2016-08-21 8 views
3

Предполагая, что я соответствую следующей нейронной сети для бинарной проблемы классификации:Как усилить нейронную сеть Keras, используя AdaBoost?

model = Sequential() 
model.add(Dense(21, input_dim=19, init='uniform', activation='relu')) 
model.add(Dense(80, init='uniform', activation='relu')) 
model.add(Dense(80, init='uniform', activation='relu')) 
model.add(Dense(1, init='uniform', activation='sigmoid')) 
# Compile model 
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) 
# Fit the model 
model.fit(x2, training_target, nb_epoch=10, batch_size=32, verbose=0,validation_split=0.1, shuffle=True,callbacks=[hist]) 

Как бы увеличить нейронную сеть с помощью AdaBoost? У keras есть какие-то команды для этого?

ответ

1

Keras сам не реализует adaboost. Тем не менее, модели Keras совместимы с scikit-learn, поэтому вы, вероятно, можете использовать AdaBoostClassifier оттуда: link. Используйте model как base_estimator после того, как вы его скомпилируете, и fit экземпляр AdaBoostClassifier вместо model.

Таким образом, вы не сможете использовать аргументы, которые вы передаете fit, например количество эпох или batch_size, поэтому будут использоваться значения по умолчанию. Если значения по умолчанию недостаточно хороши, вам может понадобиться создать собственный класс, который реализует интерфейс scikit-learn поверх вашей модели и передает правильные аргументы fit.

+0

Привет, спасибо за ваш ответ. Когда я вставляю: 'bdt = AdaBoostClassifier (base_estimator = model)' 'bdt.fit (x2, training_target)' где модель - это моя скомпилированная сеть keras, она дает мне ошибку: * TypeError: Can not clone object ' '(type ): это не похоже на оценку scikit-learn, поскольку она не реализует методы get_params. * – ishido

+0

По-видимому, сами по себе Классификаторы keras не совместимы с scikit-learn. См. Эту статью для получения информации о том, как заставить их работать вместе: https://keras.io/scikit-learn-api/ – Ishamael

0

Видимо, нейронные сети, не совместимы с sklearn AdaBoost см https://github.com/scikit-learn/scikit-learn/issues/1752

+0

Добро пожаловать в Stack Overflow! Это граничная ссылка [link-only answer] (// meta.stackexchange.com/q/8231). Вы должны расширить свой ответ, указав здесь как можно больше информации, и используйте ссылку только для справки. –