Description
Trying out this class, I ended up with a rather strange-looking loss curve (I was using an LR scheduler, which explains the waviness).
It looks like the training losses have a sign error, although training still optimizes in the right direction. This bug doesn't appear in `TraceGraph_ELBO` or `JitTrace_ELBO`.
Looking at the code, I think this is just because the JITed class overrides `loss_and_grads` and lacks a negation that appears in the non-JIT version. The validation curve is generated via `SVI.evaluate_loss`, which calls `TraceGraph_ELBO.loss` and does include the negation.
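To make the suspected mechanism concrete, here's a toy sketch in plain Python (not Pyro's code; every class and function name below is made up for illustration). It assumes, as in `TraceGraph_ELBO`, that gradients come from a separately computed surrogate loss, so a missing negation in the override flips only the *reported* training loss, while the parameter updates stay correct.

```python
from typing import List, Tuple

# Hypothetical toy stand-ins for the two ELBO classes. The "ELBO" here is a
# deterministic function of one parameter, so no sampling is involved.


def toy_elbo(theta: float) -> float:
    """Toy ELBO, maximized at theta = 3."""
    return -((theta - 3.0) ** 2)


def toy_elbo_grad(theta: float) -> float:
    """Derivative of toy_elbo with respect to theta."""
    return -2.0 * (theta - 3.0)


class ToyTraceGraphELBO:
    def loss(self, theta: float) -> float:
        # Loss = -ELBO; this is what an evaluate_loss-style validation call reports.
        return -toy_elbo(theta)

    def loss_and_grads(self, theta: float) -> Tuple[float, float]:
        # The gradient comes from the surrogate loss (-ELBO here),
        # and the reported loss is negated as well.
        grad = -toy_elbo_grad(theta)
        return -toy_elbo(theta), grad


class ToyJitTraceGraphELBO(ToyTraceGraphELBO):
    def loss_and_grads(self, theta: float) -> Tuple[float, float]:
        # Suspected bug: the override reports the ELBO itself (no minus sign),
        # while the gradient still comes from the correctly signed surrogate loss.
        # Fix: return -toy_elbo(theta), grad
        grad = -toy_elbo_grad(theta)
        return toy_elbo(theta), grad


def train(impl: ToyTraceGraphELBO, theta: float = 0.0, lr: float = 0.1,
          steps: int = 50) -> Tuple[float, List[float], List[float]]:
    """Plain gradient descent that records training and validation losses."""
    train_losses: List[float] = []
    val_losses: List[float] = []
    for _ in range(steps):
        loss, grad = impl.loss_and_grads(theta)
        theta -= lr * grad
        train_losses.append(loss)
        val_losses.append(impl.loss(theta))
    return theta, train_losses, val_losses


theta_ok, train_ok, _ = train(ToyTraceGraphELBO())
theta_bug, train_bug, val_bug = train(ToyJitTraceGraphELBO())
```

Both runs follow exactly the same parameter trajectory and converge to `theta = 3`, but with the buggy override the recorded training losses are negative and rise toward zero, while the validation losses (which go through `loss`) stay positive and fall. That matches the shape of the curves described above.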
So it seems like a simple fix, but before making a PR I wanted to flag the issue in case I'm missing something. There are negations in a couple of places, and it seems easy to get them wrong.
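Since the negations are spread across a few places, one way to guard against this kind of slip is a parity check asserting that the value returned by `loss_and_grads` matches `loss` on the same inputs. This is only a sketch of the idea: `ToyELBO`, `ToyFlippedELBO`, and `loss_sign_mismatches` are hypothetical stand-ins, not Pyro APIs or existing tests.

```python
import math
from typing import Iterable, List, Tuple


class ToyELBO:
    """Hypothetical ELBO-like object: loss(x) and loss_and_grads(x) should agree."""

    def loss(self, x: float) -> float:
        return (x - 3.0) ** 2

    def loss_and_grads(self, x: float) -> Tuple[float, float]:
        return (x - 3.0) ** 2, 2.0 * (x - 3.0)


class ToyFlippedELBO(ToyELBO):
    """Override that drops a negation in the reported loss only."""

    def loss_and_grads(self, x: float) -> Tuple[float, float]:
        value, grad = super().loss_and_grads(x)
        return -value, grad


def loss_sign_mismatches(impl: ToyELBO, points: Iterable[float]) -> List[float]:
    """Return the points where loss() and the value from loss_and_grads() disagree."""
    bad = []
    for x in points:
        reported, _ = impl.loss_and_grads(x)
        if not math.isclose(reported, impl.loss(x), rel_tol=1e-9, abs_tol=1e-12):
            bad.append(x)
    return bad
```

Note that a point where the loss is exactly zero (`x = 3` here) can't reveal a flipped sign, so such a check should evaluate away from the optimum.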