0

Я использую среднеквадратическую ошибку для вычисления функции потерь многорежимного регрессора. Я использовал рекуррентную модель нейронной сети с архитектурой от одной до многих. Мой выходной вектор имеет размер 6 (1 * 6), а значения монотонны (не убывают).Добавить ограничение на квадратную функцию стоимости ошибки в многорежимной регрессии

пример: y_i = [1,3,6,13,30,57,201]

Я хотел бы, чтобы заставить модель, чтобы изучить эту зависимость. Следовательно, добавление ограничения к функции стоимости. Я получаю ошибку, равную 300 в наборе проверки. Я верю, что после редактирования среднеквадратичной функции потери ошибок я смогу получить лучшую производительность.

Я использую keras для реализации. Вот основная модель.

batchSize = 256 
epochs = 20 

samplesData = trainX 
samplesLabels = trainY 

print("Compiling neural network model...") 

Model = Sequential() 
Model.add(LSTM(input_shape = (98,),input_dim=98, output_dim=128, return_sequences=True)) 
Model.add(Dropout(0.2)) 
#Model.add(LSTM(128, return_sequences=True)) 
#Model.add(Dropout(0.2)) 
Model.add(TimeDistributedDense(7)) 
#rmsprop = rmsprop(lr=0.0, decay=0.0) 
Model.compile(loss='mean_squared_error', optimizer='rmsprop') 
Model.summary() 
print("Training model...") 
# learning schedule callback 
#lrate = LearningRateScheduler(step_decay) 
#callbacks_list = [lrate] 
history = Model.fit(samplesData, samplesLabels, batch_size=batchSize, nb_epoch= epochs, verbose=1, 
          validation_split=0.2, show_accuracy=True) 
print("model training has been completed.") 

Любые другие советы относительно скорости обучения, распада и т. Д. Оцениваются.

ответ

0

Сохраняйте среднеквадратичную ошибку как метрику. Вместо этого используйте Smooth L1 loss. Вот моя реализация.

#Define Smooth L1 Loss 
def l1_smooth_loss(y_true, y_pred): 
    abs_loss = tf.abs(y_true - y_pred) 
    sq_loss = 0.5 * (y_true - y_pred)**2 
    l1_loss = tf.where(tf.less(abs_loss, 1.0), sq_loss, abs_loss - 0.5) 
    return tf.reduce_sum(l1_loss, -1)  
#And 

Model.compile(loss='l1_smooth_loss', optimizer='rmsprop')