gpt4 book ai didi

python - 在 tensorflow 中获取 ValueError 说我的形状不兼容

转载 作者:行者123 更新时间:2023-12-01 23:59:36 24 4
gpt4 key购买 nike

错误:

return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
C:\Users\selvaa\miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\keras\backend.py:4619 categorical_crossentropy
target.shape.assert_is_compatible_with(output.shape)
C:\Users\selvaa\miniconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\tensor_shape.py:1128 assert_is_compatible_with
raise ValueError("Shapes %s and %s are incompatible" % (self, other))

ValueError: Shapes (None, 1) and (None, 151) are incompatible

我的模型:

x = np.array(x)
y = np.array(y)

x = x/255.0

model = Sequential()
model.add(Conv2D(3, (3,3), input_shape=(128,128,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Flatten())
model.add(Dense(302, activation='relu'))
model.add(Dense(151, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x, y, batch_size=32, epochs=5, verbose=1, validation_split=0.1)

我正在尝试训练一个模型来识别不同的口袋妖怪,我的数据集(正确标记和所有)都有两张图片,每张 151 只口袋妖怪。不确定我做错了什么。

这是我打印 x.shape 和 y.shape 时发生的情况:

(301, 128, 128, 3) (301,)

最佳答案

使用损失tf.keras.losses.SparseCategoricalCrossEntropy ,如下面的代码示例所示。

损失函数 tf.keras.losses.SparseCategoricalCrossEntropy 接受形状为 (n_samples,) 的引用标签和形状为 (n_samples, n_classes ),这将适用于您的数据。您不能使用 categorical_crossentropy,因为它要求您的标签进行单热编码(请参阅答案底部)。

x = np.array(x)
y = np.array(y)

x = x / 255.0

model = Sequential()
model.add(Conv2D(3, (3,3), input_shape=(128,128,3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dense(302, activation='relu'))
model.add(Dense(151, activation='softmax'))

model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
optimizer='adam',
metrics=['accuracy'])

model.fit(x, y, batch_size=32, epochs=5, verbose=1, validation_split=0.1)

另一种解决方案是在训练之前对标签进行单热编码,例如使用函数 tf.one_hot .如果您使用这种方法,则可以使用 categorical_crossentropy

关于python - 在 tensorflow 中获取 ValueError 说我的形状不兼容,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/62075840/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com