У меня есть плотный отсевающий ANN с выходным слоем softmax. Вот метод обучения:Проблема несовместимости типа данных в Theano
def train(network, input_var, epochs, train_input, val_input, batchsize,
update_fn, loss_fn, verbose=True, deterministic=False, **kwargs):
"""
:param network: the output layer of a `lasagne`-backed ANN
:type input_var: TheanoVariable
:param train_input: (x, y)
:type train_input: (np.ndarray, np.ndarray)
:param val_input: (x, y)
:type val_input: (np.ndarray, np.ndarray)
"""
# create target var
# note: I use my own method instead of `theano.shared`, because for
# whatever reason Theano says I can't use a shared variable here
# and that I should pass it via the `givens` parameter, whatever
# that is.
target_var = self.numpy_to_theano_variable(train_input[1])
# training functions
prediction = lasagne.layers.get_output(network,
deterministic=deterministic)
loss = loss_fn(prediction, target_var).mean()
params = lasagne.layers.get_all_params(network, trainable=True)
updates = update_fn(loss, params, **kwargs)
train_fn = theano.function([input_var, target_var], loss, updates=updates)
# validation functions
val_pred = lasagne.layers.get_output(network, deterministic=True)
val_loss = loss_fn(val_pred, target_var).mean()
val_acc = T.mean(T.eq(T.argmax(val_pred, axis=1), target_var),
dtype=theano.config.floatX)
val_fn = theano.function([input_var, target_var], [val_loss, val_acc])
def run_epoch(epoch):
train_batches = yield_batches(train_input, batchsize)
val_batches = yield_batches(val_input, batchsize)
train_err = np.mean([train_fn(x, y) for x, y in train_batches])
val_err, val_acc = np.mean(
[val_fn(x, y) for x, y in val_batches], axis=0)
if verbose:
print("Epoch {} of {}: training error = {}, "
"validation error = {}, validation accuracy = {}"
"".format(epoch+1, epochs, train_err, val_err, val_acc))
return train_err, val_err, val_acc
return [run_epoch(e) for e in xrange(epochs)]
Метод numpy_to_theano_variable
определен в базовом классе:
def create_theano_variable(ndim, dtype, name=None):
"""
:type ndim: int
:type dtype: str
:type name: str
"""
if ndim == 1:
theano_var = T.vector(name, dtype=dtype)
elif ndim == 2:
theano_var = T.matrix(name, dtype=dtype)
elif ndim == 3:
theano_var = T.tensor3(name, dtype=dtype)
elif ndim == 4:
theano_var = T.tensor4(name, dtype=dtype)
else:
raise ValueError
return theano_var
def numpy_to_theano_variable(array, name=None):
"""
:type array: np.ndarray
:param array:
:rtype: T.TensorVariable
"""
return create_theano_variable(ndim=array.ndim,
dtype=str(array.dtype).split(".")[-1],
name=name)
В начале train
target_var
инициализируется как TheanoVariable
с тем же числом измерений и типа, массив numpy, используемый для его подачи. По причине за пределами моего понимания, если тип данных не int32
или int64
я получаю эту ошибку:
Traceback (most recent call last):
File "./train_net.py", line 131, in <module>
main(sys.argv[1:])
File "./train_net.py", line 123, in main
learning_rate=learning_rate, momentum=momentum, verbose=True)
File "/Users/ilia/OneDrive/GitHub/...", line 338, in train
loss = loss_fn(prediction, target_var).mean()
File "/Users/ilia/.venvs/test/lib/python2.7/site-packages/lasagne/objectives.py", line 129, in categorical_crossentropy
return theano.tensor.nnet.categorical_crossentropy(predictions, targets)
File "/Users/ilia/.venvs/test/lib/python2.7/site-packages/theano/tensor/nnet/nnet.py", line 2077, in categorical_crossentropy
return crossentropy_categorical_1hot(coding_dist, true_dist)
File "/Users/ilia/.venvs/test/lib/python2.7/site-packages/theano/gof/op.py", line 613, in __call__
node = self.make_node(*inputs, **kwargs)
File "/Users/ilia/.venvs/test/lib/python2.7/site-packages/theano/tensor/nnet/nnet.py", line 1440, in make_node
tensor.lvector))
TypeError: integer vector required for argument: true_one_of_n(got type: TensorType(<dtype>, vector) instead of: TensorType(int64, vector))
где <dtype>
представляет тип target_var
выведенный из Numpy массива (я проверил, что с int8
, int16
, uint8
, uint16
, uint32
, uint64
). В чем причина, она принимает только int32
и int64
?