Incorrect results when sampling from the prior #90

@elanmart

While going through Statistical Rethinking I wanted to run a prior-predictive simulation, but the results did not match the textbook example (see below).

I also tried some other synthetic examples, and they give unintuitive results as well (see further down).

Examples

Example from Statistical Rethinking

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def model():
    μ <~ dist.Normal(178, 20)
    σ <~ dist.Uniform(0, 50)
    h <~ dist.Normal(μ, σ)
    
    return h

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=model, 
    model_args=(), 
    num_samples=10_000
)

fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=128)
axes = axes.reshape(-1)

sns.kdeplot(prior_predictive["μ"], ax=axes[0])
sns.kdeplot(prior_predictive["σ"], ax=axes[1])
sns.kdeplot(prior_predictive["h"], ax=axes[2])

plt.tight_layout()

Result

image

Expected

image
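
For comparison, the same simulation can be sketched without mcx by sampling each variable in order (ancestral sampling). This is only a minimal reference in plain NumPy, assuming the model is exactly μ ~ Normal(178, 20), σ ~ Uniform(0, 50), h ~ Normal(μ, σ) as in the code above; the function name `reference_prior_predictive` and the seed are illustrative and not part of mcx.

```python
import numpy as np


def reference_prior_predictive(num_samples=10_000, seed=0):
    """Ancestral sampling for the height model, written without mcx.

    Hypothetical helper for comparison only; not an mcx API.
    """
    rng = np.random.default_rng(seed)

    # Sample the parents first, one draw per simulated individual.
    mu = rng.normal(178.0, 20.0, size=num_samples)
    sigma = rng.uniform(0.0, 50.0, size=num_samples)

    # Then sample h conditionally on each (mu, sigma) pair.
    h = rng.normal(mu, sigma)

    return {"μ": mu, "σ": sigma, "h": h}


reference = reference_prior_predictive()
print("mean of h:", reference["h"].mean())
print("std of h: ", reference["h"].std())
```

By the law of total variance, h should have mean 178 and standard deviation sqrt(20² + 50²/3) ≈ 35.1, so a correct `sample_joint` should produce a density for h that is centered near 178 with roughly that spread.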

Synthetic example 1

In this example I sample an offset from Uniform(0, 1).
Then I sample the outcome from Uniform(12 - offset, 12 + offset).
So I expect the samples to lie in the range [11, 13].
But I get samples in the range [-15, 15].

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def example_1():
    
    center = 12
    offset <~ dist.Uniform(0, 1)
    
    low = (center - offset)
    high = (center + offset)
    
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=example_1, 
    model_args=(), 
    num_samples=10_000
)


ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");

Result

image

Synthetic example 2

This is the same example as above, but the center variable is passed as an argument instead of being hardcoded. The results are different, although still not in the range [11, 13].

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def example_2(center):
    
    offset <~ dist.Uniform(0, 1)
    
    low = (center - offset)
    high = (center + offset)
    
    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=example_2, 
    model_args=(12, ), 
    num_samples=10_000
)


ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");

Result

image

Expectation

For examples 1 and 2, this is what I'd expect to get:

image
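
The expected behaviour for examples 1 and 2 can be reproduced with a short reference simulation that does not use mcx. This is a sketch under the assumption that `outcome` should be drawn conditionally on the sampled `offset`; `reference_example` is a hypothetical helper name, and NumPy is used here only as an independent check.

```python
import numpy as np


def reference_example(center=12.0, num_samples=10_000, seed=0):
    """Ancestral sampling for examples 1 and 2, written without mcx.

    Hypothetical helper for comparison only; not an mcx API.
    """
    rng = np.random.default_rng(seed)

    # One offset per sample, each in [0, 1).
    offset = rng.uniform(0.0, 1.0, size=num_samples)

    # Each outcome is drawn from the interval defined by its own offset.
    low = center - offset
    high = center + offset
    outcome = rng.uniform(low, high)

    return offset, outcome


offset, outcome = reference_example(center=12.0)
print("min outcome:", outcome.min())
print("max outcome:", outcome.max())
```

Because every offset is below 1, every outcome has to fall between 11 and 13, whether `center` is hardcoded or passed as an argument.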

Environment

Linux-5.8.0-44-generic-x86_64-with-glibc2.10
Python 3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0]
JAX 0.2.8
NetworkX 2.5
JAXlib 0.1.58
mcx 2a2b94801e68d94d86826863eeee80f0b84c390d

Labels

bug, example, priority-1