Multiple color-coded traceplots in pymc3
I'm new to pymc3 and I'm having trouble producing an easy-to-read trace plot. I am fitting a mixture of 4 multivariate Gaussians to a dataset of (x, y) points, and the model works fine. My question is how to customize pm.traceplot() so that its output is easier to read. Here's my code:
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import theano.tensor as tt

# points: an (N, 2) NumPy array of observed (x, y) coordinates, defined elsewhere
model = pm.Model()
N_CLUSTERS = 4
with model:
    # cluster prior
    w = pm.Dirichlet('w', np.ones(N_CLUSTERS))
    # latent cluster of each observation
    category = pm.Categorical('category', p=w, shape=len(points))
    # make sure every cluster gets a non-negligible weight
    w_min_potential = pm.Potential('w_min_potential', tt.switch(tt.min(w) < 0.1, -np.inf, 0))
    # multivariate normal means, one 2-D center per cluster
    mu = pm.MvNormal('mu', [0, 0], cov=[[1, 0], [0, 1]], shape=(N_CLUSTERS, 2))
    # break symmetry: reject any draw whose x-means are NOT in increasing order
    pm.Potential('order_mu_potential', tt.switch(
        tt.all([mu[i, 0] < mu[i + 1, 0] for i in range(N_CLUSTERS - 1)]),
        0, -np.inf))
    # likelihood: each observed point is drawn around its cluster's center
    data = pm.MvNormal('data', mu=mu[category], cov=[[1, 0], [0, 1]], observed=points)
with model:
    trace = pm.sample(1000)
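For reference, as far as I know trace['mu'] returns the draws of all chains stacked into one array of shape (draws, N_CLUSTERS, 2). A minimal sketch of that layout follows, using a made-up NumPy array in place of a real trace (the helper cluster_mean_table and the centers are hypothetical). Row i of the result holds (x_i, y_i), so an x-mean and a y-mean belong to the same cluster exactly when they share the middle index.

```python
import numpy as np


def cluster_mean_table(mu_samples):
    """Posterior mean of each cluster center as an (n_clusters, 2) array.

    mu_samples is assumed to have shape (draws, n_clusters, 2), e.g. trace['mu'].
    """
    return mu_samples.mean(axis=0)


# Hypothetical stand-in for trace['mu']: 1000 noisy draws around 4 made-up centers.
rng = np.random.default_rng(0)
centers = np.array([[-3.0, 0.0], [-1.0, 2.0], [1.0, -2.0], [3.0, 1.0]])
fake_mu = centers + 0.1 * rng.standard_normal((1000, 4, 2))

# One line per cluster: the x and y printed together form one cluster center.
for i, (x_mean, y_mean) in enumerate(cluster_mean_table(fake_mu)):
    print("cluster {}: x = {:.2f}, y = {:.2f}".format(i, x_mean, y_mean))
```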
Calling pm.traceplot(trace, ['w', 'mu']) produces this figure:
[figure: default traceplot of w and mu, not reproduced]
As you can see, it is ambiguous which peak corresponds to an x value and which to a y value, and which x and y values belong to the same cluster. I came up with the following workaround:
import matplotlib.pyplot as plt
from cycler import cycler

# plot the x-means and y-means of our data!
fig, (ax0, ax1) = plt.subplots(nrows=2)
# set the color cycle BEFORE plotting; setting it afterwards has no effect
ax0.set_prop_cycle(cycler('color', ['c', 'm', 'y', 'k']))
ax1.set_prop_cycle(cycler('color', ['c', 'm', 'y', 'k']))
for i in range(N_CLUSTERS):
    ax0.hist(trace['mu'][:, i, 0], bins=100, label='x{}'.format(i), alpha=0.6)
    ax1.hist(trace['mu'][:, i, 1], bins=100, label='y{}'.format(i), alpha=0.6)
ax1.set_xlabel(r'$\mu$')  # raw string so "\m" is not treated as an escape
for ax in (ax0, ax1):
    ax.set_ylabel('frequency')
    ax.legend()
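This gives the following, much clearer plot:
[figure: overlaid histograms of the x-means (top) and y-means (bottom) for each cluster, not reproduced]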
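The same idea can be pushed a bit closer to a real trace plot. The sketch below is plain Matplotlib, not a pymc3 built-in, and the function name plot_color_coded_trace and its colors parameter are hypothetical. For each coordinate it draws a histogram and a trace of every cluster's mean, giving cluster i the same color in the x row and the y row so that x_i and y_i are easy to pair up. A synthetic array stands in for trace['mu'] (with several chains, the real draws are simply concatenated).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_color_coded_trace(mu_samples, colors=("c", "m", "y", "k")):
    """Histogram and trace of each cluster mean, one color per cluster.

    mu_samples is assumed to have shape (draws, n_clusters, 2), e.g. trace['mu'].
    Row 0 of the figure shows the x-coordinates, row 1 the y-coordinates.
    """
    n_clusters = mu_samples.shape[1]
    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 5))
    for coord, name in enumerate(("x", "y")):
        hist_ax, trace_ax = axes[coord]
        for i in range(n_clusters):
            # cluster i gets the same color in both rows, so x_i and y_i pair up
            color = colors[i % len(colors)]
            values = mu_samples[:, i, coord]
            label = "{}{}".format(name, i)
            hist_ax.hist(values, bins=100, color=color, alpha=0.6, label=label)
            trace_ax.plot(values, color=color, alpha=0.6, linewidth=0.5, label=label)
        hist_ax.set_xlabel(r"$\mu_{}$".format(name))
        hist_ax.set_ylabel("frequency")
        trace_ax.set_xlabel("draw")
        trace_ax.set_ylabel(r"$\mu_{}$".format(name))
        hist_ax.legend()
    fig.tight_layout()
    return fig, axes


# Hypothetical stand-in for trace['mu']: 1000 noisy draws around 4 made-up centers.
rng = np.random.default_rng(1)
centers = np.array([[-3.0, 0.0], [-1.0, 2.0], [1.0, -2.0], [3.0, 1.0]])
fake_mu = centers + 0.1 * rng.standard_normal((1000, 4, 2))
fig, axes = plot_color_coded_trace(fake_mu)
```

With a real model you would call plot_color_coded_trace(trace['mu']) instead. Drawing the histogram and the trace side by side mirrors the layout of pm.traceplot, while the fixed per-cluster colors carry the pairing information that the default plot loses.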
I've looked through the pymc3 documentation and recent questions here, to no avail. So my question is: is it possible to do what I did here with matplotlib using pymc3's built-in plotting functions, and if so, how?
At least in recent versions, you can pass compact=True:
pm.traceplot(trace, var_names=['w', 'mu'], compact=True)
This draws all components of a multidimensional variable such as mu in a single panel, instead of one panel per component. Docs at: https://arviz-devs.github.io/arviz/_modules/arviz/plots/traceplot.html
However, I was unable to get the lines to use different colors.
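One possible workaround, offered as an untested sketch: the plotting function returns Matplotlib axes (as far as I know, both pm.traceplot and arviz.plot_trace do), so the lines can be recolored after the fact. The helper recolor_lines below is hypothetical. It recolors the Line2D objects in the order Matplotlib stores them, which assumes each panel holds one line per component in component order (e.g. a single chain); with several chains or a different ordering you would need to select lines more carefully, for instance by comparing line.get_ydata() against trace['mu'].

```python
from itertools import cycle

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for scripts; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np


def recolor_lines(ax, colors):
    """Hypothetical helper: assign colors, cycling, to the Line2D objects in ax.

    Only Line2D artists (ax.get_lines()) are touched; fills, histogram bars and
    LineCollections are left as they are.
    """
    for line, color in zip(ax.get_lines(), cycle(colors)):
        line.set_color(color)
    return ax


# Possible use with pymc3/ArviZ (untested assumption about the line order):
# axes = pm.traceplot(trace, var_names=['mu'], compact=True)
# for ax in np.ravel(axes):
#     recolor_lines(ax, ['c', 'm', 'y', 'k'])

# Self-contained demo on a plain Matplotlib axis with four lines.
fig, demo_ax = plt.subplots()
for k in range(4):
    demo_ax.plot(np.arange(10) * k)
recolor_lines(demo_ax, ["c", "m", "y", "k"])
```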