Status: Open
Labels: bug (Something isn't working), example (Fixing or adding an example), priority-1 (Not bug, but high priority issue / PR)
Description
While working through Statistical Rethinking, I tried to run a prior predictive simulation, but the results did not match the textbook example (see below).
I also tried some synthetic examples, and they give unintuitive results as well (see further down).
Examples
Example from Statistical Rethinking
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx
from mcx import distributions as dist
from mcx import sample_joint
@mcx.model
def model():
    μ <~ dist.Normal(178, 20)
    σ <~ dist.Uniform(0, 50)
    h <~ dist.Normal(μ, σ)
    return h

rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=model,
    model_args=(),
    num_samples=10_000
)
fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=128)
axes = axes.reshape(-1)
sns.kdeplot(prior_predictive["μ"], ax=axes[0])
sns.kdeplot(prior_predictive["σ"], ax=axes[1])
sns.kdeplot(prior_predictive["h"], ax=axes[2])
plt.tight_layout()
Result: (plot not preserved)
Expected: (plot not preserved)
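For reference, the same generative process can be simulated directly with the standard library, without mcx. Analytically, E[h] = 178 and Var(h) = 20² + E[σ²] = 400 + 50²/3 ≈ 1233, so sd(h) ≈ 35.1; a correct prior predictive for h should be centred near 178 with roughly that spread. This is a minimal sketch; the function name rethinking_prior_predictive is hypothetical and not part of mcx.

```python
import random
import statistics


def rethinking_prior_predictive(num_samples, seed=0):
    """Hypothetical reference: mu ~ N(178, 20), sigma ~ U(0, 50), h ~ N(mu, sigma)."""
    rng = random.Random(seed)
    mu = [rng.gauss(178, 20) for _ in range(num_samples)]
    sigma = [rng.uniform(0, 50) for _ in range(num_samples)]
    # Each h draw is conditioned on the mu and sigma drawn for the same sample.
    h = [rng.gauss(m, s) for m, s in zip(mu, sigma)]
    return {"mu": mu, "sigma": sigma, "h": h}


samples = rethinking_prior_predictive(10_000)
# Should print values close to the analytic 178 and 35.1.
print(round(statistics.mean(samples["h"]), 1), round(statistics.stdev(samples["h"]), 1))
```

Comparing these summary statistics with the arrays returned by sample_joint is a quick way to tell whether the mcx samples of h are drawn conditionally on μ and σ.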
Synthetic example 1
In this example I sample an offset from Uniform(0, 1), then sample the outcome from Uniform(12 - offset, 12 + offset).
I therefore expect the outcome samples to lie in the range [11, 13], but I get samples in the range [-15, 15].
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx
from mcx import distributions as dist
from mcx import sample_joint
@mcx.model
def example_1():
    center = 12
    offset <~ dist.Uniform(0, 1)
    low = (center - offset)
    high = (center + offset)
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=example_1,
    model_args=(),
    num_samples=10_000
)
ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");
Result: (plot not preserved)
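As a baseline, the synthetic model can be simulated in plain Python. Every outcome then falls inside [center - 1, center + 1], i.e. [11, 13] for center = 12. This is a sketch only; the function name synthetic_prior_predictive is hypothetical, and center is a parameter so the same function covers both the hardcoded and the argument-passing variants.

```python
import random


def synthetic_prior_predictive(center, num_samples, seed=0):
    """Hypothetical reference: offset ~ U(0, 1), outcome ~ U(center - offset, center + offset)."""
    rng = random.Random(seed)
    offsets = [rng.uniform(0, 1) for _ in range(num_samples)]
    # Each outcome uses the offset drawn for the same sample.
    outcomes = [rng.uniform(center - o, center + o) for o in offsets]
    return {"offset": offsets, "outcome": outcomes}


samples = synthetic_prior_predictive(center=12, num_samples=10_000)
print(min(samples["outcome"]), max(samples["outcome"]))
```

Whether center is a constant or an argument should make no difference to the distribution; only the offset and outcome draws are random.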
Synthetic example 2
This is the same example as above, but the center variable is passed as an argument instead of being hardcoded. The results are different, although still not in the range [11, 13].
Code
import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx
from mcx import distributions as dist
from mcx import sample_joint
@mcx.model
def example_2(center):
    offset <~ dist.Uniform(0, 1)
    low = (center - offset)
    high = (center + offset)
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)
prior_predictive = sample_joint(
    rng_key=rng_key,
    model=example_2,
    model_args=(12, ),
    num_samples=10_000
)
ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");
Result: (plot not preserved)
Expectation
For examples 1 and 2, here is what I would expect to get: (plot not preserved)
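The expected shape can also be derived in closed form. Writing o for the offset and taking center = 12 as in both examples, the outcome given o is uniform with density 1/(2o) on [12 - o, 12 + o], and o is uniform on [0, 1]. Integrating out o gives the marginal density below: it is zero outside [11, 13], integrates to 1, and has a logarithmic peak at 12.

```latex
p(x) = \int_0^1 p(x \mid o)\,p(o)\,\mathrm{d}o
     = \int_{|x-12|}^{1} \frac{1}{2o}\,\mathrm{d}o
     = -\tfrac{1}{2}\ln\lvert x - 12 \rvert,
     \qquad 11 < x < 13,

\int_{11}^{13} -\tfrac{1}{2}\ln\lvert x - 12 \rvert\,\mathrm{d}x
     = -\int_{0}^{1} \ln u\,\mathrm{d}u = 1.
```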
Environment
Linux-5.8.0-44-generic-x86_64-with-glibc2.10
Python 3.8.5 (default, Sep 4 2020, 07:30:14)
[GCC 7.3.0]
JAX 0.2.8
NetworkX 2.5
JAXlib 0.1.58
mcx 2a2b94801e68d94d86826863eeee80f0b84c390d