gpt4 book ai didi

python - 使用 numba 索引 numpy 数组时出现类型错误

转载 作者:太空宇宙 更新时间:2023-11-03 16:10:55 24 4
gpt4 key购买 nike

我需要根据另一个包含类成员资格信息的数组(标签)对一维 numpy 数组(如下:data)中的元素求和>)。我在下面的代码中使用 numba 来加快速度。但是,如果我没有在 ret[int(find(labels, g))] += y 行中显式使用 int() 进行转换,我会收到一条错误消息:

类型错误:不支持的数组索引类型?int64

是否有比显式转换更好的解决方法?

import numpy as np
from numba import jit

labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)

@jit(nopython=True)
def find(seq, value):
for ct, x in enumerate(seq):
if x == value:
return ct

@jit(nopython=True)
def subsumNumba(data, groups, labels):
ret = np.zeros(len(labels))
for y, g in zip(data, groups):
# not working without casting with int()
ret[int(find(labels, g))] += y
return ret

最佳答案

问题是 find 可以返回一个 intNone 如果它没有找到任何东西,因此我认为 >?int64 错误。为了避免强制转换,当 find 退出而没有找到所需的值时,您需要提供一个 int 返回值,然后在调用者中处理它。

关于python - 使用 numba 索引 numpy 数组时出现类型错误,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/39316939/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com