Fix interaction between PyroParam and torch.func.grad #3328
Addresses a bug described downstream in BasisResearch/chirho#393

This PR adds a fix for compatibility of `PyroModule`s and `PyroParam`s with `torch.func.grad` and the other functional automatic differentiation transforms in `torch.func`. The fix is basically to replace each `pyro.param` statement or other interaction with the parameter store with a dummy version that does not store and retrieve a parameter tensor from a nonlocal state (which is invisible to the tracing machinery in `torch.func`).
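As a rough illustration of why nonlocal state is a problem here (a minimal sketch, not Pyro's actual code; `PARAM_STORE`, `loss_with_store`, and `loss_functional` are made-up names): `torch.func.grad` only differentiates with respect to explicit function inputs, so a tensor fetched from a global store behaves like a constant and can never receive a gradient through the transform.

```python
import torch
from torch.func import grad

# Illustrative stand-in for a global parameter store (hypothetical, not Pyro's API).
PARAM_STORE = {"scale": torch.tensor(3.0)}

def loss_with_store(x):
    # The parameter comes from nonlocal state, so grad() can only see `x`;
    # there is no way to request d(loss)/d(scale) through the transform.
    return (PARAM_STORE["scale"] * x).sum()

def loss_functional(scale, x):
    # Passing the parameter explicitly makes it visible to torch.func.
    return (scale * x).sum()

x = torch.tensor([1.0, 2.0])
grad(loss_with_store)(x)                        # gradient w.r.t. x only
grad(loss_functional)(PARAM_STORE["scale"], x)  # gradient w.r.t. the parameter
```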
Without this fix, gradient computations in `torch.func.grad` do not propagate to the unconstrained parameters behind constrained `PyroParam`s, even when using `pyro.settings.set(module_local_param=True)`, and are always zero. After this fix, the functional AD system in `torch.func` behaves correctly with `AutoGuide`s and other `PyroModule`s when `module_local_param=True`, though it is still fundamentally incompatible with the global parameter store state when `module_local_param=False`.
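For reference, a hedged usage sketch of the kind of computation this enables (the module, parameter name, and loss below are hypothetical and not taken from this PR's tests): with `module_local_param=True`, the trainable tensors live on the module itself, so they can be passed explicitly through `torch.func.functional_call` and differentiated with `torch.func.grad`.

```python
import torch
import pyro
from pyro.nn import PyroModule, PyroParam
from torch.distributions import constraints
from torch.func import functional_call, grad

pyro.settings.set(module_local_param=True)

class Scaler(PyroModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # Constrained parameter; the trainable tensor lives in unconstrained space.
        self.scale = PyroParam(torch.tensor(2.0), constraint=constraints.positive)

    def forward(self, x):
        return (self.scale * x).sum()

module = Scaler()
x = torch.tensor([1.0, 2.0, 3.0])

def loss(params):
    # functional_call substitutes `params` for the module's parameters, so the
    # constrained access to `self.scale` is visible to torch.func end to end.
    return functional_call(module, params, (x,))

params = dict(module.named_parameters())
grads = grad(loss)(params)  # gradients w.r.t. the unconstrained parameters
```

Before this change, the gradients computed this way for constrained `PyroParam`s came back as zeros, as described above.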
Tested: