How to index into a batch dimension using an enumerated index? #2875

@fritzo

@samuelklee showed me a model where we couldn't find a way to correctly index into a batch dimension using either torch.Tensor.__getitem__() or pyro.ops.indexing.Vindex(). Do we need something more general than Vindex()?

Problem

Consider an enumerated model:

import torch
import pyro
import pyro.distributions as dist

def model():
    p_plate = pyro.plate("p_plate", 4, dim=-1)
    s_plate = pyro.plate("s_plate", 5, dim=-1)
    v_plate = pyro.plate("v_plate", 6, dim=-2)

    with p_plate:
        with v_plate:
            x = pyro.sample("x", dist.Dirichlet(torch.ones(3)))  # [6,4,3]
    with s_plate:
        p_s = pyro.sample("p_s", dist.Categorical(logits=torch.zeros(4)))  # [5]
        with v_plate:
            # select, for each element of s_plate, the p_plate entry named by p_s
            y = x[:, p_s, :]  # [6, 5, 3]

The crux is that

  • during generation, p_s.shape == (5,) and y.shape == (6, 5, 3), but
  • during enumeration, p_s.shape == (4, 1, 1) and we want y.shape == (4, 6, 1, 3), which the expression x[:, p_s, :] does not produce.
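The shape mismatch can be reproduced outside Pyro with plain PyTorch indexing. This is a minimal sketch: the tensors below are stand-ins with the shapes listed above, not values sampled from the model, and the names p_s_gen and p_s_enum are made up for illustration.

```python
import torch

# Stand-in for x: [v_plate=6, p_plate=4, event=3]
x = torch.randn(6, 4, 3)

# Generation: one index per element of s_plate.
p_s_gen = torch.randint(0, 4, (5,))
# Enumeration: the 4 enumerated values, placed at dim -3 (left of both plates).
p_s_enum = torch.arange(4).reshape(4, 1, 1)

y_gen = x[:, p_s_gen, :]
y_enum = x[:, p_s_enum, :]

print(y_gen.shape)   # torch.Size([6, 5, 3])
print(y_enum.shape)  # torch.Size([6, 4, 1, 1, 3])
```

With a single tensor index, the index's shape is spliced in where the indexed dimension was, so the enumeration dimension lands to the right of the v_plate dimension instead of to its left.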

Is there a one-liner to index y = f(x, p_s), or do we need to add a helper?
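One candidate, offered only as a sketch and not as an answer confirmed in this thread, is to replace the slice over the v_plate dimension with an explicit arange index, so that both indices broadcast against each other. The helper name index_p_dim is hypothetical, and the sketch assumes x has exactly the shape (6, 4, 3) with no extra leading dimensions. It uses plain PyTorch advanced indexing and stand-in tensors instead of sampled values.

```python
import torch

def index_p_dim(x, p_s):
    """Hypothetical helper: pick entries along dim 1 of x (the p_plate dim).

    Assumes x has shape (V, P, E) and p_s holds integer indices into P
    whose shape broadcasts against (V, 1).
    """
    # Explicit index over the v_plate dim, shaped (V, 1) so it lines up
    # with dim -2 of the batch shape.
    v_idx = torch.arange(x.size(0)).unsqueeze(-1)
    # Two adjacent advanced indices broadcast together; the result has
    # shape broadcast(v_idx, p_s) + (E,).
    return x[v_idx, p_s]

x = torch.randn(6, 4, 3)
p_s_gen = torch.randint(0, 4, (5,))          # generation-time shape
p_s_enum = torch.arange(4).reshape(4, 1, 1)  # enumeration-time shape

y_gen = index_p_dim(x, p_s_gen)
y_enum = index_p_dim(x, p_s_enum)
print(y_gen.shape)   # torch.Size([6, 5, 3])
print(y_enum.shape)  # torch.Size([4, 6, 1, 3])
```

In the generation case this matches x[:, p_s, :], and in the enumeration case y_enum[e, v, 0] equals x[v, e]. Whether this interacts correctly with the rest of Pyro's enumeration machinery, or whether a more general helper is still needed, is not something this sketch establishes.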
