    import arviz as az
    import numpy as np
    J = 8
    y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
    sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
    schools = np.array(
        [
            "Choate",
            "Deerfield",
            "Phillips Andover",
            "Phillips Exeter",
            "Hotchkiss",
            "Lawrenceville",
            "St. Paul's",
            "Mt. Hermon",
        ]
    )

ArviZ style sheets
    # ArviZ ships with style sheets!
    az.style.use("arviz-darkgrid")
Feel free to check the examples of style sheets here.
Get started with plotting

ArviZ is designed to be used with libraries like PyStan and PyMC, but works fine with raw NumPy arrays.
When plotting a dictionary of arrays, ArviZ interprets each key as the name of a different random variable. Each row of an array is treated as an independent series of draws from the variable, called a chain. Below, we have 10 chains of 50 draws each, for four different distributions.
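Before plotting, the row-per-chain convention can be checked with plain NumPy. The sketch below is only an illustration: the seed, the variable names `n_chains` and `n_draws`, and the choice of two distributions are assumptions made for the example, not part of ArviZ.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen only to make the sketch reproducible

n_chains, n_draws = 10, 50
samples = {
    "normal": rng.normal(size=(n_chains, n_draws)),
    "exponential": rng.exponential(size=(n_chains, n_draws)),
}

for name, draws in samples.items():
    # Axis 0 indexes chains (rows), axis 1 indexes draws within a chain.
    per_chain_mean = draws.mean(axis=1)  # one mean per chain
    pooled_mean = draws.mean()           # all chains pooled together
    print(name, draws.shape, per_chain_mean.shape, round(float(pooled_mean), 3))
```

Each dictionary key plays the role of a variable name, and an array of shape `(10, 50)` is read as 10 chains of 50 draws.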
    rng = np.random.default_rng()  # random generator used to simulate the draws

    size = (10, 50)
    az.plot_forest(
        {
            "normal": rng.normal(size=size),
            "gumbel": rng.gumbel(size=size),
            "student t": rng.standard_t(df=6, size=size),
            "exponential": rng.exponential(size=size),
        }
    );

ArviZ rcParams
You may have noticed that for both plot_posterior() and plot_forest(), the Highest Density Interval (HDI) is 94%, which you may find weird at first. This particular value is a friendly reminder of the arbitrary nature of choosing any single value without further justification, including common values like 95%, 50% and even our own default, 94%. ArviZ includes default values for a few parameters; you can access them with az.rcParams. To change the default confidence interval (CI) value (including HDI) to, say, 90%, you can do:
    az.rcParams["stats.ci_prob"] = 0.90

PyMC integration
ArviZ integrates with PyMC. In fact, the object returned by default by most PyMC sampling methods is the arviz.InferenceData object.
Therefore, we only need to define a model, sample from it and we can use the result with ArviZ straight away.
    import pymc as pm

    with pm.Model(coords={"school": schools}) as centered_eight:
        mu = pm.Normal("mu", mu=0, sigma=5)
        tau = pm.HalfCauchy("tau", beta=5)
        theta = pm.Normal("theta", mu=mu, sigma=tau, dims="school")
        pm.Normal("obs", mu=theta, sigma=sigma, observed=y, dims="school")

        # This pattern can be useful in PyMC
        idata = pm.sample_prior_predictive()
        idata.extend(pm.sample())
        pm.sample_posterior_predictive(idata, extend_inferencedata=True)
    Sampling: [mu, obs, tau, theta]
    Auto-assigning NUTS sampler...
    Initializing NUTS using jitter+adapt_diag...
    Multiprocess sampling (4 chains in 4 jobs)
    NUTS: [mu, tau, theta]
    100.00% [8000/8000 00:05<00:00 Sampling 4 chains, 84 divergences]
    Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 6 seconds.
    The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
    The effective sample size per chain is smaller than 100 for some parameters. A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
    There were 84 divergences after tuning. Increase `target_accept` or reparameterize.
    Sampling: [obs]
    100.00% [4000/4000 00:00<00:00]
Here we have combined the outputs of prior sampling, MCMC sampling (to obtain the posterior samples) and posterior predictive sampling into a single InferenceData, the main ArviZ data structure.
The more groups it contains, the more powerful the analyses it can perform. You can check the InferenceData structure specification here.
Tip
By default, PyMC does not compute the pointwise log likelihood values, which are needed for model comparison with WAIC or PSIS-LOO-CV. Pass idata_kwargs={"log_likelihood": True} to pm.sample() to have them computed right after sampling. Alternatively, you can call pymc.compute_log_likelihood() before calling compare(), loo(), waic() or loo_pit().
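To make "pointwise log likelihood" concrete for the eight schools model, the following NumPy sketch evaluates the normal log density of every observation under every draw of `theta`. The `y` and `sigma` values are the ones defined at the top of this page; the `theta` array is a simulated placeholder (an assumption for illustration, not real posterior output), and the result is only meant to show the shape and meaning of the quantity, with one value per chain, draw and school.

```python
import numpy as np

y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

rng = np.random.default_rng(0)
# Placeholder draws with shape (chain, draw, school); NOT real posterior samples.
theta = rng.normal(loc=5.0, scale=5.0, size=(4, 1000, 8))

# Normal log density of each observation under each draw.
# y and sigma (shape (8,)) broadcast against the trailing "school" axis of theta.
log_lik = -0.5 * np.log(2 * np.pi * sigma**2) - (y - theta) ** 2 / (2 * sigma**2)

print(log_lik.shape)
```

Model comparison tools such as WAIC and PSIS-LOO-CV work from an array of this kind, which is why it has to be stored alongside the posterior draws.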
    <xarray.Dataset>
    Dimensions:  (chain: 4, draw: 1000, school: 8)
    Coordinates:
      * chain    (chain) int64 0 1 2 3
      * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
      * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
    Data variables:
        mu       (chain, draw) float64 10.02 7.399 6.104 1.803 ... 1.041 12.6 10.12
        theta    (chain, draw, school) float64 8.304 4.128 12.45 ... 12.92 10.8
        tau      (chain, draw) float64 3.041 3.737 3.529 1.581 ... 3.22 1.696 2.607
    Attributes:
        created_at:                 2023-12-21T18:42:25.932752
        arviz_version:              0.17.0.dev0
        inference_library:          pymc
        inference_library_version:  5.10.2
        sampling_time:              5.5613462924957275
        tuning_steps:               1000
    <xarray.Dataset>
    Dimensions:  (chain: 4, draw: 1000, school: 8)
    Coordinates:
      * chain    (chain) int64 0 1 2 3
      * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
      * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
    Data variables:
        obs      (chain, draw, school) float64 -4.287 -7.086 3.44 ... 11.22 39.41
    Attributes:
        created_at:                 2023-12-21T18:42:27.486051
        arviz_version:              0.17.0.dev0
        inference_library:          pymc
        inference_library_version:  5.10.2
    <xarray.Dataset>
    Dimensions:                (chain: 4, draw: 1000)
    Coordinates:
      * chain                  (chain) int64 0 1 2 3
      * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999
    Data variables: (12/17)
        energy                 (chain, draw) float64 57.32 60.94 ... 61.22 58.55
        step_size              (chain, draw) float64 0.264 0.264 ... 0.2667 0.2667
        index_in_trajectory    (chain, draw) int64 2 -4 -5 5 3 -4 ... 1 0 -8 2 -10 6
        energy_error           (chain, draw) float64 0.1952 0.251 ... -0.1325 0.279
        tree_depth             (chain, draw) int64 3 3 4 4 4 3 4 4 ... 2 4 3 4 3 5 3
        process_time_diff      (chain, draw) float64 0.001172 0.001172 ... 0.001192
        ...                     ...
        diverging              (chain, draw) bool False False False ... False False
        acceptance_rate        (chain, draw) float64 0.8216 0.7868 ... 0.9907 0.842
        n_steps                (chain, draw) float64 7.0 7.0 15.0 ... 7.0 23.0 7.0
        lp                     (chain, draw) float64 -55.39 -55.34 ... -55.68 -51.57
        step_size_bar          (chain, draw) float64 0.2845 0.2845 ... 0.2817 0.2817
        perf_counter_start     (chain, draw) float64 1.286e+04 ... 1.286e+04
    Attributes:
        created_at:                 2023-12-21T18:42:25.944471
        arviz_version:              0.17.0.dev0
        inference_library:          pymc
        inference_library_version:  5.10.2
        sampling_time:              5.5613462924957275
        tuning_steps:               1000
    <xarray.Dataset>
    Dimensions:  (chain: 1, draw: 500, school: 8)
    Coordinates:
      * chain    (chain) int64 0
      * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
      * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
    Data variables:
        theta    (chain, draw, school) float64 52.13 -71.41 148.5 ... 1.115 6.39
        tau      (chain, draw) float64 120.4 7.113 1.983 2.866 ... 8.423 6.926 12.31
        mu       (chain, draw) float64 -2.798 1.822 -4.905 ... -1.888 -4.516 1.978
    Attributes:
        created_at:                 2023-12-21T18:42:18.246297
        arviz_version:              0.17.0.dev0
        inference_library:          pymc
        inference_library_version:  5.10.2
    <xarray.Dataset>
    Dimensions:  (chain: 1, draw: 500, school: 8)
    Coordinates:
      * chain    (chain) int64 0
      * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499
      * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
    Data variables:
        obs      (chain, draw, school) float64 78.06 -75.96 125.8 ... 2.925 19.95
    Attributes:
        created_at:                 2023-12-21T18:42:18.248182
        arviz_version:              0.17.0.dev0
        inference_library:          pymc
        inference_library_version:  5.10.2
    <xarray.Dataset>
    Dimensions:  (school: 8)
    Coordinates:
      * school   (school) <U16 'Choate' 'Deerfield' ... "St. Paul's" 'Mt. Hermon'
    Data variables:
        obs      (school) float64 28.0 8.0 -3.0 7.0 -1.0 1.0 18.0 12.0
    Attributes:
        created_at:                 2023-12-21T18:42:18.248964
        arviz_version:              0.17.0.dev0
        inference_library:          pymc
        inference_library_version:  5.10.2
Below is a “trace plot”, a common visualization to check MCMC output and assess convergence. Note that the labeling information we included in the PyMC model via the coords and dims arguments is kept and added to the plot (it is also available in the InferenceData HTML representation above):
ArviZ also has first-class support for CmdStanPy. After creating and sampling a CmdStanPy model:
    from cmdstanpy import CmdStanModel

    model = CmdStanModel(stan_file="schools.stan")
fit = model.sample(data="schools.json")
    19:42:30 - cmdstanpy - INFO - CmdStan start processing
    chain 1 | 00:00 Sampling completed
    chain 2 | 00:00 Sampling completed
    chain 3 | 00:00 Sampling completed
    chain 4 | 00:00 Sampling completed
    19:42:30 - cmdstanpy - INFO - CmdStan done processing.
    19:42:30 - cmdstanpy - WARNING - Some chains may have failed to converge.
        Chain 1 had 27 divergent transitions (2.7%)
        Chain 2 had 9 divergent transitions (0.9%)
        Chain 3 had 4 divergent transitions (0.4%)
        Chain 4 had 20 divergent transitions (2.0%)
        Use the "diagnose()" method on the CmdStanMCMC object to see further information.
The result can be used for plotting with ArviZ directly:
To make the most out of ArviZ, however, it is recommended to convert the results to InferenceData. This ensures all variables are assigned to the right groups and also gives you the option of labeling the data.
Tip
If ArviZ finds any variable named log_lik in the CmdStanPy output, it will interpret it as the pointwise log likelihood values, in line with the Stan conventions used by the R libraries.
    idata = az.from_cmdstanpy(
        fit,
        posterior_predictive="y_hat",
        dims={"y_hat": ["school"], "theta": ["school"]},
        coords={"school": schools},
    )
    az.plot_posterior(idata, var_names=["tau", "theta"]);
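The posterior plots report a Highest Density Interval for each variable. As a rough illustration of what such an interval is, the NumPy sketch below finds the narrowest window of sorted draws that covers a given probability and compares it with the equal-tailed interval. The helper `narrowest_interval` is a hypothetical name introduced here, the exponential draws are simulated, and the approach assumes a unimodal sample; ArviZ's own computation may differ in its details.

```python
import numpy as np


def narrowest_interval(draws, prob=0.94):
    """Return the shortest interval covering `prob` of the draws (unimodal case)."""
    x = np.sort(np.asarray(draws).ravel())
    n = x.size
    k = int(np.ceil(prob * n))            # number of draws the interval must cover
    widths = x[k - 1:] - x[: n - k + 1]   # width of every window of k sorted draws
    start = int(np.argmin(widths))        # window with the smallest width
    return x[start], x[start + k - 1]


rng = np.random.default_rng(0)
draws = rng.exponential(size=20_000)      # skewed sample, simulated for illustration

lo, hi = narrowest_interval(draws, prob=0.94)
eq_lo, eq_hi = np.quantile(draws, [0.03, 0.97])  # equal-tailed 94% interval
print(f"narrowest: [{lo:.2f}, {hi:.2f}]  equal-tailed: [{eq_lo:.2f}, {eq_hi:.2f}]")
```

For a skewed distribution like this one, the narrowest interval is shifted toward the mode and is no wider than the equal-tailed interval of the same probability.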