gpt4 book ai didi

deep-learning - 如何在 Keras 中按列拆分张量以实现 STFCN

转载 作者:行者123 更新时间:2023-12-01 01:54:41 25 4
gpt4 key购买 nike

我想在 Keras 中实现时空全卷积网络(STFCN)。我需要输入 3D 卷积输出的每个深度列,例如形状为 (64, 16, 16) 的张量,作为单独 LSTM 的输入。

为了清楚起见,我有一个 (64 x 16 x 16)维度张量 (channels, height, width) .我需要将张量(显式或隐式)拆分为 16 * 16 = 256 个形状的张量 (64 x 1 x 1) .

这是 STFCN 论文中用于说明时空模块的图表。我上面描述的是“空间特征”和“时空模块”之间的箭头。

The connection between FCn and Spatio-Temporal Module is the relevant part of the diagram.

这个想法如何在 Keras 中得到最好的实现?

最佳答案

您可以使用 tf.split来自使用 Keras 的 Tensorflow Lambda

使用 Lambda 分割形状为 (64,16,16) 的张量进入 (64,1,1,256)然后子集您需要的任何索引。

import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.models import Model
from keras.layers import Input, Lambda

# input data
data = np.ones((3,64,16,16))

# define lambda function to split
def lambda_fun(x) :
x = K.expand_dims(x, 4)
split1 = tf.split(x, 16, 2)
x = K.concatenate(split1, 4)
split2 = tf.split(x, 16, 3)
x = K.concatenate(split2, 4)
return x

## check thet splitting works fine
input = Input(shape= (64,16,16))
ll = Lambda(lambda_fun)(input)
model = Model(inputs=input, outputs=ll)
res = model.predict(data)
print(np.shape(res)) #(3, 64, 1, 1, 256)

关于deep-learning - 如何在 Keras 中按列拆分张量以实现 STFCN,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/41484413/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com