
Possible bug in gate_logits of ZeroInflatedDistribution #3301

@reemagit

=== UPDATE ===
Corrected typos in the code.
=== UPDATE ===

When I use the gate_logits parameter of ZeroInflatedPoisson, the fraction of zeros sampled from the distribution is not what I expect. This does not happen when I pass the equivalent gate parameter instead.
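The two parameters are meant to describe the same quantity: gate is the probability of a structural zero and gate_logits is its log-odds, so gate = sigmoid(gate_logits), which is the conversion the snippet below relies on. The following plain-Python sketch (no Pyro or torch; the helper names are hypothetical) checks that round trip. It also illustrates one possible explanation, which is only an assumption and not confirmed in this issue: if the logits were normalized with a multi-class softmax instead of a binary sigmoid, a single logit would always map to a probability of 1.0, and every sample would be a zero.

```python
import math


def gate_from_logits(gate_logits: float) -> float:
    """Binary (sigmoid) conversion: probability that a draw is a structural zero."""
    return 1.0 / (1.0 + math.exp(-gate_logits))


def logits_from_gate(gate: float) -> float:
    """Inverse of the sigmoid: log(p / (1 - p)), as computed in the issue's snippet."""
    return math.log(gate / (1.0 - gate))


def softmax_single(logit: float) -> float:
    """Hypothetical failure mode: softmax over a one-element vector is always 1.0."""
    exps = [math.exp(logit)]
    return exps[0] / sum(exps)


p = 0.3
logit = logits_from_gate(p)
print(gate_from_logits(logit))  # approximately 0.3
print(softmax_single(logit))    # 1.0, whatever the logit is
```

Under this assumption, the sigmoid recovers the intended gate, while a softmax discards the logit's value entirely.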

import torch
import pyro.distributions as dist

N = 1000  # number of samples
rate = torch.tensor(100.)  # arbitrary rate
p = torch.tensor(0.3)  # inflation (gate) probability
logit = torch.log(p / (1 - p))  # logit of the inflation probability

x1 = (dist.ZeroInflatedPoisson(rate=rate, gate_logits=logit).sample([N]) == 0).sum() / N
x2 = (dist.ZeroInflatedPoisson(rate=rate, gate=p).sample([N]) == 0).sum() / N

print(x1)  # prints 1. (every sample is zero)
print(x2)  # prints approximately 0.3

I would expect both x1 and x2 to be close to 0.3 (the inflation probability), since with rate=100 the Poisson component almost never produces a zero on its own. Am I missing something, or is this a bug?
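For a zero-inflated Poisson with gate π and rate λ, P(X = 0) = π + (1 − π)·e^(−λ). With π = 0.3 and λ = 100 the second term is about 3.7 × 10^-44, so the zero fraction should be essentially 0.3 for either parameterization. The sketch below samples a zero-inflated Poisson directly in plain Python (standard library only; it does not use Pyro's implementation, and the function names are hypothetical) and compares the empirical zero fraction with that formula.

```python
import math
import random


def sample_poisson(rate: float, rng: random.Random) -> int:
    """Knuth's multiplication method; adequate for moderate rates such as 100."""
    threshold = math.exp(-rate)
    k = 0
    prod = rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k


def sample_zip(rate: float, gate: float, rng: random.Random) -> int:
    """With probability `gate` emit a structural zero, otherwise draw Poisson(rate)."""
    if rng.random() < gate:
        return 0
    return sample_poisson(rate, rng)


def zero_fraction(rate: float, gate: float, n: int, seed: int = 0) -> float:
    """Empirical fraction of zeros among n seeded zero-inflated Poisson draws."""
    rng = random.Random(seed)
    return sum(sample_zip(rate, gate, rng) == 0 for _ in range(n)) / n


N, rate, p = 1000, 100.0, 0.3
gate = 1.0 / (1.0 + math.exp(-math.log(p / (1 - p))))  # gate recovered from its logit
expected = gate + (1 - gate) * math.exp(-rate)          # P(X = 0), essentially 0.3
print(expected)
print(zero_fraction(rate, gate, N))  # close to 0.3
```

Passing gate=1.0 makes every draw a structural zero, so the fraction is exactly 1.0, which is the all-zeros outcome described above.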
