2017-02-08 11 views
1

Предположим, что у нас есть два тензора (A и B) с тем же числом измерений. Мы можем умножить их на tensordot. Например:Есть ли что-то среднее между тендердотом и умножением по элементам в Theano?

T.tensordot(A, B, axes = [[0,3], [0,3]]) 

В этом случае мы «пара» оси первого тензора с некоторыми осями тензора второго, а затем мы просуммировать эти «спаренные» Оси:

C[j, k, a, b ] = sum_{i, l} A[i, j, k, l] * A[i, a, b, l] 

В в приведенном выше примере первая и последняя оси первого тензора соединяются с первой и последней осью второго тензора соответственно.

В качестве альтернативы, можно умножить два тензора поэлементны:

C[i, j, k, l] = A[i, j, k, l] * B[i, j, k, l] 

В этом случае мы «пара» все оси первого тензора с все соответствующими осями тензора второго (сначала с первым , второй со вторым и т. д.).

Теперь я хочу сделать что-то, что находится между двумя описанными выше операциями. Более подробно:

  1. Я хочу пару некоторой оси первого тензора с некоторой осью тензора второго (как ж это сделать в tensordot). Таким образом, я не хочу соединять все оси A со всей осью B.
  2. Я не хочу суммировать по всем парным осям (как, например, в парном умножении, нет суммирования по парным оси).

Вот что я хочу написал в «математической» форме:

C[a, b, c, i] = sum_d A[a, b, c, d] * B[i, b, c, d] 

Что такое лучший способ сделать это в Теано?

ответ

0

Способ подхода к описанной проблеме заключается в использовании попарного умножения *. Парное умножение «пары» всех осей первого тензора с соответствующими осями второго тензора (сначала с первым, вторым вторым, ..., последним с последним). Поэтому нам нужно решить две проблемы: (1) перетасовать оси двух тензоров, так что собственные оси спарены друг с другом, (2) добавить «фиктивные» оси для предотвращения спаривания там, где это не требуется. Наконец, мы суммируем все, что хотим.

Конкретная проблема упоминается в вопросе

C[a, b, c, i] = sum_d A[a, b, c, d] * B[i, b, c, d] 

решается следующим образом:

T.sum(A.dimshuffle(0, 1, 2, 3, 'x') * B.dimshuffle('x', 1, 2, 3, 0), axis=4)