Я столкнулся с рядом проблем, связанных с динамическими осями. Я пытаюсь реализовать сверточный rnn, подобный функции LSTM(), но обрабатывает последовательный вход изображения и выводит изображение.Динамические топоры с настраиваемым RNN
Я могу построить сеть и передать фиктивные данные через него, чтобы произвести вывод, но когда я пытаюсь вычислить ошибку с input_variable ярлыком я последовательно увидеть следующее сообщение об ошибке:
RuntimeError: Node '__v2libuid__Input471__v2libname__img_label' (InputValue operation): DataFor: FrameRange's dynamic axis is inconsistent with matrix: {numTimeSteps:1, numParallelSequences:2, sequences:[{seqId:0, s:0, begin:0, end:1}, {seqId:1, s:1, begin:0, end:1}]} vs. {numTimeSteps:2, numParallelSequences:1, sequences:[{seqId:0, s:0, begin:0, end:2}]}`
Если я правильно понимайте это сообщение об ошибке, он утверждает, что значение, которое я передал в качестве метки, имеет несогласованные оси в соответствии с ожидаемыми с помощью 2 временных шагов и 1 параллельной последовательностью, когда требуется 1 временная и 2 последовательности. Это имеет смысл для меня, но я не уверен, как данные, которые я передаю, не соответствуют этому. Вот (примерно) переменные декларации и Eval заявления:
…
img_input = input_variable(shape=img_shape, dtype=np.float32, name="img_input")
convlstm = Recurrence(conv_lstm_cell, initial_state=initial_state)(img_input)
out = select_last(convlstm)
img_label = input_variable(shape=img_shape, dynamic_axes=out.dynamic_axes, dtype=np.float32, name="img_label”)
error = squared_error(out, img_label)
…
dummy_input = np.ones(shape=(2, 3, 3, 32, 32)) # (batch, seq_len, channels, height, width)
dummy_label = np.ones(shape=(2, 3, 32, 32)) # (batch, channels, height, width)
out = error.eval({img_input:dummy_input, img_label:dummy_label})
я считаю часть проблемы: с dynamic_axes устанавливается при создании img_label input_variable, я также попытался установить его [Axis.default_batch_axis() ] и не устанавливая его вообще, и квадратная ошибка жалуется на несогласованные оси между out и img_label или я вижу ту же ошибку, что и выше.