gpt4 book ai didi

python - 使用 apply_along_axis 绘制

转载 作者:行者123 更新时间:2023-11-28 17:42:23 24 4
gpt4 key购买 nike

我有一个 3D ndarry 对象,它包含光谱数据(即空间 xy 维度和能量维度)。我想提取并绘制线图中每个像素的光谱。目前,我正在沿我感兴趣的轴使用 np.ndenumerate 来执行此操作,但速度很慢。我希望尝试 np.apply_along_axis,看看它是否更快,但我不断收到一个奇怪的错误。

什么有效:

# Setup environment, and generate sample data (much smaller than real thing!)
import numpy as np
import matplotlib.pyplot as plt

ax = range(0,10) # the scale to use when plotting the axis of interest
ar = np.random.rand(4,4,10) # the 3D data volume

# Plot all lines along axis 2 (i.e. the spectrum contained in each pixel)
# on a single line plot:

for (x,y) in np.ndenumerate(ar[:,:,1]):
plt.plot(ax,ar[x[0],x[1],:],alpha=0.5,color='black')

据我了解,这基本上是一个循环,其效率低于基于数组的方法,因此我想尝试一种使用 np.apply_along_axis 的方法,看看它是否更快。然而,这是我第一次尝试使用 Python,我仍在寻找它的工作原理,所以如果这个想法存在根本性缺陷,请纠正我!

我想尝试的:

# define a function to pass to apply_along_axis
def pa(y,x):
if ~all(np.isnan(y)): # only do the plot if there is actually data there...
plt.plot(x,y,alpha=0.15,color='black')
return

# check that the function actually works...
pa(ar[1,1,:],ax) # should produce a plot - does for me :)

# try to apply to to the whole array, along the axis of interest:
np.apply_along_axis(pa,2,ar,ax) # does not work... booo!

产生的错误:

---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-109-5192831ba03c> in <module>()
12 # pa(ar[1,1,:],ax)
13
---> 14 np.apply_along_axis(pa,2,ar,ax)

//anaconda/lib/python2.7/site-packages/numpy/lib/shape_base.pyc in apply_along_axis(func1d, axis, arr, *args)
101 holdshape = outshape
102 outshape = list(arr.shape)
--> 103 outshape[axis] = len(res)
104 outarr = zeros(outshape, asarray(res).dtype)
105 outarr[tuple(i.tolist())] = res

TypeError: object of type 'NoneType' has no len()

任何关于这里出了什么问题的想法/关于如何更好地做到这一点的建议都会很棒。

谢谢!

最佳答案

apply_along_axis 根据函数的输出创建一个新数组

您将返回 None(不返回任何内容)。因此错误。 Numpy 检查返回输出的长度以查看它对新数组是否有意义。

因为您不是根据结果构造新数组,所以没有理由使用 apply_along_axis。它不会更快。

但是,您当前的 ndenumerate 语句完全等同于:

import numpy as np
import matplotlib.pyplot as plt

ar = np.random.rand(4,4,10) # the 3D data volume
plt.plot(ar.reshape(-1, 10).T, alpha=0.5, color='black')

一般来说,你可能想做这样的事情:

for pixel in ar.reshape(-1, ar.shape[-1]):
plt.plot(x_values, pixel, ...)

这样您就可以轻松地迭代高光谱阵列中每个像素的光谱。


这里的瓶颈可能不是您使用数组的方式。在 matplotlib 中使用相同的参数单独绘制每条线会有些低效。

构造时间会稍长一些,但是 LineCollection 的渲染速度会快得多。 (基本上,使用 LineCollection 告诉 matplotlib 不要费心检查每条线的属性,只需将它们全部传递给低级渲染器以相同的方式绘制。你绕过了一堆单个 draw 调用支持大型对象的单个 draw。)

不利的一面是,代码的可读性会差一些。

稍后我将添加一个示例。

关于python - 使用 apply_along_axis 绘制,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/22431985/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com