gpt4 book ai didi

python - PyTorch - 在 torch.sort 之后取回原始张量顺序的更好方法

转载 作者:行者123 更新时间:2023-12-02 11:17:08 36 4
gpt4 key购买 nike

我想在 torch.sort 之后取回原始张量顺序操作和对排序张量的一些其他修改,以便张量不再排序。最好用一个例子来解释这一点:

x = torch.tensor([30., 40., 20.])
ordered, indices = torch.sort(x)
# ordered is [20., 30., 40.]
# indices is [2, 0, 1]
ordered = torch.tanh(ordered) # it doesn't matter what operation is
final = original_order(ordered, indices)
# final must be equal to torch.tanh(x)

我以这种方式实现了该功能:
def original_order(ordered, indices):
z = torch.empty_like(ordered)
for i in range(ordered.size(0)):
z[indices[i]] = ordered[i]
return z

有一个更好的方法吗?特别是,可以避免循环并更有效地计算操作吗?

就我而言,我有一个大小为 torch.Size([B, N]) 的张量然后我对 B 中的每一个进行排序单独行,一次调用 torch.sort .所以,我必须打电话 original_order B次与另一个循环。

任何更多 pytorch-ic 的想法?

编辑 1 - 摆脱内循环

我通过以这种方式简单地用索引索引 z 解决了部分问题:
def original_order(ordered, indices):
z = torch.empty_like(ordered)
z[indices] = ordered
return z

现在,我只需要了解如何避免 B 上的外循环。尺寸。

编辑 2 - 摆脱外循环
def original_order(ordered, indices, batch_size):
# produce a vector to shift indices by lenght of the vector
# times the batch position
add = torch.linspace(0, batch_size-1, batch_size) * indices.size(1)


indices = indices + add.long().view(-1,1)

# reduce tensor to single dimension.
# Now the indices take in consideration the new length
long_ordered = ordered.view(-1)
long_indices = indices.view(-1)

# we are in the previous case with one dimensional vector
z = torch.zeros_like(long_ordered).float()
z[long_indices] = long_ordered

# reshape to get back to the correct dimension
return z.view(batch_size, -1)

最佳答案

def original_order(ordered, indices):
return ordered.gather(1, indices.argsort(1))
例子
original = torch.tensor([
[20, 22, 24, 21],
[12, 14, 10, 11],
[34, 31, 30, 32]])
sorted, index = original.sort()
unsorted = sorted.gather(1, index.argsort(1))
assert(torch.all(original == unsorted))
为什么有效
为简单起见,想象 t = [30, 10, 20] ,省略张量符号。 t.sort()给我们排序的张量 s = [10, 20, 30] ,以及排序索引 i = [1, 2, 0]免费。 i实际上是 t.argsort() 的输出. i告诉我们怎么走 ts . “要将 t 排序到 s 中,从 t 中取元素 1,然后是 2,然后是 0”。 Argsorting i给我们另一个排序索引 j = [2, 0, 1] ,它告诉我们如何从 i 开始到自然数的规范序列 [0, 1, 2] ,实际上颠倒了排序。另一种看待它的方式是 j告诉我们怎么走 st . “要将 s 排序到 t 中,从 s 中取元素 2,然后是 0,然后是 1”。对排序索引进行 Argsorting 为我们提供了它的“反向索引”,反之亦然。
现在我们有了反向索引,我们将它转​​储到 torch.gather()使用正确的 dim ,这对张量进行了排序。
来源
torch.gather
torch.argsort
我在研究这个问题时找不到这个确切的解决方案,所以我认为这是一个原始答案。

关于python - PyTorch - 在 torch.sort 之后取回原始张量顺序的更好方法,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/52127723/

36 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com