Programs with stochastic support #37

@rlouf

Programs with stochastic support (that is, programs containing control flow) are notoriously difficult to implement and to sample from [1].

JAX cannot JIT-compile a function whose Python control flow depends on traced values (although it can apply grad to it); it instead requires a special construct, jax.lax.cond or jax.lax.switch. This notation is cumbersome because:

  1. It is verbose
  2. It requires modelers and those who read their models to know JAX syntax.
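To make the limitation concrete, here is a minimal sketch. The function names `f_python` and `f_lax` are made up for illustration and are not part of MCX. It shows that a Python `if` on the input survives `jax.grad` but fails under `jax.jit`, while the `jax.lax.cond` version compiles:

```python
import jax
import jax.numpy as jnp


def f_python(x):
    # Plain Python branch on the input value.
    if x > 0:
        return x ** 2
    return -x


def f_lax(x):
    # Same logic expressed with jax.lax.cond; both branches receive `x`.
    return jax.lax.cond(x > 0, lambda v: v ** 2, lambda v: -v, x)


# grad evaluates the branch on a concrete value, so Python control flow works.
grad_value = jax.grad(f_python)(3.0)

# jit traces with abstract values, so `if x > 0` cannot be resolved.
try:
    jax.jit(f_python)(3.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# The lax.cond version can be JIT-compiled.
jit_positive = jax.jit(f_lax)(jnp.asarray(3.0))
jit_negative = jax.jit(f_lax)(jnp.asarray(-2.0))
```

The `jax.lax.cond` form is noticeably heavier than the Python `if`, which is the verbosity concern raised above.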

Since MCX parses model expressions into a graphical model, it seems feasible to extract the control flow and translate it into a graph structure, which is then compiled into a JAX-compatible function.

We need to discuss:

  • How do we represent control flow in a graphical model?
  • How do we compile a graph with control flow into a function that can be JIT-compiled?
  • (Later) How do we go about sampling these models?
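As a starting point for the first two questions, one possible shape is sketched below. This is not MCX's actual graph representation: the `Branch` node and the `compile_branch` function are hypothetical names introduced only for this example. It assumes a branch can be stored as a node holding a predicate and two sub-functions, and that compilation lowers it to `jax.lax.cond`:

```python
from dataclasses import dataclass
from typing import Callable

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class Branch:
    """Hypothetical graph node: a two-way branch on a predicate."""

    predicate: Callable
    if_true: Callable
    if_false: Callable


def compile_branch(node: Branch) -> Callable:
    """Hypothetical compiler step: lower a Branch node to jax.lax.cond."""

    def fn(x):
        return jax.lax.cond(node.predicate(x), node.if_true, node.if_false, x)

    return jax.jit(fn)


# The modeler-facing `if x > 0: ... else: ...` would be parsed into this node.
node = Branch(
    predicate=lambda x: x > 0,
    if_true=lambda x: jnp.log1p(x),
    if_false=lambda x: jnp.zeros_like(x),
)
compiled = compile_branch(node)

positive_result = compiled(jnp.asarray(1.0))
negative_result = compiled(jnp.asarray(-1.0))
```

One constraint to keep in mind for the design: `jax.lax.cond` requires both branches to return outputs with the same structure, shape, and dtype, so branches that define different sets of random variables would need some additional handling.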

This does not need to be implemented for v0.1, but the API design should be clear before the release, as it will affect the way the graph is structured. Use this issue for discussion.

References

[1]: "Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support" https://arxiv.org/abs/1910.13324
