gpt4 book ai didi

python - 在 Python 中排序的最快方法(没有 cython)

转载 作者:行者123 更新时间:2023-11-28 18:12:46 28 4
gpt4 key购买 nike

我遇到一个问题,我必须使用自定义函数对一个非常大的数组(形状 - 7900000X4X4)进行排序。我用了sorted 但是排序花了1个多小时。我的代码是这样的。

def compare(x,y):
print('DD '+str(x[0]))
if(np.array_equal(x[1],y[1])==True):
return -1
a = x[1].flatten()
b = y[1].flatten()
idx = np.where( (a>b) != (a<b) )[0][0]
if a[idx]<0 and b[idx]>=0:
return 0
elif b[idx]<0 and a[idx]>=0:
return 1
elif a[idx]<0 and b[idx]<0:
if a[idx]>b[idx]:
return 0
elif a[idx]<b[idx]:
return 1
elif a[idx]<b[idx]:
return 1
else:
return 0
def cmp_to_key(mycmp):
class K:
def __init__(self, obj, *args):
self.obj = obj
def __lt__(self, other):
return mycmp(self.obj, other.obj)
return K
tblocks = sorted(tblocks.items(),key=cmp_to_key(compare))

这有效,但我希望它在几秒钟内完成。我认为 python 中的任何直接实现都不能给我所需的性能,所以我尝试了 cython。我的 Cython 代码是这样的,非常简单。

cdef int[:,:] arrr
cdef int size

cdef bool compare(int a,int b):
global arrr,size
cdef int[:] x = arrr[a]
cdef int[:] y = arrr[b]
cdef int i,j
i = 0
j = 0
while(i<size):
if((j==size-1)or(y[j]<x[i])):
return 0
elif(x[i]<y[j]):
return 1
i+=1
j+=1
return (j!=size-1)

def sorted(np.ndarray boxes,int total_blocks,int s):
global arrr,size
cdef int i
cdef vector[int] index = xrange(total_blocks)
arrr = boxes
size = s
sort(index.begin(),index.end(),compare)
return index

这段用 cython 编写的代码用了 33 秒! Cython 是解决方案,但我正在寻找一些可以直接在 python 上运行的替代解决方案。例如麻麻。我尝试了 Numba,但没有得到令人满意的结果。请帮忙!

最佳答案

没有工作示例很难给出答案。我假设,你的 Cython 代码中的 arrr 是一个二维数组,我假设大小是 size=arrr.shape[0]

Numba 实现

import numpy as np
import numba as nb
from numba.targets import quicksort


def custom_sorting(compare_fkt):
index_arange=np.arange(size)

quicksort_func=quicksort.make_jit_quicksort(lt=compare_fkt,is_argsort=False)
jit_sort_func=nb.njit(quicksort_func.run_quicksort)
index=jit_sort_func(index_arange)

return index

def compare(a,b):
x = arrr[a]
y = arrr[b]
i = 0
j = 0
while(i<size):
if((j==size-1)or(y[j]<x[i])):
return False
elif(x[i]<y[j]):
return True
i+=1
j+=1
return (j!=size-1)


arrr=np.random.randint(-9,10,(7900000,8))
size=arrr.shape[0]

index=custom_sorting(compare)

这为生成的测试数据提供了3.85s。但是排序算法的速度在很大程度上取决于数据....

简单示例

import numpy as np
import numba as nb
from numba.targets import quicksort

#simple reverse sort
def compare(a,b):
return a > b

#create some test data
arrr=np.array(np.random.rand(7900000)*10000,dtype=np.int32)
#we can pass the comparison function
quicksort_func=quicksort.make_jit_quicksort(lt=compare,is_argsort=True)
#compile the sorting function
jit_sort_func=nb.njit(quicksort_func.run_quicksort)
#get the result
ind_sorted=jit_sort_func(arrr)

这个实现比 np.argsort 慢了大约 35%,但这在编译代码中使用 np.argsort 时也很常见。

关于python - 在 Python 中排序的最快方法(没有 cython),我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/50206440/

28 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com