gpt4 book ai didi

pytorch - 排列后如何进行张量点运算

转载 作者:行者123 更新时间:2023-12-04 08:15:43 24 4
gpt4 key购买 nike

我有 2 个张量,A 和 B:

A = torch.randn([32,128,64,12],dtype=torch.float64)
B = torch.randn([64,12,64,12],dtype=torch.float64)
C = torch.tensordot(A,B,([2,3],[0,1]))
D = C.permute(0,2,1,3) # shape:[32,64,128,12]
张量 D 来自操作“tensordot -> permute”。如何在 f() 之后实现新操作 f() 以进行 tensordot 操作,例如:
A_2 = f(A)
B_2 = f(B)
D = torch.tensordot(A_2,B_2)

最佳答案

您是否考虑过使用 torch.einsum 这是非常灵活的?

D = torch.einsum('ijab,abkl->ikjl', A, B)
tensordot 的问题是它输出 A的所有维度在 B 之前并且您正在寻找的(在排列时)是从 A 中“交错”维度和 B .

关于pytorch - 排列后如何进行张量点运算,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65716540/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com