
Conversation

btaba
Collaborator

@btaba commented Jun 10, 2025

pendula.xml does not work with JAX interop:

Warp CUDA error 900: operation not permitted when stream is capturing (in function cuda_unload_module, /builds/omniverse/warp/warp/native/warp.cu:3929)
2025-06-10 09:26:38.685489: E external/xla/xla/stream_executor/cuda/cuda_command_buffer.cc:715] CUDA error: Failed to destroy CUDA graph: CUDA_ERROR_INVALID_VALUE: invalid argument
E0610 09:26:38.685546  359201 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: INTERNAL: CUDA error: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture
...
XlaRuntimeError: INTERNAL: CUDA error: Failed to end stream capture: CUDA_ERROR_STREAM_CAPTURE_INVALIDATED: operation failed due to a previous error during capture
pip freeze | grep jax
jax==0.6.1
jax-cuda12-pjrt==0.6.1
jax-cuda12-plugin==0.6.1
jaxlib==0.6.1
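
For context, the interop pattern being exercised looks roughly like the sketch below: a Warp kernel wrapped for JAX and called from a jitted function, with the XLA flag (discussed further down) forcing even tiny computations to be captured as CUDA graphs. This is only a hedged, minimal sketch (hypothetical scale_kernel, assuming warp.jax_experimental.jax_kernel), not the actual test:

import os

# Set before JAX initializes its GPU backend; lowers the threshold so even
# single-op computations get lowered to CUDA graphs (command buffers).
os.environ["XLA_FLAGS"] = "--xla_gpu_graph_min_graph_size=1"

import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental import jax_kernel


@wp.kernel
def scale_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
  # Hypothetical stand-in for a mujoco_warp step; just scales the input.
  tid = wp.tid()
  y[tid] = 2.0 * x[tid]


jax_scale = jax_kernel(scale_kernel)


@jax.jit
def step(x):
  # Under graph capture, the Warp launch gets recorded into XLA's CUDA graph.
  return jax_scale(x)


print(step(jnp.arange(64, dtype=jnp.float32)))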

@shi-eric
Collaborator

What CUDA driver version are you running?

@btaba
Collaborator Author

btaba commented Jun 10, 2025

@shi-eric

nvidia-smi
Tue Jun 10 10:05:42 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.06             Driver Version: 570.124.06     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5090        Off |   00000000:41:00.0 Off |                  N/A |
| 45%   33C    P0             72W /  575W |       1MiB /  32607MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

@shi-eric
Collaborator

Thanks, what if you try removing this line?

os.environ["XLA_FLAGS"] = "--xla_gpu_graph_min_graph_size=1"

That flag has been there since @thowell first created the file, but I found that when I remove it, I can run the test on my computer...

@btaba
Collaborator Author

btaba commented Jun 10, 2025

@shi-eric Hm, ok, it works if I remove the flag. How do I make sure the graph is being captured, though?

@shi-eric
Collaborator

I'm not knowledgeable about JAX; I'm just trying to offer suggestions to help understand what's going wrong.

@btaba
Collaborator Author

btaba commented Jun 10, 2025

Got it, thanks @shi-eric. I'd want to make sure the graph is being captured, since this is a more minimal repro of the same bug I hit in a much larger example where I don't use that flag at all.
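
For reference, one way to check from outside XLA (assuming an Nsight Systems install; test_pendula_jax.py is a hypothetical entry point) is to capture a CUDA API trace and look for cudaGraphInstantiate / cudaGraphLaunch calls, which indicate XLA really captured and replayed a graph:

nsys profile -t cuda -o pendula_capture python test_pendula_jax.py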

@shi-eric
Collaborator

@nvlukasz It seems like we're trying to call cuda_unload_module when an external framework is graph capturing?

@btaba
Collaborator Author

btaba commented Jun 10, 2025

@shi-eric @nvlukasz Maybe one hint here: the difference between pendula.xml and humanoid.xml is that pendula.xml has multiple kinematic trees whereas humanoid.xml has just one, and the tile operations get called in a loop:

Module mujoco_warp._src.smooth._tile_cholesky_factorize aa877a3 load on device 'cuda:0' took 11.58 ms  (cached)
Module mujoco_warp._src.smooth._tile_cholesky_factorize 1c96286 load on device 'cuda:0' took 10.98 ms  (cached)
Module mujoco_warp._src.smooth._tile_cholesky_factorize 9713ccb load on device 'cuda:0' took 10.93 ms  (cached)
Module mujoco_warp._src.smooth._tile_cholesky_factorize d864c18 load on device 'cuda:0' took 10.55 ms  (cached)
Module mujoco_warp._src.smooth._tile_cholesky_factorize db43ff4 load on device 'cuda:0' took 10.63 ms  (cached)
Module mujoco_warp._src.smooth._tile_cholesky_factorize bdb05de load on device 'cuda:0' took 10.86 ms  (cached)
Module mujoco_warp._src.smooth._tile_cholesky_factorize 8f45d34 load on device 'cuda:0' took 10.79 ms  (cached)
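
If these tile kernels are built inside functions (one per tree/tile size), each distinct size bakes a different constant into the kernel and produces a differently hashed version of the same module. A rough sketch of that pattern (hypothetical kernel, assuming Warp's support for kernels defined inside closures), just to illustrate why several hashes of one module show up:

import warp as wp


def make_scale_kernel(factor: float):
  # Each call creates a distinct kernel with `factor` baked in as a constant,
  # so every unique factor produces a new content hash for the defining module
  # and can trigger a module reload.
  @wp.kernel
  def scale(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    y[tid] = factor * x[tid]

  return scale


# One kernel per kinematic tree size follows the same pattern.
kernels = [make_scale_kernel(f) for f in (1.0, 2.0, 3.0)]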

@nvlukasz

Hmmm, there should be a guard that prevents modules from getting unloaded during graph capture. Perhaps something is falling through the cracks. Let me try to repro.

@nvlukasz

Ok, looks like it's related to using local kernels and multiple module versions. This creates some module reloading churn, and there's a flaw in how we handle these situations. I have a fix incoming in Warp.

@btaba
Collaborator Author

btaba commented Jun 10, 2025

Awesome thanks @nvlukasz for pinpointing the issue!

@nvlukasz

Warp issue for reference: NVIDIA/warp#782

@nvlukasz

The fix is merged now, should appear in the next nightly build.
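
(For anyone following along, Warp nightlies are published on the NVIDIA package index, so something like the following should pick up the fix once the nightly lands; the index URL is assumed from the Warp README:)

pip install -U --pre warp-lang --extra-index-url=https://pypi.nvidia.com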

@btaba
Collaborator Author

btaba commented Jun 11, 2025

Confirming this passes now, thanks @nvlukasz

@btaba marked this pull request as ready for review June 11, 2025 16:28
@btaba requested review from shi-eric and erikfrey June 11, 2025 16:28
Collaborator

@erikfrey left a comment

Looks great, one nit


import mujoco_warp as mjwarp
from mujoco_warp._src.test_util import fixture

# TODO(team): JAX test is temporary, remove after we land MJX:Warp
Collaborator

I think this TODO still applies; we can talk to @shi-eric about what he thinks is the best way to test the chain of warp -> mjwarp -> MJX.

We may still decide to keep the JAX test here, but let's leave the TODO until we have that discussion.

Collaborator Author

Added back the TODO for now! Ideally we'd want these to fail early in the pipeline.

@btaba merged commit 6eb935f into google-deepmind:main Jun 11, 2025
5 checks passed