2017-02-17 13 views
2

Мой коллега указал на очень крутую возможность использовать sample_weight вместо слоя маскирования, когда вам нужно маскировать ввод не-RNN в Keras.Keras: маскирующий нулевой вход для не-RNN

В моем случае у меня есть 62 столбца на входе, а 63-й - ответ. Более 97% ненулевых записей в 62 столбцах содержатся в первых 30 столбцах. Я пытаюсь просто заставить это работать, поэтому я хотел бы весить последние 32 столбца, чтобы быть 0 в обучении, по существу создавая «маску бедного человека».

Это 8-классная задача классификации с использованием MLP. Ответная переменная была преобразована с использованием функции to_categorical() в Keras.

Вот реализация:

model = Sequential() 
model.add(Dense(100, input_dim=X.shape[1], init='uniform', activation='relu')) 
model.add(Dense(8, init='uniform', activation='sigmoid')) 
hist = model.fit(X, y, 
       validation_data=(X_test, ytest), 
       nb_epoch=epochs_, 
       batch_size=batch_size_, 
       callbacks=callbacks_list, 
       sample_weight = np.array([X.shape[1]-32, 30])) 

Я получаю эту ошибку:

in standardize_weights 
assert y.shape[:sample_weight.ndim] == sample_weight.shape 

Как я могу исправить мой sample_weight в «маске» первые 32 столбцов из входа?

ответ

2

Вес образца не работает так:

sample_weight : optional array of the same length as x , containing weights to apply to the model's loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length) , to apply a different weight to every timestep of every sample. In this case you should make sure to specify sample_weight_mode="temporal" in compile() . source

Другими словами, этот параметр ставит различные веса на образцов тренировочных данных, а не на особенностях каждого образца. Это используется только на этапе обучения. Я думаю, вы должны использовать маскирование, если вы не хотите, чтобы слой использовал эти функции. Или просто удалите их из своего набора данных? Или, если это не слишком сложно, пусть сеть узнает сама по себе, какие полезные функции есть.

Помогает ли это?