OlaRonning
Member

Removes bias on correlation in SBvM normalization and fixes NaN in normalization for negative correlation (corr < 0.).

From #1511.
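
To illustrate the NaN described above, here is a minimal JAX sketch. The values of `m` and `corr` are made up for the example and are not taken from the library; only the two expressions mirror the lines changed in the diff.

```python
import jax.numpy as jnp

# Hypothetical inputs, chosen only for illustration.
m = jnp.arange(3.0)            # series indices 0, 1, 2
corr_neg = jnp.array(-0.5)     # negative correlation
corr_pos = jnp.array(0.5)      # positive correlation

# Old form: the log of a negative number is nan, and nan propagates
# through the multiplication (even for m == 0).
old_term = 2 * m * jnp.log(corr_neg)

# New form: corr**2 is positive, so the log is finite.
new_term = m * jnp.log(corr_neg**2)

# For positive corr the two forms agree, since log(c**2) == 2 * log(c).
same_for_positive = jnp.allclose(
    2 * m * jnp.log(corr_pos), m * jnp.log(corr_pos**2)
)
```

Both forms are the same quantity for `corr > 0`; squaring first simply makes the expression well defined for `corr < 0` too.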

@@ -400,7 +400,7 @@ def norm_const(self):

  fs = (
  lbinoms.reshape(-1, 1)
- + 2 * m * jnp.log(corr)
+ + m * jnp.log(corr**2)
Member

Nice catch! I guess you can clip by tiny here in case users set corr=0?

Member Author

Sure, that makes sense.
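
A rough sketch of the clipping idea discussed here, assuming the default float dtype. It uses `jnp.maximum` as a one-sided clip, which is a stand-in chosen for this example rather than the exact call in the PR; `m` and `corr` are illustrative values.

```python
import jax.numpy as jnp

# Hypothetical inputs, chosen only for illustration.
m = jnp.arange(3.0)        # series indices 0, 1, 2
corr = jnp.array(0.0)      # user sets the correlation to exactly zero

# Without clipping: log(0) is -inf, and 0 * -inf is nan for m == 0.
unclipped = m * jnp.log(corr**2)

# With clipping: bound corr**2 below by the smallest positive normal
# number of the default float dtype, so the log stays finite.
tiny = jnp.finfo(jnp.result_type(float)).tiny
clipped = m * jnp.log(jnp.maximum(corr**2, tiny))
```

With the lower bound in place, the `m == 0` term evaluates to zero instead of NaN and the remaining terms are large negative but finite.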

fehiepsi previously approved these changes Dec 17, 2022
Member

fehiepsi left a comment

Thanks, Ola! Just a small comment on the floating point.

- + 2 * m * jnp.log(corr)
- - m * jnp.log(4 * jnp.prod(conc, axis=-1))
+ fs = lbinoms.reshape(-1, 1) + m * (
+ jnp.log(jnp.clip(corr**2, a_min=jnp.finfo(float).tiny))
Member

jnp.finfo(float).tiny is the float64 value, which is too small for single precision. How about jnp.finfo(jnp.result_type(float)).tiny?
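
The concern in this comment can be checked with a short JAX sketch. It assumes nothing beyond `jax.numpy`; the variable names are invented for the example. The float64 `tiny` is far below the float32 range, so it becomes zero once cast to single precision and no longer protects the log.

```python
import jax.numpy as jnp

# tiny for Python's float (float64), roughly 2.2e-308.
tiny64 = jnp.finfo(float).tiny

# Cast to single precision: the value underflows to exactly 0.0,
# so using it as a lower bound on a float32 array has no effect.
as_f32 = jnp.asarray(tiny64, dtype=jnp.float32)

# Dtype-aware alternative: ask for tiny of JAX's default float dtype
# (float32 unless 64-bit mode is enabled).
default_dtype = jnp.result_type(float)
tiny_default = jnp.finfo(default_dtype).tiny

# Bounding zero below by the dtype-aware tiny keeps the log finite.
safe = jnp.maximum(jnp.zeros((), dtype=default_dtype), tiny_default)
log_safe = jnp.log(safe)
```

Because `jnp.result_type(float)` follows the active precision setting, the same line should give a usable bound in both single and double precision.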
