2014-01-27 2 views
1

Я работаю над машиной векторной поддержки, используя sci-kit learn in Python.scikit-learn: SVM дает мне нулевую ошибку, но не может предсказать

Я обучил модель, использовал GridSearch и перекрестная проверка, чтобы найти оптимальные параметры, и оценил лучшую модель на 15% -ном наборе.

Матрица путаницы в конце говорит, что у меня 0 неправильных классификаций.
Позже модель дала мне неправильные прогнозы, когда я даю ей рукописную цифру (я не включил код для этого, чтобы сохранить этот вопрос почтительно коротким).

Поскольку SVM имеет нулевую ошибку и далее, позже он не может предсказать правильно, я неправильно построил этот SVM.

Мой вопрос заключается в следующем:

Могу ли я право заподозрить я Перекрестная проверка наряду с GridSearch как-то неправильно? Или я дал параметры GridSearch, которые каким-то образом смехотворны и дают мне ложные результаты?

Спасибо за ваше время и силы за это прочтение.


ШАГ 1: разделить набор данных на 85%/15% с помощью функции train_test_split

X_train, X_test, y_train, y_test = 
cross_validation.train_test_split(X, y, test_size=0.15, 
random_state=0) 

ШАГ 2: применить функцию GridSearchCV для обучающего набора для настройки классификаторов

C_range = 10.0 ** np.arange(-2, 9) 
gamma_range = 10.0 ** np.arange(-5, 4) 
param_grid = dict(gamma=gamma_range, C=C_range) 
cv = StratifiedKFold(y=y, n_folds=3) 

grid = GridSearchCV(SVC(), param_grid=param_grid, cv=cv) 
grid.fit(X, y) 

print("The best classifier is: ", grid.best_estimator_) 

выход здесь:

('The best classifier is: ', SVC(C=10.0, cache_size=200, 
class_weight=None, coef0=0.0, degree=3, 
gamma=0.0001, kernel='rbf', max_iter=-1, probability=False, 
random_state=None, shrinking=True, tol=0.001, verbose=False)) 

ШАГ 3: И, наконец, оценить настроенную классификатор на оставшиеся 15% удержания отказа набора.

clf = svm.SVC(C=10.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, 
    gamma=0.001, kernel='rbf', max_iter=-1, probability=False, 
    random_state=None, shrinking=True, tol=0.001, verbose=False) 

clf.fit(X_train, y_train) 

clf.score(X_test, y_test) 
y_pred = clf.predict(X_test) 

Выход здесь:

precision recall f1-score support 

     -1.0  1.00  1.00  1.00   6 
     1.0  1.00  1.00  1.00  30 

avg/total  1.00  1.00  1.00  36 

Confusion Matrix: 
[[ 6 0] 
[ 0 30]] 
+1

Вы пробовали оценить его с большим количеством проб? Если тестовый образец, о котором вы упомянули, не от тренировки и удержания, я думаю, что это вероятный результат, а не ошибка вашего кода. Ошибка нуля при удержании не гарантирует нулевую ошибку в реальном наборе тестов. С другой стороны, ошибка сдерживания может быть чрезмерно оптимистичной оценкой ошибки теста, так как GridSearchCV видел все X и y, включая выдержки. –

+0

Я использовал cv = StratifiedKFold (y = y, n_folds = 3), который является трехкратной перекрестной проверкой, выполненной над 85% данных ... поэтому я считаю, что 15% данных, которые служат тестовым набором, никогда не видели сетчатый поиск. Между тем, образец, используемый для прогнозирования (который не работал), никогда ранее не видел SVM и живет в другом файле. К сожалению, у меня только один из них. –

+0

Исправьте меня, если я ошибаюсь, это '' X, y'' 100% данных и '' X_train, y_train'' 85% данных? –

ответ

3

Вы слишком мало данных в тестовом наборе (только 6 образцов для одного из классов), чтобы быть уверенным в прогностической точности вашего модель. Я бы рекомендовал маркировать не менее 150 выборок на классы и хранить 50 образцов в тестируемом тестировании для вычисления показателей оценки.

Редактировать: также посмотрите на новый образец, который он не может предсказать: являются ли значения функций в том же диапазоне (например, [0, 255] вместо [0, 1] или [-1, 1] для цифры от учебных и тестовых наборов)? делает ли новая цифра «похожими на другие цифры из вашего тестового набора, когда вы рисуете их, например, с помощью matplotlib?

+0

Спасибо за обе эти простые , практические и ценные идеи! В моем исходном наборе данных имеется 235 наблюдений. Поскольку это, по-видимому, слишком мало, могу ли я обойти его с помощью проверки K-fold Cross, или вы можете рекомендовать хороший ресурс, который описывает, как максимально эффективно использовать это ограничение? Благодаря! –

+0

Я принял ответ и продолжил выпуск здесь: http://stackoverflow.com/questions/21415934/using-sci-kit-learn-how-do-i-learn-a-svm-over-a- малого набора данных –