
torch.tensordot(-,-,0) no longer works #61096

@fritzo

Description


🐛 Bug

This check,
https://github.com/pytorch/pytorch/pull/53672/files#diff-5f3d4caa0693a716fc46fd7f6339312f1b5f0bf89e3a3ff58e9dc13a9486b17aR1038-R1039
introduced in #53672, breaks the simple case tensordot(torch.tensor(0.), torch.tensor(0.), 0), which worked in PyTorch 1.8 and still works in NumPy:

>>> import numpy as np
>>> np.tensordot(np.zeros(()), np.zeros(()), 0)
array(0.)
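For context, here is a short NumPy sketch of what dims=0 means (the arrays and values are chosen only for illustration). No axes are contracted, so tensordot reduces to an outer product whose shape is a.shape + b.shape. For two 0-d inputs that shape is (), so the result is simply their 0-d product, which is exactly the case reported here.

```python
import numpy as np

# With dims (NumPy calls it ``axes``) equal to 0, no axes are contracted,
# so tensordot reduces to an outer product of shape a.shape + b.shape.
a = np.arange(2.0)  # shape (2,)
b = np.arange(3.0)  # shape (3,)
outer = np.tensordot(a, b, 0)  # shape (2, 3), outer[i, j] == a[i] * b[j]

# The same rule for two 0-d arrays: the result has shape () + () == (),
# i.e. it is the 0-d product of the two inputs.
x = np.array(2.0)
y = np.array(3.0)
scalar = np.tensordot(x, y, 0)  # 0-d array holding 6.0
```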

The dims=0 edge case is needed by generic machinery such as that in Pyro and opt_einsum. Can we revert this check?
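Until the check is reverted, generic code could route dims=0 around torch.tensordot. The following is a minimal sketch of one way to do that. It assumes a hypothetical wrapper named `tensordot_compat`, which is not a PyTorch API. When dims is the plain integer 0, the wrapper computes the outer product by broadcasting. Every other value is forwarded unchanged to torch.tensordot.

```python
import torch


def tensordot_compat(a, b, dims=2):
    """Hypothetical drop-in wrapper around torch.tensordot.

    For an integer dims == 0 it computes the outer product directly by
    broadcasting, so it does not depend on tensordot accepting dims=0.
    Any other dims value is forwarded unchanged to torch.tensordot.
    """
    if isinstance(dims, int) and dims == 0:
        # Append one singleton axis to `a` per dimension of `b`; broadcasting
        # then yields shape a.shape + b.shape with out[i..., j...] = a[i...] * b[j...].
        a_expanded = a.reshape(tuple(a.shape) + (1,) * b.dim())
        return a_expanded * b
    return torch.tensordot(a, b, dims)


# Usage: the scalar case from this report, and a 1-D outer product.
scalar = tensordot_compat(torch.zeros(()), torch.zeros(()), 0)
outer = tensordot_compat(torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0]), 0)
```

Broadcasting keeps the sketch independent of tensordot's dims validation. For 0-d inputs the reshape target is the empty shape, so the result stays 0-d, matching NumPy.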

cc @neerajprad

To Reproduce

>>> import torch
>>> torch.tensordot(torch.zeros(()), torch.zeros(()), 0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/fobermey/opt/miniconda3/envs/pyro/lib/python3.7/site-packages/torch/functional.py", line 929, in tensordot
    raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
RuntimeError: unsupported input to tensordot, got dims=0

Expected behavior

torch.tensordot(torch.zeros(()), torch.zeros(()), 0) should return a 0-dim tensor, as it does in NumPy and did in PyTorch 1.8.

Environment

PyTorch version: 1.9.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 10.15.7 (x86_64)
GCC version: Could not collect
Clang version: 12.0.0 (clang-1200.0.32.2)
CMake version: version 3.18.4
Libc version: N/A

Python version: 3.7.0 (default, Jun 28 2018, 07:39:16)  [Clang 4.0.1 (tags/RELEASE_401/final)] (64-bit runtime)
Python platform: Darwin-19.6.0-x86_64-i386-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip] gpytorch==1.5.0
[pip] numpy==1.19.4
[pip] numpyro==0.6.0
[pip] torch==1.9.0
[pip] torchfile==0.1.0
[pip] torchvision==0.10.0
[conda] gpytorch                  1.5.0                    pypi_0    pypi
[conda] numpy                     1.19.4                   pypi_0    pypi
[conda] numpyro                   0.6.0                     dev_0    <develop>
[conda] torch                     1.9.0                    pypi_0    pypi
[conda] torchfile                 0.1.0                    pypi_0    pypi
[conda] torchvision               0.10.0                   pypi_0    pypi

cc @gmagogsfm @eellison @neerajprad
