Add pendula.xml to jax_test. #350
What CUDA driver version are you running?
Thanks, what if you try removing this line?
That command has been there since @thowell first wrote the file, but I found that after removing it I can run the test on my computer...
@shi-eric Hm, OK, it works if I remove the flag. How do I make sure the graph is still being captured, though?
I'm not knowledgeable about JAX; I'm just trying to offer suggestions to understand more about what's going wrong.
Got it, thanks @shi-eric. I'd want to make sure the graph is being captured, since this is a more minimal repro of the same bug I get with a much larger example where I don't use that flag at all.
@nvlukasz It seems like we're trying to call
@shi-eric @nvlukasz maybe one hint here is that the difference between
Hmmm, there should be a guard that prevents modules from being unloaded during graph capture. Perhaps something is falling through the cracks. Let me try to repro.
OK, it looks like it's related to using local kernels and multiple module versions. This creates some module-reloading churn, and there's a flaw in how we handle these situations. I have a fix incoming in Warp.
Awesome, thanks @nvlukasz for pinpointing the issue!
Warp issue for reference: NVIDIA/warp#782
The fix is merged now and should appear in the next nightly build.
Confirming this passes now. Thanks, @nvlukasz!
Looks great, one nit

import mujoco_warp as mjwarp
from mujoco_warp._src.test_util import fixture

# TODO(team): JAX test is temporary, remove after we land MJX:Warp
I think this TODO still applies. We can talk to @shi-eric about the best way to test the chain warp -> mjwarp -> MJX.
We may still decide to keep the JAX test here, but let's leave the TODO until we have that discussion.
Added back the TODO for now! Ideally we'd want these to fail early in the pipeline.
pendula.xml does not work with JAX interop: