gpt4 book ai didi

python - tf.keras.models.Sequential 模型无法适应输入类型 tf.Tensor

转载 作者:行者123 更新时间:2023-11-28 18:00:12 28 4
gpt4 key购买 nike

我写了一个简单的 tf.keras.models.Sequential 模型。当我尝试用数据和标签作为 tf.Tensor 来拟合它时,它给了我一些错误。但是我可以用具有完全相同的底层数据的 numpy 数组来适应它。这是为什么?

我正在使用只有 CPU 的 tensorflow 1.13。我检查了 fit tf.keras.models.Sequential 的函数,但它说 tf.Tensor 和 numpy 数组都可以用作数据和标签,只要它们的类型匹配。

import tensorflow as tf
tf.enable_eager_execution()

# very simple keras Sequential model
model = tf.keras.Sequential([
tf.keras.layers.Dense(3, activation='relu'),
tf.keras.layers.Dense(3, activation='softmax')])

model.compile(optimizer=tf.train.AdamOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])

# use tf.Tensor as data and label
data = tf.constant([[0,0,1],[0,1,0],[1,0,0]])
label = tf.constant([[0,0,1],[0,1,0],[1,0,0]])
# This throws the following error
# InvalidArgumentError: Index out of range using input dim 2; input has only 2 dims [Op:StridedSlice] name: strided_slice/
model.fit(data, label, epochs=10)

# use numpy array with the same underlying data and label
data = data.numpy()
label = label.numpy()
# This works
model.fit(data, label, epochs=10)


第一次拟合不起作用并抛出以下错误。但是第二个有效。这很有趣,因为它们具有完全相同的基础数据

最佳答案

好吧,看起来你可能正在使用 tensorflow 2.0,因为调用 .numpy() 我相信它在 1.13 上不存在(也许你已经意识到但你可以检查版本使用 tf.__version__)

如果您打算使用 1.13,则需要进行 2 处更改以允许对 fit 的调用无错误地执行。

  1. 您必须将输入张量转换为 float32 类型
  2. 你必须传递一个steps_per_epoch参数

例如,这段代码不会抛出任何错误:

model = tf.keras.Sequential([
tf.keras.layers.Dense(3, activation='relu'),
tf.keras.layers.Dense(3, activation='softmax')])

model.compile(optimizer=tf.train.AdamOptimizer(0.001),
loss='categorical_crossentropy',
metrics=['accuracy'])


data = tf.constant([[0,0,1],[0,1,0],[1,0,0]], dtype=tf.float32)
label = tf.constant([[0,0,1],[0,1,0],[1,0,0]], dtype=tf.float32)
model.fit(data, label, epochs=10, steps_per_epoch=2)

关于python - tf.keras.models.Sequential 模型无法适应输入类型 tf.Tensor,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56088294/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com