
TransformedDistribution and event_shape #21596

@justindomke

🐛 Bug

TransformedDistribution does not seem to report the correct event_shape when the transform changes the number of dimensions (e.g., StickBreakingTransform, whose inverse maps a length-3 simplex point to a length-2 unconstrained vector).

To Reproduce

Steps to reproduce the behavior:

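import torch as tt

# Dirichlet over the 3-simplex: event_shape is torch.Size([3])
dist = tt.distributions.Dirichlet(tt.ones(3))
support = dist.support
# biject_to(simplex) returns a StickBreakingTransform (unconstrained R^2 -> 3-simplex)
tform = tt.distributions.constraint_registry.biject_to(support)
# Its inverse maps the 3-simplex back to unconstrained R^2
dist_unconstrained = tt.distributions.TransformedDistribution(dist, tform.inv)

print(tform)
print(dist_unconstrained.sample())        # sample has 2 elements
print(dist_unconstrained.event_shape)     # but event_shape still reports 3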

This produces:

StickBreakingTransform()
tensor([ 0.3628, -1.9573])
torch.Size([3])

Expected behavior

The last print statement should presumably give the following, matching the two-element sample printed above:

torch.Size([2])
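The expected size follows from the dimension count of stick-breaking: it is a bijection between unconstrained R^(K-1) and the K-simplex, so the inverse transform used in the repro should map an event of size 3 to one of size 2. A minimal pure-Python sketch of a stick-breaking map illustrates the shape change. It is not PyTorch's implementation; the per-coordinate offset log(K-1-i), chosen so the zero vector maps to the uniform point, is an assumption of this sketch, and the function names are hypothetical.

```python
import math


def stick_breaking(x):
    """Map an unconstrained vector of length K-1 to a point on the K-simplex."""
    n = len(x)
    y = []
    remaining = 1.0  # length of the stick not yet broken off
    for i, xi in enumerate(x):
        # Assumed offset log(n - i) so that x = 0 maps to the uniform point
        z = 1.0 / (1.0 + math.exp(-(xi - math.log(n - i))))
        y.append(z * remaining)  # break off fraction z of what is left
        remaining *= 1.0 - z
    y.append(remaining)  # last coordinate takes whatever stick is left
    return y


def stick_breaking_inverse(y):
    """Map a point on the K-simplex back to an unconstrained vector of length K-1."""
    n = len(y) - 1  # one dimension is dropped
    x = []
    leftover = 1.0
    for i in range(n):
        leftover -= y[i]  # stick remaining after breaking off y[i]
        x.append(math.log(y[i]) - math.log(leftover) + math.log(n - i))
    return x
```

Applied to the two-element sample printed in the repro, stick_breaking returns three positive coordinates summing to 1, and stick_breaking_inverse recovers the original two values, so the unconstrained event has size 2, not 3.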

Environment

  • PyTorch Version (e.g., 1.0): 1.1.0
  • OS (e.g., Linux): Mac
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): n/a
  • Python version: 3.6.4
  • CUDA/cuDNN version: n/a
  • GPU models and configuration: n/a
  • Any other relevant information: n/a

Labels: module: distributions (related to torch.distributions); triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)