gpt4 book ai didi

python - TensorFlow 中类似 NumPy 的切片

转载 作者:太空宇宙 更新时间:2023-11-04 01:57:07 25 4
gpt4 key购买 nike

我有一个张量对象,我想切分它的一部分。

tf_a1 = tf.Variable([    [9.968594,  8.655439,  0.,        0.       ],
[0., 8.3356, 0., 8.8974 ],
[0., 0., 6.103182, 7.330564 ],
[6.609862, 0., 3.0614321, 0. ],
[9.497023, 0., 3.8914037, 0. ],
[0., 8.457685, 8.602337, 0. ],
[0., 0., 5.826657, 8.283971 ],
[0., 0., 0., 0. ]])

另外,我有这个数组:

tf_a2 = tf.constant([[1, 2, 5],
[1, 4, 6],
[0, 7, 7],
[2, 3, 6],
[2, 4, 7]])

我想像切片一样做这个 numpy:

 tf_a1[tf_a2]

numpy 代码的预期输出如下:

[[[0.        8.3356    0.        8.8974   ]
[0. 0. 6.103182 7.330564 ]
[0. 8.457685 8.602337 0. ]]

[[0. 8.3356 0. 8.8974 ]
[9.497023 0. 3.8914037 0. ]
[0. 0. 5.826657 8.283971 ]]

[[9.968594 8.655439 0. 0. ]
[0. 0. 0. 0. ]
[0. 0. 0. 0. ]]

[[0. 0. 6.103182 7.330564 ]
[6.609862 0. 3.0614321 0. ]
[0. 0. 5.826657 8.283971 ]]

[[0. 0. 6.103182 7.330564 ]
[9.497023 0. 3.8914037 0. ]
[0. 0. 0. 0. ]]]

我认为我可以使用以下方法在 tensorflow 中执行类似的操作:

tf.gather_nd(tf_a1, tf_a2)

但它引发了这个错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: index innermost dimension length must be <= params rank; saw: 3 vs. 2 [Op:GatherNd]

感谢任何帮助:)

最佳答案

我想你可以使用 tf.gather :

tf.gather(tf_a1, tf_a2, axis=0)                                                                                        
# <tf.Tensor 'GatherV2_10:0' shape=(5, 3, 4) dtype=float32>

TensorFlow 2.0 上的可重现示例

tf.__version__
# '2.0.0-beta0'

tf.gather(tf_a1, tf_a2, axis=0)

<tf.Tensor: id=9, shape=(5, 3, 4), dtype=float32, numpy=
array([[[0. , 8.3356 , 0. , 8.8974 ],
[0. , 0. , 6.103182 , 7.330564 ],
[0. , 8.457685 , 8.602337 , 0. ]],

[[0. , 8.3356 , 0. , 8.8974 ],
[9.497023 , 0. , 3.8914037, 0. ],
[0. , 0. , 5.826657 , 8.283971 ]],

[[9.968594 , 8.655439 , 0. , 0. ],
[0. , 0. , 0. , 0. ],
[0. , 0. , 0. , 0. ]],

[[0. , 0. , 6.103182 , 7.330564 ],
[6.609862 , 0. , 3.0614321, 0. ],
[0. , 0. , 5.826657 , 8.283971 ]],

[[0. , 0. , 6.103182 , 7.330564 ],
[9.497023 , 0. , 3.8914037, 0. ],
[0. , 0. , 0. , 0. ]]], dtype=float32)>

关于python - TensorFlow 中类似 NumPy 的切片,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56568386/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com