gpt4 book ai didi

python - 如何使 `fit_generator` 与 `tf.keras.Model` 一起工作

转载 作者:行者123 更新时间:2023-12-04 17:37:02 25 4
gpt4 key购买 nike

我正在实现一个 tf.keras.Model(不是 Sequential 模型!),它应该使用 fit_generator 进行训练。但是,fit_generator 会引发错误,可能是因为输入形状在编译时不可用。

这是一个最小的例子:

import tensorflow as tf
import numpy as np


class MyModel(tf.keras.Model):

def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(3, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(3, activation=tf.nn.softmax)

def call(self, inputs, training=None, mask=None):
return self.dense2(self.dense1(inputs))


class MyGenerator(tf.keras.utils.Sequence):

def __len__(self):
# Number of batches per epoch
return 1

def __getitem__(self, _):
# Generate one batch of data
x = np.array([[1., 2., 3.]])
y = np.array([[0., 1., 0.5]])

return x, y


if __name__ == '__main__':
m = MyModel()
g = MyGenerator()

m.compile(tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
m.fit_generator(g)

最后一行加注

AttributeError: 'MyModel' object has no attribute 'total_loss'

那么在自定义 Keras 模型中使用 fit_generator 的正确方法是什么?

最佳答案

在 Tensorflow 2.x 中,eager execution 默认启用。 Model.fit_generator 已弃用,将在未来版本中删除。所以你必须使用支持生成器的 Model.fit

请引用TF 2.4兼容代码如下所示

import tensorflow as tf
print(tf.__version__)
import numpy as np


class MyModel(tf.keras.Model):

def __init__(self):
super(MyModel, self).__init__()
self.dense1 = tf.keras.layers.Dense(3, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(3, activation=tf.nn.softmax)

def call(self, inputs, training=None, mask=None):
return self.dense2(self.dense1(inputs))


class MyGenerator(tf.keras.utils.Sequence):

def __len__(self):
# Number of batches per epoch
return 1

def __getitem__(self, _):
# Generate one batch of data
x = np.array([[1., 2., 3.]])
y = np.array([[0., 1., 0.5]])

return x, y


if __name__ == '__main__':
m = MyModel()
g = MyGenerator()

m.compile(tf.keras.optimizers.SGD(), loss=tf.keras.losses.mean_squared_error)
m.fit(g)

输出:

2.4.0
1/1 [==============================] - 0s 224ms/step - loss: 0.4725

关于python - 如何使 `fit_generator` 与 `tf.keras.Model` 一起工作,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56321468/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com