gpt4 book ai didi

python - 在二维张量上滚动具有可变步长值的行

转载 作者:行者123 更新时间:2023-12-04 07:19:04 25 4
gpt4 key购买 nike

我有一个张量 a我想首先使用 mask 进行掩码然后丢弃剩余的帧。为了确保输出张量的形状正确,填充应该在最后填充剩余的值。我可以假设只有一个连续序列 True位于掩码的每一行中。
例如

a = torch.arange(1,17).reshape(4,4)
# tensor([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12],
# [13, 14, 15, 16]])

mask = torch.tensor([[False, True, True, False],
[False, True, True, True],
[ True, False, False, False],
[ True, True, True, True]])

# desired output (assuming padding value is 0):
# tensor([[ 2, 3, 0, 0],
# [ 6, 7, 8, 0],
# [ 9, 0, 0, 0],
# [13, 14, 15, 16]])
我可以通过应用 torch.masked_select 来实现所需的输出其次是 torch.nn.functional.pad 在循环中的每一行,但我正在努力想办法更有效地批量执行此操作。
我也开始考虑使用 torch.roll 并在适当的索引后归零,但此功能只能应用于整个维度,而不是每行的自定义滚动量。

最佳答案

通过申请 torch.sort 在面具本身上,您可以达到预期的效果。事实上,如果您对 bool 值进行排序,您可以设法移动 False堆栈末尾的值,并让 True开始时的值。
请注意,这可能会因排序算法而异,某些算法可能会有一些改组.... 如 @Seraf Fej指出:您可以使用stable=True torch.stable 上的选项以便保留等效项的顺序。
然后使用排序的索引来收集 a 上的值与 torch.gather .最后,您需要屏蔽生成的矩阵以使用适当的填充替换丢弃的值。

>>> a
tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])

>>> mask
tensor([[False, True, True, False],
[False, True, True, True],
[ True, False, False, False],
[ True, True, True, True]])
对掩码进行排序:
>>> values, indices = mask.sort(1, descending=True, stable=True)

>>> values
tensor([[ True, True, False, False],
[ True, True, True, False],
[ True, False, False, False],
[ True, True, True, True]])

>>> indices
tensor([[1, 2, 0, 3],
[1, 2, 3, 0],
[0, 1, 2, 3],
[0, 1, 2, 3]])
收集自 indices和面具 values :
>>> a.gather(1, indices)*values
tensor([[ 2, 3, 0, 0],
[ 6, 7, 8, 0],
[ 9, 0, 0, 0],
[13, 14, 15, 16]])

您可以使用 torch.where 轻松扩展到任何填充值:
>>> torch.where(values, a.gather(1, indices), -1)
tensor([[ 2, 3, -1, -1],
[ 6, 7, 8, -1],
[ 9, -1, -1, -1],
[13, 14, 15, 16]])
或者使用反向掩码 ~values ,由填充值加权:
>>> a.gather(1, indices)*values -1*~values
tensor([[ 2, 3, -1, -1],
[ 6, 7, 8, -1],
[ 9, -1, -1, -1],
[13, 14, 15, 16]])

关于python - 在二维张量上滚动具有可变步长值的行,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/68621175/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com