2017-01-17 5 views
0

Самой простой вещи может быть для меня просто разместить код Numpy, что я пытаюсь выполнить непосредственно в Теано, если это возможно:Применение поэлементны условные функции на Theano TensorVariable

tensor = shared(np.random.randn(7, 16, 16)).eval() 

tensor2 = tensor[0,:,:].eval() 
tensor2[tensor2 < 1] = 0.0 
tensor2[tensor2 > 0] = 1.0 

new_tensor = [tensor2] 
for i in range(1, tensor.shape[0]): 
    new_tensor.append(np.multiply(tensor2, tensor[i,:,:].eval())) 

output = np.array(new_tensor).reshape(7,16,16) 

Если это не сразу видно, то, что я пытаюсь сделать, это использовать значения из одной матрицы тензора, состоящего из 7 разных матриц, и применить это к другим матрицам в тензоре.

Действительно, проблема, которую я решаю, делает условные утверждения в целевой функции для полностью свернутой сети в Keras. В принципе, потеря некоторых значений характеристик объектов будет рассчитываться (и впоследствии взвешена) иначе, чем другие, в зависимости от некоторых значений в одной из карт функций.

ответ

1

Вы можете легко выполнить условные обозначения с помощью инструкции switch.

Вот бы эквивалентный код:

import theano 
from theano import tensor as T 
import numpy as np 


def _check_new(var): 
    shape = var.shape[0] 
    t_1, t_2 = T.split(var, [1, shape-1], 2, axis=0) 
    ones = T.ones_like(t_1) 
    cond = T.gt(t_1, ones) 
    mask = T.repeat(cond, t_2.shape[0], axis=0) 
    out = T.switch(mask, t_2, T.zeros_like(t_2)) 
    output = T.join(0, cond, out) 
    return output 

def _check_old(var): 
    tensor = var.eval() 

    tensor2 = tensor[0,:,:] 
    tensor2[tensor2 < 1] = 0.0 
    tensor2[tensor2 > 0] = 1.0 
    new_tensor = [tensor2] 

    for i in range(1, tensor.shape[0]): 
     new_tensor.append(np.multiply(tensor2, tensor[i,:,:])) 

    output = theano.shared(np.array(new_tensor).reshape(7,16,16)) 
    return output 


tensor = theano.shared(np.random.randn(7, 16, 16)) 
out1 = _check_new(tensor).eval() 
out2 = _check_old(tensor).eval() 
print out1 
print '----------------' 
print ((out1-out2) ** 2).mean() 

Примечание: так как ваша маскировка на первый фильтр, мне нужно использовать split и join операций.