
GRPOTrainer reward functions cannot access training progress #3668

@seungduk-yanolja


Feature request

Proposed Solution

Add the trainer state to the reward function kwargs before calling the reward function:

# In GRPOTrainer, before calling the reward function:
reward_kwargs["trainer_state"] = self.state

This would allow reward functions to access training progress:

from typing import Any, Dict, List, Optional

def custom_reward_func(
    completions: List[List[Dict[str, str]]],
    sources: List[Optional[str]],
    trainer_state: Any,
    **kwargs,
) -> List[float]:
    """Custom reward function with curriculum learning."""

    # Fraction of training completed so far
    training_progress = trainer_state.global_step / trainer_state.max_steps

    # Example: no reward for the first 30% of training
    if training_progress < 0.3:
        return [0.0] * len(completions)

    # Apply the actual reward function after warmup
    # (compute_reward is the user's task-specific scoring function)
    responses = [comp[0]["content"] for comp in completions]
    return [compute_reward(response, source) for response, source in zip(responses, sources)]
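For reference, wiring such a function into the trainer follows the standard GRPOTrainer pattern; the model id and dataset below are placeholders:

from trl import GRPOConfig, GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model id
    reward_funcs=custom_reward_func,
    args=GRPOConfig(output_dir="grpo-output", max_steps=1000),
    train_dataset=train_dataset,  # assumed to provide a "sources" column
)
trainer.train()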

Benefits

  1. Flexibility: Enables dynamic reward strategies based on training progress
  2. Backward Compatible: Existing reward functions would continue to work, since the extra key is absorbed by **kwargs (see the sketch after this list)
  3. Minimal Change: Requires only a single line addition to pass the trainer state
  4. Enables Advanced Training Strategies: Curriculum learning, warm-up periods, and progressive reward shaping
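To illustrate point 2, a pre-existing reward function keeps working untouched, because the new key lands in **kwargs and is simply ignored (hypothetical example):

def legacy_reward_func(completions, **kwargs):
    # trainer_state arrives in kwargs and is never read
    return [float(len(comp[0]["content"])) for comp in completions]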

Alternative Solutions Considered

  1. Global variable: Not clean and could cause issues with multiple trainer instances
  2. Callback system: More complex and overkill for this simple use case (a sketch of this workaround follows the list)
  3. Subclassing GRPOTrainer: Requires users to maintain custom trainer code
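For comparison, the callback workaround (option 2) would look roughly like the sketch below: a TrainerCallback mirrors the step counters into a module-level dict that the reward function reads. It works, but it effectively combines a callback with a global variable:

from transformers import TrainerCallback

progress = {"global_step": 0, "max_steps": 1}

class ProgressCallback(TrainerCallback):
    def on_step_begin(self, args, state, control, **kwargs):
        # Mirror the trainer state into shared mutable state
        progress["global_step"] = state.global_step
        progress["max_steps"] = max(state.max_steps, 1)

def progress_aware_reward(completions, **kwargs):
    training_progress = progress["global_step"] / progress["max_steps"]
    return [0.0 if training_progress < 0.3 else 1.0 for _ in completions]

The callback would then have to be registered through the trainer's callbacks argument; a single trainer_state kwarg avoids all of this plumbing.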

Additional Context

This feature would be particularly useful for:

  • Research on reward curriculum and progressive training
  • Production systems that need stable training with gradual reward introduction
  • Complex reward functions that benefit from warm-up periods

Motivation

Description

Currently, when using custom reward functions with GRPOTrainer in TRL, there is no way to access training progress information (such as global_step and max_steps) inside the reward function. This prevents implementing dynamic reward strategies such as a reward curriculum, where rewards are adjusted based on how far training has progressed.

Use Case

A common use case is implementing a reward curriculum, where different reward strategies are applied at different stages of training. For example (a weighting sketch follows the list):

  • Early training: Use simpler rewards or no rewards to allow the model to explore
  • Mid training: Gradually introduce more complex reward signals
  • Late training: Apply full reward function
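One way to express such a schedule is a progress-dependent weight on the base reward (the thresholds here are illustrative):

def curriculum_weight(global_step: int, max_steps: int) -> float:
    """Map training progress onto a reward weight: explore, ramp, full reward."""
    progress = global_step / max_steps
    if progress < 0.3:   # early training: free exploration, no reward signal
        return 0.0
    if progress < 0.7:   # mid training: linear ramp from 0 to 1
        return (progress - 0.3) / 0.4
    return 1.0           # late training: full reward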

Problem

The reward function currently receives only:

  • prompts
  • completions
  • completion_ids
  • Additional custom kwargs

However, it lacks access to:

  • Current global step
  • Total training steps
  • Overall training progress

This makes it impossible to implement training-aware reward strategies.
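Concretely, everything a reward function can see today is fixed when it is defined; no argument changes as training advances (simplified sketch of the current shape):

def reward_func(prompts, completions, completion_ids, **kwargs):
    # kwargs carries extra dataset columns, but nothing here
    # exposes global_step, max_steps, or overall progress
    ...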

Your contribution

I will write a simple PR for it.
