gpt4 book ai didi

python - Matplotlib 的绘图非常缓慢

转载 作者:太空狗 更新时间:2023-10-30 03:04:50 24 4
gpt4 key购买 nike

谁能帮忙优化 Python 中的这个绘图函数?我使用 Matplotlib 绘制财务数据。下面是一个绘制 OHLC 数据的小函数。如果再添加指标或其他数据,耗时会显著增加。

import numpy as np
import datetime
from matplotlib.collections import LineCollection
from pylab import *
import urllib2

def test_plot(OHLCV):
    """Render one OHLC bar chart of ``OHLCV`` and save it as ``test.png``.

    ``OHLCV`` is indexed as a 2-D array: column 0 is passed to ``date2num``
    and columns 1-4 are used as open, high, low and close (the order given
    by the "Format" comment in the script section). Each bar is a vertical
    high-low stem with a tick to the left at the open and a tick to the
    right at the close; it is green when open < close and red otherwise.

    The open/close ticks are built from the ``OHLCV`` argument. An earlier
    version read the module-level ``parsed_table`` there, so the function
    silently depended on a global of the calling script.
    """
    bar_width = 1.3
    date_offset = 0.5
    fig = figure(figsize=(50, 20), facecolor='w')
    ax = fig.add_subplot(1, 1, 1)
    labels = ax.get_xmajorticklabels()
    setp(labels, rotation=0)

    month = MonthLocator()
    day = DayLocator()
    timeFmt = DateFormatter('%Y-%m-%d')

    opens, highs, lows, closes = (OHLCV[:, k] for k in range(1, 5))

    # One colour name per bar.
    color = np.where(opens < closes, 'green', 'red')

    # Convert once; the original rebuilt this array for every segment list.
    dates = np.asarray(date2num(OHLCV[:, 0]))

    # list(...) keeps the segment lists concrete even where zip() is lazy.
    stems = list(zip(zip(dates, highs), zip(dates, lows)))
    open_ticks = list(zip(zip((dates - date_offset).tolist(), opens),
                          zip(dates.tolist(), opens)))
    close_ticks = list(zip(zip((dates + date_offset).tolist(), closes),
                           zip(dates.tolist(), closes)))

    # Same insertion order as before: stems, close ticks, open ticks.
    for segments in (stems, close_ticks, open_ticks):
        collection = LineCollection(segments)
        collection.set_color(color)
        collection.set_linewidth(bar_width)
        ax.add_collection(collection, autolim=True)

    ax.xaxis.set_major_locator(month)
    ax.xaxis.set_major_formatter(timeFmt)
    ax.xaxis.set_minor_locator(day)

    ax.autoscale_view()

    ax.xaxis.grid(True, 'major')
    ax.grid(True)

    ax.set_title('EOD test plot')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price , $')
    fig.savefig('test.png', dpi = 50, bbox_inches='tight')
    # Close this figure explicitly rather than whichever one is current.
    close(fig)

if __name__=='__main__':
    import time

    # Fetch the CSV, skip its first (header) line and reverse the row order.
    data_table = urllib2.urlopen(r"http://ichart.finance.yahoo.com/table.csv?s=IBM&a=00&b=1&c=2012&d=00&e=15&f=2013&g=d&ignore=.csv").readlines()[1:][::-1]
    #Format: Date, Open, High, Low, Close, Volume
    dtype = (lambda x: datetime.datetime.strptime(x, '%Y-%m-%d').date(),float, float, float, float, int)

    # The last CSV field of every row is dropped; the remaining fields are
    # converted column by column with the callables in ``dtype``.
    parsed_table = np.array([
        [convert(value)
         for convert, value in zip(dtype, row.strip().split(',')[:-1])]
        for row in data_table
    ])

    bf = time.time()
    count = 100
    for i in xrange(count):
        test_plot(parsed_table)
    # The division must be parenthesised: '%' and '/' have equal precedence
    # and associate left to right, so the unparenthesised form formatted the
    # string first and then tried to divide it by ``count`` (TypeError).
    print('Plot time: %s' % ((time.time() - bf) / count))

