2017-02-22 13 views
1

Я использую CNN для классификации текста и использовал вкладчик contriborflow.Строка DataType для attr 'TI' не входит в список допустимых значений: uint8, int32, int64

Однако, когда я пытаюсь выполнить следующий код:

classifier = learn.Estimator(model_fn=cnn_model) 

classifier.fit(x_train, y_train, steps=10000) 
y_predicted = [ p['class'] for p in classifier.predict(x_test, as_iterable=True)] 

score = metrics.accuracy_score(y_test, y_predicted) 

print('Accuracy: {0:f}'.format(score)) 

Я бегу в следующее сообщение об ошибке:

ERROR:DataType string for attr 'TI' not in list of allowed values: uint8, int32, int64 on line 'classifier.fit'

+0

Я немного отформатировал ваш код, пожалуйста, проверьте, что контент по-прежнему верен. И, дикое угадывание: может быть, 'y_train' должен представлять классы как целые числа, но на самом деле содержит поплавки? – phg

+0

y_train содержит 0 и 1. – Raj

+0

И x_train содержит массив чисел – Raj

ответ

0

Вам нужно преобразовать входы y_train к данному типу. print(type(y_train)) Скорее всего, это float вместо целого.

+0

Они все ints (1 и 0) – Raj