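"""Estimate the mean and standard deviation of a normal sample with PyMC.

Simulated observations are fitted with a Uniform prior on the mean and a
HalfNormal prior on the standard deviation. The posterior predictive
distribution is then compared with the observed data.
"""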
import numpy as np
import pymc as pm
import arviz as az
from matplotlib import pyplot as plt

# Fixed seed so the simulated data, the MCMC draws and the posterior
# predictive draws are reproducible from run to run.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)

# Simulated observations: 30 draws from Normal(mean=3, sd=2).
x = rng.normal(3, 2, size=30)
# plt.hist(x)
# plt.show()

with pm.Model() as model:
    # Prior on the mean: uniform over [-5, 5].
    mu = pm.Uniform('mu', lower=-5, upper=5)
    # Prior on the standard deviation: half-normal, so it is always positive.
    sig = pm.HalfNormal('sig', sigma=10)
    # Likelihood of the observed sample given mu and sig.
    data = pm.Normal('data', mu=mu, sigma=sig, observed=x)

# model_to_graphviz returns a graph object that is only displayed
# automatically in a notebook; in a plain script, keep a handle to it
# (e.g. to render it to a file) instead of discarding the result.
graph = pm.model_to_graphviz(model)

with model:
    # 5000 kept draws per chain after 1000 tuning (warm-up) steps.
    trace = pm.sample(5000, tune=1000, random_seed=RANDOM_SEED)

# Optional diagnostics (uncomment to use): the first chain's mu draws,
# and per-variable trace plots.
# plt.plot(trace.posterior.data_vars["mu"][0, 1:100])
# az.plot_trace(trace)

# Posterior summary statistics (e.g. mean, sd, HDI, r_hat) per variable.
summary = az.summary(trace)
print(summary)

with model:
    post = pm.sample_posterior_predictive(trace, random_seed=RANDOM_SEED)

# Overlay the posterior predictive draws of the observed variable "data"
# and the observed sample on a common density scale. The variable is
# indexed explicitly rather than through Dataset attribute access.
ppc_draws = post.posterior_predictive["data"].to_numpy().ravel()
plt.hist(ppc_draws, density=True, bins=50, label="posterior predictive")
plt.hist(x, alpha=0.4, density=True, bins=6, label="observed")
plt.legend()

# ArviZ's built-in posterior predictive check (drawn on its own figure).
az.plot_ppc(post)

plt.show()