gpt4 book ai didi

machine-learning - 如何使用 Keras 将 3D 矩阵简化为 2D 矩阵?

转载 作者:行者123 更新时间:2023-11-30 08:31:01 25 4
gpt4 key购买 nike

我已经构建了一个 Keras ConvLSTM 神经网络,并且我想基于一系列 10 次步骤来预测提前一帧:

from keras.models import Sequential
from keras.layers.convolutional import Conv3D
from keras.layers.convolutional_recurrent import ConvLSTM2D
from keras.layers.normalization import BatchNormalization
import numpy as np
import pylab as plt
from keras import layers

# We create a layer which take as input movies of shape
# (n_frames, width, height, channels) and returns a movie
# of identical shape.

model = Sequential()
model.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
input_shape=(None, 64, 64, 1),
padding='same', return_sequences=True))
model.add(BatchNormalization())

model.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
padding='same', return_sequences=True))
model.add(BatchNormalization())

model.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
padding='same', return_sequences=True))
model.add(BatchNormalization())

model.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
padding='same', return_sequences=True))
model.add(BatchNormalization())

model.add(Conv3D(filters=1, kernel_size=(3, 3, 3),
activation='sigmoid',
padding='same', data_format='channels_last'))
model.compile(loss='binary_crossentropy', optimizer='adadelta')

训练:

data_train_x = data_4[0:20, 0:10, :, :, :]
data_train_y = data_4[0:20, 10:11, :, :, :]

model.fit(data_train_x, data_train_y, batch_size=10, epochs=1,
validation_split=0.05)

我测试了模型:

test_x = np.reshape(data_test_x[2,:,:,:,:], [1,10,64,64,1])
next_frame = model.predict(test_x,batch_size=1, verbose=1, steps=None)

但问题是“next_frame”形状是:(1, 10, 64, 64, 1),但我希望它的形状是(1, 1, 64, 64, 1)

这是“model.summary()”的结果:

_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv_lst_m2d_1 (ConvLSTM2D) (None, None, 64, 64, 40) 59200
_________________________________________________________________
batch_normalization_1 (Batch (None, None, 64, 64, 40) 160
_________________________________________________________________
conv_lst_m2d_2 (ConvLSTM2D) (None, None, 64, 64, 40) 115360
_________________________________________________________________
batch_normalization_2 (Batch (None, None, 64, 64, 40) 160
_________________________________________________________________
conv_lst_m2d_3 (ConvLSTM2D) (None, None, 64, 64, 40) 115360
_________________________________________________________________
batch_normalization_3 (Batch (None, None, 64, 64, 40) 160
_________________________________________________________________
conv_lst_m2d_4 (ConvLSTM2D) (None, None, 64, 64, 40) 115360
_________________________________________________________________
batch_normalization_4 (Batch (None, None, 64, 64, 40) 160
_________________________________________________________________
conv3d_1 (Conv3D) (None, None, 64, 64, 1) 1081
=================================================================
Total params: 407,001
Trainable params: 406,681
Non-trainable params: 320

所以我不知道要添加什么层,所以我将输出减少到 1 帧而不是 10 帧?

最佳答案

这是基于最后一层中的 3D 卷积的预期结果。例如,如果在 3 维张量的 Conv2D 中有 1 个滤波器,且 padding = 'same',这意味着它将产生相同高度和宽度的 2D 输出(例如滤波器也隐式地沿深度轴捕获)。

对于 4 维张量的 3D 也是如此,它沿着 channel 维度深度轴隐式捕获,从而产生与输入相同的 3-D 张量(序列索引、高度、宽度)。

听起来您想要做的是在 Conv3D 层之后添加某种池化步骤,以便它在序列维度上变平,例如 AveragePooling3D池化元组为(10, 1, 1)对第一个非批量维度进行平均(或根据您的特定网络需求进行修改)。

或者,假设您想通过仅采用最终序列元素来沿序列维度专门“池化”(例如,而不是跨序列进行平均或最大池化)。然后你就可以得到最后的 ConvLSTM2D层有return_sequences=False ,然后在最后一步中进行 2D 卷积,但这意味着您的最终卷积不会从预测帧序列的聚合中受益。可能是特定于应用程序的,无论这是否是一个好主意。

为了确认第一种方法,我添加了:

model.add(layers.AveragePooling3D(pool_size=(10, 1, 1), padding='same'))

就在Conv3D之后层,然后制作玩具数据:

x = np.random.rand(1, 10, 64, 64, 1)

然后:

In [22]: z = model.predict(x)

In [23]: z.shape
Out[23]: (1, 1, 64, 64, 1)

您需要确保第一个非批量维度中的池大小设置为最大可能的序列长度,以确保最终输出形状始终得到 (1, 1, ...)。

关于machine-learning - 如何使用 Keras 将 3D 矩阵简化为 2D 矩阵?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49802778/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com