gpt4 book ai didi

python - 如何让 Keras 仅对验证数据计算某个指标?

转载 作者:行者123 更新时间:2023-12-04 11:27:26 25 4
gpt4 key购买 nike

我正在使用 tf.keras使用 TensorFlow 1.14.0。我已经实现了一个计算量非常大的自定义指标,如果我只是将它添加到作为 model.compile(..., metrics=[...]) 提供的指标列表中,它会减慢训练过程的速度。 .

如何让 Keras 在训练迭代期间跳过度量的计算,但在每个时期结束时根据验证数据(并打印)计算它?

最佳答案

为此,您可以在指标计算中创建一个 tf.Variable 来确定计算是否继续进行,然后在使用回调运行测试时更新它。例如

class MyCustomMetric(tf.keras.metrics.Metrics):

def __init__(self, **kwargs):
# Initialise as normal and add flag variable for when to run computation
super(MyCustomMetric, self).__init__(**kwargs)
self.metric_variable = self.add_weight(name='metric_varaible', initializer='zeros')
self.update_metric = tf.Variable(False)

def update_state(self, y_true, y_pred, sample_weight=None):
# Use conditional to determine if computation is done
if self.update_metric:
# run computation
self.metric_variable.assign_add(computation_result)

def result(self):
return self.metric_variable

def reset_states(self):
self.metric_variable.assign(0.)

class ToggleMetrics(tf.keras.callbacks.Callback):
'''On test begin (i.e. when evaluate() is called or
validation data is run during fit()) toggle metric flag '''
def on_test_begin(self, logs):
for metric in self.model.metrics:
if 'MyCustomMetric' in metric.name:
metric.on.assign(True)
def on_test_end(self, logs):
for metric in self.model.metrics:
if 'MyCustomMetric' in metric.name:
metric.on.assign(False)

关于python - 如何让 Keras 仅对验证数据计算某个指标?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56826495/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com