gpt4 book ai didi

python - 使用数据生成器时,Keras 自定义指标 self.validation_data 为 none

转载 作者:行者123 更新时间:2023-12-04 04:09:48 25 4
gpt4 key购买 nike

我一直在尝试训练模型并在每个时期结束时计算精度和召回率。

自定义指标

class Metrics(keras.callbacks.Callback):
def on_train_begin(self, logs={}):
self.precision = []
self.recall = []

def on_epoch_end(self, epoch, logs={}):
print(type(self.validation_data))
print(self.validation_data)
predict = np.round(np.asarray(self.model.predict(self.validation_data[0])))
targ = self.validation_data[1]

precision_score = sklm.precision_score(targ, predict)
recall = sklm.recall_score(targ, predict)
self.precision.append(precision_score)
self.recall.append(recall)

def avg_precision_score(self):
return np.mean(self.precision_score)

def avg_recall_score(self):
return np.mean(self.recall)

训练时我使用数据生成器。
   training_set = train_datagen.flow_from_directory('train/',
target_size=(dim_x,dim_y),
batch_size=8, # 16 32
class_mode='categorical')

test_set = test_datagen.flow_from_directory('test/',
target_size=(dim_x,dim_y),
batch_size=8, # 16 32
class_mode='categorical')
metrics = Metrics()
history = classifier.fit_generator(
training_set,
steps_per_epoch=2,#50,
epochs=1, # 25
validation_data=test_set,
validation_steps=10,
callbacks=[metrics]
)

但这提供了 None 类型的 self.validation 。
我究竟做错了什么 ?

最佳答案

找到了这个问题的解决方案。
在问题评论中提及
https://github.com/keras-team/keras/issues/10472

class Metrics(Callback):

def __init__(self, val_data, batch_size = 20):
super().__init__()
self.validation_data = val_data
self.batch_size = batch_size

初始化验证数据解决了使用数据生成器时的问题。

关于python - 使用数据生成器时,Keras 自定义指标 self.validation_data 为 none,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61939790/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com