[dynamo] Added support for tensor's is_inference method #136450
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136450
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (5 Unrelated Failures) As of commit 6d69b45 with merge base a0a1873:
FLAKY - The following job failed but was likely due to flakiness present on trunk.
BROKEN TRUNK - The following jobs failed but were present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/dynamo/test_functions.py (Outdated)

    x = torch.randn(2, 2)
    fn(x)

    self.assertEqual(cnts.frame_count, 3)  # Recompile! inference_mode changed
You're actually only testing here that the inference mode state is guarded on, not that whether a tensor is an inference tensor causes a change. Move the allocation of the tensor into the mode, but do the fn call outside of it, to test this (see the sketch below).
Since you are diverging the semantics of the program on the inside, you could also just check whether the result equals the eager result.
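For illustration, a minimal sketch of that restructuring, assuming a hypothetical fn that branches on is_inference (names here are illustrative, not the PR's actual test code):

    import torch

    def fn(x):
        # Hypothetical function that branches on inference-ness.
        if x.is_inference():
            return x + 1
        return x - 1

    # Allocate the tensor inside inference mode...
    with torch.inference_mode():
        x_inference = torch.randn(2, 2)

    # ...but run fn outside the mode, so only the tensor's
    # inference-ness differs, not the global mode state.
    opt_fn = torch.compile(fn)
    eager_result = fn(x_inference)
    assert torch.equal(opt_fn(x_inference), eager_result)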
I think this makes sense - thanks!

Are you sure we're not guarding on the inference-mode-ness of a Tensor already? I checked the implementation of is_inference. It seems to be derived entirely from the dispatch key set, and we DO guard on that right now.
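For context, a quick eager-only check of the behavior under discussion, using only public torch APIs:

    import torch

    with torch.inference_mode():
        x_inf = torch.randn(2, 2)
    x_reg = torch.randn(2, 2)

    # is_inference() reports whether the tensor was created under
    # inference mode; internally this is derived from the tensor's
    # dispatch key set, which Dynamo already guards on.
    print(x_inf.is_inference())  # True
    print(x_reg.is_inference())  # False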
Yeah, I had the same thought here (#135439 (comment)).
However, I'm a noob, and I wasn't sure how specific guard semantics should be, i.e. are there other things that would break the dispatch key set guard, and if so, is this fine? It's why I placed the check for is_inference where I did. If we do want a specific guard, would it also be simpler to just check against the appropriate dispatch keys vs. introducing another field?
Oh, I missed your comment edit. Your last log suggests that we already guard on inference-mode-ness. So in fact you can get rid of all the new guard code; your test case should still pass.
Updating, thank you for your patience @ezyang.
test/dynamo/test_functions.py (Outdated)

    eager_result = fn(x_inference)

    cnts = torch._dynamo.testing.CompileCounter()
    fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
You can write this test more clearly. The most important thing is to distinguish fn from opt_fn. The second is to just directly assertEqual(fn(x_inference), opt_fn(x_inference)), and so forth (see the sketch below).
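A sketch of the test rewritten along these lines; the structure is my reading of the suggestion, not the PR's final code:

    import torch
    import torch._dynamo.testing

    def fn(x):
        if x.is_inference():
            return x + 1
        return x - 1

    with torch.inference_mode():
        x_inference = torch.randn(2, 2)

    cnts = torch._dynamo.testing.CompileCounter()
    # Keep the eager fn distinct from the compiled opt_fn.
    opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)

    # Compare the compiled result directly against eager.
    assert torch.equal(fn(x_inference), opt_fn(x_inference))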
Thanks, just nits on the test.
Thanks a lot!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
another?
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #135439
This PR adds support for the is_inference method on torch tensors, allowing an example fn that calls it to compile without graph breaks (a representative sketch appears after the cc list below). I've also tried to add guards on the tensor to guard against is_inference. I wasn't 100% sure where these should go, so please don't hesitate to correct me.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec @ezyang
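The example fn from the original description is not preserved in this transcript; below is a representative sketch, assuming standard torch.compile usage, of the kind of function this change lets compile without graph breaks:

    import torch

    def fn(x):
        # Calls Tensor.is_inference inside the compiled region.
        if x.is_inference():
            return x * 2
        return x * 3

    opt_fn = torch.compile(fn, fullgraph=True)  # fullgraph=True errors on graph breaks

    with torch.inference_mode():
        x_inference = torch.randn(2, 2)
    x_regular = torch.randn(2, 2)

    assert torch.equal(opt_fn(x_inference), fn(x_inference))
    assert torch.equal(opt_fn(x_regular), fn(x_regular))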