2016-02-08 4 views
2

Я пытаюсь построить тренировку NN, аналогичную той, которая находится в учебнике this.Tensorflow Обучение с использованием входной очереди застревает

Мой код выглядит следующим образом:

def train(): 
    init_op = tf.initialize_all_variables() 
    sess = tf.Session() 
    sess.run(init_op) 

    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(sess=sess, coord=coord) 

    step = 0 

    try: 
     while not coord.should_stop(): 
      step += 1 
      print 'Training step %i' % step 
      training = train_op() 
      sess.run(training) 

    except tf.errors.OutOfRangeError: 
     print 'Done training - epoch limit reached.' 
    finally: 
     coord.request_stop() 

    coord.join(threads) 
    sess.close() 

с

MIN_NUM_EXAMPLES_IN_QUEUE = 10 
NUM_PRODUCING_THREADS = 1 
NUM_CONSUMING_THREADS = 1 

def train_op(): 
    images, true_labels = inputs() 
    predictions = NET(images) 
    true_labels = tf.cast(true_labels, tf.float32) 
    loss = tf.nn.softmax_cross_entropy_with_logits(predictions, true_labels) 
    return OPTIMIZER.minimize(loss) 


def inputs(): 
    filenames = [os.path.join(FLAGS.train_dir, filename) 
     for filename in os.listdir(FLAGS.train_dir) 
     if os.path.isfile(os.path.join(FLAGS.train_dir, filename))] 
    filename_queue = tf.train.string_input_producer(filenames, 
     num_epochs=FLAGS.training_epochs, shuffle=True) 

    example_list = [_read_and_preprocess_image(filename_queue) 
     for _ in xrange(NUM_CONSUMING_THREADS)] 

    image_batch, label_batch = tf.train.shuffle_batch_join(
     example_list, 
     batch_size=FLAGS.batch_size, 
     capacity=MIN_NUM_EXAMPLES_IN_QUEUE + (NUM_CONSUMING_THREADS + 2) * FLAGS.batch_size, 
     min_after_dequeue=MIN_NUM_EXAMPLES_IN_QUEUE) 

    return image_batch, label_batch 

Учебник говорит

Они требуют, чтобы вы называете tf.train.start_queue_runners перед выполнением какой-либо подготовки или шагов вывода, или его будет вечно вечно.

. Я звоню tf.train.start_queue_runners, но выполнение train() по-прежнему застревает при первом вхождении sess.run(training).

Есть ли у кого-то идеи, что я делаю неправильно?

ответ

4

Вы пересматриваете свою сеть каждый раз, когда пытаетесь запустить цикл обучения.

Помните, что TensorFlow определяет график выполнения, а затем выполняет его. Вы хотите называть свой train_op() вне цикла запуска, и вам нужно определить этот график ДО того, как вы звоните initialize_all_variables и tf.train.start_queue_runners