2015-12-08 1 views
8

Я пытаюсь использовать функциональные возможности выбывания в tensorflow:Ошибки при помощи отсева в tensorflow

sess=tf.InteractiveSession() 
initial = tf.truncated_normal([1,4], stddev=0.1) 
x = tf.Variable(initial) 
keep_prob = tf.placeholder("float") 
dx = tf.nn.dropout(x, keep_prob) 
sess.run(tf.initialize_all_variables()) 
sess.run(dx, feed_dict={keep_prob: 0.5}) 
sess.close() 

Этот пример очень похож на то, как это делается в the tutorial; Однако, я в конечном итоге со следующей ошибкой:

RuntimeError: min: Conversion function <function constant at 0x7efcc6e1ec80> for type <type 'object'> returned incompatible dtype: requested = float32_ref, actual = float32 

У меня есть некоторые проблемы, чтобы понять DTYPE float32_ref, который, кажется, на фоне этой проблемы. Я также пытался указать dtype=tf.float32, но это ничего не исправить.

Я также попробовал этот пример, который отлично работает с float32:

sess=tf.Session() 
x=tf.Variable(np.array([1.0,2.0,3.0,4.0])) 
sess.run(x.initializer) 
x=tf.cast(x,tf.float32) 
prob=tf.Variable(np.array([0.5])) 
sess.run(prob.initializer) 
prob=tf.cast(prob,tf.float32) 
dx=tf.nn.dropout(x,prob) 
sess.run(dx) 
sess.close() 

Однако, если я бросаю float64 вместо float32 я получаю ту же ошибку:

RuntimeError: min: Conversion function <function constant at 0x7efcc6e1ec80> for type <type 'object'> returned incompatible dtype: requested = float64_ref, actual = float64 

Edit:

Похоже, что эта проблема возникает только при использовании выпадения непосредственно на переменной s, работает для заполнителей и изделий переменных и заполнителей, Пример:

sess=tf.InteractiveSession() 
x = tf.placeholder(tf.float64) 

sess=tf.InteractiveSession() 
initial = tf.truncated_normal([1,5], stddev=0.1,dtype=tf.float64) 
y = tf.Variable(initial) 

keep_prob = tf.placeholder(tf.float64) 
dx = tf.nn.dropout(x*y, keep_prob) 
sess.run(tf.initialize_all_variables()) 
sess.run(dx, feed_dict={x : np.array([1.0, 2.0, 3.0, 4.0, 5.0]),keep_prob: 0.5}) 
sess.close() 

ответ

7

Это ошибка в реализации tf.nn.dropout, что было зафиксировано в последние фиксации, и будет включена в следующий релиз TensorFlow. На данный момент, чтобы избежать проблемы, либо build TensorFlow from source, либо измените свою программу следующим образом:

#dx = tf.nn.dropout(x, keep_prob) 
dx = tf.nn.dropout(tf.identity(x), keep_prob) 

 Смежные вопросы

  • Нет связанных вопросов^_^