gpt4 book ai didi

python - 如何获取 numpy 数组中重复元素的所有索引的列表

转载 作者:太空狗 更新时间:2023-10-29 19:34:08 25 4
gpt4 key购买 nike

我正在尝试获取 numpy 数组中所有重复元素的索引,但我目前找到的解决方案对于大型(>20000 个元素)输入数组来说效率真的很低(大约需要 9 秒) .这个想法很简单:

  1. records_array是一个 numpy 时间戳数组 (datetime),我们要从中提取重复时间戳的索引

  2. time_array 是一个 numpy 数组,包含所有在 records_array

    中重复的时间戳
  3. records 是一个包含一些 Record 对象的 django QuerySet(可以轻松转换为列表)。我们想要创建一个由 Record 的 tagId 属性的所有可能组合构成的列表,对应于从 records_array 中找到的重复时间戳。

这是我目前可用的(但效率低下的)代码:

tag_couples = [];
for t in time_array:
users_inter = np.nonzero(records_array == t)[0] # Get all repeated timestamps in records_array for time t
l = [str(records[i].tagId) for i in users_inter] # Create a temporary list containing all tagIds recorded at time t
if l.count(l[0]) != len(l): #remove tuples formed by the first tag repeated
tag_couples +=[x for x in itertools.combinations(list(set(l)),2)] # Remove duplicates with list(set(l)) and append all possible couple combinations to tag_couples

我很确定这可以通过使用 Numpy 来优化,但是我找不到一种方法来比较 records_arraytime_array 的每个元素而不使用 for循环(这不能用 == 来比较,因为它们都是数组)。

最佳答案

使用 numpy 的矢量化解决方案,利用 unique() 的魔力.

import numpy as np

# create a test array
records_array = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])

# creates an array of indices, sorted by unique element
idx_sort = np.argsort(records_array)

# sorts records array so all unique elements are together
sorted_records_array = records_array[idx_sort]

# returns the unique values, the index of the first occurrence of a value, and the count for each element
vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True)

# splits the indices into separate arrays
res = np.split(idx_sort, idx_start[1:])

#filter them with respect to their size, keeping only items occurring more than once
vals = vals[count > 1]
res = filter(lambda x: x.size > 1, res)

以下代码是原始答案,需要更多内存,使用 numpy 广播并调用 unique 两次:

records_array = array([1, 2, 3, 1, 1, 3, 4, 3, 2])
vals, inverse, count = unique(records_array, return_inverse=True,
return_counts=True)

idx_vals_repeated = where(count > 1)[0]
vals_repeated = vals[idx_vals_repeated]

rows, cols = where(inverse == idx_vals_repeated[:, newaxis])
_, inverse_rows = unique(rows, return_index=True)
res = split(cols, inverse_rows[1:])

符合预期 res = [array([0, 3, 4]), array([1, 8]), array([2, 5, 7])]

关于python - 如何获取 numpy 数组中重复元素的所有索引的列表,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/30003068/

25 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com