gpt4 book ai didi

tensorflow - Tensorflow 中的回调

转载 作者:行者123 更新时间:2023-12-05 07:28:32 30 4
gpt4 key购买 nike

在Keras中我们可以简单的添加回调,如下所示:

self.model.fit(X_train,y_train,callbacks=[Custom_callback])

回调在doc中定义,但我找不到任何使用它们的例子。谁能告诉我如何将自定义回调添加到 TensorFlow 中?

最佳答案

这是我喜欢使用而不是 tensorboard 的损失记录回调示例。请注意,它需要后处理,我实际上更喜欢这样,因此我可以计算 cross-validation 的平均验证损失。 :

class lossesLogger(tf.keras.callbacks.Callback):
def __init__(self, fileName):
self.fileName = fileName
self.json_log = open(
self.fileName +'.json',
mode='w+',
buffering=1
)
def on_epoch_end(self, epoch, logs=None):
self.json_log.write(
json.dumps(
'epoch {}: '.format(epoch) +
str(logs)
) +
'\n'
)
def on_train_end(self, logs=None):
self.json_log.close()

要使用它,请将它添加到您的回调列表中,就像在这个示例中一样,我将它放在三个列表的末尾:

callbacks = [ #allows analyzing output of replicates
        EarlyStopping(
patience=epochsWithoutValidationLossDecrease,
verbose=1
),
        ModelCheckpoint(
            os.path.join(
os.getcwd(),
'Latest_saved_model_{}.h5'.format(uniqueID)
),
            verbose=1,
            save_best_only=True,
            save_weights_only=False
        ),
     lossesLogger(
         os.path.join(
             os.getcwd(),
             ('val_log_per_epoch_' + str(uniqueID))
         )
     )
    ]

关于tensorflow - Tensorflow 中的回调,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53324641/

30 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com