gpt4 book ai didi

matrix-multiplication - PyTorch 张量沿任意轴的乘积 à la NumPy 的 `tensordot`

转载 作者:行者123 更新时间:2023-12-05 01:17:10 25 4
gpt4 key购买 nike

NumPy 提供了非常有用的 tensordot功能。它允许您计算两个 ndarrays 沿任何轴(其大小匹配)的乘积。我很难在 PyTorch 中找到类似的东西。 mm 仅适用于二维数组,matmul 具有一些不受欢迎的广播属性。

我错过了什么吗?我真的打算使用 mm reshape 阵列以模仿我想要的产品吗?

最佳答案

原来的答案是完全正确的,但作为更新,Pytorch now supports tensordot本地人。与 numpy 相同的调用签名,但将 axes 更改为 dims

import torch
import numpy as np

a = np.arange(36.).reshape(3,4,3)
b = np.arange(24.).reshape(4,3,2)
c = np.tensordot(a, b, axes=([1,0],[0,1]))
print(c)
# [[ 2640. 2838.] [ 2772. 2982.] [ 2904. 3126.]]

a = torch.from_numpy(a)
b = torch.from_numpy(b)
c = torch.tensordot(a, b, dims=([1,0],[0,1]))
print(c)
# tensor([[ 2640., 2838.], [ 2772., 2982.], [ 2904., 3126.]], dtype=torch.float64)

关于matrix-multiplication - PyTorch 张量沿任意轴的乘积 à la NumPy 的 `tensordot`,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/51266507/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com