gpt4 book ai didi

python - Tensorflow:来自索引的 "staggered"序列掩码?

转载 作者:太空宇宙 更新时间:2023-11-03 23:59:54 25 4
gpt4 key购买 nike

输入如下:

[1, 3, 2]

期望的输出(在适当的张量中):

[1 0 0 
0 1 0
0 1 0
0 1 0
0 0 1
0 0 1]

即,与 tf.sequence_mask 非常相似(它会给出如下内容:

[1 1 1
0 1 1
0 1 0]

),但每个后续元素都“交错”在前一个序列掩码完成后开始。

非常感谢帮助。

最佳答案

这可以通过采用大小等于输入中元素数量的方形单位矩阵然后通过应用 tf.tile() inputs[i] 单位矩阵中每一行 i 的次数:

import tensorflow as tf

inputs = tf.constant([1, 3, 2])

unit = tf.eye(num_rows=inputs.get_shape().as_list()[0])
unstacked = tf.unstack(unit)
tiled = [tf.tile(u[None, ...], multiples=[inputs[i], 1])
for i, u in enumerate(unstacked)]
res = tf.concat(tiled, axis=0)

with tf.Session() as sess:
print(sess.run(res))
# [[1. 0. 0.]
# [0. 1. 0.]
# [0. 1. 0.]
# [0. 1. 0.]
# [0. 0. 1.]
# [0. 0. 1.]]

关于python - Tensorflow:来自索引的 "staggered"序列掩码?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/56047834/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com