2016-02-12 4 views
0

В Theano есть возможность использовать функцию повторения T.repeat(A,B) и предоставить пару векторов, так что каждый элемент A[i] повторяется B[i] раз.Theano Scan and Repeat

К сожалению, эта операция не имеет определенного градиента (она выбрасывает исключение не реализована), что является проблемой, поскольку я пытаюсь использовать ее с помощью пробников на основе градиента Pymc3.

Я думаю, что могу обратиться к этому с помощью функции scan и вызывать повторную рекурсию для каждого элемента из двух векторов, однако мой код не работает, возможно, потому, что я неправильно вызываю scan. Может ли кто-нибудь помочь мне понять, почему приведенный ниже код не работает?

A = T.dvector('A') 
B = T.ivector('B') 
A.tag.test_value = np.array(np.random.rand(2), dtype = "float32") 
B.tag.test_value = np.array(np.random.rand(2), dtype = "int32") 
th.config.compute_test_value = 'warn' 

results, updates = th.scan(fn = lambda prior_result, A, B: A.repeat(B), 
          sequences = [A, B], 
          outputs_info = T.constant([1,4,4,4])) 

b = th.function(inputs=[A,B], outputs=results.flatten()) 
b([1],[4]) 

Я ожидаю, что это вернет [1,1,1,1], но вместо этого вернет ошибку ниже.

395  except AttributeError: 
    396   return _wrapit(a, 'repeat', repeats, axis) 
--> 397  return repeat(repeats, axis) 
    398 
    399 

ValueError: operands could not be broadcast together with shape (1,) (4,) 

Я поднял issue на GitHub Pymc3, чтобы увидеть, если это то, что должно быть исправлено более постоянно, но я полагаю, его хорошую возможность узнать больше о Теана для меня в любом случае, и если я могу решить проблему, может быть, я могу внести свой вклад в проект.

ответ

0

Я вижу здесь две вещи:

  1. Bad упорядочение в лямбда-выражения: оно должно быть A, B, prior_result (теперь B рассматривается как outputs_info)
  2. форма A.repeat (В) отличающийся от формы previous_result (на данном этапе компиляции)

Быстрое исправление: просто удалите output_info из аргументов сканирования (и before_result из лямбда), и вы получите [1,1,1,1].