MultivariateStudentT distribution initializes batch_shape as tuple instead of torch.Size #3098

@flo-schu

Description

MultivariateStudentT currently initializes batch_shape as a plain tuple:

batch_shape = broadcast_shape(df.shape, loc.shape[:-1], scale_tril.shape[:-2])

While this is not a problem in itself, it raised an error when I used the MultivariateStudentT distribution in a higher-level package, because that package requires batch_shape to be of type torch.Size.

I don't see any downside to wrapping the result in a torch.Size object instead. Did I miss something?

batch_shape = torch.Size(broadcast_shape(df.shape, loc.shape[:-1], scale_tril.shape[:-2]))
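The difference between the two assignments can be sketched with plain PyTorch. This is only an illustration under stated assumptions: `broadcast_shape_tuple` is a hypothetical stand-in for the helper used above (assumed to return a plain tuple), and the tensor shapes are made up for the example.

```python
import torch


def broadcast_shape_tuple(*shapes):
    # Hypothetical stand-in for the broadcast_shape helper in this issue,
    # assumed here to return a plain Python tuple rather than torch.Size.
    return tuple(torch.broadcast_shapes(*shapes))


# Illustrative parameters: a batch of 3 two-dimensional distributions.
df = torch.ones(3)
loc = torch.zeros(3, 2)
scale_tril = torch.eye(2).expand(3, 2, 2)

# Current behaviour: the batch shape ends up as a plain tuple.
raw = broadcast_shape_tuple(df.shape, loc.shape[:-1], scale_tril.shape[:-2])

# Proposed behaviour: wrap the same value in torch.Size.
wrapped = torch.Size(raw)

# torch.Size subclasses tuple, so the two compare equal,
# but only the wrapped value passes a strict type check.
print(isinstance(raw, torch.Size), isinstance(wrapped, torch.Size))
print(raw == wrapped)
```

Because torch.Size is a subclass of tuple, code that only iterates over or indexes batch_shape should behave the same with either value; the wrapping matters for downstream code that checks the type explicitly.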

Labels: easy, help wanted, refactor