
fix saving the param store with PyTorch 2 (closes #3201) #3206


Closed
wants to merge 1 commit into from

ilia-kats
Contributor

In `ParamStoreDict.__getitem__`, we attach a weakref to the unconstrained value to the transformed (constrained) parameter. Apparently, when the constraint is `Real()`, `transform_to` returns the same object that it was passed, so we end up with weakrefs in `ParamStoreDict._params`. This was fine with PyTorch 1, but under PyTorch 2 an exception is raised when saving a Parameter with a weakref. So let's explicitly remove the `.unconstrained` attribute from all parameters to be saved.
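To make the mechanism concrete, here is a minimal, stdlib-only sketch. `FakeParam`, `ToyParamStore`, `identity_transform` and `serialize` are hypothetical stand-ins, not Pyro or PyTorch APIs. Pickling each object's attribute dict stands in for `torch.save`, assuming (as described above for PyTorch 2) that saving picks up attributes attached to a Parameter.

```python
import pickle
import weakref


class FakeParam:
    """Hypothetical stand-in for a torch Parameter: any object with a __dict__."""

    def __init__(self, data):
        self.data = list(data)


def identity_transform(x):
    # Mimics transform_to(constraints.real), which returns its input object.
    return x


class ToyParamStore:
    """Hypothetical, heavily simplified model of ParamStoreDict (not Pyro's class)."""

    def __init__(self):
        self._params = {}
        self._transforms = {}

    def set(self, name, value, transform=identity_transform):
        self._params[name] = value
        self._transforms[name] = transform

    def __getitem__(self, name):
        unconstrained = self._params[name]
        constrained = self._transforms[name](unconstrained)
        # The pattern described above: attach a weakref to the unconstrained value.
        constrained.unconstrained = weakref.ref(unconstrained)
        return constrained


def serialize(params):
    # Assumption: saving captures each object's attribute dict; pickle stands in for torch.save.
    return pickle.dumps({name: vars(p) for name, p in params.items()})


store = ToyParamStore()
store.set("loc", FakeParam([0.0, 1.0]))
value = store["loc"]
assert value is store._params["loc"]  # identity transform: same object as stored

try:
    serialize(store._params)
    saved_before_fix = True
except TypeError:  # weakref objects cannot be pickled
    saved_before_fix = False

# The originally proposed fix: strip the attribute in place before saving.
for param in store._params.values():
    param.__dict__.pop("unconstrained", None)
saved_after_fix = isinstance(serialize(store._params), bytes)
print(saved_before_fix, saved_after_fix)  # False True
```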

@fritzo
Member

fritzo commented May 17, 2023

Thanks for fixing this!

Hmm, I'm concerned that if we delete the `.unconstrained()` attributes from the parameters, it will corrupt the param store. While this will work for the workflow `save; restart; load`, it will break the checkpointing workflow `train; save; train; save; ...; restart; load`.
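A minimal sketch of one way in-place stripping can interfere with later training steps, assuming some code keeps a handle obtained from the store before `save` runs. All names are hypothetical. Because the transform is the identity, the handle and the stored object are the same Python object, so deleting the attribute for saving also deletes it from the handle. This illustrates the concern; it does not reproduce Pyro's actual failing test.

```python
import weakref


class FakeParam:
    """Hypothetical stand-in for a stored torch Parameter."""

    def __init__(self, data):
        self.data = list(data)


# Hypothetical store contents; for a real-valued constraint the transform is the identity.
params = {"loc": FakeParam([0.0, 1.0])}


def get(name):
    # Simplified __getitem__: identity transform, then attach the weakref.
    value = params[name]
    value.unconstrained = weakref.ref(value)
    return value


# "train": a caller fetches the parameter and keeps the handle around.
loc = get("loc")
assert loc.unconstrained() is params["loc"]

# "save" with in-place stripping: this mutates the very object the caller holds.
for p in params.values():
    del p.unconstrained

# "train" again with the old handle: the attribute the caller relied on is gone.
still_has_attr = hasattr(loc, "unconstrained")
print(still_has_attr)  # False
```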

WDYT of this slight modification of your fix?

```python
# Shallow copy parameters to remove .unconstrained() attributes.
params = {name: param[...] for name, param in self._params.items()}
```

EDIT: yup, it looks like this is what's causing the failed test.

fritzo previously approved these changes May 17, 2023
@ilia-kats
Contributor Author

Thanks for the suggestion, fixed. I don't have an overview of the Pyro codebase yet, so I wasn't aware that other Pyro internals rely on `.unconstrained` being present in the param store, since the `__getitem__` code really looks like it's meant to just add the attribute to the objects that it returns and not to the stored ones.
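A small sketch of that distinction: `__getitem__` attaches the attribute to whatever the transform returns, and only an identity transform (the `Real()` case) returns the stored object itself. A non-identity transform, for example an exp transform (PyTorch uses `ExpTransform` for `constraints.positive`), returns a new object, so the store stays clean. `FakeParam`, `getitem` and the transform functions are hypothetical stand-ins.

```python
import math
import weakref


class FakeParam:
    """Hypothetical stand-in for a tensor-like parameter."""

    def __init__(self, data):
        self.data = list(data)


def identity_transform(x):
    # Like transform_to(constraints.real): returns its input object.
    return x


def exp_transform(x):
    # Like an exp transform for a positive constraint: returns a *new* object.
    return FakeParam(math.exp(v) for v in x.data)


def getitem(params, transforms, name):
    # Simplified __getitem__: transform, then attach a weakref to the result.
    unconstrained = params[name]
    constrained = transforms[name](unconstrained)
    constrained.unconstrained = weakref.ref(unconstrained)
    return constrained


params = {"loc": FakeParam([0.0]), "scale": FakeParam([0.0])}
transforms = {"loc": identity_transform, "scale": exp_transform}

for name in params:
    getitem(params, transforms, name)

# Only the identity-transformed parameter picks up the attribute in the store.
status = {name: hasattr(p, "unconstrained") for name, p in params.items()}
print(status)  # {'loc': True, 'scale': False}
```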

@fritzo
Member

fritzo commented May 17, 2023

Hmm, looks like some tests are still failing. Let me look into it...

@fritzo
Member

fritzo commented May 17, 2023

Hey, thanks for your help fixing this, @ilia-kats. Following your lead, I found more weakref issues and fixed them all at once in #3212. I'll close this PR, but #3212 wouldn't have happened without your help!

fritzo closed this May 17, 2023