BUG: special.logsumexp: fix precision issue #18424
Conversation
Closes scipy#18295. Updates the logsumexp calculation with the more precise scheme outlined in scipy#18295, where possible.
Fix test to use assert_allclose to measure error at an appropriate tolerance
Co-authored-by: Matteo Raso <33975162+MatteoRaso@users.noreply.github.com>
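For context, here is a minimal sketch of the idea behind the fix (1-D input, unit weights, illustrative names; not the PR's actual implementation):

```python
import numpy as np

def logsumexp_sketch(a):
    # Split sum(exp(a - a_max)) into the m entries that attain the maximum
    # (each contributing exactly 1.0) and the remainder R, then rewrite
    # log(m + R) as log(m) + log1p(R / m) to keep tiny R contributions.
    a = np.asarray(a, dtype=float)
    a_max = a.max()
    mask = (a == a_max)
    m = np.count_nonzero(mask)          # exact integer count of max entries
    R = np.exp(a[~mask] - a_max).sum()  # strictly smaller entries
    return a_max + np.log(m) + np.log1p(R / m)

a = np.array([0.0, -40.0])
naive = a.max() + np.log(np.exp(a - a.max()).sum())
print(naive)                # 0.0 -- exp(-40) is rounded away in 1 + exp(-40)
print(logsumexp_sketch(a))  # ~4.25e-18, the correct tiny result
```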
Looks like a good start!
I think the implementation could use some explanatory comments and readability tweaks.
It would also be good to add a couple of tests of similar accuracy problems with higher-dimensional arrays, exercising a non-default axis and/or keepdims; see the sketch below.
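A hedged sketch of what such a test could look like (shapes, values, and tolerance are illustrative, not from the PR):

```python
import numpy as np
from numpy.testing import assert_allclose
from scipy.special import logsumexp

# Every entry attains the maximum, reduced along a non-default axis.
a = np.zeros((3, 4))
res = logsumexp(a, axis=1, keepdims=True)
assert res.shape == (3, 1)  # keepdims preserves the reduced axis
assert_allclose(res, np.full((3, 1), np.log(4.0)), rtol=1e-15)
```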
Overall, it looks good. Some improvements might be possible.
```python
# = a_max + log(m + R)
# = a_max + log(m) + log(1 + (1/m) * R)

tmp0 = (a == a_max)
```
Suggested change:
```diff
-tmp0 = (a == a_max)
+mask = (a == a_max)
```
A more meaningful name might help when reading the code.
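To see why the rewrite in the excerpted comment pays off: when a single entry attains the maximum (m = 1) and the rest are far below it, R falls under machine epsilon, so 1 + R rounds to 1 while log1p(R) keeps it (illustrative numbers):

```python
import numpy as np

R = np.exp(-40.0)       # ~4.25e-18, far below float64 eps (~2.2e-16)
print(np.log(1.0 + R))  # 0.0 -- R is rounded away inside the addition
print(np.log1p(R))      # 4.2483...e-18 -- R survives
```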
```python
tmp = b * np.exp(a - a_max)

# sumexp for a != a_max
tmp = b * np.exp(a - a_max) * (~tmp0)
```
Suggested change:
```diff
-tmp = b * np.exp(a - a_max) * (~tmp0)
+R = b * np.exp(a - a_max) * (~tmp0)
```
Why not call it R, like in the comment above? I know that my suggestion is not consistent, as I only wanted to show the idea.
```python
tmp = b * np.exp(a - a_max) * (~tmp0)

# sumexp for where a = a_max
tmp0 = b*tmp0.astype(float)
```
Suggested change:
```diff
-tmp0 = b*tmp0.astype(float)
+m = b * tmp0.astype(float)
```
To align the naming with the comment.

What happens if an np.float32 array is passed as a? Does this cast to float (i.e. float64) change the behavior of the current implementation, i.e. the dtype of the returned value? See the snippet below.
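A quick illustration of that dtype concern (hypothetical snippet, not from the PR):

```python
import numpy as np

a = np.array([0.0, -1.0], dtype=np.float32)
b = np.float32(1.0)
mask = (a == a.max())
# .astype(float) means float64, so the product -- and hence the result --
# is silently promoted:
print(mask.astype(float).dtype)        # float64
print((b * mask.astype(float)).dtype)  # float64, not float32
```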
```python
# suppress warnings about log of zero
with np.errstate(divide='ignore'):
    s = np.sum(tmp, axis=axis, keepdims=keepdims)
    s0 = np.sum(tmp0, axis=axis, keepdims=keepdims)
    sf = s + s0
```
Suggested change:
```diff
-sf = s + s0
+s = s0 + s1
```
It would be nice if the result, i.e. the sum of both terms, were called s as before.
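As an aside on the np.errstate context in the excerpt above, a small illustration of what it suppresses:

```python
import numpy as np

# np.log(0.0) emits a divide-by-zero RuntimeWarning and returns -inf;
# the errstate context silences that warning only inside the block.
with np.errstate(divide='ignore'):
    print(np.log(0.0))  # -inf, no warning
```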
```python
sf *= sgn

precise = (s0 > 0) & (s > 0)
```
From here on, it seems to me that the code could be simpler: s0 and s are either 0 or positive, and s == 0 does not hurt, so special-casing only s0 == 0 should be sufficient. A sketch follows.
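A hedged sketch of that simplification for the scalar case (hypothetical helper names, not the PR's code):

```python
import numpy as np

def combine(s0, s):
    # s0 and s are sums of exponentials, so both are >= 0 by construction.
    if s0 > 0:
        # s == 0 is harmless here since log1p(0.0) == 0.0
        return np.log(s0) + np.log1p(s / s0)
    # only s0 == 0 needs the special case; result is -inf when s == 0 too
    return np.log(s)
```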
Not trying to block this PR, but wouldn't it make sense in the long run to make this available in the scipy Cython API?
Note that even if it were available in the scipy Cython API, scikit-learn would probably still have its own Cython implementation to make inlining possible. The same story is already true for …
Hey @sadigulcelik, would you like to return to this? There are some comments above to address, but it sounds like this was close.
Needs a rebase now. Would be great to decide on this one either way.
This PR hasn't seen commit activity for more than a year, so I think bumping the milestone is the right call. May need a new champion if the original author is busy.
It looks like this PR will be superseded by gh-21597. Thanks for this anyway @sadigulcelik!