gpt4 book ai didi

python - keras:如何使用 model.train_on_batch() 的学习率衰减

转载 作者:太空狗 更新时间:2023-10-29 23:56:38 26 4
gpt4 key购买 nike

在我当前的项目中,我使用 keras 的 train_on_batch() 函数进行训练,因为 fit() 函数不支持所需的生成器和鉴别器的交替训练对于 GAN。使用(例如)Adam 优化器,我必须在构造函数中使用 optimizer = Adam(decay=my_decay) 指定学习率衰减,并将其交给模型编译方法。如果我之后使用模型的 fit() 函数,这个工作很好,因为它负责在内部计算训练重复次数,但我不知道如何使用像<这样的构造自己设置这个值/p>

counter = 0
for epoch in range(EPOCHS):
for batch_idx in range(0, number_training_samples, BATCH_SIZE):
# get training batch:
x = ...
y = ...
# calculate learning rate:
current_learning_rate = calculate_learning_rate(counter)
# train model:
loss = model.train_on_batch(x, y) # how to use the current learning rate?

具有一些计算学习率的函数。如何手动设置当前学习率?

如果这篇文章中有错误,我很抱歉,这是我的第一个问题。

感谢您的帮助。

最佳答案

编辑

在 2.3.0 中,lr 被重命名为 learning_rate:link .在旧版本中,您应该改用 lr(感谢@Bananach)。

在 keras 后端的帮助下设置值:keras.backend.set_value(model.optimizer.learning_rate, learning_rate)(其中 learning_rate 是一个 float ,期望的学习率) 适用于 fit 方法并且应该适用于 train_on_batch:

from keras import backend as K


counter = 0
for epoch in range(EPOCHS):
for batch_idx in range(0, number_training_samples, BATCH_SIZE):
# get training batch:
x = ...
y = ...
# calculate learning rate:
current_learning_rate = calculate_learning_rate(counter)
# train model:
K.set_value(model.optimizer.learning_rate, current_learning_rate) # set new learning_rate
loss = model.train_on_batch(x, y)

希望对您有所帮助!

关于python - keras:如何使用 model.train_on_batch() 的学习率衰减,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53609972/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com