gpt4 book ai didi

python - Keras CNN 维度问题

转载 作者:太空宇宙 更新时间:2023-11-03 14:20:03 24 4
gpt4 key购买 nike

我正在尝试使用 Keras 构建一个 CNN 来执行图像分割任务,基于 this文章。因为我的数据集很小,所以我想使用 Keras ImageDataGenerator 并将其提供给 fit_generator()。所以,我按照example在 Keras 网站上。但是,由于压缩图像和蒙版生成器不起作用,我遵循了此 answer并创建了我自己的生成器。

我的输入数据大小为(701,256,1),我的问题是二进制的(前景、背景)。对于每个图像,我都有一个相同形状的标签。

现在,我面临着维度问题。 answer中也提到了这一点,但我不确定如何解决。

错误:

 ValueError: Error when checking target: expected dense_3 to have 2 dimensions, but got array with shape (2, 704, 256, 1)

我将在这里粘贴完整的代码:

import numpy
import pygpu
import theano
import keras

from keras.models import Model, Sequential
from keras.layers import Input, Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, Reshape
from keras.layers import BatchNormalization
from keras.preprocessing.image import ImageDataGenerator

from keras.utils import np_utils
from keras import backend as K

def superGenerator(image_gen, label_gen):
while True:
x = image_gen.next()
y = label_gen.next()
yield x[0], y[0]


img_height = 704
img_width = 256

train_data_dir = 'Dataset/Train/Images'
train_label_dir = 'Dataset/Train/Labels'
validation_data_dir = 'Dataset/Validation/Images'
validation_label_dir = 'Dataset/Validation/Labels'
n_train_samples = 1000
n_validation_samples = 500
epochs = 50
batch_size = 2

input_shape = (img_height, img_width,1)
target_shape = (img_height, img_width)

model = Sequential()

model.add(Conv2D(80,(28,28), input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2)))

model.add(Conv2D(96,(18,18)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2), strides=(2,2)))

model.add(Conv2D(128,(13,13)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2),strides=(2,2)))


model.add(Conv2D(160,(8,8)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))

model.add(Flatten())

model.add(Dense(1024, activation='relu'))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.25))

model.add(Dense(2, activation='softmax'))

model.summary()

model.compile(loss='binary_crossentropy', optimizer='nadam', metrics=['accuracy'])

data_gen_args = dict(
rescale=1./255,
horizontal_flip=True,
vertical_flip=True
)

train_datagen = ImageDataGenerator(**data_gen_args)
train_label_datagen = ImageDataGenerator(**data_gen_args)
test_datagen = ImageDataGenerator(**data_gen_args)
test_label_datagen = ImageDataGenerator(**data_gen_args)

seed = 1

train_image_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=target_shape,
color_mode='grayscale',
batch_size=batch_size,
class_mode = 'binary',
seed=seed)
train_label_generator = train_label_datagen.flow_from_directory(
train_label_dir,
target_size=target_shape,
color_mode='grayscale',
batch_size=batch_size,
class_mode = 'binary',
seed=seed)

validation_image_generator = test_datagen.flow_from_directory(
validation_data_dir,
target_size=target_shape,
color_mode='grayscale',
batch_size=batch_size,
class_mode = 'binary',
seed=seed)

validation_label_generator = test_label_datagen.flow_from_directory(
validation_label_dir,
target_size=target_shape,
color_mode='grayscale',
batch_size=batch_size,
class_mode = 'binary',
seed=seed)

train_generator = superGenerator(train_image_generator, train_label_generator,batch_size)
test_generator = superGenerator(validation_image_generator, validation_label_generator,batch_size)

model.fit_generator(
train_generator,
steps_per_epoch= n_train_samples // batch_size,
epochs=50,
validation_data=test_generator,
validation_steps=n_validation_samples // batch_size)

model.save_weights('first_try.h5')

我是 Keras(和 CNN)的新手,因此非常感谢任何帮助。

最佳答案

好的。我做了一些橡皮鸭调试并阅读了更多文章。当然,维度是一个问题。 This简单的答案对我有用。我的标签的形状与输入图像相同,因此模型的输出也应该具有该形状。我使用 Conv2DTranspose 来解决这个问题。

关于python - Keras CNN 维度问题,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47994524/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com