结果如下图所示。每张图的平均执行时间约为 2.6 秒。在 R 中绘制图表要快得多,不过我没有实际测量过它的性能,也不想使用 Rpy,所以我认为是我的代码效率太低。(此处原有一张结果截图)

最佳答案

此解决方案重用同一个 Figure 实例,并以异步方式保存绘图。您还可以进一步修改:创建与处理器核心数相同数量的 Figure,并同时异步绘制同样多的图,这样速度应该会更快。就目前的实现而言,在我的机器上每张图大约需要 1 秒,而原来需要 2.6 秒。

import numpy as np
import datetime
import urllib2
import time
import multiprocessing as mp
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from pylab import *
from matplotlib.collections import LineCollection

class AsyncPlotter():
    """Save figures in background processes, at most ``processes`` at a time.

    A manager-hosted counter (``nc``) records how many workers are currently
    saving, and a manager-hosted lock (``lock``) serialises every update to
    it. Each ``save()`` starts one worker process; ``join()`` waits for all
    workers started so far.
    """

    def __init__(self, processes=mp.cpu_count()):
        """Start the manager that hosts the shared counter and lock.

        processes: maximum number of workers allowed to save concurrently.
            The default is evaluated once, when the class is defined.
        """
        self.manager = mp.Manager()
        self.nc = self.manager.Value('i', 0)
        # ``nc.value += 1`` is a separate read and write through the proxy,
        # so concurrent workers could lose updates without a lock.
        self.lock = self.manager.Lock()
        # Despite the name, this holds the Process objects, not their pids.
        self.pids = []
        self.processes = processes

    def async_plotter(self, nc, fig, filename, processes):
        """Worker body: wait for a free slot, save ``fig``, release the slot.

        nc: shared counter of workers currently saving.
        fig: figure to write; it is closed after a successful save.
        filename: output path handed to ``fig.savefig``.
        processes: maximum number of concurrent saves.

        Any exception from the save propagates after the slot is released.
        """
        # Claim a slot. The test and the increment happen under one lock so
        # two workers cannot both take the last free slot.
        while True:
            with self.lock:
                if nc.value < processes:
                    nc.value += 1
                    break
            time.sleep(0.1)
        try:
            print("Plotting " + filename)
            fig.savefig(filename)
            plt.close(fig)
        finally:
            # Always hand the slot back; otherwise one failed save would
            # leave every waiting worker polling forever.
            with self.lock:
                nc.value -= 1

    def save(self, fig, filename):
        """Start a worker process that saves ``fig`` to ``filename``."""
        p = mp.Process(target=self.async_plotter,
                       args=(self.nc, fig, filename, self.processes))
        p.start()
        self.pids.append(p)

    def join(self):
        """Block until every worker started by ``save()`` has finished."""
        for p in self.pids:
            p.join()

