2016-11-01 3 views
2

У меня есть подготовленный замороженный график, который я пытаюсь запустить на устройстве ARM. В основном, я использую contrib/pi_examples/label_image, но с моей сетью, а не с Inception. Моя сеть была обучена с отсевом, который сейчас вызывает у меня неприятности:Устранить операции отсева из графика TensorFlow

Invalid argument: No OpKernel was registered to support Op 'Switch' with these attrs. Registered kernels: 
    device='CPU'; T in [DT_FLOAT] 
    device='CPU'; T in [DT_INT32] 
    device='GPU'; T in [DT_STRING] 
    device='GPU'; T in [DT_BOOL] 
    device='GPU'; T in [DT_INT32] 
    device='GPU'; T in [DT_FLOAT] 

[[Node: l_fc1_dropout/cond/Switch = Switch[T=DT_BOOL](is_training_pl, is_training_pl)]] 

Одно решения, которое я вижу в том, чтобы построить такую ​​TF статической библиотеки, которая включает в себя соответствующую операцию. С другой стороны, лучше было бы исключить операции отсева из сети, чтобы упростить и ускорить работу. Есть ли способ сделать это?

Спасибо.

+0

Вы можете редактировать 'graph.pbtxt' в текстовом редакторе и избавиться от выпадения (т.е. заменить Dropout ор с идентичностью ор) –

ответ

3
#!/usr/bin/env python2 

import argparse 

import tensorflow as tf 
from google.protobuf import text_format 
from tensorflow.core.framework import graph_pb2 
from tensorflow.core.framework import node_def_pb2 

def print_graph(input_graph): 
    for node in input_graph.node: 
     print "{0} : {1} ({2})".format(node.name, node.op, node.input) 

def strip(input_graph, drop_scope, input_before, output_after, pl_name): 
    input_nodes = input_graph.node 
    nodes_after_strip = [] 
    for node in input_nodes: 
     print "{0} : {1} ({2})".format(node.name, node.op, node.input) 

     if node.name.startswith(drop_scope + '/'): 
      continue 

     if node.name == pl_name: 
      continue 

     new_node = node_def_pb2.NodeDef() 
     new_node.CopyFrom(node) 
     if new_node.name == output_after: 
      new_input = [] 
      for node_name in new_node.input: 
       if node_name == drop_scope + '/cond/Merge': 
        new_input.append(input_before) 
       else: 
        new_input.append(node_name) 
      del new_node.input[:] 
      new_node.input.extend(new_input) 
     nodes_after_strip.append(new_node) 

    output_graph = graph_pb2.GraphDef() 
    output_graph.node.extend(nodes_after_strip) 
    return output_graph 

def main(): 

    parser = argparse.ArgumentParser() 
    parser.add_argument('--input-graph', action='store', dest='input_graph') 
    parser.add_argument('--input-binary', action='store_true', default=True, dest='input_binary') 
    parser.add_argument('--output-graph', action='store', dest='output_graph') 
    parser.add_argument('--output-binary', action='store_true', dest='output_binary', default=True) 

    args = parser.parse_args() 

    input_graph = args.input_graph 
    input_binary = args.input_binary 
    output_graph = args.output_graph 
    output_binary = args.output_binary 

    if not tf.gfile.Exists(input_graph): 
     print("Input graph file '" + input_graph + "' does not exist!") 
     return 

    input_graph_def = tf.GraphDef() 
    mode = "rb" if input_binary else "r" 
    with tf.gfile.FastGFile(input_graph, mode) as f: 
     if input_binary: 
      input_graph_def.ParseFromString(f.read()) 
     else: 
      text_format.Merge(f.read().decode("utf-8"), input_graph_def) 

    print "Before:" 
    print_graph(input_graph_def) 
    output_graph_def = strip(input_graph_def, u'l_fc1_dropout', u'l_fc1/Relu', u'prediction/MatMul', u'is_training_pl') 
    print "After:" 
    print_graph(output_graph_def) 

    if output_binary: 
     with tf.gfile.GFile(output_graph, "wb") as f: 
      f.write(output_graph_def.SerializeToString()) 
    else: 
     with tf.gfile.GFile(output_graph, "w") as f: 
      f.write(text_format.MessageToString(output_graph_def)) 
    print("%d ops in the final graph." % len(output_graph_def.node)) 


if __name__ == "__main__": 
    main() 
+0

Сценарий, кажется, удаляет слои, но если я удалю промежуточные слои отсечки, следующий слой ожидает выходного тензора выпадения. В моем случае, когда я пытаюсь прочитать слои, оставшиеся на графике, я получаю сообщение об ошибке: ValueError: graph_def недопустим в узле u'fc7/Conv2D ': Входной тензор' dropout/mul_1: 0 'не найден в graph_def .. , Как я могу изменить имя тензора входного слоя u'fc7/Conv2D в моем protobuf? –

+0

Скрипты обеспечивают, что функциональность также ... отлично работает, спасибо. –

3

Как об этом, как более общее решение:

for node in temp_graph_def.node: 
    for idx, i in enumerate(node.input): 
     input_clean = node_name_from_input(i) 
     if input_clean.endswith('/cond/Merge') and input_clean.split('/')[-3].startswith('dropout'): 
      identity = node_from_map(input_node_map, i).input[0] 
      assert identity.split('/')[-1] == 'Identity' 
      parent = node_from_map(input_node_map, node_from_map(input_node_map, identity).input[0]) 
      pred_id = parent.input[1] 
      assert pred_id.split('/')[-1] == 'pred_id'    
      good = parent.input[0] 
      node.input[idx] = good