gpt4 book ai didi

python - 加速cython代码

转载 作者:太空宇宙 更新时间:2023-11-04 02:00:10 25 4
gpt4 key购买 nike

我写了一个 python 代码来管理大量数据,因此需要很多时间。所以,我发现了 Cython,并开始更改我的代码。

基本上,我所做的就是更改函数的声明(cdef 类型名称(带有变量类型的参数)),声明 cdef 变量及其类型,并声明 cdef 类。我使用 Eclipse 编写所有 .pyx 文件,使用命令 python setup.py build_ext --inplace 进行编译,然后在 Eclipse 中运行。

我的问题是比较 python 和 cython 的速度,没有任何区别。

我运行命令 cython -a <file>生成一个html文件,里面有很多黄线。

我不知道我做错了什么,是否还应该包含其他内容,也不知道如何消除这些黄线。

我只是粘贴了一些代码行,那是我想加快速度的部分,因为代码很长。


main.pyx

cdef ReadWavePoints (WavePointManagement wavePointManagement, ColumnManagement columnManagement):
    '''
    Process every batch of wave points stored in ``wavePointsFile``.

    Each pass loads one batch, rounds its coordinates, sorts it and hands the
    sorted list to GroupColumnsVoxels.  There is no explicit end-of-data test:
    the loop stops the first time any of those steps raises, at which point the
    wave-point file and the column file are closed.
    '''
    wavePointManagement.OpenWavePointFileLoad(wavePointsFile)

    while True:
        try:
            wavePointManagement.LoadWavePointFile()
            wavePointManagement.RoundCoordinates()
            wavePointManagement.SortWavePointList()
            GroupColumnsVoxels(wavePointManagement.GetWavePointList(), columnManagement)
        except:
            # NOTE(review): the bare except treats *any* failure as "no more
            # data", so a genuine error in the steps above also ends the run
            # without being reported.
            wavePointManagement.CloseWavePointFile()
            columnManagement.CloseWriteColumnFile()
            break

'''I check which points are in the same XYZ (voxel) and in the same XY (column)'''

# Walk wavePointList with a reference index (indexWavePointRef) and compare the
# reference point with the points that follow it through GetX()/GetY()/GetZ().
# While doing so it tracks the largest GetValue() seen in voxelValue, passes
# coordinates and values to CheckVoxel, registers columnInstance.Column objects
# with columnManagement.AddColumn and calls MaximumHeightColumn.
#
# NOTE(review): the indentation of this listing has been lost, so it is not
# possible to tell from the text which `if`/`for` each `else:` and `break`
# below belongs to.  Restore the nesting from the original source before
# changing anything here.
# NOTE(review): `columnInstance` is not defined in the visible code; presumably
# it is a module imported elsewhere - confirm.
cdef GroupColumnsVoxels (object wavePointList, ColumnManagement columnManagement):
cdef int indexWavePointRef, indexWavePoint
cdef int saved
cdef double voxelValue
cdef int sizeWavePointList

# Length is cached once here ...
sizeWavePointList = len(wavePointList)

indexWavePointRef = 0

