gpt4 book ai didi

python - 如何引导 numpy 数组的最里面的数组?

转载 作者:行者123 更新时间:2023-12-01 01:27:11 25 4
gpt4 key购买 nike

我有一个这些维度的 numpy 数组

data.shape(类别、模型、类型、事件):(10, 11, 50, 100)

现在我只想在最里面的数组(100)中进行替换样本。对于像这样的单个数组:

数据[0][0][0]


数组([ 40.448624 , 39.459843 , 33.76762 , 38.944622 , 21.407362 ,
35.55499、68.5111、16.512974、21.118315、18.447166、
16.026619、21.596252、41.798622、63.01645、46.886642、
68.874756、17.472408、53.015724、85.41213、59.388977、
17.352108、61.161705、23.430847、20.203123、22.73194、
77.40547、43.02974、29.745787、21.50163、13.820962、
46.91466、41.43656、18.008326、13.122162、59.79936、
94.555305、24.798452、30.362497、13.629236、10.792178、
35.298515、20.904285、15.409604、20.567234、46.376335、
13.82727、17.970661、18.408686、21.987917、21.30094、
24.26776、27.399046、49.16879、21.831453、66.577、
15.524615、18.091696、24.346598、24.709772、19.068447、
24.221592、25.244864、52.865868、22.860783、23.586731、
18.928782、21.960285、74.77856、15.176119、20.795431、
14.3638935、35.937237、29.993324、30.848495、48.145336、
38.02541、101.15249、49.801117、38.123184、12.041505、
18.788296、20.53382、31.20367、19.76104、92.56279、
41.62944、23.53344、18.967432、14.781404、20.02018、
27.736559、16.108913、44.935062、12.629299、34.65672、
20.60169、21.779675、31.585844、23.768578、92.463196]、
数据类型=float32)

我可以使用以下内容进行替换样本:np.random.choice(data[0][0][0], 100),我将做了数千次。


数组([ 13.629236, 92.56279 , 21.960285, 20.567234, 21.50163 ,
16.026619、20.203123、23.430847、16.512974、15.524615、
18.967432、22.860783、85.41213、21.779675、23.586731、
24.26776、66.577、20.904285、19.068447、21.960285、
68.874756、31.585844、23.586731、61.161705、101.15249、
59.79936、16.512974、43.02974、16.108913、24.26776、
23.430847、14.781404、40.448624、13.629236、24.26776、
19.068447、16.026619、16.512974、16.108913、77.40547、
12.629299、31.585844、24.798452、18.967432、14.781404、
23.430847、49.16879、18.408686、22.73194、10.792178、
16.108913、18.967432、12.041505、85.41213、41.62944、
31.20367、17.970661、29.745787、39.459843、10.792178、
43.02974、21.831453、21.50163、24.798452、30.362497、
21.50163、18.788296、20.904285、17.352108、41.798622、
18.447166、16.108913、19.068447、61.161705、52.865868、
20.795431、85.41213、49.801117、13.82727、18.928782、
41.43656、46.886642、92.56279、41.62944、18.091696、
20.60169、48.145336、20.53382、40.448624、20.60169、
23.586731、22.73194、92.56279、94.555305、22.73194、
17.352108, 46.886642, 27.399046, 18.008326, 15.176119],
数据类型=float32)

但是由于 np.random.choice 中没有 axis ,我该如何对所有数组(即(类别、模型、类型))执行此操作?或者循环遍历它是唯一的选择吗?

最佳答案

最快/最简单的答案是基于对数组的扁平版本进行索引:

def resampFlat(arr, reps):
n = arr.shape[-1]

# create an array to shift random indexes as needed
shift = np.repeat(np.arange(0, arr.size, n), n).reshape(arr.shape)

# get a flat view of the array
arrflat = arr.ravel()
# sample the array by generating random ints and shifting them appropriately
return np.array([arrflat[np.random.randint(0, n, arr.shape) + shift]
for i in range(reps)])

时间确认这是最快的答案。

时间

我测试了上面的 resampFlat 函数以及更简单的基于 for 循环的解决方案:

def resampFor(arr, reps):
# store the shape for the return value
shape = arr.shape
# flatten all dimensions of arr except the last
arr = arr.reshape(-1, arr.shape[-1])
# preallocate the return value
ret = np.empty((reps, *arr.shape), dtype=arr.dtype)
# generate the indices of the resampled values
idxs = np.random.randint(0, arr.shape[-1], (reps, *arr.shape))

for rep,idx in zip(ret, idxs):
# iterate over the resampled replicates
for row,rowrep,i in zip(arr, rep, idx):
# iterate over the event arrays within a replicate
rowrep[...] = row[i]

# give the return value the appropriate shape
return ret.reshape((reps, *shape))

以及基于 Paul Panzer 奇特索引方法的解决方案:

def resampFancyIdx(arr, reps):
idx = np.random.randint(0, arr.shape[-1], (reps, *data.shape))
_, I, J, K, _ = np.ogrid[tuple(map(slice, (0, *arr.shape[:-1], 0)))]

return arr[I, J, K, idx]

我使用以下数据进行了测试:

shape = ((10, 11, 50, 100))
data = np.arange(np.prod(shape)).reshape(shape)

以下是数组展平方法的结果:

%%timeit
resampFlat(data, 100)

1.25 s ± 9.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

for 循环方法的结果:

%%timeit
resampFor(data, 100)

1.66 s ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

来自 Paul 的精美索引:

%%timeit
resampFancyIdx(data, 100)

1.42 s ± 16.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

与我的预期相反,resampFancyIdx 击败了 resampFor,实际上我必须相当努力才能想出更好的东西。在这一点上,我真的很想更好地解释一下 C 级索引如何工作,以及为什么它如此高效。

关于python - 如何引导 numpy 数组的最里面的数组?,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/53236040/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com