gpt4 book ai didi

python - Python 中的一维 Wasserstein 距离

转载 作者:行者123 更新时间:2023-12-04 08:27:56 68 4
gpt4 key购买 nike

当源分布和目标分布 xy(也称为边际分布)为一维时,下面的公式是 Wasserstein 距离/最优传输的特例,也就是说,是向量。

enter image description here

其中 F^{-1} 是边际 uv 的累积分布的逆概率分布函数,来自真实名为 xy 的数据,均从正态分布生成:

import numpy as np
from numpy.random import randn
import scipy.stats as ss

n = 100
x = randn(n)
y = randn(n)

公式中的积分如何用python和scipy编码?我猜 x 和 y 必须转换为排名边缘,它们是非负的并且总和为 1,而 Scipy 的 ppf 可用于计算逆 F^{- 1}的?

最佳答案

请注意,当 n 变大时,我们有一组经过排序的 n 样本接近以 1/n、2/n、... 采样的逆 CDF不详。例如:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
plt.plot(norm.ppf(np.linspace(0, 1, 1000)), label="invcdf")
plt.plot(np.sort(np.random.normal(size=1000)), label="sortsample")
plt.legend()
plt.show()

plot

另请注意,从 0 到 1 的积分可以近似为 1/n、2/n、...、n/n 的总和。

因此我们可以简单地回答您的问题:

def W(p, u, v):
assert len(u) == len(v)
return np.mean(np.abs(np.sort(u) - np.sort(v))**p)**(1/p)

请注意,如果 len(u) != len(v),您仍然可以应用线性插值方法:

def W(p, u, v):
u = np.sort(u)
v = np.sort(v)
if len(u) != len(v):
if len(u) > len(v): u, v = v, u
us = np.linspace(0, 1, len(u))
vs = np.linspace(0, 1, len(v))
u = np.linalg.interp(u, us, vs)
return np.mean(np.abs(u - v)**p)**(1/p)

如果您有关于数据分布类型的先验信息而不是其参数,另一种方法是找到数据的最佳拟合分布(例如使用 scipy.stats.norm.fit) 为 uv 然后以所需的精度进行积分。例如:

from scipy.stats import norm as gauss
def W_gauss(p, u, v, num_steps):
ud = gauss(*gauss.fit(u))
vd = gauss(*gauss.fit(v))
z = np.linspace(0, 1, num_steps, endpoint=False) + 1/(2*num_steps)
return np.mean(np.abs(ud.ppf(z) - vd.ppf(z))**p)**(1/p)

关于python - Python 中的一维 Wasserstein 距离,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/65175268/

68 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com