Rewrite functional.tensordot to be TorchScript-able #53672
Conversation
💊 CI failures summary (Dr. CI), as of commit d8ec82c: ci.pytorch.org: 1 failed.
@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
r"""Returns a contraction of a and b over multiple dimensions. | ||
|
||
:attr:`tensordot` implements a generalized matrix product. | ||
|
||
Args: | ||
a (Tensor): Left tensor to contract | ||
b (Tensor): Right tensor to contract | ||
dims (int or Tuple[List[int]] containing two lists): number of dimensions to | ||
dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to |
In the official documentation (https://pytorch.org/docs/stable/generated/torch.tensordot.html), it looks like we don't support List[int] or List[List[int]]. Or is it necessary because of how the aten op is defined?
The documentation link you referred to is actually generated from this docstring, so when I update it in this PR, the documentation will change.
The support for List[int] is already there in the implementation; I am updating the docstring to match what's implemented.
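For illustration, a small sketch (not from the PR) of the dims forms the updated docstring lists; the shapes here are arbitrary:

import torch

a = torch.randn(3, 4, 5)
b = torch.randn(4, 5, 6)

# int: contract the last 2 dims of a with the first 2 dims of b
r1 = torch.tensordot(a, b, dims=2)                     # shape (3, 6)

# tuple of two lists: contract a's dims 1, 2 with b's dims 0, 1
r2 = torch.tensordot(a, b, dims=([1, 2], [0, 1]))      # shape (3, 6)

# list of two lists: same pairing, list form
r3 = torch.tensordot(a, b, dims=[[1, 2], [0, 1]])      # shape (3, 6)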
I don't think we should be adding overhead to eager operators to support them in JIT. This function can be supported with overloads while keeping the original code more intact.
torch/functional.py
Outdated
if isinstance(dims, (list, tuple)) or \
        (isinstance(dims, torch.Tensor) and dims.numel() > 1):

    dims_a = torch.jit.annotate(List[int], [])
This isn't up-to-date syntax; you should be using dims_a: List[int] = []
Done.
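For reference, a minimal sketch of the two spellings being discussed; both tell TorchScript the element type of an empty list:

from typing import List

import torch

dims_a = torch.jit.annotate(List[int], [])  # older torch.jit.annotate form
dims_b: List[int] = []                      # up-to-date variable annotation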
torch/functional.py
Outdated
@@ -937,15 +937,15 @@ def _consecutive_return_inverse(input, return_inverse=False, return_counts=False
 unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__


-def tensordot(a, b, dims=2, out=None):
+def tensordot(a, b, dims: Any = 2, out: Optional[torch.Tensor] = None):
Using torch.jit.isinstance adds overhead here: we're iterating over the entire list where we used to just check whether the input was a list type. I don't think we should be adding overhead to eager operators. I think you can support this without making these changes by using overloads:
@overload
def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
@overload
def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):
@overload
def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):
Good call! Changed to use @overload.
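A minimal sketch of the free-function overload pattern being suggested, using a hypothetical contract helper rather than the PR's actual code; torch.jit._overload is the decorator TorchScript uses for this, and the signature-only stubs must precede the single eager implementation:

from typing import List, Optional, Tuple

import torch


@torch.jit._overload  # noqa: F811
def contract(a: torch.Tensor, b: torch.Tensor, dims: int):
    ...


@torch.jit._overload  # noqa: F811
def contract(a: torch.Tensor, b: torch.Tensor, dims: Tuple[List[int], List[int]]):
    ...


def contract(a, b, dims):  # noqa: F811
    # Single eager implementation: nothing extra runs in eager mode, while
    # TorchScript call sites pick an overload from the static type of `dims`.
    return torch.tensordot(a, b, dims=dims)


@torch.jit.script
def scripted_caller(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return contract(a, b, 1)  # resolves to the `dims: int` overload

In eager mode the decorated stubs are never executed; only the last, undecorated definition runs.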
Revised, PTAL
Looks good, thanks for doing this! Some of the assertions may be better explicit as you have them, idk; just commenting to let you know where they will be checked as part of the following line.
torch/functional.py
Outdated
if isinstance(dims, torch.Tensor):
    dims = dims.item()

if isinstance(dims, list):
Nit: you can fold this if with the tuple one: if isinstance(dims, (list, tuple))
Done.
torch/functional.py
Outdated
    dims = dims.item()

if isinstance(dims, list):
    assert len(dims) == 2
Nit: the assertion here is unnecessary, since the list unpacking will also check that the length is 2.
Done.
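A tiny sketch of the point about unpacking (illustrative values only):

dims = [[0, 1], [1, 0]]
dims_a, dims_b = dims        # unpacking already enforces exactly two elements

# dims_a, dims_b = [[0, 1]]  # would raise ValueError: not enough values to unpack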
num_elements = dims.numel()
if num_elements > 1:
    assert dims.size()[0] == 2
    dims_a = torch.jit.annotate(List[int], dims[0].tolist())
Nit: don't use torch.jit.annotate; use the up-to-date syntax.
Unfortunately, this doesn't work at the moment. TorchScript complains about redefining the variable dims_a.
File an issue and add it in a comment?
torch/functional.py
Outdated
    dims_a = torch.jit.annotate(List[int], dims[0].tolist())
    dims_b = torch.jit.annotate(List[int], dims[1].tolist())
else:
    assert num_elements == 1
Nit: dims.item() will already check that the number of elements is 1.
Removed.
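For context, a small sketch of why the explicit element-count assertion is redundant; Tensor.item() itself refuses tensors with more than one element:

import torch

dims = torch.tensor(2)
print(int(dims.item()))   # 2

multi = torch.tensor([0, 1])
# multi.item()            # raises: only one-element tensors can be converted
#                         # to Python scalars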
    dims_b = torch.jit.annotate(List[int], dims[1].tolist())
else:
    assert num_elements == 1
    dims_val = int(dims.item())
I think you're missing a check that dims_val >= 0 here.
Good catch, added.
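A minimal sketch of the kind of guard being discussed (not necessarily the PR's exact wording):

dims_val = -1  # e.g. obtained from int(dims.item())
if dims_val < 0:
    raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims_val}")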
Codecov Report

@@            Coverage Diff             @@
##           master   #53672      +/-   ##
==========================================
- Coverage   77.36%   77.35%    -0.01%
==========================================
  Files        1879     1879
  Lines      183257   183285      +28
==========================================
+ Hits       141774   141785      +11
- Misses      41483    41500      +17
@gmagogsfm merged this pull request in f48a971.
if len(dims_a) == 0 or len(dims_b) == 0:
    raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
This check breaks the simple case tensordot(torch.tensor(0.), torch.tensor(0.), 0), which works fine in NumPy:

>>> np.tensordot(np.zeros(()), np.zeros(()), 0)
array(0.)

This important edge case (dims=0) is needed by generic machinery such as Pyro and opt_einsum. Can we revert this check?
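For reference, a short sketch of the dims=0 semantics the comment relies on (NumPy shown, since at this commit the torch path raises):

import numpy as np

# axes=0 contracts over zero dimensions, i.e. an outer (tensor) product;
# for 0-d inputs that is plain scalar multiplication.
print(np.tensordot(np.arange(3.0), np.arange(2.0), 0).shape)   # (3, 2)
print(np.tensordot(np.zeros(()), np.zeros(()), 0))             # array(0.)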
cc @neerajprad
Fixes #53487