gpt4 book ai didi

Tensorflow Keras 形状不匹配

转载 作者:行者123 更新时间:2023-12-04 09:55:11 33 4
gpt4 key购买 nike

在尝试实现许多教程用来向您介绍神经网络的标准 MNIST 数字识别器时,我遇到了错误

ValueError: Shape mismatch: The shape of labels (received (1,)) should equal the shape of logits except for the last dimension (received (28, 10)).

我想用 from_tensor_slices处理数据,因为我想将代码应用于数据来自 CSV 文件的另一个问题。无论如何,这是在行 model.fit(...) 中产生错误的代码
import tensorflow as tf

train_dataset, test_dataset = tf.keras.datasets.mnist.load_data()
train_images, train_labels = train_dataset
train_images = train_images/255.0
train_dataset_tensor = tf.data.Dataset.from_tensor_slices((train_images, train_labels))

num_of_validation_data = 10000
validation_data = train_dataset_tensor.take(num_of_validation_data)
train_data = train_dataset_tensor.skip(num_of_validation_data)

model = tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(100, activation='sigmoid'),
tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)

model.fit(train_data, batch_size=50, epochs=5)

performance = model.evaluate(validation_data)

我不明白形状在哪里 (28, 10)的 logits 来自,我以为我正在展平图像,基本上是从 2D 图像中制作一个 1D 矢量?我怎样才能防止错误?

最佳答案

您可以使用以下代码

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

train_ds = tf.data.Dataset.from_tensor_slices(
(x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(100, activation='sigmoid'),
tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)

model.fit(train_ds)

关于Tensorflow Keras 形状不匹配,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61935211/

33 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com