gpt4 book ai didi

python - 通过在 numpy 中设置一些条件来检索元素的位置

转载 作者:太空狗 更新时间:2023-10-30 02:04:12 24 4
gpt4 key购买 nike

对于给定的二维数据数组,如何检索粗体中的 7 和 11 的位置(索引)。因为只有它们才是邻居中被同值包围的元素

  import numpy as np
data = np.array([
[0,1,2,3,4,7,6,7,8,9,10],
[3,3,3,4,7,7,7,8,11,12,11],
[3,3,3,5,7,**7**,7,9,11,11,11],
[3,4,3,6,7,7,7,10,11,**11**,11],
[4,5,6,7,7,9,10,11,11,11,11]
])

print data

最佳答案

使用 scipy,您可以将这样的点表征为其邻域的最大值和最小值:

import numpy as np
import scipy.ndimage.filters as filters

def using_filters(data):
return np.where(np.logical_and.reduce(
[data == f(data, footprint=np.ones((3,3)), mode='constant', cval=np.inf)
for f in (filters.maximum_filter, filters.minimum_filter)]))

using_filters(data)
# (array([2, 3]), array([5, 9]))

仅使用 numpy,您可以将 data 与自身的 8 个移位切片进行比较,以找到相等的点:

def using_eight_shifts(data):
h, w = data.shape
data2 = np.empty((h+2, w+2))
data2[(0,-1),:] = np.nan
data2[:,(0,-1)] = np.nan
data2[1:1+h,1:1+w] = data

result = np.where(np.logical_and.reduce([
(data2[i:i+h,j:j+w] == data)
for i in range(3)
for j in range(3)
if not (i==1 and j==1)]))
return result

正如您在上面看到的,此策略生成了一个扩展数组,该数组在数据周围有一个 NaN 边界。这允许将移位的切片表示为 data2[i:i+h,j:j+w]

如果您知道要与邻居进行比较,您可能需要从一开始就定义带有 NaN 边界的 data,这样您就不必创建第二个数组如上所述。

使用八次移位(和比较)比遍历 data 中的每个单元格并将其与相邻单元格进行比较要快得多:

def using_quadratic_loop(data):
return np.array([[i,j]
for i in range(1,np.shape(data)[0]-1)
for j in range(1,np.shape(data)[1]-1)
if np.all(data[i-1:i+2,j-1:j+2]==data[i,j])]).T

这是一个基准:

using_filters            : 0.130
using_eight_shifts : 0.340
using_quadratic_loop : 18.794

这是用于生成基准的代码:

import timeit
import operator
import numpy as np
import scipy.ndimage.filters as filters
import matplotlib.pyplot as plt

data = np.array([
[0,1,2,3,4,7,6,7,8,9,10],
[3,3,3,4,7,7,7,8,11,12,11],
[3,3,3,5,7,7,7,9,11,11,11],
[3,4,3,6,7,7,7,10,11,11,11],
[4,5,6,7,7,9,10,11,11,11,11]
])

data = np.tile(data, (50,50))

def using_filters(data):
return np.where(np.logical_and.reduce(
[data == f(data, footprint=np.ones((3,3)), mode='constant', cval=np.inf)
for f in (filters.maximum_filter, filters.minimum_filter)]))


def using_eight_shifts(data):
h, w = data.shape
data2 = np.empty((h+2, w+2))
data2[(0,-1),:] = np.nan
data2[:,(0,-1)] = np.nan
data2[1:1+h,1:1+w] = data

result = np.where(np.logical_and.reduce([
(data2[i:i+h,j:j+w] == data)
for i in range(3)
for j in range(3)
if not (i==1 and j==1)]))
return result


def using_quadratic_loop(data):
return np.array([[i,j]
for i in range(1,np.shape(data)[0]-1)
for j in range(1,np.shape(data)[1]-1)
if np.all(data[i-1:i+2,j-1:j+2]==data[i,j])]).T

np.testing.assert_equal(using_quadratic_loop(data), using_filters(data))
np.testing.assert_equal(using_eight_shifts(data), using_filters(data))

timing = dict()
for f in ('using_filters', 'using_eight_shifts', 'using_quadratic_loop'):
timing[f] = timeit.timeit('{f}(data)'.format(f=f),
'from __main__ import data, {f}'.format(f=f),
number=10)

for f, t in sorted(timing.items(), key=operator.itemgetter(1)):
print('{f:25}: {t:.3f}'.format(f=f, t=t))

关于python - 通过在 numpy 中设置一些条件来检索元素的位置,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/23025497/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com