Status: Closed
Labels: module: distributions (Related to torch.distributions), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Bug
TransformedDistribution does not seem to report the correct event_shape when the transform changes the number of dimensions (e.g. with StickBreakingTransform).
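For context, here is a minimal sketch of why the shapes should differ. It uses only the public `StickBreakingTransform` class from `torch.distributions.transforms`. The transform maps an unconstrained vector in R^(K-1) onto the K-simplex, so its inverse drops one coordinate. The input values are arbitrary and chosen only for illustration.

```python
import torch
from torch.distributions.transforms import StickBreakingTransform

t = StickBreakingTransform()

# An arbitrary unconstrained point in R^2 (K - 1 = 2 coordinates).
x = torch.tensor([0.3628, -1.9573])

# Forward direction: R^2 -> 3-simplex, so one coordinate is added.
y = t(x)

# Inverse direction: 3-simplex -> R^2, so one coordinate is dropped again.
x_back = t.inv(y)

print(x.shape, y.shape, x_back.shape)
print(y.sum())  # the entries of a point on the simplex sum to 1
```

Any distribution built on `t.inv` should therefore have a last event dimension that is one shorter than its base distribution's.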
To Reproduce
Steps to reproduce the behavior:
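```python
import torch as tt

# Dirichlet over the 3-simplex; its support constraint is `simplex`.
dist = tt.distributions.Dirichlet(tt.ones(3))
support = dist.support
# biject_to(simplex) returns a StickBreakingTransform (R^2 -> 3-simplex).
tform = tt.distributions.constraint_registry.biject_to(support)
# Use the inverse to map Dirichlet samples into unconstrained R^2.
dist_unconstrained = tt.distributions.TransformedDistribution(dist, tform.inv)
print(tform)
print(dist_unconstrained.sample())     # a length-2 vector
print(dist_unconstrained.event_shape)  # reported as torch.Size([3]) on 1.1.0
```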
This produces:
```
StickBreakingTransform()
tensor([ 0.3628, -1.9573])
torch.Size([3])
```
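One way to spot this kind of mismatch is to check the invariant that a single draw has shape `batch_shape + event_shape`. The helper `event_shape_is_consistent` below is a hypothetical name used for illustration and is not part of `torch.distributions`. It is a sketch, not a complete validity check.

```python
import torch
from torch.distributions import (
    Dirichlet,
    Distribution,
    Normal,
    TransformedDistribution,
    biject_to,
)


def event_shape_is_consistent(d: Distribution) -> bool:
    """Hypothetical helper: does one draw have shape batch_shape + event_shape?"""
    sample = d.sample()
    return tuple(sample.shape) == tuple(d.batch_shape) + tuple(d.event_shape)


# Reference cases that involve no shape-changing transform.
dirichlet = Dirichlet(torch.ones(3))            # batch_shape [], event_shape [3]
normal = Normal(torch.zeros(4), torch.ones(4))  # batch_shape [4], event_shape []

# The distribution from the reproduction above.
unconstrained = TransformedDistribution(dirichlet, biject_to(dirichlet.support).inv)

print(event_shape_is_consistent(dirichlet))
print(event_shape_is_consistent(normal))
print(event_shape_is_consistent(unconstrained))
```

On the reporter's 1.1.0 setup, the last line would presumably print False, because the sample has shape [2] while event_shape reports [3].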
Expected behavior
The last print statement should presumably give the following, matching the shape of the sample:
```
torch.Size([2])
```
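The expected value can also be derived from the transform itself. Newer PyTorch releases give each `Transform` a pair of `forward_shape()`/`inverse_shape()` methods, and `TransformedDistribution` uses them to compute `event_shape`. These releases are later than the 1.1.0 used here; treat the exact version as an assumption. The following minimal sketch assumes such a release:

```python
import torch
from torch.distributions import Dirichlet, TransformedDistribution, biject_to

base = Dirichlet(torch.ones(3))
tform = biject_to(base.support)  # StickBreakingTransform for the simplex

# Assumes a PyTorch release where Transform.forward_shape() exists (not 1.1.0).
# The inverse transform maps a length-3 simplex point to a length-2 vector.
expected = tform.inv.forward_shape(base.event_shape)

unconstrained = TransformedDistribution(base, tform.inv)
print(expected)
print(unconstrained.event_shape)
print(unconstrained.sample().shape)
```

On such a release, all three printed shapes should agree on a single dimension of size 2.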
Environment
- PyTorch Version (e.g., 1.0): 1.1.0
- OS (e.g., Linux): Mac
- How you installed PyTorch (conda, pip, source): conda
- Build command you used (if compiling from source): n/a
- Python version: 3.6.4
- CUDA/cuDNN version: n/a
- GPU models and configuration: n/a
- Any other relevant information: n/a