
The signature of torch.tensordot in Python should be compatible with its signature in TorchScript #53487

@Linux-cpp-lisp


🐛 Bug

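torch.tensordot accepts entirely different keyword arguments for the contracted dimensions in Python (dims) and in TorchScript (dims_self and dims_other), so a call written for one does not run in the other.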

To Reproduce

In [1]: import torch

In [2]: x, y = torch.randn(3, 4), torch.randn(4, 3)

In [3]: torch.tensordot(x, y, dims=((1,), (0,)))
Out[3]: 
tensor([[-0.2011,  0.1324,  0.7855],
        [ 0.4894, -2.1793, -0.4989],
        [ 2.4699,  0.4472, -0.5411]])

In [5]: @torch.jit.script
   ...: def f(x, y):
   ...:     return torch.tensordot(x, y, dims=((1,), (0,)))
   ...: 
---------------------------------------------------------------------------
RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> (Tensor):
  Argument dims_self not provided.
  
  aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> (Tensor(a!)):
  Argument dims_self not provided.

The original call is:
  File "<ipython-input-5-61613126bcad>", line 3
@torch.jit.script
def f(x, y):
    return torch.tensordot(x, y, dims=((1,), (0,)))
           ~~~~~~~~~~~~~~~ <--- HERE


In [6]: @torch.jit.script
   ...: def f(x, y):
   ...:     return torch.tensordot(x, y, dims_self=[1], dims_other=[0])

In [7]: torch.tensordot(x, y, dims_self=[1], dims_other=[0])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-be1bfc130a4a> in <module>
----> 1 torch.tensordot(x, y, dims_self=[1], dims_other=[0])

TypeError: tensordot() got an unexpected keyword argument 'dims_self'

Expected behavior

It is acceptable that the TorchScript version does not accept the Python dims keyword. However, the Python function should accept dims_self and dims_other as an alternative to dims, so that the same call works both eagerly and under torch.jit.script.
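Supporting both spellings essentially means normalizing them into one pair of dimension lists before dispatching. The sketch below is torch-free and purely illustrative: normalize_tensordot_dims is a hypothetical helper, not a PyTorch API. It assumes the documented dims semantics of torch.tensordot (default 2; an integer d contracts the last d dimensions of the first tensor with the first d dimensions of the second; otherwise a pair of dimension lists).

```python
from typing import List, Optional, Sequence, Tuple, Union

DimsArg = Union[int, Sequence[Sequence[int]]]


def normalize_tensordot_dims(
    ndim_self: int,
    dims: Optional[DimsArg] = None,
    dims_self: Optional[Sequence[int]] = None,
    dims_other: Optional[Sequence[int]] = None,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper (not part of PyTorch).

    Accepts either the Python-style `dims` argument or the TorchScript-style
    `dims_self`/`dims_other` pair and returns the two lists of dimensions
    to contract. `ndim_self` is the number of dimensions of the first tensor.
    """
    explicit = dims_self is not None or dims_other is not None
    if dims is not None and explicit:
        raise TypeError("pass either dims or dims_self/dims_other, not both")

    if explicit:
        # TorchScript-style call: both lists must be given.
        if dims_self is None or dims_other is None:
            raise TypeError("dims_self and dims_other must be passed together")
        first, second = list(dims_self), list(dims_other)
    elif dims is None or isinstance(dims, int):
        # Integer form (assumed default 2): contract the last d dimensions
        # of the first tensor with the first d dimensions of the second.
        d = 2 if dims is None else dims
        if d < 0 or d > ndim_self:
            raise ValueError("integer dims must be between 0 and ndim_self")
        first, second = list(range(ndim_self - d, ndim_self)), list(range(d))
    else:
        # Pair form, e.g. dims=((1,), (0,)) as in the eager call above.
        pair = list(dims)
        if len(pair) != 2:
            raise ValueError("dims must contain exactly two dimension lists")
        first, second = list(pair[0]), list(pair[1])

    if len(first) != len(second):
        raise ValueError("both dimension lists must have the same length")
    return first, second


print(normalize_tensordot_dims(2, dims=((1,), (0,))))            # ([1], [0])
print(normalize_tensordot_dims(2, dims_self=[1], dims_other=[0]))  # ([1], [0])
print(normalize_tensordot_dims(3))                                # ([1, 2], [0, 1])
```

A Python-side tensordot built this way could pass the resulting pair to the aten::tensordot overload listed in the error message above. That would let the dims_self=[1], dims_other=[0] call from In [7] run eagerly, while dims=((1,), (0,)) keeps working.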

Environment

  • PyTorch Version (e.g., 1.0): 1.8.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): Conda
  • Build command you used (if compiling from source): n/a
  • Python version: 3.8.5
  • CUDA/cuDNN version: n/a
  • GPU models and configuration: n/a
  • Any other relevant information: n/a

cc @gmagogsfm


Labels: oncall: jit
