gpt4 book ai didi

python-3.x - Pytorch - 堆栈维度必须完全相同?

转载 作者:行者123 更新时间:2023-12-04 16:50:38 24 4
gpt4 key购买 nike

在 pytorch 中,给定张量 a形状(1X11)b形状(1X11) , torch.stack((a,b),0)会给我一个形状张量 (2X11)
然而,当 a形状(2X11)b形状(1X11) , torch.stack((a,b),0)将引发错误 cf. “两个张量大小必须完全相同”。

因为这两个张量是模型的输出( 梯度包括 ),我无法将它们转换为 numpy 以使用 np.stack()np.vstack() .

是否有任何可能的解决方案来最小化 GPU 内存使用量?

最佳答案

看来你想用 torch.cat() (沿现有维度连接张量)而不是 torch.stack() (沿新维度连接/堆叠张量):

import torch

a = torch.randn(1, 42, 1, 1)
b = torch.randn(1, 42, 1, 1)

ab = torch.stack((a, b), 0)
print(ab.shape)
# torch.Size([2, 1, 42, 1, 1])

ab = torch.cat((a, b), 0)
print(ab.shape)
# torch.Size([2, 42, 1, 1])
aab = torch.cat((a, ab), 0)
print(aab.shape)
# torch.Size([3, 42, 1, 1])

关于python-3.x - Pytorch - 堆栈维度必须完全相同?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50394505/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com