gpt4 book ai didi

indexing - 张量的 torch 逻辑索引

转载 作者:行者123 更新时间:2023-12-04 12:05:48 24 4
gpt4 key购买 nike

我正在寻找一种优雅的方法来选择满足某些约束的火炬张量的子集。
例如,假设我有:

A = torch.rand(10,2)-1

S是一个 10x1 张量,
sel = torch.ge(S,5) -- this is a ByteTensor

我希望能够进行逻辑索引,如下所示:
A1 = A[sel]

但这不起作用。
所以有 index接受 LongTensor 的函数但我找不到一种简单的方法来转换 SLongTensor ,以下情况除外:
sel = torch.nonzero(sel)

它返回一个 K x 2 张量(K 是 S >= 5 的值的数量)。那么我必须将其转换为一维数组,这最终允许我对 A 进行索引:
A:index(1,torch.squeeze(sel:select(2,1)))

这很麻烦;在例如Matlab 我所要做的就是
A(S>=5,:)

任何人都可以提出更好的方法吗?

最佳答案

一种可能的选择是:

sel = S:ge(5):expandAs(A)   -- now you can use this mask with the [] operator
A1 = A[sel]:unfold(1, 2, 2) -- unfold to get back a 2D tensor

例子:
> A = torch.rand(3,2)-1
-0.0047 -0.7976
-0.2653 -0.4582
-0.9713 -0.9660
[torch.DoubleTensor of size 3x2]

> S = torch.Tensor{{6}, {1}, {5}}
6
1
5
[torch.DoubleTensor of size 3x1]

> sel = S:ge(5):expandAs(A)
1 1
0 0
1 1
[torch.ByteTensor of size 3x2]

> A[sel]
-0.0047
-0.7976
-0.9713
-0.9660
[torch.DoubleTensor of size 4]

> A[sel]:unfold(1, 2, 2)
-0.0047 -0.7976
-0.9713 -0.9660
[torch.DoubleTensor of size 2x2]

关于indexing - 张量的 torch 逻辑索引,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/36343199/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com