
python - Find the position of the maximum per unique bin (binargmax)

Reposted. Author: IT老高. Updated: 2023-10-28 21:41:42

Setup

Suppose I have

import numpy as np

bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
k = 3

I need the position of the maximum value for each unique bin in bins.

# Bin == 0
#  ↓ ↓           ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#  ↑ ↑           ↑
#  ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 8 and happens at position 0

(vals * (bins == 0)).argmax()

0

# Bin == 1
#      ↓ ↓         ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#      ↑ ↑         ↑
#        ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 4 and happens at position 3

(vals * (bins == 1)).argmax()

3

# Bin == 2
#          ↓ ↓ ↓     ↓
# [0 0 1 1 2 2 2 0 1 2]
# [8 7 3 4 1 2 6 5 0 9]
#          ↑ ↑ ↑     ↑
#                    ⇧
# [0 1 2 3 4 5 6 7 8 9]
# Maximum is 9 and happens at position 9

(vals * (bins == 2)).argmax()

9

These expressions are hacky and do not even generalize to negative values.
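To make that limitation concrete, here is a minimal sketch. The helper names `masked_product_argmax` and `masked_argmax` are illustrative and do not appear in the original post. Multiplying by the boolean mask turns every out-of-bin entry into 0, so when every value inside the bin is negative, a zeroed slot outside the bin wins. Replacing out-of-bin entries with -inf is one possible workaround, assuming every bin is non-empty and the values can be treated as floats.

```python
import numpy as np

def masked_product_argmax(bins, vals, b):
    # The approach from the question: out-of-bin entries become 0.
    return int((vals * (bins == b)).argmax())

def masked_argmax(bins, vals, b):
    # Out-of-bin entries become -inf, so they can never win the argmax.
    return int(np.where(bins == b, vals, -np.inf).argmax())

bins = np.array([0, 0, 1, 1])
vals = np.array([-5, -2, -7, -9])

wrong = masked_product_argmax(bins, vals, 0)  # lands on a zeroed slot outside bin 0
right = masked_argmax(bins, vals, 0)          # lands on the true maximum inside bin 0
```

Here `wrong` points at a position that belongs to bin 1, while `right` stays inside bin 0.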

Question

How do I get all of these values in the most efficient way using NumPy?

What I've tried.

def binargmax(bins, vals, k):
    out = -np.ones(k, np.int64)           # -1 marks a bin that never appears
    trk = np.empty(k, vals.dtype)         # running maximum seen so far per bin
    trk.fill(np.nanmin(vals) - 1)         # start below every value

    for i in range(len(bins)):
        v = vals[i]
        b = bins[i]
        if v > trk[b]:
            trk[b] = v
            out[b] = i

    return out

binargmax(bins, vals, k)

array([0, 3, 9])
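One way to validate a candidate implementation is to compare it against a slow but easy-to-read reference. The sketch below is an assumption about how such a check could look, not the validation code the post links to. The name `binargmax_reference` is hypothetical.

```python
import numpy as np

def binargmax_reference(bins, vals, k):
    # Slow but easy to verify: inspect each bin separately.
    out = -np.ones(k, np.int64)
    for b in range(k):
        idx = np.flatnonzero(bins == b)   # positions that belong to bin b
        if idx.size:                      # empty bins keep the sentinel -1
            out[b] = idx[np.argmax(vals[idx])]
    return out

bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
expected = binargmax_reference(bins, vals, 3)
```

A candidate such as `binargmax` can then be checked with `np.array_equal(binargmax(bins, vals, k), binargmax_reference(bins, vals, k))` on random inputs, including inputs with negative values.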

LINK TO TESTING AND VALIDATION

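Best answer

The numpy_indexed library:

I know this isn't technically NumPy, but the numpy_indexed library has a vectorized group_by function that is perfect for this. I just wanted to share it as an alternative that I use frequently: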

>>> import numpy_indexed as npi
>>> npi.group_by(bins).argmax(vals)
(array([0, 1, 2]), array([0, 3, 9], dtype=int64))

Using a simple pandas groupby and idxmax:

df = pd.DataFrame({'bins': bins, 'vals': vals})
df.groupby('bins').vals.idxmax()

Using sparse.csr_matrix

This option is very fast on very large inputs.

sparse.csr_matrix(
    (vals, bins, np.arange(vals.shape[0]+1)), (vals.shape[0], k)
).argmax(0)

# matrix([[0, 3, 9]])
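To see why the sparse construction works, it may help to build the same matrix densely. This is only an illustrative sketch using plain NumPy, not the answer's code. Row i holds a single entry, vals[i], in column bins[i], so a column-wise argmax returns the original position of each bin's maximum.

```python
import numpy as np

bins = np.array([0, 0, 1, 1, 2, 2, 2, 0, 1, 2])
vals = np.array([8, 7, 3, 4, 1, 2, 6, 5, 0, 9])
k = 3
n = vals.shape[0]

# Row i holds exactly one entry: vals[i], placed in column bins[i].
dense = np.zeros((n, k), dtype=vals.dtype)
dense[np.arange(n), bins] = vals

# Column-wise argmax: for each bin, the row (original position) of its maximum.
positions = dense.argmax(axis=0)
```

Cells that are never filled stay at zero, so this dense version gives wrong positions when all values in a bin are negative. The sparse version presumably shares that caveat, since unstored entries count as zero.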

Performance

Functions

def chris(bins, vals, k):
    return npi.group_by(bins).argmax(vals)

def chris2(df):
    return df.groupby('bins').vals.idxmax()

def chris3(bins, vals, k):
    # Return the result so the function is usable outside the benchmark.
    return sparse.csr_matrix(
        (vals, bins, np.arange(vals.shape[0] + 1)), (vals.shape[0], k)
    ).argmax(0)

def divakar(bins, vals, k):
    mx = vals.max()+1

    sidx = bins.argsort()
    sb = bins[sidx]
    sm = np.r_[sb[:-1] != sb[1:],True]    # marks the last element of each sorted bin

    argmax_out = np.argsort(bins*mx + vals)[sm]
    max_out = vals[argmax_out]
    return max_out, argmax_out

def divakar2(bins, vals, k):
    last_idx = np.bincount(bins).cumsum()-1    # last sorted position of each bin
    scaled_vals = bins*(vals.max()+1) + vals   # offset each bin into its own value range
    argmax_out = np.argsort(scaled_vals)[last_idx]
    max_out = vals[argmax_out]
    return max_out, argmax_out


def user545424(bins, vals, k):
    # Broadcasts a (number of bins, N) mask against vals.
    return np.argmax(vals*(bins == np.arange(bins.max()+1)[:,np.newaxis]),axis=-1)

def user2699(bins, vals, k):
    res = []
    for v in np.unique(bins):
        idx = (bins==v)
        r = np.where(idx)[0][np.argmax(vals[idx])]
        res.append(r)
    return np.array(res)

def sacul(bins, vals, k):
    return np.lexsort((vals, bins))[np.append(np.diff(np.sort(bins)), 1).astype(bool)]

@njit
def piRSquared(bins, vals, k):
    out = -np.ones(k, np.int64)
    trk = np.empty(k, vals.dtype)
    # Start strictly below the minimum so the smallest value can still be recorded.
    trk.fill(np.nanmin(vals) - 1)

    for i in range(len(bins)):
        v = vals[i]
        b = bins[i]
        if v > trk[b]:
            trk[b] = v
            out[b] = i

    return out

Setup

import numpy_indexed as npi
import numpy as np
import pandas as pd
from timeit import timeit
import matplotlib.pyplot as plt
from numba import njit
from scipy import sparse

res = pd.DataFrame(
    index=['chris', 'chris2', 'chris3', 'divakar', 'divakar2', 'user545424', 'user2699', 'sacul', 'piRSquared'],
    columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000],
    dtype=float
)

k = 5

for f in res.index:
    for c in res.columns:
        bins = np.random.randint(0, k, c)
        vals = np.random.rand(c)
        df = pd.DataFrame({'bins': bins, 'vals': vals})
        # chris2 takes the DataFrame; every other function takes the arrays.
        stmt = '{}(df)'.format(f) if f in {'chris2'} else '{}(bins, vals, k)'.format(f)
        setp = 'from __main__ import bins, vals, k, df, {}'.format(f)
        res.at[f, c] = timeit(stmt, setp, number=50)

ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");

plt.show()

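Results

[Figure: log-log plot of relative time versus N for each function, k = 5]

Results for a much larger k (this is where broadcasting gets hit hard):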

res = pd.DataFrame(
    index=['chris', 'chris2', 'chris3', 'divakar', 'divakar2', 'user545424', 'user2699', 'sacul', 'piRSquared'],
    columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000],
    dtype=float
)

k = 500

for f in res.index:
    for c in res.columns:
        bins = np.random.randint(0, k, c)
        vals = np.random.rand(c)
        df = pd.DataFrame({'bins': bins, 'vals': vals})
        stmt = '{}(df)'.format(f) if f in {'chris2'} else '{}(bins, vals, k)'.format(f)
        setp = 'from __main__ import bins, vals, df, k, {}'.format(f)
        res.at[f, c] = timeit(stmt, setp, number=50)

ax = res.div(res.min()).T.plot(loglog=True)
ax.set_xlabel("N");
ax.set_ylabel("time (relative)");

plt.show()

[Figure: log-log plot of relative time versus N for each function, k = 500]

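The graphs make it clear that broadcasting is a nifty trick when the number of groups is small, but at higher values of k its time complexity and memory use grow too quickly for it to stay performant.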
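A rough back-of-the-envelope sketch shows why. The broadcasting approach materializes a (k, N) boolean mask and then a (k, N) product with vals. The helper `broadcast_mask_bytes` below is hypothetical and assumes float64 values, as produced by np.random.rand in the benchmark. It ignores any further temporaries NumPy may allocate.

```python
import numpy as np

def broadcast_mask_bytes(n, k):
    # Bytes for the (k, n) boolean mask plus the (k, n) float64 product.
    bool_bytes = n * k * np.dtype(np.bool_).itemsize
    float_bytes = n * k * np.dtype(np.float64).itemsize
    return bool_bytes + float_bytes

# Small demonstration that the mask really has one row per bin.
bins = np.random.randint(0, 5, 1000)
mask = bins == np.arange(5)[:, np.newaxis]   # shape (k, n)

small = broadcast_mask_bytes(500_000, 5)     # 22.5 million bytes
large = broadcast_mask_bytes(500_000, 500)   # 2.25 billion bytes
```

Under these assumptions, going from k = 5 to k = 500 at N = 500000 multiplies the intermediate memory by 100, whereas the sort-based and loop-based approaches only ever hold arrays of length N or k.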

Regarding "python - Find the position of the maximum per unique bin (binargmax)", a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/52006947/
