Closed
Issue Description
The save method of Pyro's param store (ParamStoreDict.save) appears to have been broken by PyTorch 2.0.0.

With the following versions:

- pyro-ppl 1.8.4+dd4e0f81
- torch 1.13.0

I run this script:
import pyro
import torch
from torch.distributions import constraints


def model():
    # Learnable observation noise, constrained to be positive
    x_scale = pyro.param("x_scale", torch.tensor(1.0),
                         constraint=constraints.positive)
    z = pyro.sample("z", pyro.distributions.Normal(0, 1))
    return pyro.sample("x", pyro.distributions.Normal(z, x_scale), obs=torch.tensor(0.1))


def guide():
    # Variational parameters for the latent variable z
    z_loc = pyro.param("z_loc", torch.tensor(0.0))
    z_scale = pyro.param(
        "z_scale", torch.tensor(0.5), constraint=constraints.positive
    )
    pyro.sample("z", pyro.distributions.Normal(z_loc, z_scale))


optimizer = pyro.optim.ClippedAdam({"lr": 0.01})
svi = pyro.infer.SVI(model, guide, optimizer, pyro.infer.Trace_ELBO())
# A few SVI steps populate the param store
for _ in range(10):
    svi.step()

# Save the param store to disk (this is the call that fails with torch 2.0.0)
store = pyro.get_param_store()
store.save("tmp.pt")
No error is raised, and the param store is saved successfully.

However, with:

- pyro-ppl 1.8.4+dd4e0f81
- torch 2.0.0

running the same script fails with:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/sfleming/opt/anaconda3/envs/cellbender38/lib/python3.8/site-packages/pyro/params/param_store.py", line 279, in save
    torch.save(self.get_state(), output_file)
  File "/Users/sfleming/opt/anaconda3/envs/cellbender38/lib/python3.8/site-packages/torch/serialization.py", line 441, in save
    _save(obj, opened_zipfile, pickle_module, pickle_protocol)
  File "/Users/sfleming/opt/anaconda3/envs/cellbender38/lib/python3.8/site-packages/torch/serialization.py", line 653, in _save
    pickler.dump(obj)
TypeError: cannot pickle 'weakref' object
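The last frame shows the error coming from the pickler that torch.save uses: Python's pickle module refuses to serialize weak references. The traceback does not say which entry of the state returned by get_state() holds the weakref. What follows is a minimal, stdlib-only sketch for narrowing that down. Target, Holder, pickle_error and find_unpicklable are hypothetical names made up for illustration. They are not Pyro or PyTorch APIs. The helper also uses plain pickle.dumps rather than torch.save, so its results may differ from torch.save in edge cases.

```python
import pickle
import weakref


class Target:
    """Hypothetical object that supports weak references."""


class Holder:
    """Hypothetical stand-in for an object that stores a weak reference."""

    def __init__(self, target):
        self.ref = weakref.ref(target)


def pickle_error(obj):
    """Return 'ErrorType: message' if obj cannot be pickled, else None."""
    try:
        pickle.dumps(obj)
    except Exception as err:  # TypeError, PicklingError, AttributeError, ...
        return f"{type(err).__name__}: {err}"
    return None


def find_unpicklable(obj, path="state", seen=None):
    """Return access paths of the innermost sub-objects that fail to pickle.

    Recurses into dicts, lists, tuples and objects with a __dict__; any
    subtree that pickles cleanly is skipped.
    """
    if seen is None:
        seen = set()
    if id(obj) in seen:
        return []
    seen.add(id(obj))
    if pickle_error(obj) is None:
        return []  # this whole subtree pickles fine
    if isinstance(obj, dict):
        children = [(f"{path}[{k!r}]", v) for k, v in obj.items()]
    elif isinstance(obj, (list, tuple)):
        children = [(f"{path}[{i}]", v) for i, v in enumerate(obj)]
    elif hasattr(obj, "__dict__"):
        children = [(f"{path}.{k}", v) for k, v in vars(obj).items()]
    else:
        children = []
    bad = []
    for child_path, child in children:
        bad.extend(find_unpicklable(child, child_path, seen))
    # If no child is to blame, this object itself is the culprit
    return bad or [path]


if __name__ == "__main__":
    target = Target()
    print(pickle_error(weakref.ref(target)))
    state = {"params": {"loc": 0.0, "scale": 1.0}, "holder": Holder(target)}
    print(find_unpicklable(state))  # points at the nested weakref
```

With the reproduction script above, one could try find_unpicklable(store.get_state()) to see which parameter or constraint entry is involved. That call is an untested suggestion. The exact wording of the TypeError also varies across Python versions.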
Environment
- macOS Monterey
- Python 3.8
- PyTorch and Pyro versions as above