作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
这是我的模型:
filters = 256
kernel_size = 3
strides = 1
factor = 4 # the factor of upscaling
inputLayer = Input(shape=(img_height//factor, img_width//factor, img_depth))
conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv1)
act = ReLU()(res)
res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([conv1, res])
for i in range(15): # 16-1
res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
act = ReLU()(res1)
res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)
res_rec = Add()([res_rec, res2])
conv2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)
a = Add()([conv1, conv2])
up = UpSampling2D(size=4)(a)
outputLayer = Conv2D(filters=3,
kernel_size=1,
strides=1,
padding='same')(up)
model = Model(inputs=inputLayer, outputs=outputLayer)
model.summary()
显示:
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, 350, 350, 3) 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 350, 350, 256 7168 input_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 350, 350, 256 590080 conv2d_1[0][0]
__________________________________________________________________________________________________
re_lu_1 (ReLU) (None, 350, 350, 256 0 conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 350, 350, 256 590080 re_lu_1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 350, 350, 256 0 conv2d_1[0][0]
conv2d_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 350, 350, 256 590080 add_1[0][0]
__________________________________________________________________________________________________
re_lu_2 (ReLU) (None, 350, 350, 256 0 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 350, 350, 256 590080 re_lu_2[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 350, 350, 256 0 add_1[0][0]
conv2d_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 350, 350, 256 590080 add_2[0][0]
__________________________________________________________________________________________________
re_lu_3 (ReLU) (None, 350, 350, 256 0 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 350, 350, 256 590080 re_lu_3[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 350, 350, 256 0 add_2[0][0]
conv2d_7[0][0]
...... this goes on for a long time .....
__________________________________________
add_15 (Add) (None, 350, 350, 256 0 add_14[0][0]
conv2d_31[0][0]
__________________________________________________________________________________________________
conv2d_32 (Conv2D) (None, 350, 350, 256 590080 add_15[0][0]
__________________________________________________________________________________________________
re_lu_16 (ReLU) (None, 350, 350, 256 0 conv2d_32[0][0]
__________________________________________________________________________________________________
conv2d_33 (Conv2D) (None, 350, 350, 256 590080 re_lu_16[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 350, 350, 256 0 add_15[0][0]
conv2d_33[0][0]
__________________________________________________________________________________________________
conv2d_34 (Conv2D) (None, 350, 350, 256 590080 add_16[0][0]
__________________________________________________________________________________________________
add_17 (Add) (None, 350, 350, 256 0 conv2d_1[0][0]
conv2d_34[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 1400, 1400, 2 0 add_17[0][0]
__________________________________________________________________________________________________
conv2d_35 (Conv2D) (None, 1400, 1400, 3 771 up_sampling2d_1[0][0]
==================================================================================================
Total params: 19,480,579
Trainable params: 19,480,579
Non-trainable params: 0
__________________________________________________________________________________________________
None
重要的部分就在最后,靠近输出:
__________________________________________________________________________________________________
add_17 (Add) (None, 350, 350, 256 0 conv2d_1[0][0]
conv2d_34[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 1400, 1400, 2 0 add_17[0][0]
__________________________________________________________________________________________________
conv2d_35 (Conv2D) (None, 1400, 1400, 3 771 up_sampling2d_1[0][0]
==================================================================================================
现在,看看我运行网络时遇到的错误:
Traceback (most recent call last):
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 280, in <module>
setUpImages()
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 96, in setUpImages
setUpData(trainData, testData)
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 135, in setUpData
setUpModel(X_train, Y_train, validateTestData, trainingTestData)
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 176, in setUpModel
train(model, X_train, Y_train, validateTestData, trainingTestData)
File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 192, in train
batch_size=32)
File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 950, in fit
batch_size=batch_size)
File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 787, in _standardize_user_data
exception_prefix='target')
File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training_utils.py", line 137, in standardize_input_data
str(data_shape))
ValueError: Error when checking target: expected conv2d_35 to have shape (1400, 1400, 1) but got array with shape (1400, 1400, 3)
为什么我的最后一个卷积期望有一个 (1400, 1400, 1)
张量,却得到一个 (1400, 1400, 3)
张量,而摘要说 UpSampling2D
应该返回一个 (1400, 1400, 2)
张量?
澄清一下上下文:这应该是一个采用 350x350x3 图像并输出 1400x1400x3 图像的网络。
最佳答案
很明显,错误消息与 conv2d_35
实体并不特别相关,而是与我的损失函数链接的网络的最后一个实体。
由于我选择了 sparse_categorical_crossentropy
作为损失函数,因此它期望一个单一维度向量。
将损失设置为 mean_squared_error
修复了该问题。
关于python-3.x - Keras UpSampling2D 不一致的行为,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52452947/
我是一名优秀的程序员,十分优秀!