gpt4 book ai didi

python - 在自定义回调中访问验证数据

转载 作者:太空狗 更新时间:2023-10-29 17:20:48 26 4
gpt4 key购买 nike

我正在安装一个 train_generator 并通过自定义回调来计算我的 validation_generator 上的自定义指标。如何在自定义回调中访问参数 validation_stepsvalidation_dataself.params里没有,self.model里也找不到。这就是我想做的。欢迎任何不同的方法。

model.fit_generator(generator=train_generator,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
validation_data=validation_generator,
validation_steps=validation_steps,
callbacks=[CustomMetrics()])


class CustomMetrics(keras.callbacks.Callback):

def on_epoch_end(self, batch, logs={}):
for i in validation_steps:
# features, labels = next(validation_data)
# compute custom metric: f(features, labels)
return

喀拉斯:2.1.1

更新

我设法将我的验证数据传递给自定义回调的构造函数。然而,这会导致恼人的“内核似乎已经死亡。它将自动重新启动。”信息。我怀疑这是否是正确的方法。有什么建议吗?

class CustomMetrics(keras.callbacks.Callback):

def __init__(self, validation_generator, validation_steps):
self.validation_generator = validation_generator
self.validation_steps = validation_steps


def on_epoch_end(self, batch, logs={}):

self.scores = {
'recall_score': [],
'precision_score': [],
'f1_score': []
}

for batch_index in range(self.validation_steps):
features, y_true = next(self.validation_generator)
y_pred = np.asarray(self.model.predict(features))
y_pred = y_pred.round().astype(int)
self.scores['recall_score'].append(recall_score(y_true[:,0], y_pred[:,0]))
self.scores['precision_score'].append(precision_score(y_true[:,0], y_pred[:,0]))
self.scores['f1_score'].append(f1_score(y_true[:,0], y_pred[:,0]))
return

metrics = CustomMetrics(validation_generator, validation_steps)

model.fit_generator(generator=train_generator,
steps_per_epoch=steps_per_epoch,
epochs=epochs,
validation_data=validation_generator,
validation_steps=validation_steps,
shuffle=True,
callbacks=[metrics],
verbose=1)

最佳答案

您可以直接遍历 self.validation_data 以在每个纪元结束时聚合所有验证数据。如果您想计算整个验证数据集的精度、召回率和 F1:

# Validation metrics callback: validation precision, recall and F1
# Some of the code was adapted from https://medium.com/@thongonary/how-to-compute-f1-score-for-each-epoch-in-keras-a1acd17715a2
class Metrics(callbacks.Callback):

def on_train_begin(self, logs={}):
self.val_f1s = []
self.val_recalls = []
self.val_precisions = []

def on_epoch_end(self, epoch, logs):
# 5.4.1 For each validation batch
for batch_index in range(0, len(self.validation_data)):
# 5.4.1.1 Get the batch target values
temp_targ = self.validation_data[batch_index][1]
# 5.4.1.2 Get the batch prediction values
temp_predict = (np.asarray(self.model.predict(
self.validation_data[batch_index][0]))).round()
# 5.4.1.3 Append them to the corresponding output objects
if(batch_index == 0):
val_targ = temp_targ
val_predict = temp_predict
else:
val_targ = np.vstack((val_targ, temp_targ))
val_predict = np.vstack((val_predict, temp_predict))

val_f1 = round(f1_score(val_targ, val_predict), 4)
val_recall = round(recall_score(val_targ, val_predict), 4)
val_precis = round(precision_score(val_targ, val_predict), 4)

self.val_f1s.append(val_f1)
self.val_recalls.append(val_recall)
self.val_precisions.append(val_precis)

# Add custom metrics to the logs, so that we can use them with
# EarlyStop and csvLogger callbacks
logs["val_f1"] = val_f1
logs["val_recall"] = val_recall
logs["val_precis"] = val_precis

print("— val_f1: {} — val_precis: {} — val_recall {}".format(
val_f1, val_precis, val_recall))
return

valid_metrics = Metrics()

然后您可以将 valid_metrics 添加到回调参数:

your_model.fit_generator(..., callbacks = [valid_metrics])

如果您希望其他回调使用这些措施,请务必将其放在回调的开头。

关于python - 在自定义回调中访问验证数据,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47676248/

26 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com