2016-05-10 8 views
5

Если я пытаюсь импортировать сохраненную определение TensorFlow графа сКак получить TensorFlow в «import_graph_def», чтобы вернуться тензорами

import tensorflow as tf 
from tensorflow.python.platform import gfile 

with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
x, y, y_ = tf.import_graph_def(graph_def, 
           return_elements=['data/inputs', 
               'output/network_activation', 
               'data/correct_outputs'], 
           name='') 

возвращаемые значения не Tensor s, как и ожидалось, но что-то другое: вместо того, чтобы, например, , получения x в

Tensor("data/inputs:0", shape=(?, 784), dtype=float32) 

я

name: "data/inputs_1" 
op: "Placeholder" 
attr { 
    key: "dtype" 
    value { 
    type: DT_FLOAT 
    } 
} 
attr { 
    key: "shape" 
    value { 
    shape { 
    } 
    } 
} 

То есть вместо получения ожидаемого тензора x я получаю, x.op. Это меня смущает, потому что documentation кажется, что я должен получить Tensor (хотя там есть группа или, что затрудняет ее понимание).

Как мне получить tf.import_graph_def для возврата определенного Tensor s, который затем я могу использовать (например, при загрузке загруженной модели или в ходе анализов)?

+0

Вторая строка кода должна быть 'от tensorflow.python.platform import gfile'. – tobe

ответ

3

Названия 'data/inputs', 'output/network_activation' и 'data/correct_outputs' являются фактически названиями операций. Чтобы получить tf.import_graph_def() вернуть tf.Tensor объекты, вы должны добавить индекс выхода к имени операции, которая обычно ':0' для одного вывода операций:

x, y, y_ = tf.import_graph_def(graph_def, 
           return_elements=['data/inputs:0', 
               'output/network_activation:0', 
               'data/correct_outputs:0'], 
           name='') 

 Смежные вопросы

  • Нет связанных вопросов^_^