gpt4 book ai didi

python - 如何将 PyTorch 的 torch.inverse() 函数应用于批处理中的每个样本?

转载 作者:太空狗 更新时间:2023-10-30 00:58:57 24 4
gpt4 key购买 nike

这似乎是一个基本问题,但我无法解决。

在我的神经网络的正向传递中,我有一个形状为 8x3x3 的输出张量,其中 8 是我的批量大小。我们可以假设每个 3x3 张量是一个非奇异矩阵。我需要找到这些矩阵的逆矩阵。PyTorch inverse()函数仅适用于方阵。由于我现在有 8x3x3,我如何以可微分的方式将此函数应用于批处理中的每个矩阵?

如果我遍历样本并将逆向追加到 python 列表,然后将其转换为 PyTorch 张量,那么在反向传播期间是否会出现问题? (我问的是因为将 PyTorch 张量转换为 numpy 以执行某些操作然后返回到张量不会在此类操作的反向传播期间计算梯度)

当我尝试做类似的事情时,我也会收到以下错误。

a = torch.arange(0,8).view(-1,2,2)
b = [m.inverse() for m in a]
c = torch.FloatTensor(b)

TypeError: 'torch.FloatTensor' object does not support indexing

最佳答案

编辑:

从 Pytorch 1.0 版开始,torch.inverse 现在支持批量张量。参见 here .所以你可以简单地使用内置函数 torch.inverse

旧答案

有计划很快实现批量逆。有关讨论,请参见示例 issue 7500issue 9102 .但是,截至本文撰写之时,当前的稳定版本(0.4.1),尚无批量反演操作可用。

话虽如此,最近添加了对 torch.gesv 的批处理支持。这可以(ab)用于沿着以下几行定义您自己的批量逆运算:

def b_inv(b_mat):
eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
b_inv, _ = torch.gesv(eye, b_mat)
return b_inv

我发现在 GPU 上运行时,这比 for 循环提供了很好的加速。

关于python - 如何将 PyTorch 的 torch.inverse() 函数应用于批处理中的每个样本?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/46595157/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com