gpt4 book ai didi

python - 使用 numba jit 提高 python 脚本的性能

转载 作者:行者123 更新时间:2023-12-04 10:14:59 24 4
gpt4 key购买 nike

我正在运行一个示例 python 模拟来预测加权和常规骰子。我想使用 numba 来帮助加速我的脚本,但我收到一个错误:

<timed exec>:6: NumbaWarning: 
Compilation is falling back to object mode WITH looplifting enabled because Function "roll" failed type inference due to: Untyped global name 'sum': cannot determine Numba type of <class 'builtin_function_or_method'>

File "<timed exec>", line 9:
<source missing, REPL/exec in use?>

这是我的原始代码:我可以使用另一种类型的 numba 表达式吗?现在我正在使用 2500 卷的输入进行测试;希望将其缩短到 4 秒(目前为 8.5 秒)。

%%time
from numba import jit
import random
import matplotlib.pyplot as plt
import numpy

@jit
def roll(sides, bias_list):
assert len(bias_list) == sides, "Enter correct number of dice sides"
number = random.uniform(0, sum(bias_list))
current = 0
for i, bias in enumerate(bias_list):
current += bias
if number <= current:
return i + 1

no_of_rolls = 2500
weighted_die = {}
normal_die = {}
#weighted die

for i in range(no_of_rolls):
weighted_die[i+1]=roll(6,(0.15, 0.15, 0.15, 0.15, 0.15, 0.25))
#regular die
for i in range(no_of_rolls):
normal_die[i+1]=roll(6,(0.167, 0.167, 0.167, 0.167, 0.167, 0.165))

plt.bar(*zip(*weighted_die.items()))
plt.show()
plt.bar(*zip(*normal_die.items()))
plt.show()

最佳答案

使用随机选择

重构代码

import random
import matplotlib.pyplot as plt

no_of_rolls = 2500

# weights
normal_weights = (0.167, 0.167, 0.167, 0.167, 0.167, 0.165)
bias_weights = (0.15, 0.15, 0.15, 0.15, 0.15, 0.25)

# Replaced roll function with random.choices
# Reference: https://www.w3schools.com/python/ref_random_choices.asp
bias_rolls = random.choices(range(1, 7), weights = bias_weights, k = no_of_rolls)
normal_rolls = random.choices(range(1, 7), weights = normal_weights, k = no_of_rolls)

# Create dictionaries with same structure as posted code
weighted_die = dict(zip(range(no_of_rolls), bias_rolls))
normal_die = dict(zip(range(no_of_rolls), normal_rolls))

# Use posted plotting calls
plt.bar(*zip(*weighted_die.items()))
plt.show()
plt.bar(*zip(*normal_die.items()))
plt.show()

性能

*Not including plotting.*
Original code: ~6 ms
Revised code: ~2 ms
(3x improvement, but not sure why the post mentions 8 seconds to run)

关于python - 使用 numba jit 提高 python 脚本的性能,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/61109780/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com