[MPS] Add native implementation for shift ops #131813
Conversation
Similar to how the AND/OR/XOR ops are implemented.

TODO: Consider using MPS method calls rather than Metal kernels.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/131813
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure. As of commit d1c7b2a with merge base bf6aae1, the following job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -85,8 +85,6 @@ static Tensor slow_conv2d_forward_mps(const Tensor& self,
   // These ops are not supported via MPS backend currently, and we fallback to run on CPU.
   // For the rest of unsupported ops the user needs to pass 'PYTORCH_ENABLE_MPS_FALLBACK=1'
   // to fallback on CPU, otherwise we will error out.
-  m.impl("bitwise_left_shift.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
-  m.impl("bitwise_right_shift.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
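The registrations removed above had routed the `Tensor_out` variants to a generic CPU fallback. A toy sketch of that fallback idea, in plain Python (illustrative only; PyTorch's real dispatcher is far more involved, and the names below are made up for the example):

```python
# Toy model of a per-op fallback table: ops with no native MPS kernel
# are routed to a CPU implementation; ops in neither table error out,
# mirroring the PYTORCH_ENABLE_MPS_FALLBACK behavior described above.
mps_kernels = {}  # ops with a native MPS implementation (none yet)
cpu_kernels = {
    "bitwise_left_shift.Tensor_out": lambda a, b: a << b,
    "bitwise_right_shift.Tensor_out": lambda a, b: a >> b,
}

def dispatch(op, a, b):
    """Prefer a native MPS kernel; otherwise fall back to the CPU one."""
    kernel = mps_kernels.get(op) or cpu_kernels.get(op)
    if kernel is None:
        raise NotImplementedError(
            f"{op} not supported on MPS; set PYTORCH_ENABLE_MPS_FALLBACK=1")
    return kernel(a, b)

print(dispatch("bitwise_left_shift.Tensor_out", 5, 2))   # -> 20 (CPU fallback)
print(dispatch("bitwise_right_shift.Tensor_out", 20, 2)) # -> 5  (CPU fallback)
```

Once a native MPS kernel is registered for an op (as this PR does with Metal kernels), the fallback entry becomes redundant, which is why these two lines are deleted.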
If there was a fallback in place, why were things not working before? Granted, the error was with `__lshift__.Scalar`. Not sure how that magic works.
Oh, the fallback was for the Tensor variant, not for the Scalar ones
Yeah, this only works if one calls `torch.bitwise_left_shift`; not sure what the story is with all those operator variants.
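The distinction under discussion is Python's operator protocol: `t << s` resolves to `t.__lshift__(s)`, which in PyTorch is a separate op (with its own dispatch entry) from `torch.bitwise_left_shift`, so one path can work while the other errors. A minimal non-torch illustration of the binding:

```python
class FakeTensor:
    """Stand-in showing how << binds to __lshift__ (illustrative,
    not torch.Tensor). The operator path and the named-function path
    are distinct entry points, even when they compute the same thing."""
    def __init__(self, value):
        self.value = value

    def __lshift__(self, other):
        # Called by the << operator.
        return FakeTensor(self.value << other)

    def bitwise_left_shift(self, other):
        # Explicit named op; a separate entry point from __lshift__.
        return FakeTensor(self.value << other)

t = FakeTensor(5)
print((t << 2).value)                 # operator path -> __lshift__ -> 20
print(t.bitwise_left_shift(2).value)  # named-op path -> 20
```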
@@ -8457,29 +8457,29 @@
   device_check: NoCheck # TensorIterator
   variants: method, function
   dispatch:
-    CPU, CUDA: __lshift__
+    CPU, CUDA, MPS: __lshift__
Just curious, where is the magic that ties `__lshift__` to `bitwise_left_shift_out`?
Here: pytorch/aten/src/ATen/native/BinaryOps.cpp, line 1292 in 3d7c424:

lshift_stub(iter.device_type(), iter);
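So the tie-in is an indirection: `__lshift__` builds a TensorIterator and calls `lshift_stub`, which looks up whatever kernel was registered for the iterator's device type. A rough Python sketch of that stub pattern (illustrative only, not ATen's actual DispatchStub machinery):

```python
class DispatchStub:
    """Toy version of ATen's per-device kernel stub: backends register
    a kernel per device type, and callers dispatch through the stub
    without knowing which backend will run."""
    def __init__(self, name):
        self.name = name
        self.kernels = {}

    def register(self, device_type, fn):
        self.kernels[device_type] = fn

    def __call__(self, device_type, a, b):
        kernel = self.kernels.get(device_type)
        if kernel is None:
            raise NotImplementedError(
                f"{self.name}: no kernel for device '{device_type}'")
        return kernel(a, b)

lshift_stub = DispatchStub("lshift_stub")
lshift_stub.register("cpu", lambda a, b: a << b)
lshift_stub.register("mps", lambda a, b: a << b)  # the native kernel this PR adds

# __lshift__ on every listed backend funnels into the same stub:
print(lshift_stub("cpu", 5, 2))  # -> 20
print(lshift_stub("mps", 5, 2))  # -> 20
```

This is why the yaml change above only has to add `MPS` to the dispatch line: the stub call site is shared, and the new Metal kernels register themselves for the MPS device type.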
@pytorchbot merge -f "MPS + Lint tests are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
[MPS] Add missing dispatch to rshift.Tensor (#135607)

Missed it while working on #131813.

Test plan: `python -c "import torch;print(torch.randint(100, 500, (64,), device='mps') >> torch.tensor([3,], device='mps'))"`

Pull Request resolved: #135607
Approved by: https://github.com/manuelcandales
(cherry picked from commit 3bf6be4)
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
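The arithmetic in that test plan can be sanity-checked without an MPS device: right-shifting a non-negative integer by 3 is integer division by 8. A plain-Python mirror of the command (no torch required; the ranges match the test plan's `randint(100, 500, (64,))`):

```python
# Plain-Python mirror of the test plan's arithmetic: for non-negative
# ints, x >> 3 == x // 8, so every shifted value lands in [12, 62].
import random

values = [random.randint(100, 499) for _ in range(64)]
shifted = [v >> 3 for v in values]

assert all(s == v // 8 for v, s in zip(values, shifted))
assert all(12 <= s <= 62 for s in shifted)
print(shifted[:4])
```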