gpt4 book ai didi

python - 仅使用 NumPy 优化步进函数的代码

转载 作者:行者123 更新时间:2023-12-04 10:34:28 25 4
gpt4 key购买 nike

我正在尝试仅使用 NumPy 函数(或者可能是列表推导式)优化以下代码中的函数“pw”。

from time import time
import numpy as np

def pw(x, udata):
"""
Creates the step function
| 1, if d0 <= x < d1
| 2, if d1 <= x < d2
pw(x,data) = ...
| N, if d(N-1) <= x < dN
| 0, otherwise
where di is the ith element in data.
INPUT: x -- interval which the step function is defined over
data -- an ordered set of data (without repetitions)
OUTPUT: pw_func -- an array of size x.shape[0]
"""
vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
return pw_func


N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)

ti = time()
pw(x,udata)
tf = time()
print(tf - ti)

import cProfile
cProfile.run('pw(x,udata)')

cProfile.run 告诉我大部分开销来自 np.where(大约 1 毫秒),但如果可能的话,我想创建更快的代码。似乎按行和按列执行操作会有所不同,除非我弄错了,但我想我已经考虑到了。我知道有时列表理解会更快,但我想不出比我正在做的更快的方法。

Searchsorted 似乎产生了更好的性能,但 1 ms 仍然保留在我的计算机上:
(modified)
def pw(xx, uu):
"""
Creates the step function
| 1, if d0 <= x < d1
| 2, if d1 <= x < d2
pw(x,data) = ...
| N, if d(N-1) <= x < dN
| 0, otherwise
where di is the ith element in data.
INPUT: x -- interval which the step function is defined over
data -- an ordered set of data (without repetitions)
OUTPUT: pw_func -- an array of size x.shape[0]
"""
inds = np.searchsorted(uu, xx, side='right')
vals = np.arange(1,uu.shape[0]+1)
pw_func = vals[inds[inds != uu.shape[0]]]
num_mins = np.sum(xx < np.min(uu))
num_maxs = np.sum(xx > np.max(uu))

pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
return pw_func

answer使用分段似乎非常接近,但这是在标量 x0 和 x1 上。我将如何在阵列上做到这一点?它会更有效率吗?

可以理解,x 可能相当大,但我正在尝试对其进行压力测试。

我仍在学习,所以一些可以帮助我的提示或技巧会很棒。

编辑

第二个函数似乎有错误,因为第二个函数的结果数组与第一个不匹配(我相信它可以工作):
N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)

产量
15000

不相同的数据点。所以似乎我不知道如何使用'searchsorted'。

编辑 2

实际上我修复了它:
pw_func = vals[inds[inds != uu.shape[0]]]

改为
pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]

所以至少结果数组匹配。但问题仍然是是否有更有效的方法来做到这一点。

编辑 3

感谢天丽指出错误。这个应该工作
pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]

也许一种更易读的呈现方式是
non_endpts = (inds != uu.shape[0])*(inds != 0) # only consider the points in between the min/max data values
shift_inds = inds[non_endpts]-1 # searchsorted side='right' includes the left end point and not right end point so a shift is needed
pw_func = vals[shift_inds]

我想我迷失在所有这些括号中!我想这就是可读性的重要性。

最佳答案

一个非常抽象但有趣的问题!谢谢你招待我,我玩得很开心:)

附言我不确定你的 pw2我无法获得与 pw1 相同的输出.

供引用原文pw s:

def pw1(x, udata):
vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
return pw_func

def pw2(xx, uu):
inds = np.searchsorted(uu, xx, side='right')
vals = np.arange(1,uu.shape[0]+1)
pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
num_mins = np.sum(xx < np.min(uu))
num_maxs = np.sum(xx > np.max(uu))

pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
return pw_func

我的第一次尝试是利用来自 numpy 的大量广播操作。 :

def pw3(x, udata):
# the None slice is to create new axis
step_bool = x >= udata[None,:].T

# we exploit the fact that bools are integer value of 1s
# skipping the last value in "data"
step_vals = np.sum(step_bool[:-1], axis=0)

# for the step_bool that we skipped from previous step (last index)
# we set it to zerp so that we can negate the step_vals once we reached
# the last value in "data"
step_vals[step_bool[-1]] = 0

return step_vals

看了之后 searchsorted来自您的 pw2我有一种新方法可以以更高的性能利用它:

def pw4(x, udata):
inds = np.searchsorted(udata, x, side='right')

# fix-ups the last data if x is already out of range of data[-1]
if x[-1] > udata[-1]:
inds[inds == inds[-1]] = 0

return inds

情节与:

plt.plot(pw1(x,udata.reshape(udata.shape[0],1)), label='pw1')
plt.plot(pw2(x,udata), label='pw2')
plt.plot(pw3(x,udata), label='pw3')
plt.plot(pw4(x,udata), label='pw4')

data = [1,3,4,5,5,7] :

enter image description here

data = [1,3,4,5,5,7,11]
enter image description here
pw1 , pw3 , pw4都是一样的

print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw3(x,udata)))
>>> True
print(np.all(pw1(x,udata.reshape(udata.shape[0],1)) == pw4(x,udata)))
>>> True

性能: ( timeit 默认运行 3 次,平均 number=N 次)

print(timeit.Timer('pw1(x,udata.reshape(udata.shape[0],1))', "from __main__ import pw1, x, udata").repeat(number=1000))
>>> [3.1938983199979702, 1.6096494779994828, 1.962694135003403]
print(timeit.Timer('pw2(x,udata)', "from __main__ import pw2, x, udata").repeat(number=1000))
>>> [0.6884554479984217, 0.6075002400029916, 0.7799002879983163]
print(timeit.Timer('pw3(x,udata)', "from __main__ import pw3, x, udata").repeat(number=1000))
>>> [0.7369808239964186, 0.7557657590004965, 0.8088172269999632]
print(timeit.Timer('pw4(x,udata)', "from __main__ import pw4, x, udata").repeat(number=1000))
>>> [0.20514375300263055, 0.20203858999957447, 0.19906871100101853]

关于python - 仅使用 NumPy 优化步进函数的代码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/60253834/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com