gpt4 book ai didi

pytorch - 如何将 Numba 用于 Pytorch 张量?

转载 作者:行者123 更新时间:2023-12-04 00:16:45 39 4
gpt4 key购买 nike

我是 Numba 的新手,我需要使用 Numba 来加速一些 Pytorch 功能。但我发现即使是一个非常简单的功能也不起作用:(

import torch
import numba

@numba.njit()
def vec_add_odd_pos(a, b):
res = 0.
for pos in range(len(a)):
if pos % 2 == 0:
res += a[pos] + b[pos]
return res

x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)
但是出现以下错误
def vec_add_odd_pos(a, b):
资源 = 0。
^
此错误可能是由以下参数引起的:
  • 参数 0:无法确定 的 Numba 类型
  • 参数 1:无法确定 的 Numba 类型

  • 谁能帮我?包含更多示例的链接也将不胜感激。谢谢。

    最佳答案

    正如其他人所提到的,numba 目前不支持火炬张量,只支持 numpy 张量。然而有TorchScript ,它有一个类似的目标。然后可以将您的函数重写为:

    import torch

    @torch.jit.script
    def vec_add_odd_pos(a, b):
    res = 0.
    for pos in range(len(a)):
    if pos % 2 == 0:
    res += a[pos] + b[pos]
    return res

    x = torch.tensor([3, 4, 5.])
    y = torch.tensor([-2, 0, 1.])
    z = vec_add_odd_pos(x, y)
    请注意:虽然您说您的代码片段只是一个简单的示例,但 for 循环确实很慢并且运行 TorchScript 可能对您没有太大帮助,您应该不惜一切代价避免它们,并且只有在不存在其他解决方案时才使用它们。话虽如此,以下是如何以更高效的方式实现您的功能:
    def vec_add_odd_pos(a, b):
    evenids = torch.arange(len(a)) % 2 == 0
    return (a[evenids] + b[evenids]).sum()

    关于pytorch - 如何将 Numba 用于 Pytorch 张量?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/63169760/

    39 4 0
    Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
    广告合作:1813099741@qq.com 6ren.com