fritzo commented Jan 14, 2021

Addresses pyro-ppl/funsor#412
Requires pytorch/pytorch#50547 and pytorch/pytorch#50581

This adds a TorchDistribution.infer_shapes() method to statically compute (batch_shape, event_shape) for use in Funsor. The method takes *args, **kwargs with the same signature as .__init__(), but accepts shapes rather than tensors. The following should be equivalent:

# Eager version.
d = MultivariateNormal(loc, covariance_matrix)
batch_shape = d.batch_shape
event_shape = d.event_shape

# Lazy version.
batch_shape, event_shape = MultivariateNormal.infer_shapes(loc.shape, covariance_matrix.shape)

The default implementation uses the new .support.event_dim logic from pytorch/pytorch#50547, so I've included backports in this PR (we can delete them after the next PyTorch release).
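
As a rough illustration of the idea (not the code in this PR), static shape inference can be sketched in plain Python: strip each argument's event dimensions from its shape, then broadcast what remains to get the batch shape. The names `broadcast_shape`, `infer_shapes_sketch`, `mvn_infer_shapes_sketch`, and the `arg_event_dims` mapping are hypothetical stand-ins; the real method presumably reads event dims from the distribution's constraints.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Broadcast shapes right-aligned, following NumPy/PyTorch rules."""
    result = []
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_trivial = {n for n in sizes if n != 1}
        if len(non_trivial) > 1:
            raise ValueError(f"shape mismatch: {shapes}")
        result.append(non_trivial.pop() if non_trivial else 1)
    return tuple(reversed(result))


def infer_shapes_sketch(arg_event_dims, **arg_shapes):
    """Hypothetical default for distributions with a scalar event.

    arg_event_dims maps an argument name to the number of trailing
    event dimensions of that argument (assumed 0 when absent).
    """
    batch_shapes = []
    for name, shape in arg_shapes.items():
        event_dim = arg_event_dims.get(name, 0)
        # Drop the trailing event dims; the rest is batch.
        batch_shapes.append(tuple(shape[:len(shape) - event_dim]))
    return broadcast_shape(*batch_shapes), ()


def mvn_infer_shapes_sketch(loc_shape, covariance_matrix_shape):
    """Hypothetical override for a multivariate-normal-style signature."""
    batch_shape = broadcast_shape(
        tuple(loc_shape[:-1]), tuple(covariance_matrix_shape[:-2])
    )
    event_shape = tuple(loc_shape[-1:])
    return batch_shape, event_shape


# Univariate case: both arguments are pure batch shapes.
print(infer_shapes_sketch({}, loc=(5, 1), scale=(3,)))
# Multivariate case: loc has 1 event dim, covariance_matrix has 2.
print(mvn_infer_shapes_sketch((4, 1, 3), (2, 3, 3)))
```

No tensors are allocated here, which is the point of the lazy version: only tuples of integers are manipulated.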

Tested

  • unit tests for .support.is_discrete
  • unit tests for .support.event_dim
  • unit tests for .infer_shapes()

@fehiepsi

@fritzo Do we need to do the same for transforms?

fritzo commented Jan 15, 2021

> Do we need to do the same for transforms?

Indeed, to implement .infer_shapes() for TransformedDistribution we will need some sort of static shape computation on Transforms. @stefanwebb has been thinking about this in his Bijectors design doc. I think we will be able to use that doc's suggested .forward_shape() method, but I'd like to defer that to a later PR.

UPDATE: I have implemented Transform.forward_shape() and .inverse_shape() in pytorch/pytorch#50581.
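
To make the static shape computation on transforms concrete, here is a minimal plain-Python sketch of how a pair of `forward_shape()` / `inverse_shape()` methods might compose across a chain of transforms. The classes `IdentityShape` and `AppendOneShape` and the two `compose_*` helpers are hypothetical and only mimic the method names mentioned above; they are not the PyTorch implementation.

```python
class IdentityShape:
    """Hypothetical elementwise transform: shapes pass through unchanged."""

    def forward_shape(self, shape):
        return tuple(shape)

    def inverse_shape(self, shape):
        return tuple(shape)


class AppendOneShape:
    """Hypothetical transform whose output has one more entry in the
    last dimension than its input (as a stick-breaking style map would)."""

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("expected at least one dimension")
        return tuple(shape[:-1]) + (shape[-1] + 1,)

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("expected at least one dimension")
        return tuple(shape[:-1]) + (shape[-1] - 1,)


def compose_forward_shape(transforms, shape):
    """Push a shape forward through the transforms in order."""
    for t in transforms:
        shape = t.forward_shape(shape)
    return tuple(shape)


def compose_inverse_shape(transforms, shape):
    """Pull a shape back through the transforms in reverse order."""
    for t in reversed(transforms):
        shape = t.inverse_shape(shape)
    return tuple(shape)


chain = [IdentityShape(), AppendOneShape()]
out_shape = compose_forward_shape(chain, (7, 4))
print(out_shape)
print(compose_inverse_shape(chain, out_shape))
```

With shape methods like these, a transformed distribution could in principle derive its output shape from the base distribution's shape without evaluating any transform on data.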

fritzo commented Feb 23, 2021

@fehiepsi this is now unblocked 😄

fehiepsi left a comment

LGTM!

fehiepsi merged commit fd84ea8 into dev on Feb 23, 2021

fritzo commented Feb 23, 2021

Thanks for reviewing, @fehiepsi!
