dfm (Contributor) commented Apr 2, 2025

Category

  • New feature
  • Bugfix
  • Breaking change
  • Refactoring
  • Documentation
  • Other (please explain)

Description

The jaxlib.xla_client.register_custom_call_target function is deprecated in favor of its counterpart, jax.ffi.register_ffi_target: https://docs.jax.dev/en/latest/_autosummary/jax.ffi.register_ffi_target.html

Because this PR explicitly sets api_version=0 (the older custom-call API rather than the typed FFI), the switch should have no effect on behavior.

Changelog

N/A

Before your PR is "Ready for review"

  • All commits are signed-off to indicate that your contribution
    adheres to the Developer Certificate of Origin requirements
  • Necessary tests have been added
  • Documentation is up-to-date
  • Auto-generated files modified by compiling Warp and building the documentation have been updated (e.g. stubs.py, functions.rst)
  • Code passes formatting and linting checks with pre-commit run -a

shi-eric requested a review from nvlukasz April 2, 2025 15:10

nvlukasz (Contributor) commented Apr 2, 2025

Thanks! One of the reasons that we keep this old custom_call around is to support users with older versions of JAX. Users with newer JAX are encouraged to switch to the new FFI interop utilities.

So I'd like to make sure that we can still support older JAX versions here.

dfm (Contributor Author) commented Apr 2, 2025

Makes sense. What's the minimum version that you want to support?

nvlukasz (Contributor) commented Apr 2, 2025

The minimum JAX version we'd like to support for custom_call is 0.4.25.

…get.

Signed-off-by: Dan Foreman-Mackey <danfm@google.com>
dfm force-pushed the deprecated-function branch from 1da0d1e to 9e35db8 April 2, 2025 18:16

dfm (Contributor Author) commented Apr 2, 2025

Sounds good! I've updated the PR to work on older versions of JAX, even when the jax.ffi submodule isn't available.

Edited to add: it would be good to merge this change once we're confident it supports the required JAX versions, because xla_client.register_custom_call_target will be removed in an upcoming JAX release!
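A minimal sketch of the fallback pattern described above, for readers who want to see the shape of it. This is not the code merged in this PR: the helper name `resolve_register_target` is hypothetical, and the sketch assumes that `jax.ffi.register_ffi_target` (newer JAX) and the deprecated `jaxlib.xla_client.register_custom_call_target` (older JAX) can both be called as `(name, fn, platform=..., api_version=...)`.

```python
import importlib


def resolve_register_target():
    """Pick a custom-call registration function for the installed JAX.

    Hypothetical helper, not Warp's actual implementation. Prefers the newer
    jax.ffi.register_ffi_target and falls back to the deprecated
    jaxlib.xla_client.register_custom_call_target on older JAX versions.
    Returns None if neither is available (e.g. JAX is not installed).
    """
    candidates = [
        ("jax.ffi", "register_ffi_target"),  # newer JAX
        ("jaxlib.xla_client", "register_custom_call_target"),  # older JAX, deprecated
    ]
    for module_name, attr in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            continue  # this submodule does not exist in the installed version
        register = getattr(module, attr, None)
        if register is not None:
            return register
    return None


# Usage (assumption: JAX is installed and `capsule` is a PyCapsule wrapping a
# legacy custom-call function; "my_custom_call" is a placeholder target name):
# register = resolve_register_target()
# register("my_custom_call", capsule, platform="CUDA", api_version=0)
```

Passing `api_version=0` on both paths registers the target with the same older custom-call convention, which is why swapping the registration function should not change behavior.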

nvlukasz (Contributor) commented Apr 2, 2025

Looks good, thanks!

shi-eric merged commit 9e35db8 into NVIDIA:main Apr 3, 2025 (2 checks passed)
pull bot pushed a commit to mcx/warp that referenced this pull request Apr 3, 2025
Improved handling of deprecated JAX features (NVIDIAGH-613)

See merge request omniverse/warp!1214
shi-eric (Contributor) commented Apr 3, 2025

This is now merged into main, thanks for your contribution @dfm!
