Closed
Labels
oncall: jit
Description
🐛 Bug
`torch.tensordot` accepts entirely different argument formats in Python and TorchScript.
To Reproduce
In [1]: import torch
In [2]: x, y = torch.randn(3, 4), torch.randn(4, 3)
In [3]: torch.tensordot(x, y, dims=((1,), (0,)))
Out[3]:
tensor([[-0.2011,  0.1324,  0.7855],
        [ 0.4894, -2.1793, -0.4989],
        [ 2.4699,  0.4472, -0.5411]])
In [5]: @torch.jit.script
   ...: def f(x, y):
   ...:     return torch.tensordot(x, y, dims=((1,), (0,)))
   ...:
---------------------------------------------------------------------------
RuntimeError:
Arguments for call are not valid.
The following variants are available:

  aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> (Tensor):
  Argument dims_self not provided.

  aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> (Tensor(a!)):
  Argument dims_self not provided.

The original call is:
  File "<ipython-input-5-61613126bcad>", line 3
@torch.jit.script
def f(x, y):
    return torch.tensordot(x, y, dims=((1,), (0,)))
           ~~~~~~~~~~~~~~~ <--- HERE
In [6]: @torch.jit.script
   ...: def f(x, y):
   ...:     return torch.tensordot(x, y, dims_self=[1], dims_other=[0])
In [7]: torch.tensordot(x, y, dims_self=[1], dims_other=[0])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-7-be1bfc130a4a> in <module>
----> 1 torch.tensordot(x, y, dims_self=[1], dims_other=[0])
TypeError: tensordot() got an unexpected keyword argument 'dims_self'
Expected behavior
It is fine that the TorchScript version doesn't accept the Python call signature of `dims`, but the Python call signature should accept `dims_self` and `dims_other` as an alternative to `dims`.
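To make the request concrete, here is a minimal sketch of the argument handling the Python entry point could do. The helper name `normalize_tensordot_dims` is hypothetical and not part of PyTorch; the integer case assumes the usual meaning of an integer `dims` (contract the last N dimensions of the first tensor with the first N of the second), and unlike eager `torch.tensordot` the sketch requires the dimensions to be given explicitly.

```python
from typing import List, Optional, Sequence, Tuple, Union

# Either an int N or a pair (dims_self, dims_other), as in the eager call above.
DimsArg = Union[int, Tuple[Sequence[int], Sequence[int]]]


def normalize_tensordot_dims(
    dims: Optional[DimsArg] = None,
    dims_self: Optional[Sequence[int]] = None,
    dims_other: Optional[Sequence[int]] = None,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: map either calling convention to (dims_self, dims_other)."""
    has_split = dims_self is not None or dims_other is not None
    if dims is not None and has_split:
        raise TypeError("pass either dims or dims_self/dims_other, not both")
    if has_split:
        # TorchScript-style keywords: both halves are required.
        if dims_self is None or dims_other is None:
            raise TypeError("dims_self and dims_other must be given together")
        return list(dims_self), list(dims_other)
    if dims is None:
        # This sketch does not model any default; it asks for explicit dims.
        raise TypeError("one of dims or dims_self/dims_other is required")
    if isinstance(dims, int):
        if dims < 0:
            raise ValueError("dims must be non-negative")
        # Assumed meaning: last N dims of self against first N dims of other.
        return list(range(-dims, 0)), list(range(dims))
    # Python-style pair: dims=((1,), (0,))
    self_part, other_part = dims
    return list(self_part), list(other_part)
```

With something like this in place, both spellings from the reproduction would resolve to the same pair, which could then be forwarded as `torch.tensordot(x, y, dims=(a, b))` in eager mode.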
Environment
- PyTorch Version (e.g., 1.0): 1.8.0
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): Conda
- Build command you used (if compiling from source): n/a
- Python version: 3.8.5
- CUDA/cuDNN version: n/a
- GPU models and configuration: n/a
- Any other relevant information: n/a
cc @gmagogsfm