
RuntimeError: unsupported input to tensordot, got dims=0 #2886

@bdatko

Issue Description

I am unable to run the tutorial "Inference with Discrete Latent Variables" because of the following error:
RuntimeError: unsupported input to tensordot, got dims=0
The error might be related to the CPU-only build of PyTorch, but the code snippet below runs without error when I use PyTorch 1.8.1. I see the same error on Windows; I am reporting it for Linux only because I am away from my Windows machine.

Environment

OS:

DISTRIB_ID=ManjaroLinux
DISTRIB_RELEASE=21.0.7
DISTRIB_CODENAME=Ornara
DISTRIB_DESCRIPTION="Manjaro Linux"

Python:

Python 3.7.10

pip freeze:

backcall==0.2.0
certifi==2021.5.30
decorator==4.4.2
ipython @ file:///tmp/build/80754af9/ipython_1598883837425/work
ipython-genutils==0.2.0
jedi @ file:///tmp/build/80754af9/jedi_1611333758854/work
mkl-fft==1.3.0
mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work
mkl-service==2.3.0
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831194891/work
olefile==0.46
opt-einsum==3.3.0
parso @ file:///tmp/build/80754af9/parso_1596826841367/work
pexpect @ file:///tmp/build/80754af9/pexpect_1594383317248/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1594384075987/work
Pillow @ file:///tmp/build/80754af9/pillow_1617386154241/work
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
ptyprocess==0.6.0
Pygments @ file:///tmp/build/80754af9/pygments_1600458456400/work
pyro-api==0.1.2
pyro-ppl==1.6.0
six @ file:///tmp/build/80754af9/six_1623709665295/work
torch==1.9.0
torchaudio==0.9.0a0+33b2469
torchvision==0.10.0
tqdm==4.61.1
traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work
typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work

Pyro version:

python -c "import pyro; print(pyro.__version__)"
1.6.0
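The version checks above can be collected in one place when comparing environments. A minimal sketch using only the standard library; `collect_env_info` is a hypothetical helper name, not part of Pyro or PyTorch. (PyTorch also ships `python -m torch.utils.collect_env`, which prints a fuller report.)

```python
import importlib
import platform


def collect_env_info(packages=("torch", "pyro")):
    """Hypothetical helper: return version strings useful for a bug report.

    Packages that cannot be imported are reported as "not installed".
    """
    info = {
        "os": platform.platform(),
        "python": platform.python_version(),
    }
    for name in packages:
        try:
            module = importlib.import_module(name)
            # Most scientific packages expose __version__; fall back if absent.
            info[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_env_info().items():
        print(f"{key}: {value}")
```

Running it in both conda environments makes the torch 1.9.0 vs 1.8.1 difference easy to spot side by side.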

Code Snippet

import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.6.0')
pyro.set_rng_seed(0)

@config_enumerate
def model():
    p = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(p[0]))
    y = pyro.sample("y", dist.Categorical(p[x]))
    z = pyro.sample("z", dist.Categorical(p[y]))
    print('model x.shape = {}'.format(x.shape))
    print('model y.shape = {}'.format(y.shape))
    print('model z.shape = {}'.format(z.shape))
    return x, y, z

def guide():
    pass

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=0)
# Raises the RuntimeError below with torch 1.9.0; runs without error with torch 1.8.1.
elbo.loss(model, guide);

model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-072209664cbc> in <module>
      1 elbo = TraceEnum_ELBO(max_plate_nesting=0)
----> 2 elbo.loss(model, config_enumerate(guide, "sequential"));

~/anaconda3/envs/pyro_test/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py in loss(self, model, guide, *args, **kwargs)
    356         elbo = 0.0
    357         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
--> 358             elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
    359             if is_identically_zero(elbo_particle):
    360                 continue

~/anaconda3/envs/pyro_test/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py in _compute_dice_elbo(model_trace, guide_trace)
    183             costs.setdefault(ordering[name], []).append(cost)
    184 
--> 185     return Dice(guide_trace, ordering).compute_expectation(costs)
    186 
    187 

~/anaconda3/envs/pyro_test/lib/python3.7/site-packages/pyro/infer/util.py in compute_expectation(self, costs)
    301                     else:
    302                         cost, prob = packed.broadcast_all(cost, prob)
--> 303                     expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())
    304 
    305         LAST_CACHE_SIZE[0] = count_cached_ops(cache)

~/anaconda3/envs/pyro_test/lib/python3.7/site-packages/torch/functional.py in tensordot(a, b, dims, out)
    927 
    928     if len(dims_a) == 0 or len(dims_b) == 0:
--> 929         raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
    930 
    931     if out is None:

RuntimeError: unsupported input to tensordot, got dims=0
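In the traceback, Pyro calls `torch.tensordot(prob, cost, prob.dim())`, so a 0-dimensional `prob` means `dims=0`, which the torch 1.9.0 install above rejects. With `dims=0` nothing is contracted and tensordot reduces to an outer product. A minimal sketch of a local workaround under that reading, computing the `dims=0` case by broadcasting; `tensordot_compat` is a hypothetical helper, not a Pyro or PyTorch API, and not necessarily how this is fixed upstream.

```python
import torch


def tensordot_compat(a: torch.Tensor, b: torch.Tensor, dims: int) -> torch.Tensor:
    """Hypothetical helper: contract the last `dims` dims of `a` with the first `dims` of `b`.

    For dims == 0 this is an outer product, computed here by broadcasting
    instead of calling torch.tensordot.
    """
    if dims == 0:
        # Append one singleton axis to `a` per dimension of `b`, then multiply;
        # broadcasting yields a tensor of shape a.shape + b.shape.
        return a.reshape(tuple(a.shape) + (1,) * b.dim()) * b
    return torch.tensordot(a, b, dims)


# A 0-dim `prob` (the case in the traceback) just scales `cost`.
prob = torch.tensor(0.5)
cost = torch.tensor([2.0, 4.0, 6.0])
print(tensordot_compat(prob, cost, prob.dim()))
```

Pinning PyTorch 1.8.1, as in the working environment below, sidesteps the problem without patching anything.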

Below is the environment, on the same system, in which the code snippet above runs without error.

DISTRIB_ID=ManjaroLinux
DISTRIB_RELEASE=21.0.7
DISTRIB_CODENAME=Ornara
DISTRIB_DESCRIPTION="Manjaro Linux"

conda list | grep torch
# packages in environment at /home/benda/anaconda3/envs/pyro_torch_1_8_1:
cpuonly                   1.0                           0    pytorch
ffmpeg                    4.3                  hf484d3e_0    pytorch
pytorch                   1.8.1               py3.7_cpu_0  [cpuonly]  pytorch
torchaudio                0.8.1                      py37    pytorch
torchvision               0.9.1                  py37_cpu  [cpuonly]  pytorch

conda list | grep pyro
# packages in environment at /home/benda/anaconda3/envs/pyro_torch_1_8_1:
pyro-api                  0.1.2                    pypi_0    pypi
pyro-ppl                  1.6.0                    pypi_0    pypi

python --version
Python 3.7.10

python -c "import pyro; print(pyro.__version__)"
1.6.0
