@samuelklee showed me a model where we couldn't find a way to correctly index into a batch dimension using either `torch.Tensor.__getitem__()` or `pyro.ops.indexing.Vindex()`. Do we need something more general than `Vindex()`?
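For context, here is a minimal sketch (mine, not from the issue) of the broadcasted indexing that `Vindex` already handles: the index broadcasts against the batch dims covered by the ellipsis, including extra enumeration dims added on the left. The tensor names below are illustrative, but the shapes mirror the model's.

```python
import torch
from pyro.ops.indexing import Vindex

data = torch.rand(6, 4, 3)             # batch dim 6, indexed dim 4, trailing dim 3

# One index per batch element: the index broadcasts with data's leading dim.
i = torch.randint(4, (6,))             # shape (6,)
y = Vindex(data)[..., i, :]            # shape (6, 3): y[v, :] = data[v, i[v], :]

# An extra enumeration dim added on the left of the index also broadcasts.
i_enum = torch.arange(4).view(4, 1)    # shape (4, 1)
y_enum = Vindex(data)[..., i_enum, :]  # shape (4, 6, 3): y_enum[e, v, :] = data[v, e, :]
```

As far as I can tell, the case below is different: the index `p_s` introduces a new plate dim (`s_plate`) that does not exist in `x` at all, so there is no batch dim of `x` for it to broadcast against.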
Problem
Consider an enumerated model:

```python
import torch
import pyro
import pyro.distributions as dist


def model():
    p_plate = pyro.plate("p_plate", 4, dim=-1)
    s_plate = pyro.plate("s_plate", 5, dim=-1)
    v_plate = pyro.plate("v_plate", 6, dim=-2)
    with p_plate:
        with v_plate:
            x = pyro.sample("x", dist.Dirichlet(torch.ones(3)))  # [6, 4, 3]
    with s_plate:
        p_s = pyro.sample("p_s", dist.Categorical(logits=torch.zeros(4)))  # [5]
        with v_plate:
            y = x[:, p_s, :]  # [6, 5, 3]
```
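To make the two regimes below concrete, here is a rough sketch (my assumption of the setup, not part of the issue) of inspecting `p_s` once generatively and once under parallel enumeration with `max_plate_nesting=2`, so enumeration dims are allocated starting at dim -3:

```python
from pyro import poutine
from pyro.infer import config_enumerate

# Generative run: p_s is drawn inside s_plate only.
tr = poutine.trace(model).get_trace()
print(tr.nodes["p_s"]["value"].shape)  # torch.Size([5])

# Parallel enumeration: the enumerated support is placed at dim -3.
enum_model = poutine.enum(config_enumerate(model, "parallel"), first_available_dim=-3)
tr = poutine.trace(enum_model).get_trace()
print(tr.nodes["p_s"]["value"].shape)  # torch.Size([4, 1, 1])
```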
The crux is that

- during generation, `p_s.shape == (5,)` and `y.shape == (6, 5, 3)`, but
- during enumeration, `p_s.shape == (4, 1, 1)` and we want `y.shape == (4, 6, 1, 3)` (which fails).
Is there a one-liner to index `y = f(x, p_s)`, or do we need to add a helper?
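One possibility, purely as a sketch of mine and not something discussed in the issue: pair `p_s` with an explicit index over the `v_plate` dim so that ordinary broadcasted advanced indexing lines the plates up in both regimes. This hard-codes the plate size (6), so it is a per-model workaround rather than a general helper:

```python
# Hypothetical workaround (not from the issue): give the v dim its own index
# of shape (6, 1) so it broadcasts against p_s instead of being sliced away.
v_idx = torch.arange(6).unsqueeze(-1)  # shape (6, 1)

y = x[v_idx, p_s, :]
# generation:  broadcast((6, 1), (5,))       -> (6, 5)     -> y.shape == (6, 5, 3)
# enumeration: broadcast((6, 1), (4, 1, 1))  -> (4, 6, 1)  -> y.shape == (4, 6, 1, 3)
```

Whether something along these lines can be packaged into a `Vindex`-style helper, rather than hand-built per model, seems to be exactly the open question.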