class FinanceChart():
    """OHLC chart that builds its figure once and re-fills it for every plot.

    The figure, axes, tick locators, labels and three line collections are
    created in ``__init__``; ``test_plot`` only swaps the segment data and
    colours and adjusts the axis limits, then hands the figure to the
    plotter object for saving.
    """

    def __init__(self, async_plotter):
        """Create the figure and the three placeholder line collections.

        async_plotter: object whose ``save(fig, filename)`` method is called
            by ``test_plot``.
        """
        self.async_plotter = async_plotter
        self.bar_width = 1.3
        self.date_offset = 0.5
        self.fig = plt.figure(figsize=(50, 20), facecolor='w')
        self.ax = self.fig.add_subplot(1, 1, 1)
        self.labels = self.ax.get_xmajorticklabels()
        setp(self.labels, rotation=0)

        # Dummy geometry only; test_plot() replaces it via set_segments().
        placeholder = [[(734881, 1), (734882, 5), (734883, 9), (734889, 5)]]
        self.lines_hl = self.ax.add_collection(LineCollection(placeholder), autolim=True)
        self.lines_op = self.ax.add_collection(LineCollection(placeholder), autolim=True)
        self.lines_cl = self.ax.add_collection(LineCollection(placeholder), autolim=True)

        self.ax.set_title('EOD test plot')
        self.ax.set_xlabel('Date')
        self.ax.set_ylabel('Price , $')

        month = MonthLocator()
        day = DayLocator()
        timeFmt = DateFormatter('%Y-%m-%d')
        self.ax.xaxis.set_major_locator(month)
        self.ax.xaxis.set_major_formatter(timeFmt)
        self.ax.xaxis.set_minor_locator(day)

        # Grid state does not depend on the data, so enable it once here
        # instead of on every test_plot() call.
        self.ax.xaxis.grid(True, 'major')
        self.ax.grid(True)

    def test_plot(self, OHLCV, i):
        """Fill the chart with ``OHLCV`` and queue it for saving.

        OHLCV: indexed as a 2-D array; column 0 is passed to ``date2num`` and
            columns 1-4 are used as open, high, low and close (the order
            given by the "Format" comment in the script section).
        i: integer used to build the output file name ``'%04i.png' % i``.

        The y-axis spans the lowest low to the highest high. An earlier
        version used the min/max of the open column, which clipped every
        bar whose high or low lay outside the range of opening prices. The
        x-axis is padded by ``date_offset`` on both sides so the open tick
        of the first bar and the close tick of the last bar stay visible.
        """
        opens, highs, lows, closes = (OHLCV[:, k] for k in range(1, 5))

        # One colour name per bar: green when open < close, red otherwise.
        color = np.where(opens < closes, 'green', 'red')

        date_array = np.asarray(date2num(OHLCV[:, 0]))
        dates = date_array.tolist()

        # list(...) keeps the segment lists concrete even where zip() is lazy.
        self.lines_hl.set_segments(list(zip(zip(dates, highs), zip(dates, lows))))
        self.lines_hl.set_color(color)
        self.lines_hl.set_linewidth(self.bar_width)
        self.lines_op.set_segments(list(zip(zip((date_array - self.date_offset).tolist(), opens), zip(dates, opens))))
        self.lines_op.set_color(color)
        self.lines_op.set_linewidth(self.bar_width)
        self.lines_cl.set_segments(list(zip(zip((date_array + self.date_offset).tolist(), closes), zip(dates, closes))))
        self.lines_cl.set_color(color)
        self.lines_cl.set_linewidth(self.bar_width)

        self.ax.set_xlim(min(dates) - self.date_offset, max(dates) + self.date_offset)
        self.ax.set_ylim(min(lows), max(highs))

        self.async_plotter.save(self.fig, '%04i.png'%i)

if __name__=='__main__':
    import time

    print("Starting")
    raw_lines = urllib2.urlopen(r"http://ichart.finance.yahoo.com/table.csv?s=IBM&a=00&b=1&c=2012&d=00&e=15&f=2013&g=d&ignore=.csv").readlines()
    # Skip the first (header) line, then reverse the remaining rows.
    data_table = raw_lines[1:][::-1]

    # Format: Date, Open, High, Low, Close, Volume
    converters = (
        lambda text: datetime.datetime.strptime(text, '%Y-%m-%d').date(),
        float, float, float, float, int,
    )
    # The last CSV field of each row is dropped before conversion.
    parsed_table = np.array([
        [convert(cell)
         for convert, cell in zip(converters, line.strip().split(',')[:-1])]
        for line in data_table
    ])

    started = time.time()
    count = 10

    plotter = AsyncPlotter()
    chart = FinanceChart(plotter)

    print("Done with startup tasks")
    for index in xrange(count):
        chart.test_plot(parsed_table, index)

    # Wait for every background save before taking the final timestamp.
    plotter.join()
    elapsed = float(time.time() - started)
    print('Plot time: %.2f' % (elapsed / float(count)))

关于python - Matplotlib 的绘图非常缓慢,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/14349148/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com