gpt4 book ai didi

python - 如何有效地乘以具有重复行的 torch 张量而不将所有行存储在内存中或迭代?

转载 作者:行者123 更新时间:2023-12-04 07:40:37 25 4
gpt4 key购买 nike

给定一个火炬张量:

# example tensor size 2 x 4
a = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
另一个是每 n 行重复一次:
# example tensor size 4 x 3 where every 2 rows repeated
b = torch.Tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
如何执行矩阵乘法:
>>> torch.mm(a, b)
tensor([[ 28., 38., 48.],
[ 68., 94., 120.]])
不将整个重复行张量复制到内存中或迭代?
即只存储前两行:
# example tensor size 2 x 3 where only the first two rows from b are actually stored in memory
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
因为这些行将重复。
有一个功能
torch.expand()
但这在重复多行时确实有效,而且,正如这个问题:
Repeating a pytorch tensor without copying memory
指示并且我自己的测试证实在调用时通常最终将整个张量复制到内存中
.to(device)
也可以迭代地执行此操作,但速度相对较慢。
是否有某种方法可以有效地执行此操作而不将整个重复行张量存储在内存中?
编辑说明:
抱歉,最初没有澄清:一个被用作第一个张量的第一维以保持示例简单,但我实际上正在寻找任何两个张量 a 和 b 的一般情况的解决方案,以便它们的尺寸兼容矩阵乘法和 b 的行每 n 行重复一次。我已经更新了示例以反射(reflect)这一点。

最佳答案

假设a的第一维在您的示例中为 1,您可以执行以下操作:

a = torch.Tensor([[1, 2, 3, 4]])
b_abbreviated = torch.Tensor([[1, 2, 3], [4, 5, 6]])
torch.mm(a.reshape(-1, 2), b_abbreviated).sum(axis=0, keepdim=True)
在这里,不是重复行,而是乘以 a块,然后将它们按列相加以获得相同的结果。

如果 a的第一个维度不一定是 1,您可以尝试以下操作:
torch.cat(torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]), dim=1).sum(
dim=0, keepdim=True).reshape(a.shape[0], -1)
在这里,您执行以下操作:
  • torch.mm(a.reshape(-1,2),b_abbreviated ,您再次拆分 a 的每一行成大小为 2 的块并将它们一个叠放在另一个上面,然后将每一行叠放在另一行上。
  • torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]) ,这些堆栈然后按行分隔,以便拆分的每个结果组件对应于单行的块。
  • torch.cat(torch.split(torch.mm(a.reshape(-1,2),b_abbreviated), a.shape[0]), dim=1)然后将这些堆栈按列连接。
  • .sum(dim=0, keepdim=True) , 结果对应于 a 中各个行的单独块加起来。
  • .reshape(a.shape[0], -1) , 行 a按列连接的再次按行堆叠。

  • 与直接矩阵乘法相比,它似乎相当慢,这并不奇怪,但与显式迭代相比,我还没有检查过。可能有更好的方法可以做到这一点,如果我想到任何方法,将进行编辑。

    关于python - 如何有效地乘以具有重复行的 torch 张量而不将所有行存储在内存中或迭代?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/67496315/

    25 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com