gpt4 book ai didi

python - Tensorflow:在时间流中的索引处查找张量

转载 作者:行者123 更新时间:2023-12-01 01:20:00 27 4
gpt4 key购买 nike

我正在训练一个 RNN,我需要使用索引来查找示例时间流的另一部分中的值

v = tf.constant([
[[.1, .2], [.3, .4]], # timestream 1 values
[[.6, .5], [.7, .8]] # timestream 2 values
])
ixs = tf.constant([
[1, 0], # indices into timestream 1 values
[0, 1] # indices into timestream 2 values
])

我正在寻找一个可以进行查找并用张量值和产量替换索引的操作:

[
[[.3, .4], [.1, .2]],
[[.6, .5], [.7, .8]]
]

tf.gather 和 tf.gather_nd 听起来可能是正确的路径,但我并不真正理解从它们那里得到的结果。

v_at_ix = tf.gather(v, ixs, axis=-1)
sess.run(v_at_ix)
array([[[[0.2, 0.1],
[0.1, 0.2]],
[[0.4, 0.3],
[0.3, 0.4]]],
[[[0.5, 0.6],
[0.6, 0.5]],
[[0.8, 0.7],
[0.7, 0.8]]]], dtype=float32)

v_at_ix = tf.gather_nd(v, ixs)
sess.run(v_at_ix)
array([[0.6, 0.5],
[0.3, 0.4]], dtype=float32)

有谁知道正确的方法吗?

最佳答案

tf.gather只能获取基于指定轴的切片,其索引是并列的。在v_at_ix = tf.gather(v, ixs, axis=-1)中:

[1, 0] 中的

1 表示 [[[.2],[.4]],[[.5],[.8 ]]]v 中。

[1, 0] 中的

0 表示 [[[.1],[.3]],[[.6],[.7 ]]]v 中。

[0, 1] 中的

0 表示 [[[.1],[.3]],[[.6],[.7 ]]]v 中。

[0, 1] 中的

1 表示 [[[.2],[.4]],[[.5],[.8 ]]]v 中。

tf.gather_nd能够获取指定索引处的切片,并且其索引是渐进的。在v_at_ix = tf.gather_nd(v, ixs)中:

[1, 0] 中的

1 代表 [[.6, .5], [.7, .8]] v

[1, 0] 中的

0 代表 [[.6, .5] 中的 [.6, .5] ], [.7, .8]]

[0, 1] 中的

0 代表 [[.1, .2], [.3, .4]] v

[0, 1] 中的

1 表示 [[.1, .2] 中的 [.3, .4] ], [.3, .4]]

所以当我们使用tf时,我们需要的是[[[0,1],[0,0]],[[1,0],[1,1]]] .gather_nd。它可以由[[0,0],[1,1]][[1,0],[0,1]]组成。前者是重复的行号,后者是ixs。所以我们可以做到

ixs_row = tf.tile(tf.expand_dims(tf.range(v.shape[0]),-1),multiples=[1,v.shape[1]])
ixs = tf.concat([tf.expand_dims(ixs_row,-1),tf.expand_dims(ixs,-1)],axis=-1)
v_at_ix = tf.gather_nd(v,ixs)

关于python - Tensorflow:在时间流中的索引处查找张量,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53919354/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com