Description
System Info
- Platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.35
- Python version: 3.10.15
- PyTorch version: 2.5.1
- CUDA device(s): not available
- Transformers version: 4.46.3
- Accelerate version: 1.1.1
- Accelerate config: not found
- Datasets version: 3.1.0
- HF Hub version: 0.26.2
- TRL version: 0.13.0
- bitsandbytes version: not installed
- DeepSpeed version: 0.16.0
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: 1.57.1
- PEFT version: 0.13.2
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
I ran 2 experiments on RLOOTrainer with a dataset of 490 training examples on a single device.
experiment 1

```yaml
num_train_epochs: 1
per_device_train_batch_size: 4
gradient_accumulation_steps: 1
num_ppo_epochs: 1
num_mini_batches: 1
rloo_k: 2
```

The progress bar showed a maximum of 62 steps, but the actual step count exceeded it, ending at a final value of 125. Training ran for 0.5 epochs and logged 500 episodes.
experiment 2

Same as above, but with:

```yaml
rloo_k: 4
```

The progress bar again showed a maximum of 62 steps, but the actual step count exceeded it, ending at a final value of 125. Training ran for 0.25 epochs and logged 500 episodes.
Expected behavior
Let me start with how I understand these terms:
- episodes - based on the number of sequences generated in rollouts. The number of training episodes is affected by `rloo_k`, e.g., if `rloo_k` is increased by x2, total episodes should also rise by x2.
- steps - the number of parameter updates; this is what the training progress bar shows. It should be affected by `num_ppo_epochs` and `num_mini_batches`, e.g., increasing `num_ppo_epochs` or `num_mini_batches` by x2 should increase total training steps by x2.

Obviously, the experiments should stop at 1 epoch. They should run for `dataset_len * rloo_k` episodes, which is about ~1000 for experiment 1 and ~2000 for experiment 2. It seems that the steps/episodes/epochs calculations are off.
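To make that expectation concrete, here is a rough back-of-the-envelope sketch (plain Python, not TRL code; `expected_totals` is just an illustrative helper, and it assumes `per_device_train_batch_size` counts generated completions, i.e., episodes):

```python
# Back-of-the-envelope check of the expected numbers (illustrative only).
def expected_totals(dataset_len, num_train_epochs, rloo_k,
                    per_device_train_batch_size, gradient_accumulation_steps,
                    num_mini_batches, num_ppo_epochs, world_size=1):
    # Episodes: every prompt is answered rloo_k times per epoch.
    episodes = int(num_train_epochs * dataset_len * rloo_k)
    # Episodes collected per rollout batch (assuming per_device_train_batch_size
    # is measured in completions, i.e., episodes).
    batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
    num_batches = episodes // batch_size
    # Each rollout batch is split into num_mini_batches and replayed for
    # num_ppo_epochs passes; every mini-batch pass is one optimizer step.
    steps = num_batches * num_mini_batches * num_ppo_epochs
    return episodes, steps

print(expected_totals(490, 1, 2, 4, 1, 1, 1))  # experiment 1 -> (980, 245)
print(expected_totals(490, 1, 4, 4, 1, 1, 1))  # experiment 2 -> (1960, 490)
```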
Proposed fixes
Total episodes
trl/trl/trainer/rloo_trainer.py
Lines 122 to 123 in aed5da5
```python
if args.total_episodes is None:  # allow the users to define episodes in terms of epochs.
    args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
```
This should be multiplied by `rloo_k`:

```python
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len * args.rloo_k)
```
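With this change, the numbers for the two experiments line up with the expectation (a quick arithmetic check, not trainer code):

```python
# args.num_train_epochs * self.train_dataset_len * args.rloo_k
print(int(1 * 490 * 2))  # experiment 1 -> 980  (~1000 episodes)
print(int(1 * 490 * 4))  # experiment 2 -> 1960 (~2000 episodes)
```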
Batch size
trl/trl/trainer/rloo_trainer.py
Lines 127 to 129 in aed5da5
```python
args.local_batch_size = (
    args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
)
```
I'm not sure why it's multiplied by `args.num_mini_batches`, and I think that factor should be removed. Without it, `local_batch_size` would be the number of rollout samples in the main PPO loop.

It's worth noting that the checks below are not strictly necessary with the current version, since the underlying values are already multiplied by `args.num_mini_batches` beforehand, but they will make sense after this change.

trl/trl/trainer/rloo_trainer.py
Lines 132 to 137 in aed5da5
```python
args.mini_batch_size = exact_div(
    args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
)
args.local_mini_batch_size = exact_div(
    args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
)
```
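For illustration, here is a rough standalone sketch of the batch-size bookkeeping before and after the proposed change (made-up variable names, single device, so `batch_size == local_batch_size`):

```python
per_device_train_batch_size = 4
gradient_accumulation_steps = 1
num_mini_batches = 2  # example value, just to make the difference visible

# Current behaviour: num_mini_batches inflates the rollout batch.
local_batch_size_current = (
    per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches
)  # 8
mini_batch_size_current = local_batch_size_current // num_mini_batches  # 8 // 2 = 4

# Proposed behaviour: the rollout batch is just per-device batch * grad accumulation,
# and num_mini_batches only controls how that batch is split for PPO updates.
local_batch_size_proposed = per_device_train_batch_size * gradient_accumulation_steps  # 4
mini_batch_size_proposed = local_batch_size_proposed // num_mini_batches  # 4 // 2 = 2
```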
Total steps
trl/trl/trainer/rloo_trainer.py
Line 278 in aed5da5
```python
self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2
```
I'm not sure why the `// 2` was introduced here (#2433). As per my understanding of steps described above, it should be:

```python
self.state.max_steps = args.num_total_batches * args.num_mini_batches * args.num_ppo_epochs
```
The update of the global step also depends on this:
trl/trl/trainer/rloo_trainer.py
Line 483 in aed5da5
```python
self.state.global_step += 1
```
So it should be changed to:

```python
self.state.global_step += args.num_ppo_epochs * args.num_mini_batches
```
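A quick consistency check of the two proposed formulas, using the experiment 1 numbers and assuming `num_total_batches` is derived as `ceil(total_episodes / batch_size)` (plain Python sketch, not trainer code):

```python
import math

# Experiment 1 with the proposed fixes, single device.
total_episodes = 980   # 1 epoch * 490 examples * rloo_k=2
batch_size = 4         # per_device_train_batch_size * grad_accum (proposed)
num_ppo_epochs = 1
num_mini_batches = 1

# Assumption: the trainer derives the number of rollout batches from episodes and batch size.
num_total_batches = math.ceil(total_episodes / batch_size)  # 245

# Proposed max_steps.
max_steps = num_total_batches * num_mini_batches * num_ppo_epochs  # 245

# Proposed global_step bookkeeping: one increment per rollout batch.
global_step = 0
for _ in range(num_total_batches):
    global_step += num_ppo_epochs * num_mini_batches
assert global_step == max_steps  # the progress bar ends exactly at max_steps
```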
Number of train epochs
trl/trl/trainer/rloo_trainer.py
Line 279 in aed5da5
```python
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
```
`rloo_k` should be included here as a consequence of including it in `total_episodes` above:

```python
self.state.num_train_epochs = (args.total_episodes / args.rloo_k) / self.train_dataset_len
```
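Plugging in the experiment configurations (with `total_episodes` already including `rloo_k` as proposed above):

```python
# num_train_epochs = (total_episodes / rloo_k) / train_dataset_len
print((980 / 2) / 490)   # experiment 1 -> 1.0 epoch
print((1960 / 4) / 490)  # experiment 2 -> 1.0 epoch
```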
Effects
I've tested the RLOO trainer with these changes using the configurations from experiments 1 and 2, as well as different batch sizes, gradient accumulation steps, PPO epochs, mini-batches, `rloo_k` values, and train epochs, and also in a distributed environment with 4 devices. In all cases, steps (progress bar values), epochs, and episodes appear to be calculated and logged correctly.
I will create a pull request for this soon, but I think it should first be discussed whether my understanding is correct. Some of these changes may also apply to the PPO trainer (I haven't checked), since the RLOO trainer appears to be based on it.
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshot, more on code blocks)
- Any traceback provided is complete