Conversation

mdhaber (Contributor) commented Sep 20, 2024:

Reference issue

Closes gh-18295
Supersedes gh-18424
May be used to address gh-19521/gh-19549 and make log_softmax array API compatible. (The idea of using log1p is essentially the same, but I think softmax needs its own implementation.)

What does this implement/fix?

gh-18295 reported that logsumexp can lose precision when one element is much bigger than the rest, especially when the exponential of it is close to 1. This improves the precision as described in the issue and linked paper.
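
The failure mode is easy to reproduce with plain NumPy. The sketch below is illustrative only — the function names and the shift-by-max formulation are assumptions for demonstration, not the SciPy implementation: when one element dominates, the correction term inside the log falls far below machine epsilon, so `log(1 + tiny)` rounds to 0 while `log1p(tiny)` keeps it.

```python
import numpy as np

def logsumexp_naive(a):
    # Classic shift-by-max formulation: m + log(sum(exp(a - m))).
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))

def logsumexp_log1p(a):
    # Split off the maximum element and apply log1p to the sum of the
    # remaining (tiny) exponentials, which stays accurate near 0.
    m = np.max(a)
    rest = np.delete(a, np.argmax(a))
    return m + np.log1p(np.sum(np.exp(rest - m)))

a = np.array([0.0, -60.0, -60.0])
print(logsumexp_naive(a))   # 0.0 -- the ~1.75e-26 correction is lost
print(logsumexp_log1p(a))   # ~1.75e-26, i.e. about 2*exp(-60)
```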

Additional information

gh-18424 was out of date after converting logsumexp to the array API. For instance, xp.max does not work on the real component of complex arrays, so conversion of some parts would not be trivial. Also, there were some unresolved comments about the complexity, so I chose to start from scratch.

Also, logsumexp was getting quite complicated as it was. I found it challenging to work within the existing structure, so I refactored to simplify (26bb631) before getting started with the upgrade.

I'll add a review that documents the math inline with the code.

@mdhaber added the enhancement (a new feature or improvement) and scipy.special labels Sep 20, 2024
@mdhaber changed the title from "Gh18295" to "ENH: special.logsumexp: improve lost precision when one element is bigger than the rest" Sep 20, 2024
@mdhaber changed the title to "ENH: special.logsumexp: improve precision when one element is much bigger than the rest" Sep 20, 2024
a[b == 0] = -xp.inf

# Scale by real part for complex inputs, because this affects
# the magnitude of the exponential.
if xp_size(a) == 0:

mdhaber (Contributor, Author) commented:

Make size-0 arrays a special case. It's possible that they can be made to work with the existing code again; I'd be happy to review that as a simple follow-up. But conceptually, it's simpler to treat the special case here so we don't need to complicate the algorithmic code.
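
For reference, the convention the size-0 branch has to reproduce follows from empty reductions; a minimal sketch (not the PR's code):

```python
import numpy as np

a = np.empty((0,))
# The sum over an empty set is 0 and log(0) is -inf, so logsumexp of a
# size-0 array is -inf; errstate silences the expected divide warning.
with np.errstate(divide='ignore'):
    result = np.log(np.sum(np.exp(a)))
print(result)  # -inf
```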

Comment on lines +117 to +123
# Deal with shape details - reducing dimensions and convert 0-D to scalar for NumPy
out = xp.squeeze(out, axis=axis) if not keepdims else out
sgn = xp.squeeze(sgn, axis=axis) if (sgn is not None and not keepdims) else sgn
out = out[()] if out.ndim == 0 else out
sgn = sgn[()] if (sgn is not None and sgn.ndim == 0) else sgn

return (out, sgn) if return_sign else out

mdhaber (Contributor, Author) commented:

For simplicity, _logsumexp uses keepdims throughout and always returns sgn (which may be None if it won't be used). Reduce the axes away, convert to 0d, and choose what to return at the end.
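
The pattern being described can be sketched as follows (`reduce_then_squeeze` is a hypothetical name used only for illustration; the real `_logsumexp` does considerably more):

```python
import numpy as np

def reduce_then_squeeze(a, axis, keepdims=False):
    # Compute with keepdims=True so intermediates broadcast against `a`...
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    # ...then reduce the axes away and convert 0-D to scalar at the end.
    out = out if keepdims else np.squeeze(out, axis=axis)
    return out[()] if out.ndim == 0 else out

a = np.arange(6.0).reshape(2, 3)
print(reduce_then_squeeze(a, axis=1).shape)                 # (2,)
print(reduce_then_squeeze(a, axis=1, keepdims=True).shape)  # (2, 1)
```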


# Deal with shape details - reducing dimensions and convert 0-D to scalar for NumPy
out = xp.squeeze(out, axis=axis) if not keepdims else out
sgn = xp.squeeze(sgn, axis=axis) if (sgn is not None and not keepdims) else sgn

A maintainer (Member) commented:

This is annoyingly complicated, but I guess we can't remove keepdims now. Is it necessary to do this for sgn? I suspect so; otherwise you wouldn't have this line here.

mdhaber (Contributor, Author) commented Sep 20, 2024:

> but I guess we can't remove keepdims now

It is easier to do the calculations internally with keepdims=True. The annoying thing is supporting keepdims=False, the default, but it is very natural for the user to want this function to behave like any other reducing operation.

It looked a little simpler in the old implementation, but this is how it was working - it used keepdims throughout (because that's much more convenient) and squeezed at the end.

mdhaber (Contributor, Author) commented Sep 20, 2024:

It's possible that some of this could be removed by letting the last reducing operation in _logsumexp eliminate the axes (if keepdims=False). For simplicity, I just decided to ignore that possibility for now and separate these details in logsumexp from the math stuff in _logsumexp.

mdhaber (Contributor, Author) commented Sep 20, 2024:

Will fix the vectorization bug (in _elements_and_indices_with_max_real) that is causing the failures shortly. Done.
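
For context, a helper with a name like `_elements_and_indices_with_max_real` has to select, along an axis, the element whose real part is largest. A hypothetical vectorized sketch (not the PR's code), using `argmax` over the real component:

```python
import numpy as np

def elements_with_max_real(a, axis):
    # argmax over the real part, vectorized along `axis`; keepdims lets
    # take_along_axis pick out the corresponding complex elements.
    i = np.argmax(np.real(a), axis=axis, keepdims=True)
    return np.take_along_axis(a, i, axis=axis), i

a = np.array([[1 + 9j, 2 + 0j],
              [5 - 1j, 3 + 3j]])
vals, idx = elements_with_max_real(a, axis=1)
print(vals.ravel())  # [2.+0.j 5.-1.j] -- max by real part, not by modulus
```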

mdhaber (Contributor, Author) commented Sep 20, 2024:

The failure in the Array API job surprised me, given that there was no problem with array_api_strict; but see data-apis/array-api-strict#62.

lucascolley (Member) left a review:

Thanks Matt, just a couple questions.

@lucascolley added this to the 1.15.0 milestone Sep 21, 2024
@lucascolley added the needs-release-note label (a maintainer should add a release note written by a reviewer/author to the wiki) Sep 21, 2024

lucascolley (Member) commented:

Thanks Matt and for the review Stéfan!

apaszke commented Sep 23, 2024:

FYI it seems like this change has had a pretty big impact on the values of imaginary components of results: #21610

@mdhaber removed the needs-release-note label Nov 17, 2024
Successfully merging this pull request may close these issues.

BUG: special: Loss of precision in logsumexp
5 participants