gpt4 book ai didi

python - 访问 tf.keras.callbacks.Callback 中已弃用的属性 "validation_data"

转载 作者:行者123 更新时间:2023-12-04 11:50:38 27 4
gpt4 key购买 nike

我决定从 keras 切换到 tf.keras(建议使用 here)。因此我安装了 tf.__version__=2.0.0tf.keras.__version__=2.2.4-tf .在我的旧版本代码(使用一些较旧的 Tensorflow 版本 tf.__version__=1.x.x )中,我使用回调来计算每个时期结束时整个验证数据的自定义指标。这样做的想法来自 here .但是,似乎不推荐使用“validation_data”属性,因此以下代码不再起作用。

class ValMetrics(Callback):

def on_train_begin(self, logs={}):

self.val_all_mse = []

def on_epoch_end(self, epoch, logs):

val_predict = np.asarray(self.model.predict(self.validation_data[0]))
val_targ = self.validation_data[1]

val_epoch_mse = mse_score(val_targ, val_predict)

self.val_epoch_mse.append(val_epoch_mse)

# Add custom metrics to the logs, so that we can use them with
# EarlyStop and csvLogger callbacks
logs["val_epoch_mse"] = val_epoch_mse

print(f"\nEpoch: {epoch + 1}")
print("-----------------")
print("val_mse: {:+.6f}".format(val_epoch_mse))

return

我目前的解决方法如下。我只是给了validation_data 作为 ValMetrics 的参数。类(class) :
class ValMetrics(Callback):

def __init__(self, validation_data):
super(Callback, self).__init__()
self.X_val, self.y_val = validation_data

我仍然有一些问题:“validation_data”属性真的被弃用了还是可以在其他地方找到?有没有比上述解决方法更好的方法来访问每个时期结束时的验证数据?

非常感谢!

最佳答案

你说得对,validation_data已根据 Tensorflow Callbacks Documentation 弃用.

您所面临的问题已在 Github 中提出。相关问题是Issue1 , Issue2Issue3 .

上述 Github 问题均未解决,您通过的解决方法 Validation_Data作为自定义回调的参数是一个很好的参数,根据此 Github Comment ,因为很多人发现它很有用。

为了 Stackoverflow Community 的利益,在下面指定解决方法的代码,即使它存在于 Github 中。

class Metrics(Callback):

def __init__(self, val_data, batch_size = 20):
super().__init__()
self.validation_data = val_data
self.batch_size = batch_size

def on_train_begin(self, logs={}):
print(self.validation_data)
self.val_f1s = []
self.val_recalls = []
self.val_precisions = []

def on_epoch_end(self, epoch, logs={}):
batches = len(self.validation_data)
total = batches * self.batch_size

val_pred = np.zeros((total,1))
val_true = np.zeros((total))

for batch in range(batches):
xVal, yVal = next(self.validation_data)
val_pred[batch * self.batch_size : (batch+1) * self.batch_size] = np.asarray(self.model.predict(xVal)).round()
val_true[batch * self.batch_size : (batch+1) * self.batch_size] = yVal

val_pred = np.squeeze(val_pred)
_val_f1 = f1_score(val_true, val_pred)
_val_precision = precision_score(val_true, val_pred)
_val_recall = recall_score(val_true, val_pred)

self.val_f1s.append(_val_f1)
self.val_recalls.append(_val_recall)
self.val_precisions.append(_val_precision)

return

我将继续关注上面提到的 Github 问题,并将相应地更新答案。

希望这可以帮助。快乐学习!

关于python - 访问 tf.keras.callbacks.Callback 中已弃用的属性 "validation_data",我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60080646/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com