gpt4 book ai didi

python - PyTorch 张量 - 使用给定的结束索引列表进行矢量化切片

转载 作者:行者123 更新时间:2023-12-04 15:03:45 32 4
gpt4 key购买 nike

假设我有一个一维 PyTorch 张量 end_index长度为L。

我想构造一个 2D PyTorch 张量 T有 L 行,其中 T[i,j] = 2什么时候j < end_index[i]T[i,j] = 1否则。

以下作品:

T = torch.ones([4,3], dtype=torch.long)
for element in end_index:
T[:, :element] = 2

有没有矢量化的方法来做到这一点?

最佳答案

您可以使用 broadcast semantics 构建这样的张量

# sample inputs
L, C = 4, 3
end_index = torch.tensor([0, 2, 2, 1])

# Construct tensor of shape [L, C] such that for all (i, j)
# T[i, j] = 2 if j < end_index[i] else 1
j_range = torch.arange(C, device=end_index.device)
T = (j_range[None, :] < end_index[:, None]).long() + 1

结果

T = 
tensor([[1, 1, 1],
[2, 2, 1],
[2, 2, 1],
[2, 1, 1]])

关于python - PyTorch 张量 - 使用给定的结束索引列表进行矢量化切片,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/66520261/

32 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com