Issue Description
I am unable to run the tutorial "Inference with Discrete Latent Variables" because of the following error:
RuntimeError: unsupported input to tensordot, got dims=0
The error might be related to the CPU-only build of PyTorch. However, the code snippet below runs without error when I use PyTorch 1.8.1. I see the same error on Windows. I am posting the Linux report because I am away from my Windows machine.
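To test the CPU-only hypothesis, one could print the build details in each environment and compare them. The following is a minimal sketch. The helper name `describe_torch_build` is hypothetical. The sketch relies on `torch.version.cuda`, which is `None` on CPU-only builds.

```python
import torch


def describe_torch_build() -> dict:
    """Summarize the installed PyTorch build: version and CUDA support."""
    return {
        "torch_version": torch.__version__,
        # None for CPU-only builds, otherwise the CUDA version string
        "cuda_build": torch.version.cuda,
        # False on CPU-only builds or when no usable GPU is present
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in describe_torch_build().items():
        print(f"{key}: {value}")
```

If both environments report `cuda_build: None`, the two setups differ mainly in the PyTorch version.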
Environment
OS:
DISTRIB_ID=ManjaroLinux
DISTRIB_RELEASE=21.0.7
DISTRIB_CODENAME=Ornara
DISTRIB_DESCRIPTION="Manjaro Linux"
python:
Python 3.7.10
pip freeze:
backcall==0.2.0
certifi==2021.5.30
decorator==4.4.2
ipython @ file:///tmp/build/80754af9/ipython_1598883837425/work
ipython-genutils==0.2.0
jedi @ file:///tmp/build/80754af9/jedi_1611333758854/work
mkl-fft==1.3.0
mkl-random @ file:///tmp/build/80754af9/mkl_random_1618853974840/work
mkl-service==2.3.0
numpy @ file:///tmp/build/80754af9/numpy_and_numpy_base_1620831194891/work
olefile==0.46
opt-einsum==3.3.0
parso @ file:///tmp/build/80754af9/parso_1596826841367/work
pexpect @ file:///tmp/build/80754af9/pexpect_1594383317248/work
pickleshare @ file:///tmp/build/80754af9/pickleshare_1594384075987/work
Pillow @ file:///tmp/build/80754af9/pillow_1617386154241/work
prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1602688806899/work
ptyprocess==0.6.0
Pygments @ file:///tmp/build/80754af9/pygments_1600458456400/work
pyro-api==0.1.2
pyro-ppl==1.6.0
six @ file:///tmp/build/80754af9/six_1623709665295/work
torch==1.9.0
torchaudio==0.9.0a0+33b2469
torchvision==0.10.0
tqdm==4.61.1
traitlets @ file:///tmp/build/80754af9/traitlets_1602787416690/work
typing-extensions @ file:///home/ktietz/src/ci_mi/typing_extensions_1612808209620/work
wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work
Pyro version:
python -c "import pyro; print(pyro.__version__)"
1.6.0
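Both environments run Pyro 1.6.0 and differ in their PyTorch version, so it may help to report the two versions together. The following is a minimal sketch using the standard-library `importlib.metadata`. That module requires Python 3.8 or later, so the Python 3.7.10 used here would need the `importlib_metadata` backport instead. The helper name `installed_versions` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version  # Python 3.8+


def installed_versions(packages=("pyro-ppl", "torch", "torchvision", "torchaudio")):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Distribution is not installed in this environment
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```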
Code Snippet
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.6.0')
pyro.set_rng_seed(0)

@config_enumerate
def model():
    p = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(p[0]))
    y = pyro.sample("y", dist.Categorical(p[x]))
    z = pyro.sample("z", dist.Categorical(p[y]))
    print('model x.shape = {}'.format(x.shape))
    print('model y.shape = {}'.format(y.shape))
    print('model z.shape = {}'.format(z.shape))
    return x, y, z

def guide():
    pass

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, guide);
model x.shape = torch.Size([3])
model y.shape = torch.Size([3, 1])
model z.shape = torch.Size([3, 1, 1])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-5-072209664cbc> in <module>
1 elbo = TraceEnum_ELBO(max_plate_nesting=0)
----> 2 elbo.loss(model, config_enumerate(guide, "sequential"));
~/anaconda3/envs/pyro_test/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py in loss(self, model, guide, *args, **kwargs)
356 elbo = 0.0
357 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
--> 358 elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
359 if is_identically_zero(elbo_particle):
360 continue
~/anaconda3/envs/pyro_test/lib/python3.7/site-packages/pyro/infer/traceenum_elbo.py in _compute_dice_elbo(model_trace, guide_trace)
183 costs.setdefault(ordering[name], []).append(cost)
184
--> 185 return Dice(guide_trace, ordering).compute_expectation(costs)
186
187
~/anaconda3/envs/pyro_test/lib/python3.7/site-packages/pyro/infer/util.py in compute_expectation(self, costs)
301 else:
302 cost, prob = packed.broadcast_all(cost, prob)
--> 303 expected_cost = expected_cost + scale * torch.tensordot(prob, cost, prob.dim())
304
305 LAST_CACHE_SIZE[0] = count_cached_ops(cache)
~/anaconda3/envs/pyro_test/lib/python3.7/site-packages/torch/functional.py in tensordot(a, b, dims, out)
927
928 if len(dims_a) == 0 or len(dims_b) == 0:
--> 929 raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
930
931 if out is None:
RuntimeError: unsupported input to tensordot, got dims=0
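The failing line in `pyro/infer/util.py` calls `torch.tensordot(prob, cost, prob.dim())`. The check shown in the installed `torch/functional.py` rejects `dims=0`, so the error is reached whenever `prob` is 0-dimensional. Assume `prob` and `cost` have identical shapes at that point, as they appear to after `packed.broadcast_all`. In that case, contracting over every dimension equals an elementwise product followed by a sum, and that form has no `dims=0` special case.

The sketch below illustrates that equivalence. It is not Pyro's actual fix, and `contract_all_dims` is a hypothetical name.

```python
import torch


def contract_all_dims(prob: torch.Tensor, cost: torch.Tensor) -> torch.Tensor:
    """Sum of prob * cost over every dimension.

    For same-shape inputs this matches torch.tensordot(prob, cost, prob.dim()),
    and it is also defined when prob is 0-dimensional (i.e. dims=0).
    """
    return (prob * cost).sum()


if __name__ == "__main__":
    # 0-dim inputs: prob.dim() == 0, so tensordot receives dims=0
    prob, cost = torch.tensor(0.25), torch.tensor(4.0)
    try:
        print("tensordot:", torch.tensordot(prob, cost, prob.dim()))
    except RuntimeError as err:
        # The torch version used in the traceback above raises here
        print("tensordot failed:", err)
    print("elementwise:", contract_all_dims(prob, cost))

    # Higher-dimensional, same-shape inputs: both forms should agree
    prob, cost = torch.rand(2, 3), torch.rand(2, 3)
    print(torch.allclose(torch.tensordot(prob, cost, prob.dim()),
                         contract_all_dims(prob, cost)))
```

Running this in both environments would be expected to show `tensordot` failing only under torch 1.9.0. The elementwise form should work under both versions.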
Below is an environment on the same system in which the same code snippet runs without error:
DISTRIB_ID=ManjaroLinux
DISTRIB_RELEASE=21.0.7
DISTRIB_CODENAME=Ornara
DISTRIB_DESCRIPTION="Manjaro Linux"
conda list | grep torch
# packages in environment at /home/benda/anaconda3/envs/pyro_torch_1_8_1:
cpuonly 1.0 0 pytorch
ffmpeg 4.3 hf484d3e_0 pytorch
pytorch 1.8.1 py3.7_cpu_0 [cpuonly] pytorch
torchaudio 0.8.1 py37 pytorch
torchvision 0.9.1 py37_cpu [cpuonly] pytorch
conda list | grep pyro
# packages in environment at /home/benda/anaconda3/envs/pyro_torch_1_8_1:
pyro-api 0.1.2 pypi_0 pypi
pyro-ppl 1.6.0 pypi_0 pypi
python --version
Python 3.7.10
python -c "import pyro; print(pyro.__version__)"
1.6.0