gpt4 book ai didi

python - 使用 Keras 扩充 CSV 文件数据集

转载 作者:行者123 更新时间:2023-11-30 09:59:20 24 4
gpt4 key购买 nike

我正在开发一个已经在 Kaggle 实现的项目这与图像分类有关。我总共有 6 个类别要预测,分别是愤怒、快乐、悲伤等。我已经实现了 CNN 模型,目前仅使用 4 个类别(图像数量最多的类别),但我的模型过度拟合,我的验证准确率最高可达 53%,因此我尝试了多种方法,但似乎并没有提高我的准确率。现在我看到人们提到了一种叫做数据增强的东西,我想尝试一下,因为它似乎有可能提高准确性。但是我遇到了一个我无法弄清楚的错误。

数据集分布:

6_classes

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from matplotlib.pyplot import imread, imshow, subplots, show


def plot(data_generator):
"""
Plots 4 images generated by an object of the ImageDataGenerator class.
"""
data_generator.fit(df_training)
image_iterator = data_generator.flow(df_training)

# Plot the images given by the iterator
fig, rows = subplots(nrows=1, ncols=4, figsize=(18,18))
for row in rows:
row.imshow(image_iterator.next()[0].astype('int'))
row.axis('off')
show()

x_train = df_training.drop("emotion",axis=1)
image = x_train[1:2].values.reshape(48, 48)
x_train = x_train.values.reshape(x_train.shape[0], 48, 48,1)
x_train = x_train.astype("float32")
image = image.astype("float32")
image = x_train[1:2].reshape(48, 48)

# Creating a dataset which contains just one image.
images = image.reshape((1, image.shape[0], image.shape[1]))

imshow(images[0])
show()
print(x_train.shape)
data_generator = ImageDataGenerator(rotation_range=90)
plot(data_generator)

错误:

ValueError: Input to .fit() should have rank 4. Got array with shape: (28709, 2305)

我已经将数据重新整形为 4d 数组,但由于某种原因,错误中它显示为 2d 数据。这是 print(x_train.shape) => (28709, 48, 48, 1)

的形状

x_train 是数据集所在的位置,x_train[1:2] 访问一张图像。

P.s 您是否会推荐任何其他方法来根据此数据集提高我的准确性。有关我的数据集的更多问题,如果您不理解此部分代码中的某些内容,请告诉我。

最佳答案

您在 df_training 上使用 data_generator,而不是在 x_train 上。

关于如何避免过度拟合的更多想法:Tensorflow 有一个官方教程,其中有一些很好的建议: https://www.tensorflow.org/tutorials/keras/overfit_and_underfit

关于python - 使用 Keras 扩充 CSV 文件数据集,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/59664575/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com