gpt4 book ai didi

python - model.fit() 不接受 tf.data.Dataset 的输入形状

转载 作者:行者123 更新时间:2023-12-03 08:40:46 24 4
gpt4 key购买 nike

我想通过应用 tf.data.Dataset 为我的模型提供数据。

检查了 TF 2.0 的文档后,我发现 .fit() 函数 ( https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit ) 接受:

x - A tf.data dataset. Should return a tuple of either (inputs, targets)or (inputs, targets, sample_weights).

因此,我编写了以下简短的概念验证代码:

from sklearn.datasets import make_blobs
import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.metrics import Accuracy, AUC

X, Y = make_blobs(n_samples=500, n_features=2, cluster_std=3.0, random_state=1)

def define_model():
model = Sequential()
model.add(Dense(units=1, activation="sigmoid", input_shape=(2,)))
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=[AUC(), Accuracy()])
return model

model = define_model()

X_ds = tf.data.Dataset.from_tensor_slices(X)
Y_ds = tf.data.Dataset.from_tensor_slices(Y)
dataset = tf.data.Dataset.zip((X_ds, Y_ds))

for elem in dataset.take(1):
print(type(elem))
print(elem)

model.fit(x=dataset) #<-- does not work
#model.fit(x=X, y=Y) <-- does work without any problems....

正如第二条评论中提到的,不应用 tf.data.Dataset 的代码工作正常。

但是,当应用数据集对象时,我收到以下错误消息:

<class 'tuple'>
(<tf.Tensor: shape=(2,), dtype=float64, numpy=array([-10.42729974, -0.85439721])>, <tf.Tensor: shape=(), dtype=int64, numpy=1>)
... other output here...
ValueError: Error when checking input: expected dense_19_input to have
shape (2,) but got array with shape (1,)

根据我对文档的理解,我构建的数据集应该正是 fit 方法期望的元组对象。

我不明白这个错误消息。

我在这里做错了什么?

最佳答案

当您将数据集传递给 fit 时,预计它会直接生成批处理,而不是单个示例。您只需在训练之前对数据集进行批处理即可。

dataset = dataset.batch(batch_size)
model.fit(x=dataset)

关于python - model.fit() 不接受 tf.data.Dataset 的输入形状,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62877768/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com