gpt4 book ai didi

python - 多元正态的 Numpy 向量化

转载 作者:太空宇宙 更新时间:2023-11-04 02:33:00 24 4
gpt4 key购买 nike

我有两个二维 numpy 数组 A、B。我想使用 scipy.stats.multivariate_normal 来计算 A 中每一行的联合 logpdf,使用 B 中的每一行作为协方差矩阵。是否有某种方法可以在不显式循环行的情况下执行此操作? scipy.stats.multivariate_normal 对 A 和 B 的直接应用确实计算了 A 中每一行的 logpdf(这是我想要的),但是使用整个二维数组 A 作为协方差矩阵,这不是我想要的(我需要B 的每一行创建一个不同的协方差矩阵)。我正在寻找一种使用 numpy 向量化并避免显式循环遍历两个数组的解决方案。

最佳答案

我也在尝试完成类似的事情。这是我的代码,它包含三个 NxD 矩阵。每行X是一个数据点,每行means是一个均值向量,每行covariances是对角线的对角线向量协方差矩阵。结果是一个长度为 N 的对数概率向量。

def vectorized_gaussian_logpdf(X, means, covariances):
"""
Compute log N(x_i; mu_i, sigma_i) for each x_i, mu_i, sigma_i
Args:
X : shape (n, d)
Data points
means : shape (n, d)
Mean vectors
covariances : shape (n, d)
Diagonal covariance matrices
Returns:
logpdfs : shape (n,)
Log probabilities
"""
_, d = X.shape
constant = d * np.log(2 * np.pi)
log_determinants = np.log(np.prod(covariances, axis=1))
deviations = X - means
inverses = 1 / covariances
return -0.5 * (constant + log_determinants +
np.sum(deviations * inverses * deviations, axis=1))

请注意,此代码仅适用于对角协方差矩阵。在这种特殊情况下,下面的数学定义被简化:行列式变为元素上的乘积,逆变为逐元素倒数,矩阵乘法变为逐元素乘法。

multivariate normal pdf

正确性和运行时间的快速测试:

def test_vectorized_gaussian_logpdf():
n = 128**2
d = 64

means = np.random.uniform(-1, 1, (n, d))
covariances = np.random.uniform(0, 2, (n, d))
X = np.random.uniform(-1, 1, (n, d))

refs = []

ref_start = time.time()
for x, mean, covariance in zip(X, means, covariances):
refs.append(scipy.stats.multivariate_normal.logpdf(x, mean, covariance))
ref_time = time.time() - ref_start

fast_start = time.time()
results = vectorized_gaussian_logpdf(X, means, covariances)
fast_time = time.time() - fast_start

print("Reference time:", ref_time)
print("Vectorized time:", fast_time)
print("Speedup:", ref_time / fast_time)

assert np.allclose(results, refs)

我获得了大约 250 倍的加速。 (是的,我的应用程序要求我计算 16384 个不同的高斯分布。)

关于python - 多元正态的 Numpy 向量化,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48686934/

24 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com