ENH: special.logsumexp: improve precision when one element is much bigger than the rest #21597
scipy/special/_logsumexp.py
a[b == 0] = -xp.inf

# Scale by real part for complex inputs, because this affects
# the magnitude of the exponential.
if xp_size(a) == 0:
Make size-0 arrays a special case. It is possible that they can be made to work with the existing code again, and I'd be happy to review that as a simple follow-up. But conceptually, it's simpler to treat the special case here so we don't need to complicate the algorithmic code.
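To illustrate why the size-0 case is easy to handle separately, here is a minimal NumPy sketch. It is not the code in this PR: `logsumexp_empty` is a hypothetical helper, and it assumes the convention that a sum over zero elements is 0, so its log is `-inf`.

```python
import numpy as np

def logsumexp_empty(a, axis=None, keepdims=False):
    """Hypothetical helper: logsumexp result for an array with no elements."""
    a = np.asarray(a, dtype=np.float64)
    # The sum over zero elements is 0 and log(0) is -inf. np.sum also
    # produces the output shape implied by `axis` and `keepdims`.
    with np.errstate(divide="ignore"):
        return np.log(np.sum(np.exp(a), axis=axis, keepdims=keepdims))

print(logsumexp_empty(np.empty(0)))                  # -inf
print(logsumexp_empty(np.empty((0, 3)), axis=0))     # [-inf -inf -inf]
```

Handling this up front means the main algorithm never has to take a maximum over an empty axis.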
# Deal with shape details - reducing dimensions and convert 0-D to scalar for NumPy
out = xp.squeeze(out, axis=axis) if not keepdims else out
sgn = xp.squeeze(sgn, axis=axis) if (sgn is not None and not keepdims) else sgn
out = out[()] if out.ndim == 0 else out
sgn = sgn[()] if (sgn is not None and sgn.ndim == 0) else sgn

return (out, sgn) if return_sign else out
For simplicity, `_logsumexp` uses `keepdims` throughout and always returns `sgn` (which may be `None` if it won't be used). Reduce the axes away, convert to 0d, and choose what to return at the end.
# Deal with shape details - reducing dimensions and convert 0-D to scalar for NumPy
out = xp.squeeze(out, axis=axis) if not keepdims else out
sgn = xp.squeeze(sgn, axis=axis) if (sgn is not None and not keepdims) else sgn
This is annoyingly complicated, but I guess we can't remove `keepdims` now. Is it necessary to do this for `sgn`? I suspect so, otherwise you wouldn't have this line here.
> but I guess we can't remove `keepdims` now
It is easier to do the calculations internally with `keepdims=True`. The annoying thing is supporting `keepdims=False`, the default, but it is very natural for the user to want this function to behave like any other reducing operation.
It looked a little simpler in the old implementation, but this is how it was working: it used `keepdims` throughout (because that's much more convenient) and squeezed at the end.
It's possible that some of this could be removed by letting the last reducing operation in `_logsumexp` eliminate the axes (if `keepdims=False`). For simplicity, I just decided to ignore that possibility for now and separate these details in `logsumexp` from the math stuff in `_logsumexp`.
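The split discussed in this thread can be sketched roughly as follows. This is an illustration in plain NumPy rather than the PR's code: `_reduce_keepdims` and `reduce_public` are hypothetical names, the math is the simple shift-by-maximum form, and the sign output is omitted.

```python
import numpy as np

def _reduce_keepdims(a, axis):
    """Hypothetical stand-in for the math: always reduces with keepdims=True."""
    m = np.max(a, axis=axis, keepdims=True)
    # `m` broadcasts against `a` only because the reduced axes were kept.
    return m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))

def reduce_public(a, axis=None, keepdims=False):
    """Hypothetical wrapper: shape details are handled once, at the end."""
    a = np.asarray(a, dtype=np.float64)
    out = _reduce_keepdims(a, axis)
    out = out if keepdims else np.squeeze(out, axis=axis)
    return out[()] if out.ndim == 0 else out  # 0-D array -> NumPy scalar

x = np.zeros((2, 3))
print(reduce_public(x, axis=1).shape)                 # (2,)
print(reduce_public(x, axis=1, keepdims=True).shape)  # (2, 1)
```

Keeping the reduced axes inside the helper is what lets `a - m` broadcast without extra reshaping. The cost is the squeeze and 0-D conversion in the wrapper.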
Failure in Array API job surprised me given that there was no problem with
Thanks Matt, just a couple questions.
Thanks Matt, and thanks for the review, Stéfan!
FYI it seems like this change has had a pretty big impact on the values of imaginary components of results: #21610
Reference issue
Closes gh-18295
Supersedes gh-18424
May be used to address gh-19521/gh-19549 and make `log_softmax` array API compatible. (The idea of using `log1p` is essentially the same, but I think `softmax` needs its own implementation.)
What does this implement/fix?
gh-18295 reported that `logsumexp` can lose precision when one element is much bigger than the rest, especially when the exponential of it is close to 1. This improves the precision as described in the issue and linked paper.
Additional information
gh-18424 was out of date after converting `logsumexp` to the array API. For instance, `xp.max` does not work on the real component of complex arrays, so conversion of some parts would not be trivial. Also, there were some unresolved comments about the complexity, so I chose to start from scratch.
Also, `logsumexp` was getting quite complicated as it was. I found it challenging to work within the existing structure, so I refactored to simplify (26bb631) before getting started with the upgrade.
I'll add a review that documents the math inline with the code.
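For readers who want a concrete picture of the precision loss described under "What does this implement/fix?", here is a minimal sketch for real 1-D input. It is not the implementation in this PR: `naive_logsumexp` and `log1p_logsumexp` are hypothetical names, and infinities, NaNs, complex values, weights, and axes are deliberately ignored.

```python
import numpy as np

def naive_logsumexp(a):
    """Shift by the maximum, then take the log of the full sum."""
    a = np.asarray(a, dtype=np.float64).ravel()
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def log1p_logsumexp(a):
    """Hypothetical sketch: handle the largest element separately via log1p."""
    a = np.asarray(a, dtype=np.float64).ravel()
    i = np.argmax(a)
    m = a[i]
    rest = np.delete(a, i)  # all elements except one copy of the maximum
    # log(sum(exp(a))) = m + log(1 + sum(exp(rest - m)))
    return m + np.log1p(np.sum(np.exp(rest - m)))

a = [0.0, -40.0]
print(naive_logsumexp(a))  # 0.0
print(log1p_logsumexp(a))  # roughly 4.25e-18, i.e. exp(-40)
```

In the naive version, `1 + exp(-40)` rounds to exactly 1 in double precision, so the small contribution is lost before the log is taken. `log1p` never forms that sum explicitly.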