
Conversation

@SalmanMohammadi (Contributor) commented Sep 23, 2024

Fixes #135439

This PR adds support for the is_inference method on torch tensors, which allows the following example fn to compile without graph breaks:

def fn_simple(x):
    if x.is_inference():
        return x.sum()
    else:
        return x.min()
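
For illustration, here's a minimal sketch of exercising this end to end (my sketch, not the PR's test code; fullgraph=True turns any graph break into an error):

    import torch

    def fn_simple(x):
        if x.is_inference():
            return x.sum()
        else:
            return x.min()

    compiled_fn = torch.compile(fn_simple, fullgraph=True)

    with torch.inference_mode():
        x_inf = torch.randn(2, 2)  # inference tensor: x_inf.is_inference() is True
    x = torch.randn(2, 2)          # regular tensor: x.is_inference() is False

    # With fullgraph=True a graph break would raise, so both calls completing
    # shows that is_inference is traced rather than falling back to eager.
    print(compiled_fn(x_inf))  # takes the sum() branch
    print(compiled_fn(x))      # takes the min() branch (recompiles on the guard)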

I've also tried to add guards on the tensor's is_inference state. I wasn't 100% sure where these should go, so please don't hesitate to correct me.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames @rec @ezyang

pytorch-bot bot commented Sep 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/136450

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (5 Unrelated Failures)

As of commit 6d69b45 with merge base a0a1873:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed, but the failures were already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Sep 23, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

x = torch.randn(2, 2)
fn(x)

self.assertEqual(cnts.frame_count, 3) # Recompile! inference_mode changed
Contributor commented on the snippet above:

You're actually only testing here that the inference mode state is guarded on, not that whether a tensor is an inference tensor causes a recompile. Move the allocation of the tensor inside the mode, but do the fn call outside of it, to test this.

Since you are diverging the semantics of the program on the inside, you could also just check whether the result equals the eager result.
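
For example, a sketch of the suggested pattern (names like fn and x_inference are illustrative, mirroring the test under review):

    import torch
    import torch._dynamo.testing

    def fn(x):
        if x.is_inference():
            return x.sum()
        else:
            return x.min()

    # Allocate the tensor *inside* inference mode...
    with torch.inference_mode():
        x_inference = torch.randn(2, 2)

    # ...but call the function *outside* it, so the only thing being tested is
    # the tensor's inference-ness, not the global inference-mode state.
    eager_result = fn(x_inference)

    cnts = torch._dynamo.testing.CompileCounter()
    opt_fn = torch.compile(fn, backend=cnts, fullgraph=True)

    # The two branches compute different values, so matching the eager result
    # confirms the compiled function took the is_inference() branch.
    assert torch.equal(opt_fn(x_inference), eager_result)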

SalmanMohammadi (Contributor Author) replied:

I think this makes sense - thanks!

@ezyang (Contributor) commented Sep 24, 2024

Are you sure we're not guarding on the inference-mode-ness of a Tensor already? I checked the implementation of is_inference:

    bool is_inference() {
      bool no_ADInplaceOrView = !key_set_.has_any(c10::inplace_or_view_ks);
      bool no_Autograd = !key_set_.has_any(c10::autograd_dispatch_keyset);
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
          no_ADInplaceOrView == no_Autograd,
          "ADInplaceOrView and Autograd keys must be on/off at the same time.");
      return no_ADInplaceOrView && no_Autograd;
    }

This seems to be derived entirely from the dispatch key set. But we DO guard on that right now:

TensorCheck::TensorCheck(
    const LocalState& state,
    PyTypeObject* pt,
    const at::Tensor& v,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
    std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
    : pytype(pt),
      dispatch_key_(state.apply(v.key_set()).raw_repr()),
      dtype_(v.dtype().toScalarType()),

??

@SalmanMohammadi (Contributor Author) commented:

Yeah I had the same thought here (#135439 (comment))

V0922 18:59:34.045000 493191 torch/_dynamo/guards.py:2830] [0/1] [__recompiles] Recompiling function fn_simple in /home/salman/pytorch/build/test.py:16
V0922 18:59:34.045000 493191 torch/_dynamo/guards.py:2830] [0/1] [__recompiles] triggered by the following guard failure(s):
V0922 18:59:34.045000 493191 torch/_dynamo/guards.py:2830] [0/1] [__recompiles] - 0/0: GLOBAL_STATE changed: grad_mode
V0922 18:59:34.082000 493191 torch/_dynamo/guards.py:2830] [0/2] [__recompiles] Recompiling function fn_simple in /home/salman/pytorch/build/test.py:16
V0922 18:59:34.082000 493191 torch/_dynamo/guards.py:2830] [0/2] [__recompiles] triggered by the following guard failure(s):
V0922 18:59:34.082000 493191 torch/_dynamo/guards.py:2830] [0/2] [__recompiles] - 0/1: tensor 'L['x']' dispatch key set mismatch. expected DispatchKeySet(CPU, BackendSelect), actual DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU)
V0922 18:59:34.082000 493191 torch/_dynamo/guards.py:2830] [0/2] [__recompiles] - 0/0: GLOBAL_STATE changed: grad_mode
Does this mean the guards on the dispatch key set pick up the inference-mode-ness?

However, I'm a noob and I wasn't sure how specific guard semantics should be, i.e. are there other things that would invalidate the dispatch key set guard, and if so, is this fine? That's why I placed the check for is_inference to be triggered first, but that did feel a bit weird.

If we do want a specific guard, would it also be simpler to just check against the appropriate dispatch keys vs. introducing another field?
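
For reference, the difference is observable directly (a sketch using torch._C._dispatch_keys, which is an internal API, so the exact key sets may vary by build):

    import torch

    with torch.inference_mode():
        x_inf = torch.randn(2, 2)
    x = torch.randn(2, 2)

    # Inference tensors lack the ADInplaceOrView and Autograd keys, which is
    # exactly the difference the existing dispatch-key-set guard checks.
    print(torch._C._dispatch_keys(x_inf))  # e.g. DispatchKeySet(CPU, BackendSelect)
    print(torch._C._dispatch_keys(x))      # e.g. DispatchKeySet(CPU, BackendSelect, ADInplaceOrView, AutogradCPU)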

@ezyang (Contributor) commented Sep 24, 2024

Oh, I missed your comment edit. Your last log suggests that we already guard on inference-mode-ness. So in fact you can get rid of all the new guard code; your test case should still pass.

@ezyang ezyang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Sep 24, 2024
@SalmanMohammadi (Contributor Author) commented:

Updating, thank you for your patience @ezyang.

eager_result = fn(x_inference)

cnts = torch._dynamo.testing.CompileCounter()
fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
Contributor commented on the snippet above:

You can write this test more clearly. The most important thing is to distinguish fn from opt_fn. The second is to just directly assertEqual(fn(x_inference), opt_fn(x_inference)), and so forth.
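
For example (a sketch; x_inference and the enclosing unittest-style test class are assumed context):

    cnts = torch._dynamo.testing.CompileCounter()
    opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)

    # Keep fn as the eager reference and opt_fn as the compiled version,
    # then compare them directly.
    self.assertEqual(fn(x_inference), opt_fn(x_inference))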

@ezyang (Contributor) commented Sep 25, 2024

Thanks, just nits on the test

@ezyang (Contributor) left a review comment:

Thanks a lot!

@ezyang (Contributor) commented Sep 26, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label (Trigger trunk jobs on your pull request) Sep 26, 2024
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@SalmanMohammadi (Contributor Author) commented:

another?
@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


Labels: ciflow/trunk, Merged, module: dynamo, open source, release notes: dynamo, topic: bug fixes, triaged
Successfully merging this pull request may close these issues:

torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_method is_inference