gpt4 book ai didi

python - 前向传播缓慢 - 训练时间正常

转载 作者:行者123 更新时间:2023-11-30 22:16:45 28 4
gpt4 key购买 nike

我无法弄清楚为什么当我执行前向传播时我的代码非常慢。有问题的代码可以在这里找到:https://github.com/rekkit/lazy_programmer_ml_course/blob/develop/05_unsupervised_deep_learning/poetry_generator_rnn.py

我正在将我的代码的性能与此代码的性能进行比较:https://github.com/lazyprogrammer/machine_learning_examples/blob/master/rnn_class/srn_language_tf.py

不同之处在于我运行的时候

self.session.run(self.predict(x_batch), feed_dict={...})

或者当我运行时

self.returnPrediction(x_batch)

运行大约需要0.14秒。现在这听起来可能不像是一场灾难,但每个句子需要 0.14 秒(我正在创建一个 RNN 来预测句子中的下一个单词)。由于有 1436 个句子,因此我们每个 epoch 的时长约为 3 分 20 秒。如果我想训练 10 个 epoch,那就是半个小时。比其他代码花费的时间要多得多。

有人知道问题出在哪里吗?我能看到的唯一区别是我已经模块化了代码。

感谢您提前提供的帮助。

最佳答案

我已经弄清楚了。每次我调用预测方法时,我都会重建图表。相反,在 fit 方法中我定义了一个变量:

preds = self.predict(self.tfX)

然后每次我需要预测时,而不是使用:

predictions = self.session.run(self.predict(x_batch), feed_dict={...})

我使用:

predictions = self.session.run(self.preds, feed_dict={...})

关于python - 前向传播缓慢 - 训练时间正常,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49865478/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com