
python - interpolate.griddata uses only one core

Reposted. Author: 太空宇宙. Updated: 2023-11-04 04:33:37

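I am interpolating a 2D numpy array to fill in missing values that are marked with NaN. The following code works, but it uses only one core. Is there a better function I can use to take advantage of all 24 cores I have?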

import numpy as np
from scipy import interpolate

# `array` is the 2D input, with NaN marking the missing values
x = np.arange(0, array.shape[1])
y = np.arange(0, array.shape[0])
# mask invalid values
array = np.ma.masked_invalid(array)
xx, yy = np.meshgrid(x, y)
# get only the valid values
x1 = xx[~array.mask]
y1 = yy[~array.mask]
newarr = array[~array.mask]

GD1 = interpolate.griddata((x1, y1), newarr.ravel(),
                           (xx, yy),
                           method='cubic')
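
For reference, the single-core approach from the question can be wrapped into a self-contained, runnable sketch. The function name `fill_nans` and the synthetic test surface are illustrative assumptions, not part of the original question.

```python
import numpy as np
from scipy import interpolate


def fill_nans(array, method="cubic"):
    """Fill NaN cells of a 2D array with scipy.interpolate.griddata (single core)."""
    masked = np.ma.masked_invalid(array)
    # getmaskarray always returns a full boolean array, even if nothing is masked
    valid = ~np.ma.getmaskarray(masked)
    xx, yy = np.meshgrid(np.arange(array.shape[1]), np.arange(array.shape[0]))
    return interpolate.griddata(
        (xx[valid], yy[valid]), array[valid], (xx, yy), method=method
    )


# hypothetical smooth test surface with a small interior hole
yy, xx = np.mgrid[0:50, 0:60]
truth = np.sin(xx / 10.0) * np.cos(yy / 15.0)
demo = truth.copy()
demo[20:23, 30:33] = np.nan
filled = fill_nans(demo)
```

Note that with `method='cubic'` or `method='linear'`, cells outside the convex hull of the valid points (for example a missing corner) stay NaN, because `griddata` does not extrapolate.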

Best answer

I think you can do this with dask. I am not very familiar with dask, but here is a start:

import numpy as np
from scipy import interpolate
import dask.array as da
import matplotlib.pyplot as plt
from dask import delayed

# create data with random missing entries
ar_size = 2000
chunk_size = 500
z_array = np.ones((ar_size, ar_size))
z_array[np.random.randint(0, ar_size-1, 50),
        np.random.randint(0, ar_size-1, 50)] = np.nan

# XY coords
x = np.linspace(0, 3, z_array.shape[1])
y = np.linspace(0, 3, z_array.shape[0])

# gen sin wave for testing
z_array = z_array * np.sin(x)
# prove there are nans in the dataset
assert np.isnan(np.sum(z_array))

xx, yy = np.meshgrid(x, y)
print("global x.size: ", xx.size)

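# make dask arrays
# (blocks of chunk_size rows; "auto" is expected to leave the columns unsplit
# for this array size, which the zip over blocks further down relies on)
dask_xyz = da.from_array((xx, yy, z_array), chunks=(3, chunk_size, "auto"), name="dask_all")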
dask_xx = dask_xyz[0,:,:]
dask_yy = dask_xyz[1,:,:]
dask_zz = dask_xyz[2,:,:]

# select only valid values
dask_valid_y1 = dask_yy[~da.isnan(dask_zz)]
dask_valid_x1 = dask_xx[~da.isnan(dask_zz)]
dask_newarr = dask_zz[~da.isnan(dask_zz)]

def gd_wrapped(x1, y1, newarr, xx, yy):
    # note: linear and cubic griddata impl do not extrapolate
    # and therefore fail near the boundaries... see RBF interp instead
    print("local x.size: ", x1.size)
    gd_zz = interpolate.griddata((x1, y1), newarr.ravel(),
                                 (xx, yy),
                                 method='nearest')
    return gd_zz

def rbf_wrapped(x1, y1, newarr, xx, yy):
    rbf_interpolant = interpolate.Rbf(x1, y1, newarr, function='linear')
    return rbf_interpolant(xx, yy)

# interpolate: one delayed task per block
# (swap gd_wrapped for rbf_wrapped to try the RBF variant)
gd_chunked = [
    delayed(gd_wrapped)(x1, y1, newarr, xx, yy)
    for x1, y1, newarr, xx, yy in zip(
        dask_valid_x1.to_delayed().flatten(),
        dask_valid_y1.to_delayed().flatten(),
        dask_newarr.to_delayed().flatten(),
        dask_xx.to_delayed().flatten(),
        dask_yy.to_delayed().flatten(),
    )
]
gd_out = delayed(da.concatenate)(gd_chunked, axis=0)
gd_out.visualize("dask_par.png")
gd1 = np.array(gd_out.compute())
print(gd1)
assert gd1.shape == (ar_size, ar_size)
print(gd1.shape)
plt.figure()
plt.imshow(gd1)
plt.savefig("dask_par_sin.png")

# prove we have no more nans in the data
assert not np.isnan(np.sum(gd1))

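This implementation has some problems. griddata cannot extrapolate, so with the linear and cubic methods NaNs appear at the block boundaries. You could probably work around this by giving neighbouring blocks some overlapping cells. As a stopgap, you can use method='nearest' or try radial basis function interpolation.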
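
The overlapping-cells idea could be sketched roughly as follows. This is not the answer author's implementation: it is a minimal sketch that assumes the array is split into row blocks, that each block is padded with `halo` extra rows from its neighbours, and that every block contains enough valid points to triangulate. The names `fill_nans_blockwise`, `_fill_block`, `block_rows`, `halo`, and `map_fn` are hypothetical. Cells that linear or cubic interpolation still leaves as NaN (outside the convex hull of the valid points) are patched with nearest-neighbour values.

```python
import numpy as np
from scipy import interpolate


def _fill_block(task):
    """Interpolate one padded row block and return only its core rows."""
    block, top_pad, core_rows, method = task
    rows, cols = np.indices(block.shape)
    valid = ~np.isnan(block)
    points = (cols[valid], rows[valid])
    filled = interpolate.griddata(points, block[valid], (cols, rows), method=method)
    # linear/cubic do not extrapolate: patch leftover NaNs with nearest values
    holes = np.isnan(filled)
    if holes.any():
        filled[holes] = interpolate.griddata(
            points, block[valid], (cols[holes], rows[holes]), method="nearest"
        )
    # drop the halo rows so that blocks tile the original array exactly
    return filled[top_pad:top_pad + core_rows]


def fill_nans_blockwise(array, block_rows=64, halo=8, method="cubic", map_fn=map):
    """Fill NaNs block by block; map_fn decides how the blocks are executed."""
    n_rows = array.shape[0]
    tasks = []
    for start in range(0, n_rows, block_rows):
        stop = min(start + block_rows, n_rows)
        lo = max(start - halo, 0)       # padded block start
        hi = min(stop + halo, n_rows)   # padded block end
        tasks.append((array[lo:hi], start - lo, stop - start, method))
    return np.concatenate(list(map_fn(_fill_block, tasks)), axis=0)
```

With the default `map_fn=map` the blocks run serially. To spread them across cores, one option is to pass `pool.map` from a `concurrent.futures.ProcessPoolExecutor` (created under an `if __name__ == "__main__":` guard) as `map_fn`. The result is not identical to a single global interpolation, since each block only sees the valid points inside its own padded window; a larger `halo` reduces that difference at the cost of more work per block.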

Regarding "python - interpolate.griddata uses only one core", we found a similar question on Stack Overflow: https://stackoverflow.com/questions/52227599/
