python - Fast 1D linear np.NaN interpolation over a large 3D array

Reposted. Author: 太空宇宙. Updated: 2023-11-03 13:15:43

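I have a 3D array (z, y, x) with shape=(92, 4800, 4800), where each value along axis 0 represents a different point in time. In a few cases the acquisition of values in the time domain failed, so some values are np.NaN. In other cases no values were acquired at all, and all values along z are np.NaN.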

What is the most efficient way to fill np.NaN along axis 0 using linear interpolation, ignoring the cases where all values are np.NaN?
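To make the goal concrete, here is a minimal sketch of the operation for a single time series, using `np.interp`. The helper name `fill_nan_1d` is hypothetical and not part of the question; it only illustrates what "fill NaN by linear interpolation, skip all-NaN series" could mean for one (y, x) position. Note that `np.interp` holds the first and last valid values constant at the edges, which may or may not be the behavior wanted.

```python
import numpy as np

def fill_nan_1d(series):
    """Linearly interpolate NaN gaps in a 1D array (hypothetical helper)."""
    series = np.asarray(series, dtype=float)
    valid = ~np.isnan(series)
    if not valid.any():
        # All-NaN series: nothing to interpolate from, so return it unchanged.
        return series.copy()
    idx = np.arange(series.size)
    # Interpolate every position from the valid samples only;
    # positions outside the valid range get the nearest valid value.
    return np.interp(idx, idx[valid], series[valid])

partly_missing = fill_nan_1d([1.0, np.nan, 3.0, np.nan, np.nan, 6.0])
all_missing = fill_nan_1d([np.nan, np.nan, np.nan])
print(partly_missing)
print(all_missing)
```

Applying such a helper to every (y, x) position in a Python loop would be slow on a (92, 4800, 4800) array; it is meant as a reference for the expected result, not as a fast solution.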

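Here is a working example of what I am doing; it uses the pandas wrapper around scipy.interpolate.interp1d. On the original dataset this takes about 2 seconds per slice, which means the whole array is processed in about 2.6 hours. The reduced-size example dataset takes about 9.5 seconds.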

import numpy as np
import pandas as pd

# create example data, original is (92, 4800, 4800)
test_arr = np.random.randint(low=-10000, high=10000, size=(92, 480, 480))
test_arr[1:90:7, :, :] = -32768  # NaN fill value in original data
test_arr[:, 1:90:6, 1:90:8] = -32768

def interpolate_nan(arr, method="linear", limit=3):
    """return array interpolated along time-axis to fill missing values"""
    result = np.zeros_like(arr, dtype=np.int16)

    for i in range(arr.shape[1]):
        # slice along y axis, interpolate with pandas wrapper to interp1d
        line_stack = pd.DataFrame(data=arr[:, i, :], dtype=np.float32)
        # turn the fill value into NaN so pandas treats it as missing
        line_stack.replace(to_replace=-32768, value=np.nan, inplace=True)
        line_stack.interpolate(method=method, axis=0, inplace=True, limit=limit)
        # restore the fill value wherever interpolation left a gap
        line_stack.fillna(-32768, inplace=True)
        result[:, i, :] = line_stack.values.astype(np.int16)
    return result
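
As a small illustration of the `limit` argument used in the function above, the sketch below runs the same replace, interpolate, restore steps on one short column. The constant name `NO_DATA` and the sample values are assumptions made for this example. With `limit=3`, pandas fills at most three consecutive missing values, so the fourth value in a run of four keeps the fill value.

```python
import numpy as np
import pandas as pd

NO_DATA = -32768  # fill value from the question (name chosen for this sketch)

# One time series with a run of four missing values between 10 and 60.
column = pd.Series([10, NO_DATA, NO_DATA, NO_DATA, NO_DATA, 60], dtype=float)

# Same three steps as in interpolate_nan, on a single column.
column = column.replace(NO_DATA, np.nan)
filled = column.interpolate(method="linear", limit=3)
restored = filled.fillna(NO_DATA).astype(np.int16)

print(restored.tolist())
```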

Performance on my machine with the example dataset:

%timeit interpolate_nan(test_arr)
1 loops, best of 3: 9.51 s per loop

Edit:

I should clarify that the code is producing the results I expect. The question is: how can this process be optimized?

Best answer

I recently solved this for my particular use case with the help of numba, and also did a little writeup on it.

import numpy as np
from numba import jit

@jit(nopython=True)
def interpolate_numba(arr, no_data=-32768):
    """return array interpolated along time-axis to fill missing values"""
    result = np.zeros_like(arr, dtype=np.int16)

    for x in range(arr.shape[2]):
        # slice along x axis
        for y in range(arr.shape[1]):
            # slice along y axis
            for z in range(arr.shape[0]):
                value = arr[z, y, x]
                if z == 0:  # don't interpolate first value
                    new_value = value
                elif z == arr.shape[0] - 1:  # don't interpolate last value
                    new_value = value

                elif value == no_data:  # interpolate

                    left = arr[z-1, y, x]
                    right = arr[z+1, y, x]
                    # look for valid neighbours
                    if left != no_data and right != no_data:  # left and right are valid
                        new_value = (left + right) / 2

                    elif left == no_data and z == 1:  # boundary condition left
                        new_value = value
                    elif right == no_data and z == arr.shape[0] - 2:  # boundary condition right
                        new_value = value

                    elif left == no_data and right != no_data:  # take second neighbour to the left
                        more_left = arr[z-2, y, x]
                        if more_left == no_data:
                            new_value = value
                        else:
                            # plain mean of the two values, not weighted by distance
                            new_value = (more_left + right) / 2

                    elif left != no_data and right == no_data:  # take second neighbour to the right
                        more_right = arr[z+2, y, x]
                        if more_right == no_data:
                            new_value = value
                        else:
                            new_value = (more_right + left) / 2

                    elif left == no_data and right == no_data:  # take second neighbour on both sides
                        more_left = arr[z-2, y, x]
                        more_right = arr[z+2, y, x]
                        if more_left != no_data and more_right != no_data:
                            new_value = (more_left + more_right) / 2
                        else:
                            new_value = value
                    else:
                        new_value = value
                else:
                    # valid value, keep it as is
                    new_value = value
                result[z, y, x] = int(new_value)
    return result

This is 20 times faster than my initial code.
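
For cross-checking results without numba, the simplest branch of the answer (an isolated missing value whose direct neighbours are both valid) can also be written with plain NumPy slicing. This is a sketch under that assumption only, not the answer's method: the function name `fill_single_gaps` is hypothetical, and gaps two or more values wide are left untouched, unlike in `interpolate_numba`.

```python
import numpy as np

def fill_single_gaps(arr, no_data=-32768):
    """Fill isolated no-data values along axis 0 with the mean of both neighbours."""
    result = arr.copy()
    left = arr[:-2]    # neighbour at z-1 for every inner z
    mid = arr[1:-1]    # inner time steps; first and last are never touched
    right = arr[2:]    # neighbour at z+1 for every inner z
    gap = (mid == no_data) & (left != no_data) & (right != no_data)
    # Widen before adding so int16 inputs cannot overflow.
    midpoint = (left.astype(np.int32) + right.astype(np.int32)) / 2
    inner = result[1:-1]  # view into result, so the assignment writes through
    # Truncate toward zero, like int(new_value) in the numba version.
    inner[gap] = np.trunc(midpoint[gap]).astype(arr.dtype)
    return result

# Two time series of length 5, stored as shape (z, y, x) = (5, 1, 2).
demo = np.array(
    [[10, -32768], [-32768, -32768], [30, 5], [40, -32768], [50, 7]],
    dtype=np.int16,
).reshape(5, 1, 2)
filled = fill_single_gaps(demo)
print(filled[:, 0, 0].tolist())
print(filled[:, 0, 1].tolist())
```

In the second series the missing values at z=0 and z=1 stay as they are, because the first time step is never filled and z=1 has no valid left neighbour.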

Regarding python - Fast 1D linear np.NaN interpolation over a large 3D array, we found a similar question on Stack Overflow: https://stackoverflow.com/questions/30910944/
