gpt4 book ai didi

java - 如何使用 DeepLearning4J 训练 RBM 和重建输入?

转载 作者:塔克拉玛干 更新时间:2023-11-02 08:36:39 24 4
gpt4 key购买 nike

我正在尝试使用 DeepLearning4J 0.7 训练受限玻尔兹曼机 (RBM),但没有成功。我发现的所有示例要么没有做任何有用的事情,要么不再适用于 DeepLearning4J 0.7。

我需要使用 Contrastive Divergence 训练单个 RBM,然后计算重构误差。

这是我目前所拥有的:

import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.datasets.fetchers.MnistDataFetcher;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.api.Layer;
import static org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit;
import static org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;

public class experiment3 {
private static final Logger log = LoggerFactory.getLogger(experiment3.class);

public static void main(String[] args) throws Exception {
DataSetIterator mnistTrain = new MnistDataSetIterator(100, 60000, true);

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.regularization(false)
.iterations(1)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.list()
.layer(0, new RBM.Builder()
.nIn(784).nOut(500)
.weightInit(WeightInit.XAVIER)
.lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY)
.updater(Updater.NESTEROVS)
.learningRate(0.1)
.momentum(0.9)
.k(1)
.build())
.pretrain(true).backprop(false)
.build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(600));

for(int i = 0; i < 50; i++) {
model.fit(mnistTrain);
}
}
}

它在每个时期编译并打印一些分数,但是分数在应该减少的时候增加了,我还没有找到任何重建的方法。

我尝试使用重建函数并计算距离:

        while(mnistTrain.hasNext()){
DataSet next = mnistTrain.next();
INDArray in = next.getFeatureMatrix();
INDArray out = model.reconstruct(in, 1); // tried with 0 but arrayindexoutofbounds

log.info("distance(1):" + in.distance1(out));
}

但每个元素的距离始终为 0.0,即使模型没有经过单个时期的训练,这是不可能的。

这是训练 RBM 的正确方法吗?如何用单个 RBM 重建输入?

最佳答案

我意识到这个问题很老了,但最近的 Activity 在我的流程中揭示了它。我只想说我最近在 DL4j 中使用 RBM,包括单层和多层。它们可能没有得到官方支持,但它们确实有效。我还偶然发现了 reconstruct。要测试模型,您应该使用 output,就像在 FF 网络中一样。在你的情况下,我假设它是:

INDArray in = next.getFeatureMatrix();
INDArray out = model.output(in);

一些补充:

我正在使用 0.9.1

关于java - 如何使用 DeepLearning4J 训练 RBM 和重建输入?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41143092/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com