2017-01-13 2 views
1

Я столкнулся с рядом проблем, связанных с динамическими осями. Я пытаюсь реализовать сверточный rnn, подобный функции LSTM(), но обрабатывает последовательный вход изображения и выводит изображение.Динамические топоры с настраиваемым RNN

Я могу построить сеть и передать фиктивные данные через него, чтобы произвести вывод, но когда я пытаюсь вычислить ошибку с input_variable ярлыком я последовательно увидеть следующее сообщение об ошибке:

RuntimeError: Node '__v2libuid__Input471__v2libname__img_label' (InputValue operation): DataFor: FrameRange's dynamic axis is inconsistent with matrix: {numTimeSteps:1, numParallelSequences:2, sequences:[{seqId:0, s:0, begin:0, end:1}, {seqId:1, s:1, begin:0, end:1}]} vs. {numTimeSteps:2, numParallelSequences:1, sequences:[{seqId:0, s:0, begin:0, end:2}]}` 

Если я правильно понимайте это сообщение об ошибке, он утверждает, что значение, которое я передал в качестве метки, имеет несогласованные оси в соответствии с ожидаемыми с помощью 2 временных шагов и 1 параллельной последовательностью, когда требуется 1 временная и 2 последовательности. Это имеет смысл для меня, но я не уверен, как данные, которые я передаю, не соответствуют этому. Вот (примерно) переменные декларации и Eval заявления:

… 
img_input = input_variable(shape=img_shape, dtype=np.float32, name="img_input") 
convlstm = Recurrence(conv_lstm_cell, initial_state=initial_state)(img_input) 
out = select_last(convlstm) 
img_label = input_variable(shape=img_shape, dynamic_axes=out.dynamic_axes, dtype=np.float32, name="img_label”) 
error = squared_error(out, img_label) 
… 

dummy_input = np.ones(shape=(2, 3, 3, 32, 32)) # (batch, seq_len, channels, height, width) 
dummy_label = np.ones(shape=(2, 3, 32, 32))  # (batch, channels, height, width) 
out = error.eval({img_input:dummy_input, img_label:dummy_label}) 

я считаю часть проблемы: с dynamic_axes устанавливается при создании img_label input_variable, я также попытался установить его [Axis.default_batch_axis() ] и не устанавливая его вообще, и квадратная ошибка жалуется на несогласованные оси между out и img_label или я вижу ту же ошибку, что и выше.

ответ

0

Единственная проблема, которую я вижу с выше установки является то, что ваша метка манекена должна иметь четкую динамическую ось поэтому она должна быть объявлена ​​как

dummy_label = np.ones(shape=(2, 1, 3, 32, 32)) 

Если предположить, что convlstm работы похож на LSTM, то следующие работы: без проблем для меня, и он оценивает потерю для двух пар ввода/вывода.

x = C.input_variable((3,32,32)) 
cx = convlstm(x) 
lx = C.sequence.last(cx) 
y = C.input_variable(lx.shape, dynamic_axes=lx.dynamic_axes) 
loss = C.squared_error(y, lx) 
x0 = np.arange(2*3*3*32*32,dtype=np.float32).reshape(2,3,3,32,32) 
y0 = np.arange(2*1*3*32*32,dtype=np.float32).reshape(2,1,3,32,32) 
loss.eval({x:x0, y:y0})