Feature request
Proposed Solution
Add the trainer state to the reward function kwargs before calling the reward function:
# In GRPOTrainer, before calling the reward function:
reward_kwargs["trainer_state"] = self.state
This would allow reward functions to access training progress:
from typing import Any, Dict, List, Optional

def custom_reward_func(
    completions: List[List[Dict[str, str]]],
    sources: List[Optional[str]],
    trainer_state: Any,
    **kwargs,
) -> List[float]:
    """Custom reward function with curriculum learning."""
    # Fraction of training completed, in [0, 1]
    training_progress = trainer_state.global_step / trainer_state.max_steps

    # Example: no reward for the first 30% of training
    if training_progress < 0.3:
        return [0.0] * len(completions)

    # Apply the actual reward function after the warm-up period
    # (compute_reward is the user's underlying reward logic)
    responses = [comp[0]["content"] for comp in completions]
    return [compute_reward(response, source) for response, source in zip(responses, sources)]
Benefits
- Flexibility: Enables dynamic reward strategies based on training progress
- Backward Compatible: Existing reward functions that accept **kwargs would continue to work unchanged, since the extra parameter would simply be ignored (see the sketch after this list)
- Minimal Change: Requires only a single line addition to pass the trainer state
- Enables Advanced Training Strategies: Curriculum learning, warm-up periods, and progressive reward shaping
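As a minimal sketch of the backward-compatibility point, assuming the trainer passes trainer_state as just another keyword argument: a reward function that never declares it simply absorbs it through **kwargs and behaves as before (length_reward and its scoring rule are made up for illustration).

from typing import Dict, List

def length_reward(completions: List[List[Dict[str, str]]], **kwargs) -> List[float]:
    # trainer_state would arrive inside kwargs and is simply ignored here
    return [min(len(comp[0]["content"]) / 100.0, 1.0) for comp in completions]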
Alternative Solutions Considered
- Global variable: Not clean and could cause issues with multiple trainer instances
- Callback system: A callback could stash the state somewhere the reward function can read (roughly sketched after this list), but this is more complex and overkill for this simple use case
- Subclassing GRPOTrainer: Requires users to maintain custom trainer code
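For comparison, here is a rough sketch of the callback-plus-global-variable workaround mentioned above (StateTrackingCallback, CURRENT_TRAINER_STATE, and progress_aware_reward are made-up names). It works for a single trainer registered via callbacks=[StateTrackingCallback()], but it relies on module-level mutable state, which is exactly what makes it fragile with multiple trainer instances in one process.

from transformers import TrainerCallback

# Module-level holder the reward function reads from; shared mutable state
# like this is what makes the workaround fragile.
CURRENT_TRAINER_STATE = None

class StateTrackingCallback(TrainerCallback):
    def on_step_begin(self, args, state, control, **kwargs):
        global CURRENT_TRAINER_STATE
        CURRENT_TRAINER_STATE = state

def progress_aware_reward(completions, **kwargs):
    # Before the first step (or if max_steps is unset), treat progress as 0.0
    if CURRENT_TRAINER_STATE is None or CURRENT_TRAINER_STATE.max_steps <= 0:
        progress = 0.0
    else:
        progress = CURRENT_TRAINER_STATE.global_step / CURRENT_TRAINER_STATE.max_steps
    # Placeholder: scale a trivial constant reward by training progress
    return [progress] * len(completions)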
Additional Context
This feature would be particularly useful for:
- Research on reward curriculum and progressive training
- Production systems that need stable training with gradual reward introduction
- Complex reward functions that benefit from warm-up periods
Motivation
Description
Currently, when using custom reward functions with GRPOTrainer in TRL, there is no way to access training progress information (such as global_step and max_steps) within the reward function. This limitation prevents implementing dynamic reward strategies such as a reward curriculum, where rewards are adjusted based on training progress.
Use Case
A common use case is implementing a reward curriculum, where different reward strategies are applied at different stages of training (a rough sketch follows this list). For example:
- Early training: Use simpler rewards or no rewards to allow the model to explore
- Mid training: Gradually introduce more complex reward signals
- Late training: Apply full reward function
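A rough sketch of such a schedule, assuming the proposed trainer_state kwarg is available (the 20%/60% breakpoints are arbitrary, and compute_reward stands in for whatever task-specific reward the user already has):

from typing import Any, Dict, List

def ramped_reward_func(completions: List[List[Dict[str, str]]], trainer_state: Any, **kwargs) -> List[float]:
    progress = trainer_state.global_step / max(trainer_state.max_steps, 1)
    # Piecewise schedule: free exploration early, linear ramp-in mid training,
    # full-strength reward late in training
    if progress < 0.2:
        scale = 0.0
    elif progress < 0.6:
        scale = (progress - 0.2) / 0.4
    else:
        scale = 1.0
    responses = [comp[0]["content"] for comp in completions]
    return [scale * compute_reward(response) for response in responses]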
Problem
The reward function signature currently receives only:
- prompts
- completions
- completion_ids
- Additional custom kwargs
However, it lacks access to:
- Current global step
- Total training steps
- Overall training progress
This makes it impossible to implement training-aware reward strategies.
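For reference, a reward function written against the signature listed above looks roughly like this sketch; there is simply no argument through which step or progress information could arrive:

from typing import Dict, List

def reward_func(prompts: List, completions: List[List[Dict[str, str]]], completion_ids: List, **kwargs) -> List[float]:
    # kwargs carries the additional custom columns, but nothing here exposes
    # global_step, max_steps, or overall training progress
    return [0.0] * len(completions)  # placeholder scoring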
Your contribution
I will open a simple PR implementing this.