gpt4 book ai didi

python - pytorch中矩阵和向量之间的加/减

转载 作者:行者123 更新时间:2023-11-28 17:05:54 28 4
gpt4 key购买 nike

我想在pytorch中的矩阵和向量之间做+/-/*。我怎样才能有好的表现?我尝试使用扩展,但它真的很慢(我使用的是带有小向量的大矩阵)。

a = torch.rand(2,3)
print(a)
0.7420 0.2990 0.3896
0.0715 0.6719 0.0602
[torch.FloatTensor of size 2x3]
b = torch.rand(2)
print(b)
0.3773
0.6757
[torch.FloatTensor of size 2]
a.add(b)
Traceback (most recent call last):
File "C:\ProgramData\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3066, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-17-a1cb1b03d031>", line 1, in <module>
a.add(b)
RuntimeError: inconsistent tensor size, expected r_ [2 x 3], t [2 x 3] and src [2] to have the same number of elements, but got 6, 6 and 2 elements respectively at c:\miniconda2\conda-bld\pytorch-cpu_1519449358620\work\torch\lib\th\generic/THTensorMath.c:1021

预期结果:

 0.7420-0.3773  0.2990-0.3773  0.3896-0.3773
0.0715-0.6757 0.6719-0.6757 0.0602-0.6757

最佳答案

要使用广播,您需要将张量 b 的维数提升到二维,因为张量 a 是二维的。

In [43]: a
Out[43]:
tensor([[ 0.9455, 0.2088, 0.1070],
[ 0.0823, 0.6509, 0.1171]])

In [44]: b
Out[44]: tensor([ 0.4321, 0.8250])

# subtraction
In [46]: a - b[:, None]
Out[46]:
tensor([[ 0.5134, -0.2234, -0.3252],
[-0.7427, -0.1741, -0.7079]])

# alternative way to do subtraction
In [47]: a.sub(b[:, None])
Out[47]:
tensor([[ 0.5134, -0.2234, -0.3252],
[-0.7427, -0.1741, -0.7079]])

# yet another approach
In [48]: torch.sub(a, b[:, None])
Out[48]:
tensor([[ 0.5134, -0.2234, -0.3252],
[-0.7427, -0.1741, -0.7079]])

其他操作(+*)可以类推。


就性能而言,使用一种方法似乎没有优于其他方法的优势。只需使用三种方法中的任何一种即可。

In [49]: a = torch.rand(2000, 3000)
In [50]: b = torch.rand(2000)

In [51]: %timeit torch.sub(a, b[:, None])
2.4 ms ± 8.31 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [52]: %timeit a.sub(b[:, None])
2.4 ms ± 6.94 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [53]: %timeit a - b[:, None]
2.4 ms ± 12 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

关于python - pytorch中矩阵和向量之间的加/减,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51097719/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com