while indexWavePointRef < sizeWavePointList - 1:
# `saved` is a 0/1 flag; it is set to 1 wherever CheckVoxel is called inside the for loop.
saved = 0
# Start from the reference point's own value.
voxelValue = (wavePointList[indexWavePointRef]).GetValue()
# ... but len(wavePointList) is evaluated again here instead of reusing sizeWavePointList.
for indexWavePoint in xrange(indexWavePointRef + 1, len(wavePointList)):
# Same X and Y as the reference point.
if (wavePointList[indexWavePointRef]).GetX() == (wavePointList[indexWavePoint]).GetX() and (wavePointList[indexWavePointRef]).GetY() == (wavePointList[indexWavePoint]).GetY():
# Same Z as well.
if (wavePointList[indexWavePointRef]).GetZ() == (wavePointList[indexWavePoint]).GetZ():
# Keep the larger of the two values.
if voxelValue < (wavePointList[indexWavePoint]).GetValue():
voxelValue = (wavePointList[indexWavePoint]).GetValue()
else:
saved = 1
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
indexWavePointRef = indexWavePoint
# The new reference is the last element: hand it to CheckVoxel with its own value.
if indexWavePointRef == sizeWavePointList - 1:
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), (wavePointList[indexWavePointRef]).GetValue())
break
else:
saved = 1
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
columnManagement.AddColumn(columnObject)
MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ())
indexWavePointRef = indexWavePoint
break
# No CheckVoxel call was flagged by the for loop above.
if saved == 0:
CheckVoxel((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ(), voxelValue)
indexWavePointRef = indexWavePoint
columnObject = columnInstance.Column((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY())
columnManagement.AddColumn(columnObject)
MaximumHeightColumn((wavePointList[indexWavePointRef]).GetX(), (wavePointList[indexWavePointRef]).GetY(), (wavePointList[indexWavePointRef]).GetZ())



cdef CheckVoxel (double X, double Y, double Z, double newValue):
    '''
    Write newValue into the voxel raster if it exceeds the value stored there.

    The band is int(floor(Z / 0.3)) + 1 of ``datasetVoxels``; the pixel is
    located from X and Y relative to the module-level Xmin and Ymax, dividing
    by 0.25 (presumably the raster cell size - confirm).  A single Float32
    pixel is read, and it is overwritten only when newValue is strictly larger.
    '''
    cdef object band, rawPixel
    cdef tuple stored

    band = datasetVoxels.GetRasterBand(int(math.floor(Z/0.3))+1)

    # The pixel offsets are the same for the read and the write.
    column = int(math.floor((X-Xmin)/0.25))
    row = int(math.floor((Ymax-Y)/0.25))

    rawPixel = band.ReadRaster(column, row, 1, 1, buf_type=gdal.GDT_Float32)
    stored = struct.unpack('f', rawPixel)

    if newValue > stored[0]:
        band.WriteRaster(column, row, 1, 1, struct.pack('f', newValue))

cdef MaximumHeightColumn(double X, double Y, double newZ):
    '''
    Record newZ as the column height if it exceeds the height stored so far.

    The height is kept in band 10 of ``datasetMetrics``; the pixel is located
    from X and Y relative to the module-level Xmin and Ymax, dividing by 0.25
    (presumably the raster cell size - confirm).  The stored Float32 value is
    rounded to one decimal before the comparison, and the pixel is overwritten
    only when newZ is strictly larger.
    '''
    cdef object band, rawPixel
    cdef tuple stored

    band = datasetMetrics.GetRasterBand(10)

    # The pixel offsets are the same for the read and the write.
    column = int(math.floor((X-Xmin)/0.25))
    row = int(math.floor((Ymax-Y)/0.25))

    rawPixel = band.ReadRaster(column, row, 1, 1, buf_type=gdal.GDT_Float32)
    stored = struct.unpack('f', rawPixel)

    if newZ > round(stored[0], 1):
        band.WriteRaster(column, row, 1, 1, struct.pack('f', newZ))

波点管理.pyx

'''this class serializes, rounds and sorts the points of each ndarray'''
import cPickle as pickle
import numpy as np
cimport numpy as np
import math

cdef class WavePointManagement(object):
    '''
    This class manages all the points extracted from the waveform.

    It reads pickled batches of points from a file, rounds their coordinates
    (x down and y up to a multiple of 0.25, z down to a multiple of 0.3) and
    sorts them by (x, y, z).

    NOTE(review): every method below is declared ``cdef void`` without an
    exception clause.  Depending on the Cython version, an exception raised
    inside such a method may be reported and discarded instead of reaching the
    caller - confirm before relying on callers catching errors from them.
    '''
    cdef object fileObject, wavePointList
    # NOTE(review): the attributes are already declared with cdef above; check
    # whether __slots__ has any effect on a cdef class.
    __slots__ = ('wavePointList', 'fileObject')

    def __cinit__(self):
        '''
        Constructor: no file is open yet and the point list is an empty array.
        '''
        self.fileObject = None
        self.wavePointList = np.array([])

    cdef object GetWavePointList(self):
        '''Return the current batch of points.'''
        return self.wavePointList

    cdef void OpenWavePointFileLoad (self, object fileName):
        '''Open fileName for binary reading and keep the handle for later loads.'''
        # open() instead of the Python-2-only file() builtin: same result on
        # Python 2 and still available on Python 3.
        self.fileObject = open(fileName, 'rb')

    cdef void LoadWavePointFile (self):
        '''Replace the current batch with the next pickled object in the file.'''
        # Drop the previous batch before loading the next one.
        self.wavePointList = None
        # SECURITY: unpickling executes arbitrary code from the file; only use
        # this on files produced by a trusted source.
        self.wavePointList = pickle.load(self.fileObject)

    cdef void SortWavePointList (self):
        '''Sort the current batch by x, then y, then z; the result is a list.'''
        self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))

    cdef void RoundCoordinates (self):
        '''Round the coordinates of every point of the current batch in place.'''
        for pointObject in self.wavePointList:
            # x is rounded down and y rounded up to a multiple of 0.25.
            pointObject.SetX(round(math.floor(pointObject.GetX()/0.25)*0.25, 2))
            pointObject.SetY(round(math.ceil(pointObject.GetY()/0.25)*0.25, 2))
            # z is rounded down to a multiple of 0.3.
            pointObject.SetZ(round(math.floor(pointObject.GetZ()/0.3)*0.3, 1))

    cdef void CloseWavePointFile(self):
        '''Close the file opened by OpenWavePointFileLoad.'''
        self.fileObject.close()

