Closed
Problem Description
Given the incredible performance of the DDPG + JAX prototype (#187), it's worth prototyping JAX with other algorithms as well! This issue tracks the overall progress of integrating JAX with CleanRL.
Useful resources
- prototype jax with ddpg #187 (a working JAX + DDPG example, useful as a reference implementation)
- CleanRL's DDPG docs: https://docs.cleanrl.dev/rl-algorithms/ddpg/
- PPO + JAX + EnvPool + MuJoCo #217 (a working JAX + PPO example, useful as a reference implementation)
- CleanRL's PPO docs: https://docs.cleanrl.dev/rl-algorithms/ppo/
Common gotchas and errors:
- Simple Classification with MSE Loss jax-ml/jax#2697 (comment)
- Passing the NN parameters as a non-first argument to the loss function results in a confusing error: `TypeError: unsupported operand type(s) for *: 'float' and 'FrozenDict'` google-deepmind/optax#366 (comment)
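The parameter-order gotcha comes from `jax.grad` differentiating with respect to the first positional argument by default (`argnums=0`). Below is a minimal sketch using plain JAX rather than Flax/Optax; the linear model, `mse_loss`, and `mse_loss_params_last` are hypothetical names chosen for illustration. It shows that with `params` in a non-first position, `jax.grad` silently returns the gradient with respect to the inputs instead, and that either moving `params` first or passing `argnums` gives a gradient pytree that matches `params`.

```python
import jax
import jax.numpy as jnp


def mse_loss(params, x, y):
    """Hypothetical linear model; params come FIRST, as jax.grad expects by default."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def mse_loss_params_last(x, y, params):
    """Same loss with params in the last position (the pitfall)."""
    return mse_loss(params, x, y)


params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))

# Correct: params first, so the gradient is a pytree shaped like params.
grads = jax.grad(mse_loss)(params, x, y)

# Pitfall: grad is taken w.r.t. the first argument (x), not params.
wrong = jax.grad(mse_loss_params_last)(x, y, params)

# Fix when reordering is inconvenient: point argnums at params explicitly.
grads_alt = jax.grad(mse_loss_params_last, argnums=2)(x, y, params)
```

Here `wrong` is an array shaped like `x`, which does not line up with the parameter pytree an optimizer update expects, while `grads` and `grads_alt` match `params`.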
Useful pattern when extending
In CleanRL, a file diff is incredibly helpful. For example, to learn how TD3 differs from DDPG, I could:
- open VS Code and select `ddpg_continuous_action.py` and `td3_continuous_action.py`
- right-click and choose "Compare Selected"
- a diff window then opens, showing the two files side by side
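Outside VS Code, a similar comparison can be printed with Python's standard-library `difflib`. This is a minimal sketch; the helper name `file_diff` is hypothetical, and the example paths in the usage note assume you run it from the repository root with the scripts under `cleanrl/`.

```python
import difflib
from pathlib import Path


def file_diff(old_path, new_path, context=3):
    """Return a unified diff (as one string) between two source files.

    Returns an empty string when the files are identical.
    """
    old_lines = Path(old_path).read_text().splitlines(keepends=True)
    new_lines = Path(new_path).read_text().splitlines(keepends=True)
    diff = difflib.unified_diff(
        old_lines,
        new_lines,
        fromfile=str(old_path),
        tofile=str(new_path),
        n=context,  # number of unchanged context lines around each change
    )
    return "".join(diff)
```

For example, `print(file_diff("cleanrl/ddpg_continuous_action.py", "cleanrl/td3_continuous_action.py"))` prints lines removed from DDPG with a `-` prefix and lines added in TD3 with a `+` prefix.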
Contribution process
There is a contribution checklist to help streamline the contribution process. For each new contribution, we'd need to add documentation, tests, run benchmark experiments, etc. See #186 as an example.
Tracked issues