🐛 Bug
When a batch dimension is reinterpreted to be an event dimension, the number of event dimensions should increase by 1 (e.g. a distribution over scalars becomes a distribution over vectors) and as a result the distribution's support should be updated.
Because the distribution's support is not updated, the transform inferred by biject_to is applied elementwise (rather than summing/squeezing out the event_dim), and methods like log_abs_det_jacobian do not correctly compute the determinant of what should now be a multivariate transform (here, this multivariate log determinant is the sum of the elementwise log determinants).
Maybe we should call it log_abs_diag_jacobian here instead...? Just kidding.
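To make the "sum of elementwise log determinants" point concrete, here is a small pure-Python check that needs no PyTorch. It assumes the elementwise transform is y = exp(x), which is the map biject_to returns for a positive support. It builds the full Jacobian of that map for a 3-vector and compares log|det J| against the sum of the per-element terms. The helper names exp_jacobian and det are hypothetical.

```python
import math


def exp_jacobian(x):
    """Full Jacobian of y = exp(x) applied elementwise to the vector x.

    y_i depends only on x_i, so every off-diagonal entry is zero.
    """
    n = len(x)
    return [[math.exp(x[i]) if i == j else 0.0 for j in range(n)] for i in range(n)]


def det(m):
    """Determinant via Gaussian elimination with partial pivoting (small matrices)."""
    a = [row[:] for row in m]
    n = len(a)
    result = 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        if a[pivot][col] == 0.0:
            return 0.0
        if pivot != col:
            a[col], a[pivot] = a[pivot], a[col]
            result = -result
        result *= a[col][col]
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
    return result


x = [0.5, -1.0, 2.0]

# Elementwise terms log|dy_i/dx_i| = log(exp(x_i)) = x_i: one number per element.
elementwise = list(x)

# Multivariate term log|det J|: a single number for the whole vector.
multivariate = math.log(abs(det(exp_jacobian(x))))

print(elementwise, multivariate)
print(math.isclose(multivariate, sum(elementwise)))
```

The Jacobian of an elementwise map is diagonal, so its determinant is the product of the diagonal entries and the log determinant collapses to a sum. The transform in this report skips that reduction over the event dimension.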
To Reproduce
When I run
import torch
import torch.distributions as dist
dist.Independent(dist.Normal(torch.zeros(1), torch.ones(1)), 1).support == dist.Normal(torch.zeros(1), torch.ones(1)).support  # evaluates to True here
I find that the supports are identical before and after reinterpreting batch dimensions.
As a result, the following snippet:
import torch
import torch.distributions as dist
from torch.distributions.constraint_registry import biject_to
broken_dist = dist.Independent(dist.Exponential(torch.ones(2)), 1)
tform = biject_to(broken_dist.support)
x = torch.randn(3, 2)
y = tform(x)  # NOTE: this is applied elementwise across x
tform.log_abs_det_jacobian(x, y).shape  # broken_dist.event_shape == [2], so after taking the determinant this should be [3]
should report torch.Size([3]), because taking the log determinant removes the event_dim, but instead it reports torch.Size([3, 2]).
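The expected shape follows from summing the elementwise log determinants over the event dimension. The elementwise log-Jacobian has one entry per element (shape [3, 2]), and summing over the single reinterpreted dimension leaves one entry per event (shape [3]). Below is a minimal pure-Python sketch of that arithmetic. It assumes the inferred transform is an elementwise exp, so each elementwise term equals x itself, and it uses fixed inputs in place of torch.randn.

```python
# Fixed stand-in for x = torch.randn(3, 2): a batch of 3 events, each of size 2.
x = [
    [0.1, -0.3],
    [1.2, 0.0],
    [-2.0, 0.5],
]

# For y = exp(x), log|dy/dx| = x elementwise, so this has the same [3, 2] shape as x.
elementwise_ladj = [[xi for xi in row] for row in x]

# Summing over the rightmost (event) dimension gives one log determinant per event.
event_ladj = [sum(row) for row in elementwise_ladj]

shape = (len(elementwise_ladj), len(elementwise_ladj[0]))
print(shape)            # (3, 2): what the issue reports
print(len(event_ladj))  # 3: what the issue expects
```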
Expected behavior
After reinterpreting batch dimensions, the left-hand side is a distribution over 1-dimensional tensors (vectors of length 1). Therefore, I expect its support to be constraints.real_vector in this example. More generally, I expect the support of a dist.Independent to be the product of the individual supports over the reinterpreted batch dimensions. I also expect the inferred transforms to sum their log_abs_det_jacobian over the new event dimensions.
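One way to picture the requested behavior is a wrapper that reuses an elementwise transform and constraint but reduces over the reinterpreted dimensions. The log-Jacobian is summed over those dimensions. A value counts as inside the support only if every element of its event is, which gives a product of supports. The sketch below is pure Python on nested lists. IndependentSketch, ExpElementwise, positive and the underscore helpers are hypothetical names for illustration, not PyTorch API.

```python
import math


class ExpElementwise:
    """Hypothetical elementwise bijection y = exp(x) on Python floats."""

    def __call__(self, x):
        return math.exp(x)

    def log_abs_det_jacobian(self, x, y):
        return x  # d/dx exp(x) = exp(x), so log|dy/dx| = x


def positive(v):
    """Hypothetical elementwise check for a positive support."""
    return v > 0.0


def _map(fn, *trees):
    """Apply fn leafwise to nested lists that share the same shape."""
    if isinstance(trees[0], list):
        return [_map(fn, *children) for children in zip(*trees)]
    return fn(*trees)


def _depth(v):
    return 1 + _depth(v[0]) if isinstance(v, list) else 0


def _flatten(v):
    return [leaf for e in v for leaf in _flatten(e)] if isinstance(v, list) else [v]


def _reduce_rightmost(reducer, v, ndims):
    """Collapse the rightmost `ndims` dimensions of v with reducer (e.g. sum, all)."""
    if ndims == 0:
        return v
    if _depth(v) == ndims:
        return reducer(_flatten(v))
    return [_reduce_rightmost(reducer, e, ndims) for e in v]


class IndependentSketch:
    """Treat the rightmost `ndims` batch dimensions as event dimensions."""

    def __init__(self, base, check, ndims):
        self.base, self.check, self.ndims = base, check, ndims

    def __call__(self, x):
        return _map(self.base, x)  # the forward map itself stays elementwise

    def log_abs_det_jacobian(self, x, y):
        elementwise = _map(self.base.log_abs_det_jacobian, x, y)
        return _reduce_rightmost(sum, elementwise, self.ndims)

    def in_support(self, value):
        return _reduce_rightmost(all, _map(self.check, value), self.ndims)


tform = IndependentSketch(ExpElementwise(), positive, ndims=1)
x = [[0.1, -0.3], [1.2, 0.0], [-2.0, 0.5]]
y = tform(x)
print(tform.log_abs_det_jacobian(x, y))  # one value per event
print(tform.in_support([[1.0, 2.0], [3.0, -1.0]]))  # [True, False]
```

The wrapper leaves the forward map elementwise and changes only the reductions. That matches what the issue asks for: the same bijection, but with the log determinant and the support check defined per event rather than per element.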
Environment
Collecting environment information...
PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.19.1
Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.4
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] botorch==0.3.3
[pip3] gpytorch==1.3.0
[pip3] numpy==1.19.4
[pip3] numpyro==0.4.1
[pip3] torch==1.7.0
[pip3] torchvision==0.8.1
[conda] Could not collect