
[bug] Sign error in JitTraceGraph_ELBO #2798

@jamestwebber


I tried out this class and ended up with a rather strange-looking loss curve (I was using an LR scheduler, which explains the waviness):

[Figure: elbo_bug (training and validation loss curves)]

It looks like the training losses have a sign error, although the optimization still proceeds in the right direction. The bug doesn't appear with TraceGraph_ELBO or JitTrace_ELBO.

Looking at the code, I think this happens because the JITed class overrides loss_and_grads and lacks the negation that appears in the non-JIT version. The validation curve is generated via SVI.evaluate_loss, which calls TraceGraph_ELBO.loss and does include the negation.
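To make the suspected failure mode concrete, here is a minimal toy sketch. It assumes, as the symptoms suggest, that the gradients still come from a correctly signed loss (-ELBO) while only the scalar returned by loss_and_grads is missing its negation. Every function name below is a hypothetical stand-in, not Pyro's API, and the "ELBO" is just a one-parameter quadratic so the numbers are exact.

```python
# Toy illustration only (hypothetical names, not Pyro's implementation).
# The "ELBO" is a concave quadratic in one parameter, maximized at theta = 3.


def elbo(theta):
    # Stand-in for the ELBO estimate; always <= 0 in this toy.
    return -(theta - 3.0) ** 2


def grad_neg_elbo(theta):
    # Gradient of the loss (-ELBO) with respect to theta.
    return 2.0 * (theta - 3.0)


def evaluate_loss(theta):
    # Mirrors the convention of the non-JIT .loss path: loss = -ELBO.
    return -elbo(theta)


def loss_and_grads_buggy(theta):
    # The gradient is taken from the correctly signed loss...
    grad = grad_neg_elbo(theta)
    # ...but the reported scalar forgets the negation.
    return elbo(theta), grad


def loss_and_grads_fixed(theta):
    # Reported scalar and gradient both use loss = -ELBO.
    return -elbo(theta), grad_neg_elbo(theta)


def train(step_fn, theta=0.0, lr=0.1, num_steps=50):
    """Plain gradient descent, logging the 'training' and 'validation' loss."""
    train_losses, eval_losses = [], []
    for _ in range(num_steps):
        loss, grad = step_fn(theta)
        train_losses.append(loss)                # what the training curve shows
        eval_losses.append(evaluate_loss(theta))  # what the validation curve shows
        theta -= lr * grad
    return theta, train_losses, eval_losses


if __name__ == "__main__":
    theta_b, train_b, eval_b = train(loss_and_grads_buggy)
    theta_f, train_f, eval_f = train(loss_and_grads_fixed)
    # Both runs converge to theta close to 3 because the gradients are identical...
    print(theta_b, theta_f)
    # ...but the buggy run reports training losses with the opposite sign.
    print(train_b[:3], eval_b[:3])
```

In this toy, both runs take exactly the same optimization steps, which matches the observation that training still moves in the right direction; only the logged training loss is mirrored around zero relative to the validation loss.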

So it seems like a simple fix, but before making a PR I wanted to flag the issue in case I'm missing something. There are negations in a couple of places, and it seems easy to mess up.
