2016-12-14 2 views
3

Я пытаюсь обучить ограниченную машину Больцмана (RBM) с DeepLearning4J 0.7, но безуспешно. Все примеры, которые я нашел, либо не делают ничего полезного, либо не работают с DeepLearning4J 0.7.Как обучить RBM и восстановить вход с помощью DeepLearning4J?

Мне нужно обучить единое RBM с контрастирующей дивергенцией, а затем вычислить ошибку восстановления.

Вот то, что я до сих пор:

import org.nd4j.linalg.factory.Nd4j; 
import org.deeplearning4j.datasets.fetchers.MnistDataFetcher; 
import org.deeplearning4j.nn.conf.layers.RBM; 
import org.deeplearning4j.nn.api.Layer; 
import static org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit; 
import static org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit; 
import org.deeplearning4j.optimize.api.IterationListener; 
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; 
import org.deeplearning4j.eval.Evaluation; 
import org.deeplearning4j.nn.api.OptimizationAlgorithm; 
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; 
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 
import org.deeplearning4j.nn.conf.Updater; 
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; 
import org.deeplearning4j.nn.weights.WeightInit; 
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; 
import org.nd4j.linalg.dataset.DataSet; 
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; 
import org.nd4j.linalg.lossfunctions.LossFunctions; 
import org.slf4j.Logger; 
import org.slf4j.LoggerFactory; 
import org.nd4j.linalg.api.ndarray.INDArray; 

public class experiment3 { 
    private static final Logger log = LoggerFactory.getLogger(experiment3.class); 

    public static void main(String[] args) throws Exception { 
     DataSetIterator mnistTrain = new MnistDataSetIterator(100, 60000, true); 

     MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() 
      .regularization(false) 
      .iterations(1) 
      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 
      .list() 
      .layer(0, new RBM.Builder() 
        .nIn(784).nOut(500) 
        .weightInit(WeightInit.XAVIER) 
        .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) 
        .updater(Updater.NESTEROVS) 
        .learningRate(0.1) 
        .momentum(0.9) 
        .k(1) 
        .build()) 
      .pretrain(true).backprop(false) 
      .build(); 

     MultiLayerNetwork model = new MultiLayerNetwork(conf); 
     model.init(); 
     model.setListeners(new ScoreIterationListener(600)); 

     for(int i = 0; i < 50; i++) { 
      model.fit(mnistTrain); 
     } 
    } 
} 

Он собирает и печатать некоторые счет в каждой эпохе, но счет усиливает, когда она должна быть убывающей, и я не нашел способ сделать реконструкцию.

Я пытался использовать функцию реконструирует и вычислить расстояние:

 while(mnistTrain.hasNext()){ 
      DataSet next = mnistTrain.next(); 
      INDArray in = next.getFeatureMatrix(); 
      INDArray out = model.reconstruct(in, 1); // tried with 0 but arrayindexoutofbounds 

      log.info("distance(1):" + in.distance1(out)); 
     } 

но расстояние всегда 0,0 для каждого элемента, даже если модель не обученные для одной эпохи, что невозможно.

Это правильный способ обучения RBM? Как я могу восстановить ввод с помощью единого RBM?

+0

Пожалуйста, поднимите этот вопрос на канал Gitter DL4J, который очень активен: https://gitter.im/deeplearning4j/deeplearning4j – tremstat

+0

Я проверю Gitter, если у меня нет ответа здесь. Благодарю. –

+0

Любой успех? Все еще борется с теми же проблемами с последней версией (0.9.1). –

ответ

1

Я говорил с Адамом Гибсоном (автором) по телефону the project's Gitter channel по этому вопросу. Он говорит, что они фактически отказались от поддержки RBM во всех, кроме кодовой базы, поэтому любые ошибки RBM могут произойти и не будут исправлены.

Причина, по которой они отбрасываются, заключается в том, что в целом RBM был заменен VAE (Variational Auto-Encoder), поэтому они заставляют людей использовать это вместо этого.

 Смежные вопросы

  • Нет связанных вопросов^_^