2

Одним из способов повышения стабильности в глубоких Q-учебных задачах является поддержание набора целевых весов для сети, которые обновляются медленно и используются для вычисления целей Q-значения. В результате в разное время в учебной процедуре в переднем проходе используются два разных набора весов. Для нормальной DQN это не сложно реализовать, так как веса tensorflow переменных, которые могут быть установлены в feed_dict то есть:Как получить доступ к весам повторяющейся ячейки в Tensorflow?

sess = tf.Session() 
input = tf.placeholder(tf.float32, shape=[None, 5]) 
weights = tf.Variable(tf.random_normal(shape=[5,4], stddev=0.1) 
bias = tf.Variable(tf.constant(0.1, shape=[4]) 
output = tf.matmul(input, weights) + bias 
target = tf.placeholder(tf.float32, [None, 4]) 
loss = ... 

... 

#Here we explicitly set weights to be the slowly updated target weights 
sess.run(output, feed_dict={input: states, weights: target_weights, bias: target_bias}) 

# Targets for the learning procedure are computed using this output. 

.... 

#Now we run the learning procedure, using the most up to date weights, 
#as well as the previously computed targets 
sess.run(loss, feed_dict={input: states, target: targets}) 

Я хотел бы использовать эту технику целевой сети в рецидивирующий версии DQN, но Я не знаю, как получить доступ и установить весы, используемые внутри повторяющейся ячейки. В частности, я использую tf.nn.rnn_cell.BasicLSTMCell, но я хотел бы знать, как это сделать для любого типа повторяющейся ячейки.

ответ

3

BasicLSTMCell не раскрывает свои переменные как часть своего публичного API. Я рекомендую вам либо посмотреть, какие имена имеют эти переменные на вашем графике, и подправить эти имена (эти имена вряд ли изменятся, так как они находятся в контрольных точках, и изменение этих имен приведет к нарушению совместимости контрольных точек).

В качестве альтернативы вы можете сделать копию BasicLSTMCell, которая выставляет переменные. Думаю, это самый чистый подход.

+1

Это сработало, спасибо Александру. Для тех, кто хочет получить более подробную информацию, переменные веса и смещения создаются при подаче рекуррентной ячейки в 'tf.nn.dynamicrnn()'. После запуска 'tf.initialize_all_variables()' в сеансе будет два новых обучаемых тензора, которые вы можете увидеть, если вы запустите 'tf.trainable_variables()'. В моем случае они были названы «RNN/BasicLSTMCell/Linear/Matrix: 0' и« RNN/BasicLSTMCell/Linear/Bias: 0'. –