gpt4 book ai didi

python - 如何对形状为 (batch_size, 200, 256) 的张量进行索引以获得 (batch_size, 1, 256) 给定的长度 = batch_size 的索引张量列表?

转载 作者:太空宇宙 更新时间:2023-11-03 20:17:40 24 4
gpt4 key购买 nike

我有形状为 (batch_size, 200, 256) 的 LSTM 层的输出,其中 200 是标记序列的长度,256 是 LSTM 输出维度。我还有另一个形状为 (batch_size) 的张量,它是我想要从批处理中的每个样本序列中切出的标记的索引列表。

如果标记索引不为-1,我将切出标记向量表示(长度= 256)。如果 token 索引为-1,我将给出零向量(长度= 256)。

预期输出结果的形状为 (batch_size, 1, 256)。我该怎么做?

谢谢

这是我迄今为止尝试过的

bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256) 
dropout = Dropout(params['dropout_rate'])(bidir)
def slice_by_tensor(x):
matrix_to_slice = x[0]
index_tensor = x[1]


out_tensor = tf.where(index_tensor == -1,
tf.zeros(tf.shape(tf.gather(matrix_to_slice,
index_tensor, axis=1))),
tf.gather(matrix_to_slice, index_tensor, axis=1))



return out_tensor


representation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0])
# stack_idx0 shape is (batch_size)
# I got output with shape (batch_size, batch_size, 256) with this code

最佳答案

a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))
# [[[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]],

# [[12, 13, 14, 15],
# [16, 17, 18, 19],
# [20, 21, 22, 23]]]

b=tf.constant([-1,2])

aa=tf.pad(a,[[0,0],[1,0],[0,0]])

bb=b+1

index=tf.stack([tf.range(tf.size(b)),bb],axis=-1)
res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)
#[[[ 0, 0, 0, 0]],
#[[20, 21, 22, 23]]]

当索引为-1时,我们需要像张量一样的零。所以我们可以首先沿着第二个轴填充原始张量。然后将索引增加 1。之后,使用 tf.gather_nd 将返回答案。

关于python - 如何对形状为 (batch_size, 200, 256) 的张量进行索引以获得 (batch_size, 1, 256) 给定的长度 = batch_size 的索引张量列表?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58355115/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com