gpt4 book ai didi

python - pymc3多元traceplot颜色编码

转载 作者:太空宇宙 更新时间:2023-11-03 12:02:52 27 4
gpt4 key购买 nike

我是 pymc3 的新手,我在生成易于阅读的跟踪图时遇到了问题。我将 4 个多元高斯混合拟合到数据集中的某些 (x, y) 点。该模型运行良好。我的问题是关于操作 pm.traceplot() 命令以使输出更加用户友好。这是我的代码:

import matplotlib.pyplot as plt
import numpy as np
model = pm.Model()
N_CLUSTERS = 4
with model:
#cluster prior
w = pm.Dirichlet('w', np.ones(N_CLUSTERS))
#latent cluster of each observation
category = pm.Categorical('category', p=w, shape=len(points))

#make sure each cluster has some values:
w_min_potential = pm.Potential('w_min_potential', tt.switch(tt.min(w) < 0.1, -np.inf, 0))
#multivariate normal means
mu = pm.MvNormal('mu', [0,0], cov=[[1,0],[0,1]], shape = (N_CLUSTERS,2) )
#break symmetry
pm.Potential('order_mu_potential', tt.switch(
tt.all(
[mu[i, 0] < mu[i+1, 0] for i in range(N_CLUSTERS - 1)]), -np.inf, 0))
#multivariate centers
data = pm.MvNormal('data', mu =mu[category], cov=[[1,0],[0,1]], observed=points)

with model:
trace = pm.sample(1000)

调用 pm.traceplot(trace, ['w', 'mu']) 生成此图像: Default traceplot() output

如您所见,峰值对应于 x 或 y 值以及哪些配对在一起是不明确的。我已经管理了一个解决方法如下:

from cycler import cycler
#plot the x-means and y-means of our data!
fig, (ax0, ax1) = plt.subplots(nrows=2)
plt.xlabel('$\mu$')
plt.ylabel('frequency')
for i in range(4):
ax0.hist(trace['mu'][:,i,0], bins=100, label='x{}'.format(i), alpha=0.6);
ax1.hist(trace['mu'][:,i,1],bins=100, label='y{}'.format(i), alpha=0.6);
ax0.set_prop_cycle(cycler('color', ['c', 'm', 'y', 'k']))
ax1.set_prop_cycle(cycler('color', ['c', 'm', 'y', 'k']))
ax0.legend()
ax1.legend()

这会产生以下更清晰的图: A more legible histogram of means

我已经查看了 pymc3 文档和此处最近的问题,但无济于事。我的问题是:是否可以通过 pymc3 中的内置方法使用 matplotlib 完成我在这里完成的操作,如果可以,怎么做?

最佳答案

最近将多维变量和不同链之间的更好区分添加到 ArviZ(PyMC3 依赖于绘图的库)。

在 ArviZ 的最新版本中,您应该可以:

az.plot_trace(trace, compact=True, legend=True)

获取每个变量的不同维度,以颜色区分,以linestyle区分不同的chain。默认设置使用 matplotlib 的默认颜色循环和 4 种不同的线型,实线、虚线、点线和点划线。通过使用 compact_prop 自定义维度表示和使用 chain_prop 自定义链表示,这两个属性都可以设置为自定义美学和自定义值。此外,如果使用 compact,使用 combined=True 来减少第一列中的困惑也是一个好主意。例如:

az.plot_trace(trace, compact=True, combined=True, legend=True, chain_prop=("ls", "-"))

将使用来自所有链的数据在第一列中绘制 KDE,并将使用实线样式绘制所有链(由于组合 arg,仅与第二列相关)。将显示两个图例,一个用于链信息,另一个用于紧凑信息。

关于python - pymc3多元traceplot颜色编码,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/43193223/

27 4 0
Copyright 2021 - 2024 cfsdn All Rights Reserved 蜀ICP备2022000587号
广告合作:1813099741@qq.com 6ren.com