How to use the ArviZ plot_lm function?

Utkarsh Maheshwari
Aug 14, 2021


API Reference
Source code
PR

This is a quick guide to plotting Bayesian regression models with plot_lm, a function I developed during GSoC 2021. If you want to dive into the lower-level details, head over to the API reference and source code linked above. In this blog, I will show the different visualizations ArviZ's plot_lm offers and how to produce them. I have taken the kidiq example from ROS (Regression and Other Stories), which is a good one for showing the various use cases plot_lm supports.

plot_lm aims to support Bayesian regression models, helping users visualize them in fewer lines of code. plot_lm is a useful tool for you if:

  1. You want quick, interactive visualizations of your Bayesian models.
  2. You are less familiar with xarray and reshaping functions and don't want to get stuck solving frustrating dimension-mismatch errors.
  3. You have less experience working with plotting libraries like matplotlib and bokeh.

import matplotlib.pyplot as plt
import pymc3 as pm
import pandas as pd
import arviz as az

We have the data as a CSV file, which can easily be loaded into a pandas DataFrame with pandas' read_csv function.

data = pd.read_csv("../kidiq.csv")
data.head()
data.describe()

We'll model kid_score against a single predictor, mom_iq. But first, let's visualize kid_score vs. mom_iq:

data.plot.scatter("mom_iq", "kid_score", figsize=(12,8), s=40)
plt.show()

Let's fit a linear regression model the Bayesian way using PyMC3.

with pm.Model() as model:
    # Wrap the predictor in pm.Data so it ends up in idata's constant_data group
    mom_iq = pm.Data("mom_iq", data["mom_iq"])

    # Priors for the noise scale, intercept and slope
    sigma = pm.HalfNormal("sigma", sigma=10)
    intercept = pm.Normal("Intercept", mu=0, sigma=10)
    x_coeff = pm.Normal("slope", mu=0, sigma=10)

    # Linear predictor and likelihood of the observed kid scores
    mean = intercept + x_coeff * mom_iq
    likelihood = pm.Normal(
        "kid_score", mu=mean, sigma=sigma, observed=data["kid_score"]
    )

    trace = pm.sample(500)
    prior = pm.sample_prior_predictive()
    posterior_predictive = pm.sample_posterior_predictive(trace, samples=500)
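Written out as equations, the model above reads as follows, where Intercept is α, slope is β, and the second argument of each normal distribution is its standard deviation:

$$
\begin{aligned}
\text{kid\_score}_i &\sim \mathcal{N}\left(\alpha + \beta\,\text{mom\_iq}_i,\ \sigma\right) \\
\alpha &\sim \mathcal{N}(0, 10) \\
\beta &\sim \mathcal{N}(0, 10) \\
\sigma &\sim \text{HalfNormal}(10)
\end{aligned}
$$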

Next comes the InferenceData creation step. InferenceData is the data structure ArviZ uses; it is based on xarray Datasets and DataArrays. To learn more about it, refer to the ArviZ docs.
We can use the handy from_pymc3 converter. ArviZ has converters for several other libraries as well, such as PyStan and TensorFlow Probability. This is a crucial step! Be careful with coords and dims.

idata = az.from_pymc3(
    trace=trace,
    posterior_predictive=posterior_predictive,
    model=model,  # passing the model explicitly lets ArviZ store the pm.Data values in constant_data
    dims={
        "kid_score": ["mom_iq"],
    },
    coords={
        "mom_iq": data.mom_iq,
    },
)
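Most dimension-mismatch errors at this step come from a dims entry that doesn't line up with the array's shape. The sketch below uses a hypothetical helper, check_dims (not an ArviZ function), to mimic that bookkeeping: samples for a variable are shaped (chain, draw, *), so the dims listed for kid_score must name the trailing axes, and each coordinate must have the matching length. The sizes are made up for illustration.

```python
import numpy as np


def check_dims(values, var_dims, coords):
    """Hypothetical helper: check that dims/coords fit an array shaped (chain, draw, *).

    values   -- samples for one variable, e.g. posterior predictive draws
    var_dims -- names for the trailing (non-chain, non-draw) axes
    coords   -- mapping from dim name to coordinate values
    """
    trailing = values.shape[2:]  # drop the chain and draw axes
    if len(trailing) != len(var_dims):
        raise ValueError(
            f"expected {len(trailing)} dim name(s) for shape {trailing}, got {var_dims}"
        )
    for dim, size in zip(var_dims, trailing):
        if dim in coords and len(coords[dim]) != size:
            raise ValueError(
                f"coord {dim!r} has length {len(coords[dim])}, but the axis has size {size}"
            )
    return dict(zip(var_dims, trailing))


# Illustrative sizes only: 2 chains, 500 draws, 10 observations.
rng = np.random.default_rng(0)
mom_iq = rng.normal(100, 15, size=10)
kid_score_pp = rng.normal(80, 20, size=(2, 500, 10))

print(check_dims(kid_score_pp, ["mom_iq"], {"mom_iq": mom_iq}))  # prints {'mom_iq': 10}
```

Passing a coordinate of the wrong length, e.g. mom_iq[:5], makes the helper raise a ValueError, which is the same kind of mismatch to watch for in the real dims and coords arguments.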

Our model is ready! Now all we need is to pass this information to plot_lm in the right way. There are several ways to do this; let's look at them one by one.

Using idata makes things easier for both users and developers. However, the idata argument is not compulsory: you can instead pass array-like inputs directly to y and x. For visualizing uncertainty, though, idata is necessary.

Just give it idata and y, the name of the variable to be plotted. x is then automatically taken from the coords of y.

ax = az.plot_lm(idata=idata, y="kid_score", figsize=(10,6))

Do you want to choose the x values yourself? Just store them in the constant_data group and pass the variable name. Alternatively, you can pass the values directly as an array-like structure.

az.plot_lm(idata=idata, y="kid_score", x="mom_iq", figsize=(10,6))

If your x values are in observed_data, you can pass x=idata.observed_data["x"].

plot_lm is not limited to plotting individual samples. You can also choose how the uncertainty is displayed with the kind_pp argument: either "samples" or "hdi", with "samples" as the default.

ax = az.plot_lm(idata=idata, y="kid_score", kind_pp="hdi", figsize=(10,6))
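With kind_pp="hdi", the posterior predictive samples are summarized as a highest density interval (HDI) band instead of individual sample lines. The sketch below shows the idea with NumPy: for every x value, pool the chain and draw samples and find the narrowest interval containing a given probability mass. hdi_band is a hypothetical helper written for illustration, and 0.94 mirrors ArviZ's default hdi_prob setting; the actual band is computed by ArviZ, not by this code.

```python
import numpy as np


def hdi_band(samples, prob=0.94):
    """Narrowest interval holding `prob` of the mass, computed per observation.

    samples -- array shaped (chain, draw, n_obs), e.g. posterior predictive draws
    returns -- array shaped (n_obs, 2) with the lower and upper bounds
    """
    n_obs = samples.shape[-1]
    flat = np.sort(samples.reshape(-1, n_obs), axis=0)  # pool chains and draws
    n = flat.shape[0]
    k = int(np.floor(prob * n))  # sorted samples spanned by each candidate interval
    widths = flat[k:] - flat[: n - k]  # width of every candidate interval
    best = np.argmin(widths, axis=0)  # narrowest interval per observation
    cols = np.arange(n_obs)
    return np.stack([flat[best, cols], flat[best + k, cols]], axis=1)


rng = np.random.default_rng(1)
# Illustrative draws: 2 chains x 500 draws x 3 observations centred at 0, 10 and 20.
pp = rng.normal(loc=[0.0, 10.0, 20.0], scale=1.0, size=(2, 500, 3))
band = hdi_band(pp)
print(band.round(2))
```

Each row of band gives the lower and upper edge of the shaded region at one x value.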

Not satisfied? There is more! You can also visualize the uncertainty in the mean along with the uncertainty in the observed data. For that, we need to provide the model information through the y_model argument, which points to a variable with shape (chain, draw, *).

idata.posterior["model"] = idata.posterior["Intercept"] + idata.posterior["slope"] * idata.constant_data["mom_iq"]
az.plot_lm(
    idata=idata,
    y="kid_score",
    y_model="model",
    figsize=(10, 6),
    legend=False,
)
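The xarray expression above works because xarray broadcasts by dimension name: Intercept and slope have dims (chain, draw), mom_iq has one observation dimension, and the sum has (chain, draw, observation), the shape y_model expects. In plain NumPy the same computation needs explicit trailing axes. The sizes and parameter draws below are arbitrary stand-ins, not fitted results.

```python
import numpy as np

rng = np.random.default_rng(2)
n_chain, n_draw, n_obs = 2, 500, 10  # illustrative sizes

# Arbitrary stand-ins for posterior draws of Intercept and slope, each (chain, draw).
intercept = rng.normal(26.0, 1.0, size=(n_chain, n_draw))
slope = rng.normal(0.6, 0.05, size=(n_chain, n_draw))
mom_iq = np.linspace(70.0, 140.0, n_obs)  # shaped (n_obs,)

# Add a trailing axis to the parameters so they broadcast against mom_iq.
y_model = intercept[..., None] + slope[..., None] * mom_iq
print(y_model.shape)  # (2, 500, 10)
```

Forgetting the [..., None] would try to broadcast (2, 500) against (10,) and fail, which is exactly the kind of reshaping chore that xarray's named dimensions avoid.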

And just like kind_pp, there is kind_model to specify how the uncertainty in the mean is shown: either "lines" (the default) or "hdi".

az.plot_lm(
    idata=idata,
    y="kid_score",
    y_model="model",
    kind_model="hdi",
    figsize=(10, 6),
    legend=False,
)
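With the default kind_model="lines", the mean is drawn as a set of individual regression lines rather than a band, and plot_lm's num_samples argument controls how many draws are used. The sketch below only illustrates the idea of picking a random subset of the pooled (chain, draw) samples to draw as lines; pick_lines is a hypothetical helper, and ArviZ's own selection code may differ.

```python
import numpy as np


def pick_lines(y_model, num_samples=50, seed=None):
    """Hypothetical helper: choose `num_samples` pooled draws to plot as lines.

    y_model -- array shaped (chain, draw, n_obs)
    returns -- array shaped (num_samples, n_obs); each row is one line to draw
    """
    flat = y_model.reshape(-1, y_model.shape[-1])  # (chain * draw, n_obs)
    rng = np.random.default_rng(seed)
    idx = rng.choice(flat.shape[0], size=num_samples, replace=False)
    return flat[idx]


# Illustrative input: 2 chains x 500 draws of a mean line over 10 x values.
rng = np.random.default_rng(3)
x = np.linspace(0.0, 1.0, 10)
y_model = rng.normal(size=(2, 500, 1)) + 2.0 * x
lines = pick_lines(y_model, num_samples=50, seed=0)
print(lines.shape)  # (50, 10)
```

With matplotlib, each row would then be drawn with something like ax.plot(x, row, alpha=0.3), so the spread of the lines conveys the uncertainty in the mean.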

Want multiple plots? Hold on, that is supported too! Specify a tuple of x variables, and don't forget to add those variables to the constant_data group. You can also add a jitter effect with the xjitter argument: just add xjitter=True!

# Add the extra predictors to constant_data so plot_lm can find them by name
idata.constant_data["mom_age"] = data.mom_age
idata.constant_data["mom_hs"] = data.mom_hs
idata.constant_data["mom_work"] = data.mom_work

_, ax = plt.subplots(2, 2, figsize=(20, 12))
az.plot_lm(
    idata=idata,
    y="kid_score",
    x=("mom_iq", "mom_age", "mom_hs", "mom_work"),
    legend=False,
    axes=ax,
)
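xjitter is most useful for discrete predictors such as mom_hs, which only takes the values 0 and 1, so many points would otherwise sit on top of each other. Jittering means adding a small amount of random noise to the x positions before plotting. The sketch below uses a hypothetical jitter helper that adds uniform noise scaled to a fraction of the smallest gap between distinct x values; ArviZ's own jitter rule may differ.

```python
import numpy as np


def jitter(x, fraction=0.1, seed=None):
    """Hypothetical helper: shift each x by uniform noise to reduce overplotting.

    The noise half-width is `fraction` of the smallest gap between distinct
    x values, so jittered groups never overlap when fraction < 0.5.
    """
    x = np.asarray(x, dtype=float)
    distinct = np.unique(x)
    gap = np.min(np.diff(distinct)) if distinct.size > 1 else 1.0
    half_width = fraction * gap
    rng = np.random.default_rng(seed)
    return x + rng.uniform(-half_width, half_width, size=x.shape)


# A binary predictor like mom_hs: without jitter, all points share two x positions.
mom_hs = np.array([0, 1, 1, 0, 1, 1, 0, 1])
jittered = jitter(mom_hs, seed=0)
print(np.round(jittered, 3))
```

Scaling the noise to the gap between values keeps the two groups visually separate while spreading out the points within each group.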

If you face any issues while plotting, please open an issue in the ArviZ GitHub repository. I'll try my best to resolve it ASAP.

Now, what if your data is multidimensional? Complex things are made simple here too. Let's say we have a small fake InferenceData dataset in which a variable, obs_val, has been observed over 5 timestamps at 4 locations (xy = 00, 01, 10, 11).

We can easily plot these values with the plot_dim argument. Let's also try out xjitter.

_, ax = plt.subplots(2, 2, figsize=(20, 12))  # a fresh 2x2 grid, one panel per location
az.plot_lm(
    idata=idata,
    y="obs_val",
    plot_dim="time",
    axes=ax,
    xjitter=True,
)

Given our data, this call generates 4 plots, one for each of the 4 locations: xy = 00, 01, 10, 11.

Be careful with large multidimensional data. Use xarray indexing if you don't want plot_lm to return a large number of plots covering every combination of dimensions. For example, to plot all timestamps for xy = 01 and 11 (that is, y = 1), do it like this:

az.plot_lm(
    y=idata.observed_data["obs_val"].sel(y=1),
    plot_dim="time",
    y_hat=idata.posterior_predictive["obs_val"].sel(y=1),
)
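To make the panel counts from the last two calls concrete: plot_lm draws one panel per combination of the dimensions left over once plot_dim is set aside. The sketch below enumerates those combinations for a synthetic obs_val array shaped (time=5, x=2, y=2), a stand-in for the fake dataset described above. The panels helper is hypothetical and only mimics the counting; it does not call ArviZ.

```python
from itertools import product

import numpy as np


def panels(sizes, plot_dim, fixed=None):
    """Hypothetical helper: list the index combinations that get their own panel.

    sizes    -- mapping of dim name to length, e.g. {"time": 5, "x": 2, "y": 2}
    plot_dim -- the dim drawn along each panel's horizontal axis
    fixed    -- optional {dim: value} selection, like xarray's .sel(y=1)
    """
    fixed = fixed or {}
    other = [d for d in sizes if d != plot_dim and d not in fixed]
    combos = product(*(range(sizes[d]) for d in other))
    return [{**dict(zip(other, combo)), **fixed} for combo in combos]


# Synthetic stand-in for obs_val: 5 timestamps at 4 (x, y) locations.
obs_val = np.random.default_rng(4).normal(size=(5, 2, 2))
sizes = dict(zip(("time", "x", "y"), obs_val.shape))

print(len(panels(sizes, "time")))  # prints 4 (xy = 00, 01, 10, 11)
print(panels(sizes, "time", fixed={"y": 1}))  # prints [{'x': 0, 'y': 1}, {'x': 1, 'y': 1}]
```

The number of panels is the product of the sizes of the remaining dimensions, so it grows quickly with extra dimensions; selecting values first is the easy way to keep it manageable.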

You can also tweak the visualization with y_kwargs, y_hat_plot_kwargs, and several other keyword arguments, changing the marker style, size, color, and many other visual properties. To learn more about these parameters, see matplotlib's Line2D properties and bokeh's circle parameters.

That's all for now. I will be writing more posts about plot_ts in the future, so stay tuned! Please follow, share, and comment with your questions, if any!

There is one more function, plot_ts, which I am currently working on. The user API and the matplotlib function for plotting the overall time series are ready; however, the bokeh artist function still needs to be done. The function can be extended by adding support for plotting the components used to model the overall time series. You can refer to this PR if you want to know more about it.

Hope I made it easy for you! See you again, soon!
