gpt4 book ai didi

machine-learning - 从 Keras 中的 3 维张量收集 2 维张量列表

转载 作者:行者123 更新时间:2023-11-30 08:40:48 25 4
gpt4 key购买 nike

我有一个名为 ma​​in_decoder 的 3-d 张量,形状为 (None,9,256)

我想提取 9 个形状为 (无,256)的张量

我尝试使用 Keras gather,以下是模式代码片段:

for i in range(0,9):
sub_decoder_input = Lambda(lambda main_decoder:gather(main_decoder,(i)), name='lambda'+str(i))(main_decoder)

结果是 9 个形状为 (9,256)

的 lambda 层

如何修改它以便可以获得或收集形状为 (None,256) 的 9 个张量

谢谢。

最佳答案

您可以将 3D 张量分割为 9 个 2D 张量,并从 Lambda 层返回张量列表。

main_decoder = Input(shape=(9, 256))
sub_decoder_input = Lambda(lambda x: [x[:, i, :] for i in range(9)])(main_decoder)

print(sub_decoder_input)
[<tf.Tensor 'lambda_1/strided_slice:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_1:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_2:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_3:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_4:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_5:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_6:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_7:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_8:0' shape=(?, 256) dtype=float32>]

关于machine-learning - 从 Keras 中的 3 维张量收集 2 维张量列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/47347738/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com