fritzo commented Oct 19, 2022

This makes a number of improvements to sequential_gaussian_filter_sample():

  1. Allows the user to pass in noise rather than drawing it internally (i.e. more JAX-style). The motivation is to support both (1) computing the mean and (2) antithetic sampling, where we can pass in cat([noise, -noise]).
  2. Avoids dropping the first entry of the returned samples. This makes sequential_gaussian_filter_sample() more useful outside of GaussianHMM; e.g. it will be easier to use in a fully-observed Markov model. This attempts to make up for my admittedly bad design choice of making the hidden Z series one step longer than the observed X series in GaussianHMM 😬
  3. Adds a profiling script. I saw no speed changes due to this PR.
  4. Refactors some backward-sampling logic to write results into a preallocated torch.empty() tensor rather than building them with torch.nn.functional.pad and torch.stack. This is admittedly less functional, but it does reduce memory usage and IMHO reads more cleanly.
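Items 1, 2 and 4 can be illustrated together outside of Pyro. The following is a minimal sketch, with NumPy standing in for torch, of a scalar linear-Gaussian Markov chain; `sample_chain`, `sample_chain_stacked` and their parameters are hypothetical and are not the signature of sequential_gaussian_filter_sample(). The caller passes in the noise, all T states are returned, and the two functions contrast writing into a preallocated array with stacking a list of per-step arrays.

```python
import numpy as np


def sample_chain(init_loc, init_scale, coef, trans_scale, noise):
    """Hypothetical sampler for the scalar linear-Gaussian chain
        z[0] = init_loc + init_scale * noise[..., 0]
        z[t] = coef * z[t - 1] + trans_scale * noise[..., t]
    The caller supplies standard-normal `noise` of shape (..., T), and
    all T states are returned (the first entry is not dropped).
    Each step is written into a preallocated array (cf. item 4).
    """
    noise = np.asarray(noise, dtype=float)
    z = np.empty_like(noise)
    z[..., 0] = init_loc + init_scale * noise[..., 0]
    for t in range(1, noise.shape[-1]):
        z[..., t] = coef * z[..., t - 1] + trans_scale * noise[..., t]
    return z


def sample_chain_stacked(init_loc, init_scale, coef, trans_scale, noise):
    """Same chain, built as a Python list of per-step arrays and combined
    with np.stack: the pattern item 4 moves away from."""
    noise = np.asarray(noise, dtype=float)
    states = [init_loc + init_scale * noise[..., 0]]
    for t in range(1, noise.shape[-1]):
        states.append(coef * states[-1] + trans_scale * noise[..., t])
    return np.stack(states, axis=-1)


rng = np.random.default_rng(0)
T = 5
params = (1.0, 0.5, 0.9, 0.3)  # init_loc, init_scale, coef, trans_scale
noise = rng.standard_normal((1000, T))

# Samples are affine in the noise, so zero noise yields the mean.
mean = sample_chain(*params, np.zeros(T))

# Antithetic sampling: pass cat([noise, -noise]); each pair averages to the mean.
antithetic = np.concatenate([noise, -noise])
samples = sample_chain(*params, antithetic)
assert samples.shape == (2000, T)  # all T entries are kept
assert np.allclose(samples.mean(axis=0), mean)

# Both construction patterns produce the same values.
assert np.allclose(samples, sample_chain_stacked(*params, antithetic))
```

Because each sample is an affine function of the noise, zero noise gives the mean and an antithetic pair averages to it exactly. The stacked version keeps every per-step array alive and then copies them into a new array, whereas the preallocated version writes each step once into its final slot.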

This also exposes the helper matrix_and_gaussian_to_gaussian(), which I'm finding useful.
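As a rough sketch of the kind of conversion such a helper performs, assume (from its name and the HMM context, not from its actual signature) that it encodes p(y | x) for y = x @ M + e, with e drawn from a given Gaussian. In information form, with e having precision P and information vector b, substituting e = y - x @ M into the log density gives a joint Gaussian over [x, y]. The function `linear_gaussian_to_joint` is hypothetical and works on plain NumPy arrays rather than Pyro's own Gaussian type.

```python
import numpy as np


def linear_gaussian_to_joint(matrix, precision, info_vec):
    """Hypothetical sketch: information-form Gaussian over z = [x, y] for
    y = x @ matrix + e, where log p(e) = -0.5 * e @ P @ e + e @ b + const.
    Substituting e = y - x @ matrix gives
        joint precision = [[M P M^T, -M P], [-P M^T, P]]
        joint info_vec  = [-M b, b]
    """
    M, P, b = matrix, precision, info_vec
    MP = M @ P  # (dx, dy); its transpose is P M^T because P is symmetric
    joint_precision = np.block([[MP @ M.T, -MP], [-MP.T, P]])
    joint_info_vec = np.concatenate([-M @ b, b])
    return joint_precision, joint_info_vec


rng = np.random.default_rng(0)
dx, dy = 2, 3
M = rng.standard_normal((dx, dy))
A = rng.standard_normal((dy, dy))
P = A @ A.T + dy * np.eye(dy)  # symmetric positive definite
b = rng.standard_normal(dy)
Lam, eta = linear_gaussian_to_joint(M, P, b)

# Sanity check by conditioning on x: y | x has precision Lam_yy and
# information vector eta_y - Lam_yx @ x, so its mean should be inv(P) b + x @ M.
x = rng.standard_normal(dx)
Lam_yy, Lam_yx = Lam[dx:, dx:], Lam[dx:, :dx]
cond_mean = np.linalg.solve(Lam_yy, eta[dx:] - Lam_yx @ x)
assert np.allclose(cond_mean, np.linalg.solve(P, b) + x @ M)
```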

Tested

  • added new gradient tests
  • added tests for antithetic sampling
  • profiled with and without gradients, saw no time difference

fehiepsi left a comment

The changes make sense to me. Looks great overall! Thanks, Fritz.

if noise is None:
    noise = torch.randn(shape, dtype=loc.dtype, device=loc.device)
else:
    noise = noise.reshape(shape)

I think it is better to broadcast noise here.

Never mind, it is better to avoid broadcasting the noise.
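The thread does not say why avoiding broadcasting is better. One plausible reason, sketched here with NumPy standing in for torch (`prepare_noise` and `prepare_noise_broadcast` are hypothetical helpers), is that reshape fails loudly when the caller passes the wrong number of noise values, whereas broadcasting silently reuses a single draw across the batch and so yields perfectly correlated samples.

```python
import numpy as np


def prepare_noise(noise, shape):
    # Mirrors the reviewed branch `noise = noise.reshape(shape)`:
    # the caller must supply exactly prod(shape) noise values.
    return np.asarray(noise).reshape(shape)


def prepare_noise_broadcast(noise, shape):
    # The alternative first suggested in review: broadcast to `shape`.
    return np.broadcast_to(np.asarray(noise), shape)


shape = (4, 3)  # hypothetical (batch, dim)
small = np.random.default_rng(0).standard_normal((1, 3))

# Broadcasting accepts the undersized noise, but every batch row is identical,
# so the four "independent" samples would all be the same.
b = prepare_noise_broadcast(small, shape)
print(all(np.array_equal(row, b[0]) for row in b))

# Reshaping rejects it (NumPy raises ValueError; torch raises RuntimeError).
try:
    prepare_noise(small, shape)
except ValueError as err:
    print("rejected:", err)
```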

fehiepsi merged commit 1098e38 into dev Oct 23, 2022

fritzo commented Oct 24, 2022

Thanks for reviewing @fehiepsi!

fritzo deleted the gaussian-rsample branch August 10, 2023 17:37