gpt4 book ai didi

python - 带有空张量的tensorflow scatter_nd?

转载 作者:行者123 更新时间:2023-12-01 07:01:31 24 4
gpt4 key购买 nike

我正在尝试将两个张量混合在一起。 scatter_nd 非常适合这种情况,我编写了以下函数来完成我的任务。它基本上只是将 2 个 scatter_nds 广告放在一起。

def tf_munge(t, i, r, j, axis=0):
#insert tensor t at indices i and tensor r at indices j on axis `axis`.
#requires: i.shape[0] == t.shape[axis] && j.shape[0] == r.shape[axis] && t.shape[k] == r.shape[k] ∀k != axis
i = tf.expand_dims(i, -1)
j = tf.expand_dims(j, -1)
rank_indices = tf.range(tf.rank(t))
roller = tf.roll(rank_indices, -axis, 0)
rolled_t = tf.transpose(t, roller)
rolled_r = tf.transpose(r, roller)
scatter_shape = tf.concat((tf.shape(i)[0:1] + tf.shape(j)[0:1], tf.shape(rolled_t)[1:]), axis=0)
scattered = tf.scatter_nd(i, rolled_t, scatter_shape) + tf.scatter_nd(j, rolled_r, scatter_shape)
return tf.transpose(scattered, tf.roll(rank_indices, axis, 0))

一般来说,它按预期工作。但是,只要 rt 沿某个轴均为空,它就会失败。我有两个代码“路径”,具体取决于 bool 值,其中我分割张量并根据该 bool 值是真还是假执行不同的操作。有时,该 bool 值对于 0 行为 false。在这种情况下,我最终会对一个空张量做一些事情。其中之一就是这种尝试分散。该错误实际上引用了输出形状(上述代码中的 scatter_shape ),声称:

ValueError: Indices and updates specified for empty output shape for 'ScatterNd_4' (op: 'ScatterNd')
with input shapes: [3,1], [3,0,2], [3] and with input tensors computed as partial shapes: input[2] = [5,0,2].

请注意,空轴与我散射的轴不同。这是一个工作示例:

foo = tf.ones((3,1,2))
bar = tf.ones((2,1,2))*2
i = tf.constant([1,3,4])
j = tf.constant([0,2])
tf_munge(foo,i,bar,j,axis=0)
#Output: <tf.Tensor 'transpose_13:0' shape=(5, 1, 2) dtype=float32>

这是一个失败的例子:

foo = tf.ones((3,0,2))
bar = tf.ones((2,0,2))*2
tf_munge(foo,i,bar,j,axis=0)
#Output: The error above

这里的预期输出显然是形状为(5,0,2)的空张量。

我考虑过对输入的形状使用条件,但是tf.cond executes both pathways 。当我有一个带有 scatter_nd 的空张量时,如何处理这种情况?

最佳答案

您可以使用 tf.gather 更简单地做到这一点以适用于所有情况的方式:

import tensorflow as tf

def tf_munge(t, i, r, j, axis=0):
tr = tf.concat([t, r], axis=axis)
idx = tf.argsort(tf.concat([i, j], axis=0))
return tf.gather(tr, idx, axis=axis)

with tf.Graph().as_default(), tf.Session() as sess:
foo = tf.ones((3, 1, 2))
bar = tf.ones((2, 1, 2)) * 2
i = tf.constant([1, 3, 4])
j = tf.constant([0, 2])
out = tf_munge(foo, i, bar, j, axis=0)
print(sess.run(out))
# [[[2. 2.]]
#
# [[1. 1.]]
#
# [[2. 2.]]
#
# [[1. 1.]]
#
# [[1. 1.]]]
foo2 = tf.ones((3, 0, 2))
bar2 = tf.ones((2, 0, 2)) * 2
out2 = tf_munge(foo2, i, bar2, j, axis=0)
print(sess.run(out2))
# []

关于python - 带有空张量的tensorflow scatter_nd?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58612735/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com