При чтении реализации тензорного потока для модели глубокого обучения я пытаюсь понять следующий сегмент кода, включенный в процесс обучения.Мини-пакетный градиент достойной реализации в тензорном потоке
self.net.gradients_node = tf.gradients(loss, self.variables)
for epoch in range(epochs):
total_loss = 0
for step in range((epoch*training_iters), ((epoch+1)*training_iters)):
batch_x, batch_y = data_provider(self.batch_size)
# Run optimization op (backprop)
_, loss, lr, gradients = sess.run((self.optimizer, self.net.cost, self.learning_rate_node, self.net.gradients_node),
feed_dict={self.net.x: batch_x,
self.net.y: util.crop_to_shape(batch_y, pred_shape),
self.net.keep_prob: dropout})
if avg_gradients is None:
avg_gradients = [np.zeros_like(gradient) for gradient in gradients]
for i in range(len(gradients)):
avg_gradients[i] = (avg_gradients[i] * (1.0 - (1.0/(step+1)))) + (gradients[i]/(step+1))
norm_gradients = [np.linalg.norm(gradient) for gradient in avg_gradients]
self.norm_gradients_node.assign(norm_gradients).eval()
total_loss += loss
Я думаю, что это связанно с мини-периодическим градиентом приличными, но я не могу понять, как это работает, или у меня есть некоторые трудности, чтобы подключить его к алгоритму, показанный следующему
Привет, я благодарю вас за ответ. Я включаю всю часть итерации обучения в исходный пост. Другое, что меня смущает, - это avg_gradients, который первоначально был определен как ноль. Тогда в avg_gradients [i] = (avg_gradients [i] * (1.0 - (1.0/(step + 1)))) + (градиенты [i]/(step + 1)), так как avg_gradients [i] = 0, выглядит как первый член в левой части просто равен 0. и avg_gradients [i] = градиенты [i]/(шаг + 1), так ли? Я просто не могу понять, к чему стремится этот градиентный отступ. – user288609
Да, это правильно ** на первом шаге **, в это время 'step + 1' = 1, и поэтому' avg_gradients [i] = grandients [i] '. На каждом последовательном шаге условие «avg_gradients is None» не выполняется, и поэтому оно больше не равно нулю. – Ishamael
Вижу, спасибо. Но каков основной алгоритм (или логика) для этой реализации, если это не пакетный SGD. Я только что обновил оригинальный пост. – user288609