Update dev branch to use PyTorch 1.8 prerelease #2753
cf2a76c to 649e00a
@@ -87,8 +84,6 @@ class CholeskyTransform(Transform):
positive definite matrix.
"""
bijective = True
sign = +1
Note: `.sign` is defined only for scalar transforms, not multivariate transforms.
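To illustrate the point, here is a minimal pure-Python sketch. It is not PyTorch's or Pyro's actual code, and the names `SketchTransform`, `SketchExp`, `SketchCholesky`, and `has_sign` are hypothetical. It assumes a base class that leaves `sign` undefined, a scalar monotone map that sets it, and a matrix-valued map that does not.

```python
import math


class SketchTransform:
    """Hypothetical base class: sign is undefined unless a subclass sets it."""

    @property
    def sign(self):
        raise NotImplementedError("sign is only meaningful for monotone scalar maps")


class SketchExp(SketchTransform):
    """Scalar map x -> exp(x); strictly increasing, so its sign is +1."""

    sign = +1

    def __call__(self, x):
        return math.exp(x)


class SketchCholesky(SketchTransform):
    """Matrix-valued map; 'increasing' or 'decreasing' has no meaning here."""


def has_sign(transform):
    """Return True if the transform defines a sign, False otherwise."""
    try:
        transform.sign
    except NotImplementedError:
        return False
    return True
```

Under these assumptions, dropping `sign = +1` from a multivariate transform simply lets the attribute fall back to the undefined default.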
{'d': 3, 'eta': [1.], 'test_data':
[[[1.0000, 0.0000, 0.0000], [-0.8221, 0.5693, 0.0000], [0.7655, 0.1756, 0.6190]],
[[1.0000, 0.0000, 0.0000], [-0.5345, 0.8451, 0.0000], [-0.5459, -0.3847, 0.7444]],
[[1.0000, 0.0000, 0.0000], [-0.3758, 0.9267, 0.0000], [-0.2409, 0.4044, 0.8823]],
[[1.0000, 0.0000, 0.0000], [-0.8800, 0.4750, 0.0000], [-0.9493, 0.1546, 0.2737]],
[[1.0000, 0.0000, 0.0000], [0.2284, 0.9736, 0.0000], [-0.1283, 0.0451, 0.9907]]]},
]),
{
'd': 3,
'eta': [1.],
'test_data': [
[[1.0, 0.0, 0.0],
[-0.17332135, 0.98486533, 0.0],
[0.43106407, -0.54767312, 0.71710384]],
[[1.0, 0.0, 0.0],
[-0.31391555, 0.94945091, 0.0],
[-0.31391296, -0.29767500, 0.90158097]],
],
},
]),
The old low-precision test data was failing the new `test_support_shape`, so I created new data.
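A rough sketch of why precision can matter here: each row of the Cholesky factor of a correlation matrix has unit Euclidean norm, so values rounded to four decimals drift from that property more than eight-decimal values do. The tolerance `TOL` below is a hypothetical value chosen for illustration, not the one the real support check uses, and whether this was the exact cause of the failure is an assumption.

```python
import math

# Hypothetical tolerance chosen for illustration; the real check may differ.
TOL = 1e-6


def max_row_norm_error(matrix):
    """Largest deviation of any row's Euclidean norm from 1."""
    return max(abs(math.sqrt(sum(x * x for x in row)) - 1.0) for row in matrix)


# First matrix of the old, 4-decimal test data.
old = [[1.0000, 0.0000, 0.0000],
       [-0.8221, 0.5693, 0.0000],
       [0.7655, 0.1756, 0.6190]]

# First matrix of the new, 8-decimal test data.
new = [[1.0, 0.0, 0.0],
       [-0.17332135, 0.98486533, 0.0],
       [0.43106407, -0.54767312, 0.71710384]]

old_ok = max_row_norm_error(old) < TOL  # rounding error exceeds TOL
new_ok = max_row_norm_error(new) < TOL  # rows are unit norm to within TOL
```

With this tolerance the old matrix is rejected and the new one is accepted.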
@fehiepsi @neerajprad can we discuss this today? This PR is blocking development on pyro.distributions (#2739, #2766, ...)
from pyro.distributions.torch import Beta, TransformedDistribution
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.transforms.cholesky import CorrMatrixCholeskyTransform, _vector_to_l_cholesky

from . import constraints
We should have LKJ in pytorch now, so we can remove this file.
It looks like PyTorch has an `LKJCholesky` but neither an `LKJ` nor an `LKJCorrCholesky`. When I try to replace Pyro's `LKJCorrCholesky` with PyTorch's `LKJCholesky`, the test_lkj.py tests fail. The code is quite different.
@fehiepsi do you understand the differences between these distributions, and can you suggest a deprecation plan?
Hmm, LKJCholesky and LKJCorrCholesky should be the same, except that the sampling method should be more efficient. I think we can merge this for now, and I'll take a look at the tests to see why they are failing.
LGTM. Thanks for making this change, @fritzo.
@neerajprad thanks for reviewing! I'm happy merging this now and resolving LKJ differences in a follow-up PR. WDYT?
That sounds great. I'll take a look at the LKJ distribution differences after this PR merges.
Hmm is this binary actually available (yet, or still)? I am getting the following, both in a fresh local conda installation on my mac, as well as in CI:
|
Hmm ok I see from the readme that this requires manual install of the pytorch nightly. I still get the following on linux though:
|
@Balandat On Linux you could try adding
And please let me know if you can suggest a better option for depending on the PyTorch 1.8 prerelease.
Addresses #2754
Blocking #2739, #2766
This pins our dev branch to track the latest PyTorch 1.8 nightly build, and fixes Pyro to be compatible with PyTorch 1.8.
This may be inconvenient for advanced users who use Pyro's dev branch, but it will help minimize breakage around the upcoming PyTorch 1.8 release by allowing us to release Pyro faster.
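As a rough sketch of how a user on the dev branch might guard against an older PyTorch, the helper below compares a version string against a minimum. The function names `major_minor` and `meets_minimum` and the sample version strings are hypothetical and not part of Pyro.

```python
import re


def major_minor(version):
    """Extract (major, minor) from a version string such as '1.8.0.dev20210101'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version, minimum=(1, 8)):
    """Return True if the version is at least the given (major, minor) pair."""
    return major_minor(version) >= minimum
```

In practice one would pass `torch.__version__` to `meets_minimum` and fail early with a clear message if it returns False.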
Tested
test_distributions.py::test_support_shape
test_distributions.py::test_is_discrete
and is_not_discrete collapse tests; these will be fixed by #2739 (Add a Distribution.infer_shapes() method to statically infer shapes)