Rewrite functional.tensordot to be TorchScript-able #53672


Closed
wants to merge 1 commit into from

Conversation

gmagogsfm
Contributor

Fixes #53487

@facebook-github-bot added the oncall: jit (Add this issue/PR to JIT oncall triage queue) and cla signed labels on Mar 10, 2021
Contributor
facebook-github-bot commented Mar 10, 2021

💊 CI failures summary and remediations

As of commit d8ec82c (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 2/2 non-scanned failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

Contributor
@facebook-github-bot facebook-github-bot left a comment

@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

r"""Returns a contraction of a and b over multiple dimensions.

:attr:`tensordot` implements a generalized matrix product.

Args:
a (Tensor): Left tensor to contract
b (Tensor): Right tensor to contract
dims (int or Tuple[List[int]] containing two lists): number of dimensions to
dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
Contributor

In the official documentation (https://pytorch.org/docs/stable/generated/torch.tensordot.html), it looks like we don't support List[int] or List[List[int]]. Or is it necessary because of how aten op is defined?

Contributor Author

The documentation link you referred to is actually generated from this docstring, so when I update it in this PR, documentation will change.

The support for List[int] is already there in the implementation; I am updating the docstring to match what's implemented.
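
For illustration (example shapes chosen here, not taken from the PR), the two dims forms the docstring describes behave as follows: an integer contracts the last dims dimensions of a with the first dims dimensions of b, while a pair of index lists names the contracted dimensions explicitly.

>>> import torch
>>> a = torch.randn(3, 4, 5)
>>> b = torch.randn(4, 5, 6)
>>> torch.tensordot(a, b, dims=2).shape                  # last 2 dims of a with first 2 dims of b
torch.Size([3, 6])
>>> torch.tensordot(a, b, dims=([1, 2], [0, 1])).shape   # same contraction, dimensions listed explicitly
torch.Size([3, 6])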

Contributor
@eellison eellison left a comment

I don't think we should be adding overhead to eager operators to support them in JIT. This function can be supported with overloads while keeping the original code more intact.

if isinstance(dims, (list, tuple)) or \
(isinstance(dims, torch.Tensor) and dims.numel() > 1):

dims_a = torch.jit.annotate(List[int], [])
Contributor

This isn't up-to-date syntax; you should be using dims_a: List[int] = []

Contributor Author

Done.
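
For readers unfamiliar with why the annotation is needed at all: TorchScript assumes an empty list literal is List[Tensor], so the element type has to be pinned down one way or the other. A minimal, hypothetical function (not from this PR) showing the variable-annotation style under @torch.jit.script:

import torch
from typing import List

@torch.jit.script
def collect_dims(n: int) -> List[int]:
    # The annotation tells TorchScript this empty list holds ints;
    # without it, the empty literal would be inferred as List[Tensor].
    dims_a: List[int] = []
    for i in range(n):
        dims_a.append(i)
    return dims_a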

@@ -937,15 +937,15 @@ def _consecutive_return_inverse(input, return_inverse=False, return_counts=False
unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__


def tensordot(a, b, dims=2, out=None):
def tensordot(a, b, dims: Any = 2, out: Optional[torch.Tensor] = None):
Contributor

torch.jit.isinstance is adding overhead here: we're iterating over the entire list when we used to just check whether the input was a list type. I don't think we should be adding overhead to eager operators.

I think you can support this without making these changes by using overloads:

@overload
def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):

@overload
def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):

@overload
def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):

Contributor Author

Good call! Changed to use @overload.
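
For context, a condensed sketch of what overload-based dispatch can look like for this function. This is an illustration under assumptions: it uses typing.overload (TorchScript relies on PyTorch's own overload machinery, and the stubs and implementation in the merged code may differ), and the body simply defers to torch.tensordot for the actual contraction.

from typing import List, Optional, Tuple, overload

import torch


@overload
def tensordot(a: torch.Tensor, b: torch.Tensor, dims: int = 2,
              out: Optional[torch.Tensor] = None) -> torch.Tensor: ...

@overload
def tensordot(a: torch.Tensor, b: torch.Tensor, dims: Tuple[List[int], List[int]],
              out: Optional[torch.Tensor] = None) -> torch.Tensor: ...

@overload
def tensordot(a: torch.Tensor, b: torch.Tensor, dims: List[List[int]],
              out: Optional[torch.Tensor] = None) -> torch.Tensor: ...

def tensordot(a, b, dims=2, out=None):
    # Single concrete implementation; the typed stubs above are what a type
    # checker (or an overload mechanism) sees, while this body dispatches on
    # the runtime type of `dims`.
    if isinstance(dims, (list, tuple)):
        dims_a, dims_b = dims                  # unpacking also checks len(dims) == 2
    elif isinstance(dims, torch.Tensor):
        if dims.numel() > 1:
            dims_a = dims[0].tolist()
            dims_b = dims[1].tolist()
        else:
            d = int(dims.item())
            dims_a = list(range(-d, 0))
            dims_b = list(range(d))
    elif isinstance(dims, int):
        dims_a = list(range(-dims, 0))         # last `dims` dimensions of a
        dims_b = list(range(dims))             # first `dims` dimensions of b
    else:
        raise RuntimeError(f"unsupported type for dims: {type(dims)}")
    # Defer to the library op for the contraction itself.
    if out is None:
        return torch.tensordot(a, b, dims=(dims_a, dims_b))
    return torch.tensordot(a, b, dims=(dims_a, dims_b), out=out)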

Contributor Author
@gmagogsfm gmagogsfm left a comment

Revised, PTAL

@gmagogsfm gmagogsfm requested a review from eellison March 11, 2021 02:00
Contributor
@facebook-github-bot facebook-github-bot left a comment

@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor
@facebook-github-bot facebook-github-bot left a comment

@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor
@eellison eellison left a comment

Looks good, thanks for doing this! Some of the assertions may be better left explicit as you have them; I'm just commenting to let you know where they would be checked as part of the following line.

if isinstance(dims, torch.Tensor):
dims = dims.item()

if isinstance(dims, list):
Contributor

Nit: you can fold this if with the tuple one:
if isinstance(dims, (list, tuple))

Contributor Author

Done.

dims = dims.item()

if isinstance(dims, list):
assert len(dims) == 2
Contributor

Nit: the assertion here is unnecessary, since the list unpacking will also check that the length is 2.

Contributor Author

Done.
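
For illustration only (plain Python, not code from this PR), sequence unpacking enforces the length-2 requirement on its own:

>>> dims_a, dims_b = [[1, 2], [0, 1]]   # fine: exactly two elements
>>> dims_a, dims_b = [[1, 2]]           # raises ValueError: not enough values to unpack (expected 2, got 1)
>>> dims_a, dims_b = [[1], [2], [3]]    # raises ValueError: too many values to unpack (expected 2)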

num_elements = dims.numel()
if num_elements > 1:
assert dims.size()[0] == 2
dims_a = torch.jit.annotate(List[int], dims[0].tolist())
Contributor

Nit: don't use torch.jit.annotate; use the up-to-date syntax.

Contributor Author

Unfortunately, this doesn't work at the moment. TS complains about redefining variable dims_a.

Contributor

File an issue and add it in a comment?

dims_a = torch.jit.annotate(List[int], dims[0].tolist())
dims_b = torch.jit.annotate(List[int], dims[1].tolist())
else:
assert num_elements == 1
Contributor

Nit: dims.item() will already check that len == 1.

Contributor Author

Removed.
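
Similarly, as an illustration (not from this PR), Tensor.item() itself rejects tensors with more than one element, which is what makes the explicit numel check redundant:

>>> import torch
>>> torch.tensor([3]).item()       # single element: returns the Python scalar 3
3
>>> torch.tensor([1, 2]).item()    # raises an error: only a one-element tensor can be converted to a scalar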

dims_b = torch.jit.annotate(List[int], dims[1].tolist())
else:
assert num_elements == 1
dims_val = int(dims.item())
Contributor

I think you're missing a check that dims_val >= 0 here

Contributor Author

Good catch, added.
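
As a sketch of the kind of guard being asked for (assumed wording; the exact check and message in the merged code may differ):

dims_val = int(dims.item())
if dims_val < 0:
    raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims_val}")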

Contributor
@facebook-github-bot facebook-github-bot left a comment

@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


codecov bot commented Mar 12, 2021

Codecov Report

Merging #53672 (d8ec82c) into master (0584fd9) will decrease coverage by 0.00%.
The diff coverage is 78.78%.

@@            Coverage Diff             @@
##           master   #53672      +/-   ##
==========================================
- Coverage   77.36%   77.35%   -0.01%     
==========================================
  Files        1879     1879              
  Lines      183257   183285      +28     
==========================================
+ Hits       141774   141785      +11     
- Misses      41483    41500      +17     

Contributor

@gmagogsfm merged this pull request in f48a971.

Comment on lines +1038 to +1039
if len(dims_a) == 0 or len(dims_b) == 0:
raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
Collaborator

This check breaks the simple case tensordot(torch.tensor(0.), torch.tensor(0.), 0), which works fine in NumPy

>>> np.tensordot(np.zeros(()), np.zeros(()), 0)
array(0.)

This important edge case, dims=0, is needed by generic machinery such as Pyro and opt_einsum. Can we revert this check?

cc @neerajprad
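
For context on why dims=0 matters (illustrative example, not part of the original report): contracting over zero dimensions is an outer product, which einsum-style machinery builds on, and for 0-d inputs it degenerates to a plain scalar product.

>>> np.tensordot(np.arange(3.), np.arange(2.), 0)   # dims=0 is an outer product
array([[0., 0.],
       [0., 1.],
       [0., 2.]])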

Labels
cla signed, Merged, oncall: jit (Add this issue/PR to JIT oncall triage queue)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

The signature of torch.tensordot in Python should be compatible with its signature in TorchScript
5 participants