gpt4 book ai didi

python - 替换 "tf.gather_nd"

转载 作者:太空宇宙 更新时间:2023-11-03 20:47:54 25 4
gpt4 key购买 nike

我正在做一个项目,但他们的tensorflow版本不支持tf.gather_nd。我问是否可以使用 tf.gather、tf.slice 或 tf.strided_slice 重写 tf.gather_nd 的函数?

tf.gather_nd 用于将张量中的切片收集到具有由索引指定的形状的张量中。详情可见https://www.tensorflow.org/api_docs/python/tf/gather_nd

谢谢

最佳答案

这个函数应该做同样的工作:

import tensorflow as tf
import numpy as np

def my_gather_nd(params, indices):
idx_shape = tf.shape(indices)
params_shape = tf.shape(params)
idx_dims = idx_shape[-1]
gather_shape = params_shape[idx_dims:]
params_flat = tf.reshape(params, tf.concat([[-1], gather_shape], axis=0))
axis_step = tf.cumprod(params_shape[:idx_dims], exclusive=True, reverse=True)
indices_flat = tf.reduce_sum(indices * axis_step, axis=-1)
result_flat = tf.gather(params_flat, indices_flat)
return tf.reshape(result_flat, tf.concat([idx_shape[:-1], gather_shape], axis=0))

# Test
np.random.seed(0)
with tf.Graph().as_default(), tf.Session() as sess:
params = tf.constant(np.random.rand(10, 20, 30).astype(np.float32))
indices = tf.constant(np.stack([np.random.randint(10, size=(5, 8)),
np.random.randint(20, size=(5, 8))], axis=-1))
result1, result2 = sess.run((tf.gather_nd(params, indices),
my_gather_nd(params, indices)))
print(np.allclose(result1, result2))
# True

关于python - 替换 "tf.gather_nd",我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56452714/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com