2016-08-08 5 views
4

Я пытаюсь написать пользовательскую функцию градиента для 'my_op', которая для примера содержит только вызов tf.identity() (в идеале это может быть любой граф).Создать пользовательскую функцию градиента на основе Python для операции? (без реализации C++)

import tensorflow as tf 
from tensorflow.python.framework import function 


def my_op_grad(x): 
    return [tf.sigmoid(x)] 


@function.Defun(a=tf.float32, python_grad_func=my_op_grad) 
def my_op(a): 
    return tf.identity(a) 


a = tf.Variable(tf.constant([5., 4., 3., 2., 1.], dtype=tf.float32)) 

sess = tf.Session() 
sess.run(tf.initialize_all_variables()) 

grad = tf.gradients(my_op(a), [a])[0] 

result = sess.run(grad) 

print(result) 

sess.close() 

К сожалению, я получаю следующее сообщение об ошибке:

Traceback (most recent call last): 
    File "custom_op.py", line 19, in <module> 
    grad = tf.gradients(my_op(a), [a])[0] 
    File "/Users/njk/tfm/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 528, in __call__ 
    return call_function(self._definition, *args, **kwargs) 
    File "/Users/njk/tfm/lib/python3.5/site-packages/tensorflow/python/framework/function.py", line 267, in call_function 
    compute_shapes=False) 
    File "/Users/njk/tfm/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2285, in create_op 
    raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) 
TypeError: Input #0 is not a tensor: <tensorflow.python.ops.variables.Variable object at 0x1080d2710> 

Я знаю, что можно создать пользовательскую C++ операции, но в моем случае мне просто нужно написать собственный градиент для функции, могут быть легко записаны на Python, используя стандартные операции TensorFlow, поэтому я хотел бы избежать написания ненужного кода на C++.

Кроме того, я использую восходящую версию TensorFlow от GitHub.

+0

Вы попробовали @ ops.RegisterGradient ("my_op")? Вы можете следовать примеру части python и пропустить часть C++: https://www.tensorflow.org/versions/r0.10/how_tos/adding_an_op/index.html#implement-the-gradient-in-python –

+0

Я думаю, что вход в ops.RegisterGradient() - это имя зарегистрированной операции TensorFlow, это не просто имя функции Python, содержащей операции TensorFlow. Итак, как-то мне нужно сначала зарегистрировать операцию, не так ли? – njk

+0

Я думаю, что вы правы, и код близок, но не работает из-за ошибки здесь: https://github.com/tensorflow/tensorflow/issues/3710 Обратите внимание, что python_grad_func нуждается в том же интерфейсе, что и ops.RegisterGradient https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/function.py#L349 –

ответ

3

Обратите внимание, что для python_grad_func нужен тот же интерфейс, что и ops.RegisterGradient (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/framework/function.py#L349).

Вот модифицированный пример кода:

def my_op_grad(op, grad): ### instead of my_op_grad(x)             
    return tf.sigmoid(op.inputs[0])            

@function.Defun(a=tf.float32, python_grad_func=my_op_grad)      
def my_op(a):                  
    return tf.identity(a)               

def main(unused_argv):               

    a = tf.Variable(tf.constant([-5., 4., -3., 2., 1.], dtype=tf.float32))   
    sess = tf.Session()                
    sess.run(tf.initialize_all_variables())           

    a = tf.identity(a) #workaround for bug github.com/tensorflow/tensorflow/issues/3710 

    grad = tf.gradients(my_op(a), [a])[0]           
    result = sess.run(grad)               

    print(result)                 

    sess.close()  

Выход:

[ 0.00669286 0.98201376 0.04742587 0.88079709 0.7310586 ] 
+0

Запустив код, я получил NotFoundError: тип Op не зарегистрирован «my_op_2f8a34ee» –

+0

Выполнение этого дает мне: 'ValueError: неизвестные аргументы ключевого слова: dict_keys (['a'])' at my_op (a). Что мне делать? – nikpod

+0

Обнаружено обходное решение [здесь] (https://github.com/tensorflow/tensorflow/issues/6804). Пример: Определить 'y = my_op (a)' before 'sess.run (tf.initialize_all_variables())' Решает мою проблему и @EverettYou issue – nikpod

2

Следующая кажется отлично работает. У вас есть какая-то причина, предпочитающая python_grad_func?

@tf.function.Defun(tf.float32, tf.float32) 
def bprop(x, dy): 
    return tf.sigmoid(x) 

@tf.function.Defun(tf.float32, grad_func=bprop) 
def fprop(x): 
    return x # identity 

a = tf.Variable(tf.constant([-5., 4., -3., 2., 1.], dtype=tf.float32)) 
grad = tf.gradients(fprop(a), [a])           

with tf.Session() as sess:                
    sess.run(tf.initialize_all_variables()) 
    result = sess.run(grad)               

print(result)                 
+0

Я получаю 'tensorflow.python.framework.errors_impl.NotFoundError: тип Op не зарегистрирован 'fprop_da39a3ee''. – Albert

+0

Упс. Прости. Обновите образец кода. Дайте мне знать, если это сработает для вас. Часто сложная вещь заключается в том, что определение функции должно быть в том же Графе, что и его использование. – zfc