2017-01-30 2 views
0

Я пытаюсь использовать учебник CIFAR10 для создания собственного сценария обучения. Мой набор данных хранится в файле MAT, который я конвертирую в массив Numpy, используя h5py. В учебнике, они читают данные с помощью:Как сделать учебник tenorflow cifar10 из массива numpy?

reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) 

в то время как в моем случае, я использую:

images_placeholder = tf.placeholder(tf.float32, shape=shape) 
labels_placeholder = tf.placeholder(tf.int32, shape=batch_size) 

Проблема заключается в том, когда я пытаюсь запустить обучение с использованием MonitoredTrainingSession, поскольку они используют в CIFAR10 пример:

def train(): 
with tf.Graph().as_default(): 
    global_step = tf.contrib.framework.get_or_create_global_step() 

    with inputs.read_imdb(FLAGS.input_path) as imdb: 
     sets = np.asarray(imdb['images']['set'], dtype=np.int32) 
     data_set = inputs.DataSet(imdb, np.where(sets == 1)[0]) 
    images, labels = inputs.placeholder_inputs(data_set, batch_size=128) 

    logits = model.vgg16(images) 
    loss = model.loss(logits, labels) 
    train_op = model.train(loss, global_step, data_set.num_examples) 

    class _LoggerHook(tf.train.SessionRunHook): 
     def begin(self): 
      self._step = -1 

     def before_run(self, run_context): 
      self._step += 1 
      self._start_time = time.time() 
      return tf.train.SessionRunArgs(loss) 

     def after_run(self, run_context, run_values): 
      duration = time.time() - self._start_time 
      loss_value = run_values.results 
      if self._step % 10 == 0: 
       num_examples_per_step = FLAGS.batch_size 
       examples_per_sec = num_examples_per_step/duration 
       sec_per_batch = float(duration) 

       format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 
           'sec/batch)') 
       print(format_str % (datetime.now(), self._step, loss_value, 
            examples_per_sec, sec_per_batch)) 

    with tf.train.MonitoredTrainingSession(
      checkpoint_dir=FLAGS.train_dir, 
      hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), 
        tf.train.NanTensorHook(loss), 
        _LoggerHook()], 
      config=tf.ConfigProto(
       log_device_placement=FLAGS.log_device_placement)) as mon_sess: 
     while not mon_sess.should_stop(): 
      mon_sess.run(train_op) 

, где inputs.DataSet основан в примере MNIST. Некоторые вспомогательные функции:

def read_imdb(path): 
    imdb = h5py.File(path) 
    check_imdb(imdb) 
    return imdb 

def placeholder_inputs(data_set, batch_size): 
    shape = (batch_size,) + data_set.images.shape[1:][::-1] 
    images_placeholder = tf.placeholder(tf.floatz32, shape=shape) 
    labels_placeholder = tf.placeholder(tf.int32, shape=batch_size) 
    return images_placeholder, labels_placeholder 

Когда я пытаюсь запустить, он, очевидно, возвращает ошибку You must feed a value for placeholder tensor 'Placeholder', потому что я не создал канал. Дело в том, что у меня есть функция, которая создает фид, но я не знаю, куда его передать.

def fill_feed_dict(data_set, images, labels): 
    images_feed, labels_feed = data_set.next_batch(images.get_shape()[0].value) 
    feed_dict = {images: images_feed, labels: labels_feed} 
    return feed_dict 

Может ли кто-нибудь помочь?

Спасибо, Daniel

ответ

0

Вам просто нужно пройти dict созданный fill_feed_dict каждый раз, когда вы вызываете run метод:

mon_sess.run(train_op, feed_dict=fill_feed_dict(data_set, images, labels)) 

 Смежные вопросы

  • Нет связанных вопросов^_^