gpt4 book ai didi

python - 未实现错误 : Learning rate schedule must override get_config

转载 作者:行者123 更新时间:2023-12-03 14:38:06 60 4
gpt4 key购买 nike

我已经使用 tf.keras 创建了一个自定义计划,并且在保存模型时遇到了这个错误:

NotImplementedError: Learning rate schedule must override get_config



这个类看起来像这样:

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

def __init__(self, d_model, warmup_steps=4000):
super(CustomSchedule, self).__init__()

self.d_model = d_model
self.d_model = tf.cast(self.d_model, tf.float32)

self.warmup_steps = warmup_steps

def __call__(self, step):
arg1 = tf.math.rsqrt(step)
arg2 = step * (self.warmup_steps**-1.5)

return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

def get_config(self):
config = {
'd_model':self.d_model,
'warmup_steps':self.warmup_steps

}
base_config = super(CustomSchedule, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

最佳答案

当您使用自定义子类模型时,保存模型架构有点棘手。相反,使用 Model.save_weights() 仅保存权重更容易。

如果您将代码更改为此,您将不会看到该错误:

  def get_config(self):
config = {
'd_model': self.d_model,
'warmup_steps': self.warmup_steps,

}
return config

关于python - 未实现错误 : Learning rate schedule must override get_config,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61557024/

60 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com