2016-04-24 5 views
1

Я пытаюсь обучить Ленте, определенному здесь Solving in Python with LeNet , чтобы обучить данные распознавания цифр, установленные в kaggle. Сначала я использую приведенный здесь учебник Create lmdb для передачи данных в формат lmdb. Затем я следую инструкциям в ссылке 1 («Решение на Python с LeNet») для создания прототипов подготовки, тестирования и решения. Однако, когда я извлекаю решатель из solver.prototxt, я обнаружил, что каждый элемент в данных изображения равен нулю. Что-то не так с моим кодом?Python with Caffe: пользовательские данные - все нули при чтении с решателя

import pandas as pd 
import lmdb 
import caffe 
import numpy as np 
import numpy as np 
from caffe import layers as L, params as P 
from pylab import * 
import os, sys 
from caffe.proto import caffe_pb2 
%matplotlib inline 

train_original = pd.read_csv(path/to/my/train.csv) 
test = pd.read_csv(path/to/my/test.csv) 
train_obs, dim = train_data.shape 
val_obs, dim = val_data.shape 
train_data_array = np.array(train_data, dtype = float32) 
train_label_array = np.array(train_label, dtype = float32) 
val_data_array = np.array(val_data, dtype = float32) 
val_label_array = np.array(val_label, dtype = float32) 

train_lmdb_size = train_data_array.nbytes * 10 
val_lmdb_size = val_data_array.nbytes * 10 
env = lmdb.open('train_lmdb', map_size=train_lmdb_size) 
with env.begin(write=True) as txn: 
    for i in range(train_num): 
     datum = caffe.proto.caffe_pb2.Datum() 
     datum.channels = 1 
     datum.height = 28 
     datum.width = 28 
     datum.data = train_data_array[i].reshape(28, 28).tobytes() # or .tostring() if numpy < 1.9 
     datum.label = int(train_label_array[i]) 
     str_id = '{:08}'.format(i) 
     # The encode is only essential in Python 3 
     txn.put(str_id.encode('ascii'), datum.SerializeToString()) 

env = lmdb.open('test_lmdb', map_size=train_lmdb_size) 
with env.begin(write=True) as txn: 
    for i in range(val_num): 
     datum = caffe.proto.caffe_pb2.Datum() 
     datum.channels = 1 
     datum.height = 28 
     datum.width = 28 
     datum.data = val_data_array[i].reshape(28, 28).tobytes() # or .tostring() if numpy < 1.9 
     datum.label = int(val_label_array[i]) 
     str_id = '{:08}'.format(i) 
     # The encode is only essential in Python 3 
     txn.put(str_id.encode('ascii'), datum.SerializeToString()) 

train_path = 'CNN_training.prototxt' 
test_path = 'CNN_testing.prototxt' 
train_lmdb_path = 'train_lmdb' 
test_lmdb_path = 'test_lmdb' 
solver_path = 'CNN_solver.prototxt' 

def lenet(lmdb, batch_size): 
    # our version of LeNet: a series of linear and simple nonlinear transformations 
    n = caffe.NetSpec() 

    n.data, n.label = L.Data(batch_size=batch_size, backend=P.Data.LMDB, source=lmdb, 
          transform_param=dict(scale=1./255), ntop=2) 

    n.conv1 = L.Convolution(n.data, kernel_size=5, num_output=20, weight_filler=dict(type='xavier')) 
    n.pool1 = L.Pooling(n.conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX) 
    n.conv2 = L.Convolution(n.pool1, kernel_size=5, num_output=50, weight_filler=dict(type='xavier')) 
    n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX) 
    n.fc1 = L.InnerProduct(n.pool2, num_output=500, weight_filler=dict(type='xavier')) 
    n.relu1 = L.ReLU(n.fc1, in_place=True) 
    n.score = L.InnerProduct(n.relu1, num_output=10, weight_filler=dict(type='xavier')) 
    n.loss = L.SoftmaxWithLoss(n.score, n.label) 

    return n.to_proto() 

with open(train_path, 'w') as f: 
    f.write(str(lenet(train_lmdb_path, 64))) 

with open(test_path, 'w') as f: 
    f.write(str(lenet(test_lmdb_path, 100))) 

s = caffe_pb2.SolverParameter() 
s.random_seed = 0xCAFFE 
s.train_net = train_path 
s.test_net.append(test_path) 
s.test_interval = 500 
s.test_iter.append(100) 
s.max_iter = 10000 
s.type = 'Adam' 
s.base_lr = 0.01 
s.momentum = 0.75 
s.weight_decay = 5e-1 
s.lr_policy = 'inv' 
s.gamma = 0.0001 
s.power = 0.75 
s.display = 1000 
s.snapshot = 5000 
s.snapshot_prefix = 'lin_lnet' 
s.solver_mode = caffe_pb2.SolverParameter.CPU 
with open(solver_path,'w') as f: 
    f.write(str(s)) 

solver = None 
solver = caffe.get_solver(solver_path) 
# result in solver.net['data'].data[0] are zeros 
print solver.net['data'].data[0] 
array([[[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.], 
     [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
      0., 0.]]], dtype=float32) 

ответ

1

Попробуйте сделать net.forward(). Вы должны уметь видеть свои данные, если все остальное верно.

Простой и безопасный способ записи в LMDB использует caffe.io.array_to_datum, как показано на рисунке here.

+0

спасибо, я вижу – user3162707

+0

@ user3162707, пожалуйста, подумайте о «принятии» этого ответа, щелкнув значок «v» рядом с ним. – Shai

 Смежные вопросы

  • Нет связанных вопросов^_^