gpt4 book ai didi

python - 如何在未知大小的张量上使用 tf.slice

转载 作者:太空宇宙 更新时间:2023-11-04 09:45:46 25 4
gpt4 key购买 nike

我正在处理一个矩阵 sim_mat,它是一个大小为 (?,?,?) 的张量。

我想像这样在上面使用 tf.slice:

a=tf.slice(sim_matrix,[0,0,0],tf.stack([tf.shape(sim_matrix[0],tf.shape(sim_matrix)[1],3]))
print(a)

但这给了我大小为 (?,?,?) 而不是 (?,?,3) 的张量 a。此外,当我将 a 用于更多功能时,我收到一条错误消息:

输入大小(输入的深度)必须可以通过形状推断访问,但没有看到值。

有解决办法吗?

谢谢。

最佳答案

您可以编写以下等效但可读性更好的行,并带有花哨的索引:

a = sim_matrix[ :, :, 0 : 3 ]

但是,这仍然会为您提供形状张量 ( ?, ?, ? )。

如果您确实需要固定生成的张量的三维,请使用 tf.gather() .您必须使用 tf.transpose() 排列尺寸然而,之前和之后,因为tf.gather()仅适用于第一维度。 (也有 tf.gather_nd(),但这在这里没有帮助,因为您不想在前两个 dims 上指定索引,您想要切片。)所以像这样(测试代码):

import tensorflow as tf

sim_matrix = tf.placeholder( shape=(None,None,None), dtype = tf.float32 )
sim_matrixT = tf.transpose( sim_matrix, [ 2, 0, 1 ] )
aT = tf.gather( sim_matrixT, [ 0, 1, 2 ] ) # get first three
a = tf.transpose( aT, [ 1, 2, 0 ] )

print( a.get_shape() )

输出:

(?, ?, 3)

您还说您在其他功能上遇到错误。如果这些函数需要其输入张量的固定形状,情况仍然可能如此。

关于python - 如何在未知大小的张量上使用 tf.slice,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49858041/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com