Description
This issue proposes a streaming architecture for MCMC on models with large memory footprint.
The problem this addresses is that, in models with high-dimensional latents (say >1M latent variables), it becomes difficult to store a list of samples, especially on GPUs with limited memory. The proposed solution is to compute statistics on those samples eagerly during inference and then discard the samples.
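Eager statistics can be computed with Welford's online algorithm. It keeps only a running mean and a running sum of squared deviations, so memory is O(dim) no matter how many samples are drawn, and each sample can be dropped as soon as it has been folded in. The sketch below is a minimal pure-Python version that assumes the latents are flattened into a list of floats. The class name `StreamingMeanVar` is hypothetical and is not part of Pyro.

```python
import math
from typing import List, Sequence


class StreamingMeanVar:
    """Welford's online mean/variance, applied elementwise to a flat latent vector.

    Memory is O(dim) regardless of the number of samples seen, so each sample
    can be discarded right after update() returns.
    """

    def __init__(self, dim: int) -> None:
        self.count = 0
        self.mean: List[float] = [0.0] * dim
        self._m2: List[float] = [0.0] * dim  # running sum of squared deviations

    def update(self, sample: Sequence[float]) -> None:
        self.count += 1
        for i, x in enumerate(sample):
            delta = x - self.mean[i]
            self.mean[i] += delta / self.count
            # Uses the *updated* mean, which keeps the update numerically stable.
            self._m2[i] += delta * (x - self.mean[i])

    def variance(self) -> List[float]:
        # Unbiased (n - 1) estimator; undefined for fewer than two samples.
        if self.count < 2:
            return [math.nan] * len(self.mean)
        return [m2 / (self.count - 1) for m2 in self._m2]


# Usage: feed samples one at a time; none of them is stored.
stats = StreamingMeanVar(dim=2)
for sample in ([1.0, 10.0], [2.0, 20.0], [3.0, 30.0]):
    stats.update(sample)
print(stats.mean)        # [2.0, 20.0]
print(stats.variance())  # [1.0, 100.0]
```

A covariance helper can follow the same pattern by accumulating a running sum of outer products of deviations, at O(dim²) memory.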
@fehiepsi suggested creating a new MCMC class (say `StreamingMCMC`) with an interface similar to `MCMC` that is still independent of the kernel (using either `HMC` or `NUTS`) but follows an internal streaming architecture. Since large models like these usually run on GPU or are otherwise memory constrained, it is reasonable to avoid multiprocessing support in `StreamingMCMC`.
Along with the new `StreamingMCMC` class, I think there should be a set of helpers to compute statistics incrementally from sample streams, e.g. mean, variance, covariance, and `r_hat` statistics.
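As an example of why `r_hat` fits a streaming design: the classic (non-split) Gelman-Rubin statistic needs only each chain's sample count, mean and variance, all of which can be accumulated online. With $m$ chains of length $n$, $W$ is the mean within-chain variance, $B/n$ is the variance of the chain means, $\hat{V} = \frac{n-1}{n}W + \frac{B}{n}$ and $\hat{R} = \sqrt{\hat{V}/W}$. The sketch below is a minimal pure-Python version under those assumptions. `ChainMoments` and `streaming_r_hat` are hypothetical names, not the `pyro.ops.streaming` API.

```python
import math
from typing import Sequence


class ChainMoments:
    """Running mean/variance of one scalar parameter for one chain (Welford)."""

    def __init__(self) -> None:
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0

    def update(self, x: float) -> None:
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    @property
    def var(self) -> float:
        return self.m2 / (self.n - 1)  # unbiased within-chain variance


def streaming_r_hat(chains: Sequence[ChainMoments]) -> float:
    """Classic Gelman-Rubin R-hat computed from per-chain running moments only."""
    m = len(chains)
    if m < 2:
        raise ValueError("need at least two chains")
    n = chains[0].n
    if n < 2 or any(c.n != n for c in chains):
        raise ValueError("chains must have equal length >= 2")
    w = sum(c.var for c in chains) / m                     # within-chain variance W
    grand_mean = sum(c.mean for c in chains) / m
    b_over_n = sum((c.mean - grand_mean) ** 2 for c in chains) / (m - 1)  # B / n
    var_hat = (n - 1) / n * w + b_over_n
    return math.sqrt(var_hat / w)
```

The split variant would need two accumulators per chain (first and second half), which in turn requires knowing the number of samples in advance.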
Tasks (to be split into multiple PRs)
- Create a `StreamingMCMC` class with an interface identical to `MCMC` (except disallowing parallel chains). (#2857: StreamingMCMC class)
- Generalize unit tests of `MCMC` to parametrize over both `MCMC` and `StreamingMCMC`. (#2857: StreamingMCMC class)
- Add some tests ensuring `StreamingMCMC` and `MCMC` perform identical computations, up to numerical precision.
- Create a tutorial using `StreamingMCMC` on a big model.
- Create streaming helpers for mean, variance, etc. (#2856: Implement pyro.ops.streaming module)
- Add `r_hat` to `pyro.ops.streaming`.
- Add `n_eff = ess` to `pyro.ops.streaming`.
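For the task of checking that streaming and batch paths agree up to numerical precision, one possible test shape is to push the same draws through a batch reducer (as `MCMC` would after storing all samples) and through a single-pass streaming reducer, then compare with a relative tolerance. This is a hypothetical pure-Python test, not taken from the Pyro test suite. `streaming_mean_var` and `check_streaming_matches_batch` are illustrative names, and seeded Gaussian draws stand in for real MCMC samples. A full test would instead run `MCMC` and `StreamingMCMC` with the same seed and kernel.

```python
import math
import random
import statistics
from typing import Iterable, Tuple


def streaming_mean_var(stream: Iterable[float]) -> Tuple[int, float, float]:
    """Consume the stream once (Welford) and return (count, mean, unbiased variance)."""
    n, mean, m2 = 0, 0.0, 0.0
    for x in stream:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    if n < 2:
        raise ValueError("need at least two samples")
    return n, mean, m2 / (n - 1)


def check_streaming_matches_batch(seed: int = 0, num_samples: int = 10_000,
                                  rel_tol: float = 1e-9) -> None:
    rng = random.Random(seed)
    # Batch path: every draw is kept in memory.
    samples = [rng.gauss(3.0, 2.0) for _ in range(num_samples)]
    # Streaming path: each draw is seen exactly once and not retained.
    n, mean, var = streaming_mean_var(iter(samples))
    assert n == num_samples
    assert math.isclose(mean, statistics.fmean(samples), rel_tol=rel_tol)
    assert math.isclose(var, statistics.variance(samples), rel_tol=rel_tol)


check_streaming_matches_batch()
```

Using a relative tolerance rather than exact equality reflects that the two paths sum in different orders, so only agreement up to floating-point error is expected.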