Description
🚀 Feature
Currently, the API of `torch.distributions.constraints` exposes a mix of singleton `Constraint` instances and `Constraint` classes, e.g.

```python
>>> type(constraints.real)
torch.distributions.constraints._Real
>>> type(constraints.interval)
type
```

The ask is to uniformly expose the more flexible class in the public API: e.g. instead of exposing `real = _Real()`, we should expose `real = _Real`.
Motivation
The motivation is that the current API makes checking a constraint's type awkward. For example, to check whether a constraint is an interval or a real constraint, we currently need to write:

```python
constraint is constraints.real or isinstance(constraint, constraints.interval)
```

This requires knowing which names are types and which are instances, instead of the simpler:

```python
isinstance(constraint, (constraints.real, constraints.interval))
```
Note that this would be a breaking change, so it may need to go through a deprecation cycle.
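To make the difference concrete, here is a small sketch against the current API, assuming a PyTorch version where `constraints.real` is a `_Real()` instance and `constraints.interval` is the `_Interval` class. The helper names `is_real_or_interval` and `simpler_check` are hypothetical and only for illustration. The first works today; the second raises `TypeError`, because `isinstance` only accepts a type (or a tuple of types) as its second argument.

```python
from torch.distributions import constraints


def is_real_or_interval(constraint):
    # Hypothetical helper. constraints.real is a singleton instance, so it is
    # compared by identity; constraints.interval is a class, so isinstance works.
    return constraint is constraints.real or isinstance(constraint, constraints.interval)


def simpler_check(constraint):
    # Hypothetical helper in the proposed style. With the current API this raises
    # TypeError, because constraints.real is an instance rather than a class.
    return isinstance(constraint, (constraints.real, constraints.interval))


if __name__ == "__main__":
    print(is_real_or_interval(constraints.real))           # True
    print(is_real_or_interval(constraints.unit_interval))  # True
    print(is_real_or_interval(constraints.positive))       # False
    try:
        simpler_check(constraints.real)
    except TypeError as err:
        print("TypeError:", err)
```

If `real` were exposed as a class, `simpler_check` would work unchanged, which is the point of the proposal.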
Alternative
Alternatively, we could do `isinstance` checks on the non-public classes, e.g. `isinstance(constraint, (constraints._Real, constraints._Interval))`, but I am not sure whether there is a strong reason not to expose these classes publicly.
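A sketch of that alternative, assuming the private names `_Real` and `_Interval` keep their current meaning (they are not public API and could change without notice). `is_real_or_interval_private` is a hypothetical name. Because both names are classes, a single `isinstance` call covers the `constraints.real` singleton as well as parameterized intervals.

```python
from torch.distributions import constraints


def is_real_or_interval_private(constraint):
    # Hypothetical helper relying on the private classes behind the public names:
    # constraints.real is an instance of _Real, constraints.interval is _Interval.
    return isinstance(constraint, (constraints._Real, constraints._Interval))


if __name__ == "__main__":
    print(is_real_or_interval_private(constraints.real))                 # True
    print(is_real_or_interval_private(constraints.interval(-1.0, 1.0)))  # True
    print(is_real_or_interval_private(constraints.positive))             # False
```

One caveat: a wrapped constraint such as `constraints.real_vector` (an `independent` wrapper around `constraints.real`) does not match this check, since the wrapper is not itself a `_Real` instance.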
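Another option is to provide a backward-compatible utility function:

```python
from torch.distributions import constraints

def _unwrap(constraint):
    # Peel off `independent` wrappers to reach the underlying constraint.
    if isinstance(constraint, constraints.independent):
        return _unwrap(constraint.base_constraint)
    # Map singleton instances (e.g. constraints.real) to their class.
    return constraint if isinstance(constraint, type) else constraint.__class__

def constraint_type_eq(constraint1, constraint2):
    return _unwrap(constraint1) == _unwrap(constraint2)
```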
cc @fritzo @neerajprad @alicanb @vishwakftw @nikitaved @fehiepsi