
Use unified type for distributions.constraint API #50616

@neerajprad

Description


🚀 Feature

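Currently, the public API of `torch.distributions.constraints` exposes a mix of singleton `Constraint` instances and `Constraint` classes. For example:

```python
>>> from torch.distributions import constraints
>>> type(constraints.real)
<class 'torch.distributions.constraints._Real'>
>>> type(constraints.interval)
<class 'type'>
```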

The ask is to consistently expose the class, which is the more flexible of the two, in the public API. For example, instead of exposing `real = _Real()`, we should expose `real = _Real`.

Motivation

The motivation is that the current API makes checking a constraint's type awkward. For example, to check whether a constraint is an interval or a real constraint, we currently need:

```python
constraint is constraints.real or isinstance(constraint, constraints.interval)
```

This requires knowing which constraints are exposed as types and which as instances, whereas a unified API would allow simply:

```python
isinstance(constraint, (constraints.real, constraints.interval))
```

Note that this will be a breaking change and we may need to go through a round of deprecation to support this.

Alternatives

Alternatively, we could do `isinstance` checks against the non-public classes, e.g. `isinstance(constraint, (constraints._Real, constraints._Interval))`, but I am not sure whether there are strong reasons not to expose these classes publicly.

Another option is to provide backward-compatible utility functions:

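```python
from torch.distributions import constraints


def _unwrap(constraint):
    # Strip any `independent` wrappers to reach the underlying constraint.
    if isinstance(constraint, constraints.independent):
        return _unwrap(constraint.base_constraint)
    # Normalize to a class: classes pass through, instances map to their class.
    return constraint if isinstance(constraint, type) else constraint.__class__


def constraint_type_eq(constraint1, constraint2):
    # True if both constraints unwrap to the same constraint class.
    return _unwrap(constraint1) == _unwrap(constraint2)
```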

cc @fritzo @neerajprad @alicanb @vishwakftw @nikitaved @fehiepsi

Metadata


Assignees

No one assigned

    Labels

    module: distributions (Related to torch.distributions), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
