[bug] bart.py example consistently fails with cholesky error with default arguments #3001

@abeppu

Description

I'm just getting started with Pyro and was working through the examples when I found that the example listed on the examples page under "Multivariate Forecasting" reliably fails for me.

There is a related, quite old open issue (#2017) with some promising suggestions for the broader problem, but the PR implementing the approach discussed there (#2019) was never approved.

It's possible that this is caused by some environmental factor, although I did create a clean virtualenv to explore Pyro. If there is a contributing environmental factor that I'm not aware of, please document it, or, even better, add a helper method to health-check a given environment.

If you can confirm or reproduce the failure, I would respectfully suggest one of the following:

  • a solution should be pursued
  • a working example of a workaround should be exhibited and linked
  • or at a minimum non-functioning "examples" should be removed from the examples page

Issue Description

examples/contrib/forecast/bart.py fails with a Cholesky error when run with default parameters (no arguments).

Note that this example uses the backtest method, which trains a model several times over different time windows. The first several windows succeed.
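
The windowing that backtest performs can be sketched in plain Python. This is a hypothetical helper (`backtest_windows` is not a Pyro function); its parameter names are loosely modeled on the keyword arguments of `pyro.contrib.forecast.backtest`, but the logic is an assumption about the general technique, not Pyro's actual implementation. Each returned triple is one fit-then-evaluate round, so in a loop like this a numerical failure in a later window ends the run even after earlier windows have succeeded.

```python
from typing import List, Optional, Tuple


def backtest_windows(
    duration: int,
    min_train_window: int = 1,
    test_window: Optional[int] = None,
    min_test_window: int = 1,
    train_window: Optional[int] = None,
    stride: int = 1,
) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: enumerate (train_begin, train_end, test_end) triples.

    Each triple means: fit a fresh model on steps [train_begin, train_end),
    then forecast and score steps [train_end, test_end).
    """
    # The last train_end must leave room for a (minimum-size) test window.
    min_test = test_window if test_window is not None else min_test_window
    last_train_end = duration - min_test
    windows = []
    for train_end in range(min_train_window, last_train_end + 1, stride):
        # Expanding window by default; sliding window when train_window is set.
        train_begin = 0 if train_window is None else max(0, train_end - train_window)
        # Without a fixed test_window, test on everything after train_end.
        test_end = duration if test_window is None else train_end + test_window
        windows.append((train_begin, train_end, test_end))
    return windows


if __name__ == "__main__":
    # Ten time steps, at least six for training, two-step test windows.
    print(backtest_windows(10, min_train_window=6, test_window=2))
    # [(0, 6, 8), (0, 7, 9), (0, 8, 10)]
```

Passing `train_window` switches from an expanding to a sliding training window; the number of model fits is the length of the returned list.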

The error appears as follows:

Traceback (most recent call last):
  File "examples/contrib/forecast/bart.py", line 180, in <module>
    main(args)
  File "examples/contrib/forecast/bart.py", line 156, in main
    forecaster_options=forecaster_options,
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/evaluate.py", line 205, in backtest
    batch_size=batch_size,
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 361, in __call__
    return super().__call__(data, covariates, num_samples, batch_size)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 390, in forward
    return self.model(data, covariates)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/nn/module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 185, in forward
    self.model(zero_data, covariates)
  File "examples/contrib/forecast/bart.py", line 121, in model
    self.predict(noise_model, prediction)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/contrib/forecast/forecaster.py", line 157, in predict
    noise = pyro.sample("residual", noise_dist)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/primitives.py", line 163, in sample
    apply_stack(msg)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 218, in apply_stack
    default_process_message(msg)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 179, in default_process_message
    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/distributions/torch_distribution.py", line 49, in __call__
    if self.has_rsample
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/distributions/hmm.py", line 584, in rsample
    z = _sequential_gaussian_filter_sample(self._init, trans, sample_shape)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/distributions/hmm.py", line 144, in _sequential_gaussian_filter_sample
    contracted = joint.marginalize(left=state_dim)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/ops/gaussian.py", line 244, in marginalize
    P_b = cholesky(P_bb)
  File "/Users/aaron/.pyenv/versions/vpyro/lib/python3.7/site-packages/pyro/ops/tensor_utils.py", line 399, in cholesky
    return torch.linalg.cholesky(x)
RuntimeError: torch.linalg.cholesky: (Batch element 255): The factorization could not be completed because the input is not positive-definite (the leading minor of order 2 is not positive-definite).
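
The last line of the traceback says the factorization failed for batch element 255 because "the leading minor of order 2 is not positive-definite". A plain-Python sketch makes that message concrete: Cholesky produces one pivot per row and fails at the first pivot that is not strictly positive, which for a 2×2 covariance that is singular (or positive-definite only up to rounding error) is the second one. The jittered retry shows one common, generic mitigation for this class of error: adding a small multiple of the identity to the diagonal. All names here (`cholesky_lower`, `cholesky_with_jitter`, `NotPositiveDefiniteError`) are hypothetical, not Pyro or PyTorch APIs; this has not been tested as a fix for bart.py and is not necessarily what #2017 or #2019 propose.

```python
import math
from typing import List

Matrix = List[List[float]]


class NotPositiveDefiniteError(ValueError):
    """Raised when a leading principal minor is not positive-definite."""

    def __init__(self, order: int) -> None:
        super().__init__(f"the leading minor of order {order} is not positive-definite")
        self.order = order


def cholesky_lower(a: Matrix) -> Matrix:
    """Return lower-triangular L with L @ L^T == a (Cholesky-Banachiewicz order)."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                pivot = a[i][i] - s
                # A non-positive pivot means the (i+1)x(i+1) leading minor is not PD.
                if pivot <= 0.0:
                    raise NotPositiveDefiniteError(i + 1)
                L[i][i] = math.sqrt(pivot)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L


def cholesky_with_jitter(a: Matrix, jitter: float = 1e-6, max_tries: int = 5) -> Matrix:
    """Retry the factorization of a + eps * I with eps = 0, jitter, 10*jitter, ..."""
    n = len(a)
    eps = 0.0  # the first attempt uses the matrix unchanged
    for attempt in range(max_tries):
        shifted = [[a[i][j] + (eps if i == j else 0.0) for j in range(n)] for i in range(n)]
        try:
            return cholesky_lower(shifted)
        except NotPositiveDefiniteError:
            if attempt == max_tries - 1:
                raise  # still failing after the largest jitter: give up
            eps = jitter if eps == 0.0 else eps * 10.0
    raise ValueError("max_tries must be at least 1")


if __name__ == "__main__":
    # Rank-deficient 2x2 covariance: its 1x1 minor is PD, the 2x2 one is singular.
    cov = [[1.0, 1.0], [1.0, 1.0]]
    try:
        cholesky_lower(cov)
    except NotPositiveDefiniteError as err:
        print(err)  # the leading minor of order 2 is not positive-definite
    print(cholesky_with_jitter(cov))
```

Jitter trades a small bias in the covariance for numerical stability; the starting value and the factor-of-10 growth above are arbitrary choices, and a matrix with a genuinely negative eigenvalue (rather than a rounding-level one) will still fail.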

Environment

  • OS: macOS Big Sur (11.5.2), Intel
  • Python version: 3.7.9
  • PyTorch version: 1.10.1
  • Pyro version: 1.8.0

I get the same behavior on Linux in Docker.
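
The health-check helper suggested earlier could start by collecting the same details as this Environment section. This is a minimal sketch, assuming a version report is enough (it does not check versions against any known-good combination); `environment_report` is hypothetical and not part of Pyro. It assumes Pyro is installed under its PyPI distribution name `pyro-ppl`. `importlib.metadata` requires Python 3.8+, so on Python 3.7, as in this report, the sketch falls back to reporting "unknown".

```python
import platform
from typing import Dict, Tuple

try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7 lacks importlib.metadata
    version = None
    PackageNotFoundError = Exception


def environment_report(packages: Tuple[str, ...] = ("torch", "pyro-ppl")) -> Dict[str, str]:
    """Hypothetical helper: collect OS, CPU architecture, Python and package versions."""
    report = {
        "os": platform.platform(),
        "machine": platform.machine(),  # e.g. x86_64 on an Intel Mac
        "python": platform.python_version(),
    }
    for name in packages:
        if version is None:
            report[name] = "unknown (importlib.metadata unavailable)"
            continue
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key:10} {value}")
```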

Code Snippet

I copy-pasted the example from https://pyro.ai/examples/forecast_simple.html (source: https://github.com/pyro-ppl/pyro/blob/dev/examples/contrib/forecast/bart.py) and simply ran:

python bart.py
