gpt4 book ai didi

python - TensorFlow 模型拟合和 train_on_batch 之间的区别

转载 作者:行者123 更新时间:2023-12-05 08:18:12 26 4
gpt4 key购买 nike

我正在构建一个普通的 DQN 模型来玩 OpenAI gym Cartpole 游戏。

但是,在我将状态作为输入并将目标 Q 值作为标签的训练步骤中,如果我使用 model.fit(x=states, y=target_q),它工作正常,代理最终可以很好地玩游戏,但如果我使用 model.train_on_batch(x=states, y=target_q),损失不会减少,模型不会玩游戏比随机策略更好的任何地方。

我想知道 fittrain_on_batch 有什么区别?据我了解,fit 调用 train_on_batch 时批处理大小为 32,这应该没有什么区别,因为将批处理大小指定为等于我输入的实际数据大小没有区别。

如果需要更多上下文信息来回答这个问题,完整的代码在这里:https://github.com/ultronify/cartpole-tf

最佳答案

model.fit 将训练 1 个或多个 epoch。这意味着它将训练多个批处理。 model.train_on_batch,顾名思义,只训练一个batch。

举一个具体的例子,假设你正在用 10 张图片训练一个模型。假设您的批量大小为 2。model.fit 将对所有 10 张图像进行训练,因此它将更新梯度 5 次。 (您可以指定多个时期,因此它会遍历您的数据集。)model.train_on_batch 将执行一次梯度更新,因为您只批量提供模型。如果批量大小为 2,您将给 model.train_on_batch 两张图片。

如果我们假设 model.fit 在幕后调用了 model.train_on_batch(虽然我不这么认为),那么 model.train_on_batch 会被多次调用,很可能在一个循环中。这是解释的伪代码。

def fit(x, y, batch_size, epochs=1):
for epoch in range(epochs):
for batch_x, batch_y in batch(x, y, batch_size):
model.train_on_batch(batch_x, batch_y)

关于python - TensorFlow 模型拟合和 train_on_batch 之间的区别,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62629997/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com