setup.py

from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

import numpy

# Describe the "main" extension built from main.pyx.  The NumPy include
# directory is added for sources that `cimport numpy`.
ext = Extension(
    "main",
    ["main.pyx"],
    include_dirs=[numpy.get_include()],
)

# Cython's build_ext command compiles the .pyx source when running
# `python setup.py build_ext --inplace`.
setup(
    ext_modules=[ext],
    cmdclass={'build_ext': build_ext},
)

test_cython.py

'''this is the file I run with eclipse after compiling'''
# Import the entry point from the compiled "main" module and run it.
from main import main

main()

我怎样才能加快这段代码的速度?

最佳答案

您的代码在使用 numpy 数组和列表之间来回跳转。因此,cython 生成的代码几乎没有区别。

下面的代码生成一个python列表,key函数也是一个纯python函数。

self.wavePointList = sorted(self.wavePointList, key=lambda k: (k.x, k.y, k.z))

您需要使用 ndarray.sort(如果您不想就地排序,则可以使用 numpy.sort)。为此,您还需要更改对象在数组中的存储方式。也就是说,您需要使用结构化数组(structured array)。有关如何对结构化数组进行排序的示例,请参见 numpy.sort 的文档,特别是页面上的最后两个示例。

将数据存储在 numpy 数组中后,您需要告诉 cython 数据是如何存储在数组中的。这包括提供类型信息和数组的维度。 This page提供了有关如何高效使用 numpy 数组的更多信息。

创建和排序结构化数组的示例:

import numpy as np
cimport numpy as np

DTYPE = [('name', 'S10'), ('height', np.float64), ('age', np.int32)]

cdef packed struct Person:
char name[10]
np.float64_t height
np.int32_t age

ctypedef Person DTYPE_t

def create_array():
    '''Return a 1-D structured array of dtype DTYPE holding three sample rows.'''
    return np.array(
        [
            ('Arthur', 1.8, 41),
            ('Lancelot', 1.9, 38),
            ('Galahad', 1.7, 38),
        ],
        dtype=DTYPE,
    )

# Sort arr in place with ndarray.sort, ordering by the 'age' field first and by
# 'height' second.  No value is returned explicitly.
cpdef sort_by_age_then_height(np.ndarray[DTYPE_t, ndim=1] arr):
arr.sort(order=['age', 'height'])

最后,您需要将代码从使用 Python 方法转换为使用标准 C 库方法,以进一步加快速度。下面是一个使用 RoundCoordinates 的例子。`cpdef` 意味着该函数也通过包装函数暴露给 python。

cimport cython
cimport numpy as np
from libc.math cimport floor, ceil, round

import numpy as np

DTYPE = [('x', np.float64), ('y', np.float64), ('z', np.float64)]

cdef packed struct Point3D:
np.float64_t x, y, z

ctypedef Point3D DTYPE_t

# Caution should be used when turning the bounds check off as it can lead to undefined
# behaviour if you use an invalid index.
@cython.boundscheck(False)
cpdef RoundCoordinates_cy(np.ndarray[DTYPE_t] pointlist):
    '''
    Round the coordinates of every point of pointlist in place.

    x is rounded down and y rounded up to a multiple of 0.25; z is rounded
    down to a multiple of 0.3.  Nothing is returned explicitly.
    '''
    cdef int i
    cdef DTYPE_t point
    for i in range(len(pointlist)):  # this line is optimised into a c loop
        point = pointlist[i]  # creates a copy of the point
        # floor()/ceil() yield a whole number of 0.25 steps.  Multiples of 0.25
        # need two decimals (0.25, 0.75, ...), so scale by 25 and divide by
        # 100; rounding to one decimal would turn 0.25 into 0.3.
        point.x = round(floor(point.x / 0.25) * 25) / 100
        point.y = round(ceil(point.y / 0.25) * 25) / 100
        # Multiples of 0.3 only need one decimal.
        point.z = round(floor(point.z / 0.3) * 3) / 10
        pointlist[i] = point  # overwrites the old point data with the new data

最后,在重写整个代码库之前,您应该先对代码进行性能分析,找出程序花费大部分时间的函数,并优先优化这些函数,然后再优化其他函数。

关于python - 加速cython代码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/28276078/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com