ENH: fft: support array API standard #19005
Merged
151 commits
b8be637
WIP: ENH: fft: array API support [Design 4]
lucascolley 96dec74
MAINT: use npbasic instead of pfft
lucascolley cfc5265
ENH: xp for `init_nd_shape_and_axes`
lucascolley c27b516
ENH: `xp` and `device` keywords for `fftfreq` and `rfftfreq`
lucascolley 82e84df
TST: `test_identity` in `TestFFT1D`
lucascolley 8e2d9c5
TST: `test_identity` working with cupy
lucascolley ea0075e
TST: standard fns in `test_numpy.py`
lucascolley 02bf620
ENH: `ValueError`'s in `_basic.py`
lucascolley c09f1f3
MAINT: `test_numpy.py` renamed to `test_basic.py`
lucascolley 092db64
ENH: `xp` added to `_fftlog.py`
lucascolley cf5dcdb
MAINT: minor docstring improvements
lucascolley 27dd21d
MAINT: add `.` to error msg
lucascolley 22ffbcb
ENH: `fftfreq` and `rfftfreq` with `device` exception
lucascolley 0eaf77f
MAINT: remove unnecessary imports
lucascolley 1482b8d
MAINT: added keyword args
lucascolley 0964cfb
ENH: change `norm` to correct default value
lucascolley c90fa5d
MAINT: remove redundant `is not False`
lucascolley 15c398e
ENH: `is_numpy` helper added
lucascolley bafb6f5
MAINT: remove unnecessary fns from `_fftlog.py`
lucascolley 677f7b1
ENH: use `is_numpy` in `_basic.py`
lucascolley e39efee
MAINT/TST: fftlog tests switched to `_fftlog_np.py` for now
lucascolley 09ff62c
TST: namespace preservation
lucascolley 4a85dd5
`fftfreq`: `xp=None` and remove exception for `device`
lucascolley 99b8832
MAINT: move `_assert_matching_namespace` to `_array_api.py`
lucascolley 4721ce0
ENH: add `set_assert_allclose`
lucascolley 38b8428
MAINT: remove unnecessary import
lucascolley 6824a77
TST: `test_fft_n` array API compatible
lucascolley 5fd4dad
ENH: `fftfreq` now passing `device` to `xp`
lucascolley 656f747
CI: let some CI run on my fork
lucascolley 3140309
MAINT: remove unnecessary import
lucascolley 147c8c3
CI: stop some workflows on my fork
lucascolley efbbcbc
TST: `TestNamespaces` added to `test_basic.py`
lucascolley 87c7097
TST: `test_basic.py` now uses `@skip_if_array_api_gpu`
lucascolley bc5c9c7
TST: `fft1` reverted to `np` version
lucascolley ea9a157
TST: revert tests with non-standard fns to `@skip_if_array_api`
lucascolley 3c7c883
BUG: fix issues with `sqrt` of `int`s
lucascolley 3eb08cd
BUG: add `dtype=float64` to fix `sqrt` exceptions
lucascolley d72e485
MAINT: remove erroneous style change
lucascolley 5af7777
MAINT: reduce number of lines for `xp.asarray` calls
lucascolley 7b4b9ce
BUG: fix return structure of `_init_nd_shape_and_axes`
lucascolley 576d764
TST: skip more non-standard fns tests
lucascolley 3bd3155
TST: `Test_init_nd_shape_and_axes` array API compatible
lucascolley d6f2907
TST: skip more non-standard fns tests
lucascolley d770a5b
TST: skip some fns only on GPU
lucascolley fbed725
MAINT: revert `test_dtypes` changes
lucascolley 0764c55
BUG: fix `set_assert_allclose` for `xp=None`
lucascolley 4ff5909
TST: only skip on GPU for `test_axes`
lucascolley 9c4df0e
TST: remove `set_assert_allclose` for fns skipped on GPU
lucascolley 1a210d7
TST: `TestFFTThreadSafe` array API compatible
lucascolley d7b936a
TST: remove unnecessary GPU skip
lucascolley f5bc71a
TST: `test_fftlog.py` array API compatible
lucascolley 1a81819
TST: most of `test_real_transforms.py` array API compatible
lucascolley fe8fa04
TST: added pytorch case in `set_assert_allclose`
lucascolley 397f9d1
BUG: add `.copy()` to pass pytorch tests
lucascolley 1b7a122
TST/MAINT: use `size` helper from array-api-compat
lucascolley d3b1d05
TST: add pytorch case for `test_fft_n`
lucascolley be89ec2
TST: `test_identity` pytorch compatible
lucascolley 94a2da5
TST/BUG: fix dtype issues in `test_hfft`
lucascolley 307a3ed
TST: add missing `xp.asarray`s
lucascolley c010c7b
TST: `TestFFTThreadSafe` pytorch compatible
lucascolley 7e0a39a
ENH: handle `xp=None` case for `fftfreq` and `rfftfreq`
lucascolley 2e10b49
MAINT: import `size` from `_array_api` instead of `array_api_compat`
lucascolley 43974d8
TST: remove unnecessary skips in `test_multithreading.py`
lucascolley ec77650
MAINT/TST: use `xp.linalg`
lucascolley b5cf63c
TST/BUG: revert to `assert_array_almost_equal` in `test_dtypes`
lucascolley e395a33
TST: change `linalg.norm` to `linalg.vector_norm`
lucascolley 2f0bd40
TST: change `rtol` from `1e-11` to `1e-12`
lucascolley e597d85
TST: exceptions for unsupported params
lucascolley 4352792
STY: comments cleanup
lucascolley 2484b80
TST: skip tests on `torch` backend where it is incompatible
lucascolley 4626d2c
TST: remove `torch` skips where it breaks with `parametrize`
lucascolley df80b91
MAINT: move basic docstrings and rename to `uarray`
lucascolley 9112619
MAINT: revert to original order of `basic` functions
lucascolley fb911bc
MAINT: move docstrings to new file for `helper`
lucascolley 2e6813f
MAINT: use `==` instead of `in`
lucascolley 7c6c2af
TST: change `rtol` to `1e-10` to let `test_identity` pass
lucascolley 42b76e6
TST: skip some multithreading tests for array API
lucascolley 7d71b3e
MAINT: merge `helper` files into one and get docstrings from numpy
lucascolley ac1141a
MAINT: `realtransforms` move docstrings to new files and rename `np` …
lucascolley f77b902
MAINT: `fftlog` move docstrings to new file
lucascolley 14c4ecb
TST/STY: `xfail`s for pytorch and modify comments
lucascolley 4e467ac
BUG: correct import location for fht docstrings
lucascolley f1e63ab
MAINT: change `is_numpy` to just check for the compat version
lucascolley bd3447f
STY: PEP8 clean-up
lucascolley c909dfd
BUG: fix circular import issue with `fftlog`
lucascolley 3a1222a
CI: remove workflow comment for my fork
lucascolley 5f45410
Merge branch 'main' into 'fft_array_api'
lucascolley f7277ed
BUG: avoid `device` for `xp=np` in `fftfreq` and `rfftfreq`
lucascolley c558d40
TST: adjust tolerance so tests pass
lucascolley 55622bc
TST: make `test_dtypes` array API compatible
lucascolley e9222cd
TST: add tests for new helper fns
lucascolley 2cb6268
MAINT: comments to explain `uarray` dispatch structure
lucascolley a097d91
MAINT: use `elif` in `test_helper.py`
lucascolley 49a1676
MAINT: change the use of `xp` in unrelated comments
lucascolley 5466d04
MAINT/TST: avoid extra imports with `xp.testing`
lucascolley 564a9bd
TST: skip `TestFFTThreadSafe` on GPU
lucascolley 46d43e0
Merge branch 'main' into fft_array_api
lucascolley 4168f8c
STY: PEP 257 in `_fftlog_np.py`
lucascolley 4a1faa7
STY: more PEP 257
lucascolley ad8cbe9
MAINT: use `copy` utility from #19014
lucascolley 08a47dc
Merge branch 'main' into fft_array_api
lucascolley dbbc1e9
MAINT: move `set_assert_allclose` to `_lib._array_api.py`
lucascolley e73cf27
MAINT: refactor `_basic.py` with execute helpers
lucascolley e43b8ba
Merge branch 'main' into fft_array_api
lucascolley 9890300
DOC/MAINT: include example from #19129 in renamed file
lucascolley 6b36494
Merge branch 'main' into fft_array_api
lucascolley c648e35
TST/MAINT: remove unnecessary `xp_test` and gpu skip
lucascolley 88ae151
MAINT: remove unnecessary `xp_test`
lucascolley 939f723
BUG: fix issue in `_execute_nD` in `_basic.py`
rgommers 98f6d9d
TST: rework 1-D fft/rfft/hfft tests for array API
rgommers 7057899
MAINT: refactor fft `_execute_1D/nD`
rgommers a889b7d
MAINT: refactor real transforms with `execute` fns
lucascolley 1c12ae7
Merge branch 'main' into fft_array_api
lucascolley 6da65af
TST: replace `xfail`s with skips
lucascolley 2262fbe
Merge branch 'main' into fft_array_api
lucascolley 80f6f8f
MAINT: use `assert_close` and `assert_equal`
lucascolley 7eed157
MAINT: remove `set_assert_allclose`
lucascolley 20cfa33
Merge branch 'main' into fft_array_api
lucascolley 4f5691a
MAINT: remove `_execute_1D`
lucascolley c33fae8
TST: remove `skip_if_array_api` markers
lucascolley 527a0a4
STY: clean-up
lucascolley 5ff5615
MAINT: refactor `test_axes_*`
lucascolley 61eceff
ENH: array-agnostic `fftlog`
lucascolley ee682eb
MAINT: refactor `fftlog` with `from_uarray` parameter
lucascolley 3e70cc4
MAINT/STY: clean-up
lucascolley a113893
MAINT: remove unnecessary `copy` from `fftlog`
lucascolley 0e2a4d3
MAINT: refactor and clean up `_realtransforms`
lucascolley 283419a
STY: revert `realtransforms` docstring changes
lucascolley 719a7b1
MAINT: remove more unnecessary `copy`s from `realtransforms`
lucascolley dee321b
STY: lint to 88 char line length
lucascolley 0a2abf8
MAINT: move `fftlog` backend to a separate file
lucascolley 07c6dec
MAINT: import array assertions as future names
lucascolley 17bf785
Merge remote-tracking branch 'upstream/main' into fft_array_api
lucascolley b0248e2
MAINT: update imports for gh-19186
lucascolley d3a9d87
TST/MAINT: modify tests to satisfy stricter assertions
lucascolley 5d6de96
MAINT/BUG: refactor to fix custom uarray backends
lucascolley 23bffb7
TST/BUG: check that the mock backend implements the public scipy API
lucascolley 65684db
MAINT: improve comments for `_basic_backend`
lucascolley ea36a04
MAINT: allow `overwrite_x` to be dropped
lucascolley a158634
MAINT: move singular transform checks to `fhtcoeff`
lucascolley 291b9d7
MAINT: check `device` is `None` when not supported
lucascolley 4eef6be
MAINT: replace slicing with `flip`
lucascolley 5f68e8c
BUG: use `kwargs` in basic backend to be compatible with standard
lucascolley aed9f82
MAINT: revert `_init_nd_shape_and_axes` to returning lists
lucascolley dda7057
MAINT: lint
lucascolley 5002220
Merge branch 'main' into fft_array_api
rgommers b7e817a
CI: add scipy.fft to the submodules tested in the array API CI job
rgommers 14cd806
MAINT: update `_init_nd_shape_and_axes` return types and tests
rgommers 6ee7a89
STY: fix linter complaint
rgommers 8994857
DOC: update fftfreq/rfftfreq docs
lucascolley 2befe78
TST: fix fftfreq/rfftfreq tests for torch-cuda
rgommers
@@ -0,0 +1,190 @@
from scipy._lib._array_api import array_namespace, is_numpy
from . import _pocketfft
import numpy as np


def arg_err_msg(param):
    return f'Providing {param!r} is only supported for numpy arrays.'


def _validate_fft_args(workers, plan, norm):
    if workers is not None:
        raise ValueError(arg_err_msg("workers"))
    if plan is not None:
        raise ValueError(arg_err_msg("plan"))
    if norm is None:
        norm = 'backward'
    return norm


def _execute(func_str, pocketfft_func, x, **kwargs):
    xp = array_namespace(x)
    # pocketfft is used whenever SCIPY_ARRAY_API is not set,
    # or x is a NumPy array or array-like.
    # When SCIPY_ARRAY_API is set, we try to use xp.fft for CuPy arrays,
    # PyTorch arrays and other array API standard supporting objects.
    # If xp.fft does not exist, we attempt to convert to np and back to use pocketfft.
    if is_numpy(xp):
        return pocketfft_func(x, **kwargs)

    try:
        s = kwargs["n"]
    except KeyError:
        s = kwargs["s"]
    try:
        axes = kwargs["axis"]
    except KeyError:
        axes = kwargs["axes"]
    norm = kwargs["norm"]
    workers = kwargs["workers"]
    plan = kwargs["plan"]

    norm = _validate_fft_args(workers, plan, norm)
    if hasattr(xp, 'fft'):
        xp_func = getattr(xp.fft, func_str)
        return xp_func(x, s, axes, norm=norm)

    x = np.asarray(x)
    y = pocketfft_func(x, s, axes, norm=norm)
    return xp.asarray(y)


def fft(x, n=None, axis=-1, norm=None,
        overwrite_x=False, workers=None, *, plan=None):
    return _execute('fft', _pocketfft.fft, x, n=n, axis=axis, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def ifft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, *,
         plan=None):
    return _execute('ifft', _pocketfft.ifft, x, n=n, axis=axis, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def rfft(x, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    return _execute('rfft', _pocketfft.rfft, x, n=n, axis=axis, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def irfft(x, n=None, axis=-1, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    return _execute('irfft', _pocketfft.irfft, x, n=n, axis=axis, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def hfft(x, n=None, axis=-1, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    return _execute('hfft', _pocketfft.hfft, x, n=n, axis=axis, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def ihfft(x, n=None, axis=-1, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    return _execute('ihfft', _pocketfft.ihfft, x, n=n, axis=axis, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def fftn(x, s=None, axes=None, norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    return _execute('fftn', _pocketfft.fftn, x, s=s, axes=axes, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def ifftn(x, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    return _execute('ifftn', _pocketfft.ifftn, x, s=s, axes=axes, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def fft2(x, s=None, axes=(-2, -1), norm=None,
         overwrite_x=False, workers=None, *, plan=None):
    xp = array_namespace(x)
    x = np.asarray(x)
    y = _pocketfft.fft2(x, s=s, axes=axes, norm=norm,
                        overwrite_x=overwrite_x,
                        workers=workers, plan=plan)
    return xp.asarray(y)


def ifft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    xp = array_namespace(x)
    x = np.asarray(x)
    y = _pocketfft.ifft2(x, s=s, axes=axes, norm=norm,
                         overwrite_x=overwrite_x,
                         workers=workers, plan=plan)
    return xp.asarray(y)


def rfftn(x, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    return _execute('rfftn', _pocketfft.rfftn, x, s=s, axes=axes, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def rfft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    xp = array_namespace(x)
    x = np.asarray(x)
    y = _pocketfft.rfft2(x, s=s, axes=axes, norm=norm,
                         overwrite_x=overwrite_x,
                         workers=workers, plan=plan)
    return xp.asarray(y)


def irfftn(x, s=None, axes=None, norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    return _execute('irfftn', _pocketfft.irfftn, x, s=s, axes=axes, norm=norm,
                    overwrite_x=overwrite_x, workers=workers, plan=plan)


def irfft2(x, s=None, axes=(-2, -1), norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    xp = array_namespace(x)
    x = np.asarray(x)
    y = _pocketfft.irfft2(x, s=s, axes=axes, norm=norm,
                          overwrite_x=overwrite_x,
                          workers=workers, plan=plan)
    return xp.asarray(y)


def hfftn(x, s=None, axes=None, norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    xp = array_namespace(x)
    x = np.asarray(x)
    y = _pocketfft.hfftn(x, s=s, axes=axes, norm=norm,
                         overwrite_x=overwrite_x,
                         workers=workers, plan=plan)
    return xp.asarray(y)


def hfft2(x, s=None, axes=(-2, -1), norm=None,
          overwrite_x=False, workers=None, *, plan=None):
    xp = array_namespace(x)
    x = np.asarray(x)
    y = _pocketfft.hfft2(x, s=s, axes=axes, norm=norm,
                         overwrite_x=overwrite_x,
                         workers=workers, plan=plan)
    return xp.asarray(y)


def ihfftn(x, s=None, axes=None, norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    xp = array_namespace(x)
    x = np.asarray(x)
    y = _pocketfft.ihfftn(x, s=s, axes=axes, norm=norm,
                          overwrite_x=overwrite_x,
                          workers=workers, plan=plan)
    return xp.asarray(y)


def ihfft2(x, s=None, axes=(-2, -1), norm=None,
           overwrite_x=False, workers=None, *, plan=None):
    xp = array_namespace(x)
    x = np.asarray(x)
    y = _pocketfft.ihfft2(x, s=s, axes=axes, norm=norm,
                          overwrite_x=overwrite_x,
                          workers=workers, plan=plan)
    return xp.asarray(y)
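`_validate_fft_args` turns `norm=None` into `'backward'`, which is also the default in the array API standard. As a reminder of what the three `norm` modes mean, the sketch below expresses each one as a pair of scale factors applied to the unnormalized DFT. `norm_scales` is a hypothetical helper, not part of `scipy.fft`; NumPy is used only to check it.

```python
import numpy as np


def norm_scales(n, norm="backward"):
    # Hypothetical helper: returns (forward_scale, inverse_scale), the factors
    # that multiply the unnormalized length-n forward and inverse DFTs.
    if norm == "backward":   # default: only the inverse is scaled, by 1/n
        return 1.0, 1.0 / n
    if norm == "ortho":      # both directions scaled by 1/sqrt(n)
        return 1.0 / np.sqrt(n), 1.0 / np.sqrt(n)
    if norm == "forward":    # only the forward transform is scaled, by 1/n
        return 1.0 / n, 1.0
    raise ValueError(f"invalid norm: {norm!r}")
```

Whichever mode is chosen, the product of the two factors is always `1/n`, so a forward transform followed by the matching inverse returns the input.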
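The routing in `_execute` can be sketched in isolation. This is a simplified, hypothetical version for 1-D transforms only: `execute_1d_sketch` and `_hypothetical_array_namespace` are illustrative names, and `numpy.fft` stands in for SciPy's private `_pocketfft`. NumPy input goes straight to the NumPy-side implementation; for other input, `workers`/`plan` are rejected and `norm` is defaulted, then `xp.fft` is used if the namespace has it, otherwise the data round-trips through NumPy. Unlike the diff, the sketch passes `n`, `axis` and `norm` by keyword, since the array API standard declares them keyword-only.

```python
import numpy as np


def _hypothetical_array_namespace(x):
    # Assumption: stands in for scipy._lib._array_api.array_namespace.
    # NumPy arrays map to the numpy module; other objects must expose the
    # standard __array_namespace__() method.
    if isinstance(x, np.ndarray):
        return np
    return x.__array_namespace__()


def execute_1d_sketch(func_name, x, *, n=None, axis=-1, norm=None,
                      workers=None, plan=None):
    """Route a 1-D transform to NumPy, xp.fft, or a NumPy round trip."""
    xp = _hypothetical_array_namespace(x)
    if xp is np:
        # NumPy input: numpy.fft stands in for pocketfft here
        # (workers and plan are simply not forwarded in this sketch).
        return getattr(np.fft, func_name)(x, n=n, axis=axis, norm=norm)

    # Non-NumPy input: reject SciPy-only options, default norm to 'backward'.
    for name, value in (("workers", workers), ("plan", plan)):
        if value is not None:
            raise ValueError(
                f"Providing {name!r} is only supported for numpy arrays.")
    if norm is None:
        norm = "backward"

    if hasattr(xp, "fft"):
        # The standard declares n, axis and norm keyword-only.
        return getattr(xp.fft, func_name)(x, n=n, axis=axis, norm=norm)

    # No fft extension: convert to NumPy, transform, convert back.
    y = getattr(np.fft, func_name)(np.asarray(x), n=n, axis=axis, norm=norm)
    return xp.asarray(y)
```

With a NumPy array, `execute_1d_sketch("fft", np.arange(4.0))` should agree with `np.fft.fft(np.arange(4.0))`. The non-NumPy branches only need an object exposing `__array_namespace__()`, plus `__array__` for the round trip.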
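In this version the 2-D functions and `hfftn`/`ihfftn` always convert to NumPy and back instead of going through `_execute`. Part of the reason may be that, to our understanding, the array API standard's `fft` extension defines `fftn`, `ifftn`, `rfftn` and `irfftn` but no 2-D variants and no `hfftn`/`ihfftn`. A 2-D transform is just an n-D transform over two axes, so a sketch like the one below (hypothetical helper `fft2_via_fftn`, taking the namespace explicitly) could route it through `xp.fft.fftn`; the Hermitian n-D transforms have no such standard counterpart.

```python
import numpy as np


def fft2_via_fftn(xp, x, s=None, axes=(-2, -1), norm="backward"):
    # Hypothetical helper: a 2-D FFT is an n-D FFT restricted to two axes,
    # so any namespace whose fft extension provides fftn can compute it.
    axes = tuple(axes)
    if len(axes) != 2:
        raise ValueError("fft2 requires exactly two axes")
    return xp.fft.fftn(x, s=s, axes=axes, norm=norm)
```

For example, with `a = np.arange(12.0).reshape(3, 4)`, `fft2_via_fftn(np, a)` should match `np.fft.fft2(a)`.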