作者热门文章
- c - 在位数组中找到第一个零
- linux - Unix 显示有关匹配两种模式之一的文件的信息
- 正则表达式替换多个文件
- linux - 隐藏来自 xtrace 的命令
假设我有一个 Pandas 系列 s
,其值总和为 1
,并且其值也都大于或等于 0
。我需要从所有值中减去一个常量,以使新系列的总和等于 0.6
。问题是,当我减去这个常量时,值永远不会小于零。
在数学公式中,假设我有一系列 x
,我想找到 k
import pandas as pd
import numpy as np
from string import ascii_uppercase
np.random.seed([3, 141592653])
s = np.power(
1000, pd.Series(
np.random.rand(10),
list(ascii_uppercase[:10])
)
).pipe(lambda s: s / s.sum())
s
A 0.001352
B 0.163135
C 0.088365
D 0.010904
E 0.007615
F 0.407947
G 0.005856
H 0.198381
I 0.027455
J 0.088989
dtype: float64
总和是1
s.sum()
0.99999999999999989
我可以使用 Scipy 的 optimize
模块中的牛顿法(以及其他方法)
from scipy.optimize import newton
def f(k):
return s.sub(k).clip(0).sum() - .6
找到这个函数的根会得到我需要的k
initial_guess = .1
k = newton(f, x0=initial_guess)
然后从s
中减去这个
new_s = s.sub(k).clip(0)
new_s
A 0.000000
B 0.093772
C 0.019002
D 0.000000
E 0.000000
F 0.338583
G 0.000000
H 0.129017
I 0.000000
J 0.019626
dtype: float64
新的总和是
new_s.sum()
0.60000000000000009
我们能否在不使用求解器的情况下找到 k
?
最佳答案
我没想到 newton
会取得成功。但在大型阵列上,它确实如此。
numba.njit
灵感来自 Thierry's Answer
使用 numba
s jit
import numpy as np
from numba import njit
@njit
def find_k_numba(a, t):
a = np.sort(a)
m = len(a)
s = a.sum()
to_remove = s - t
if to_remove <= 0:
k = 0
else:
for i, x in enumerate(a):
k = to_remove / (m - i)
if k < x:
break
else:
to_remove -= x
return k
numpy
灵感来自 Paul's Answer
Numpy 肩负重任。
import numpy as np
def find_k_numpy(a, t):
a = np.sort(a)
m = len(a)
s = a.sum()
to_remove = s - t
if to_remove <= 0:
k = 0
else:
c = a.cumsum()
n = np.arange(m)[::-1]
b = n * a + c
i = np.searchsorted(b, to_remove)
k = a[i] + (to_remove - b[i]) / (m - i)
return k
scipy.optimize.newton
我的牛顿方法
import numpy as np
from scipy.optimize import newton
def find_k_newton(a, t):
s = a.sum()
if s <= t:
k = 0
else:
def f(k_):
return np.clip(a - k_, 0, a.max()).sum() - t
k = newton(f, (s - t) / len(a))
return k
import pandas as pd
from timeit import timeit
res = pd.DataFrame(
np.nan, [10, 30, 100, 300, 1000, 3000, 10000, 30000],
'find_k_newton find_k_numpy find_k_numba'.split()
)
for i in res.index:
a = np.random.rand(i)
t = a.sum() * .6
for j in res.columns:
stmt = f'{j}(a, t)'
setp = f'from __main__ import {j}, a, t'
res.at[i, j] = timeit(stmt, setp, number=200)
res.plot(loglog=True)
res.div(res.min(1), 0)
find_k_newton find_k_numpy find_k_numba
10 57.265421 17.552150 1.000000
30 29.221947 9.420263 1.000000
100 16.920463 5.294890 1.000000
300 10.761341 3.037060 1.000000
1000 1.455159 1.033066 1.000000
3000 1.000000 2.076484 2.550152
10000 1.000000 3.798906 4.421955
30000 1.000000 5.551422 6.784594
关于python - 求根的闭式解,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/49314543/
我是一名优秀的程序员,十分优秀!