
GRPOTrainer does not have a feature flag to prevent dataset shuffling #2998

@sidmadala

Description


Reproduction

In GRPOTrainer:

def __iter__(self):
    # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
    indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()

    #    [2, 4, 3, 1, 0, 6, 5]
    # -> [[2, 4, 3], [1, 0, 6], [5]]  (batch_size = 3)
    indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]

    #    [[2, 4, 3], [1, 0, 6], [5]]
    # -> [[2, 4, 3], [1, 0, 6]]
    indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]

    for chunk in indexes:
        for _ in range(self.repeat_count):
            for index in chunk:
                for _ in range(self.mini_repeat_count):
                    yield index

def __len__(self) -> int:
    return self.num_samples * self.mini_repeat_count * self.repeat_count

Observed behavior:

The dataset is shuffled even though I want to apply curriculum learning (no shuffling). I've pasted code below that works the way I want, but I'd love advice on creating a PR for this issue: this would be my first contribution, and I'm unsure which other files would need to change (e.g. GRPOConfig, PretrainedConfig). I also want to keep shuffling available as a config option, since others will likely want to retain the original behavior.
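The shuffle comes from the `torch.randperm` call in the sampler's `__iter__`. For illustration, the same iteration logic can be reproduced standalone (a sketch, with `random.Random` standing in for the torch generator; the class name here is made up for the demo):

```python
import random

class RepeatRandomSamplerSketch:
    """Standalone sketch of the sampler above; random.Random
    stands in for torch.randperm + generator."""

    def __init__(self, num_samples, batch_size, mini_repeat_count, repeat_count, seed=0):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.mini_repeat_count = mini_repeat_count
        self.repeat_count = repeat_count
        self.generator = random.Random(seed)

    def __iter__(self):
        # Random permutation of all indices, e.g. [2, 4, 3, 1, 0, 6, 5]
        indexes = self.generator.sample(range(self.num_samples), self.num_samples)
        # Chunk into batches and drop the incomplete final chunk
        chunks = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
        for chunk in (c for c in chunks if len(c) == self.batch_size):
            for _ in range(self.repeat_count):
                for index in chunk:
                    for _ in range(self.mini_repeat_count):
                        yield index

order = list(RepeatRandomSamplerSketch(num_samples=7, batch_size=3,
                                       mini_repeat_count=2, repeat_count=1))
# 2 full chunks of 3 survive (index from the dropped chunk never appears);
# each surviving index is yielded mini_repeat_count times back to back
```

This makes it easy to see that the order depends entirely on the permutation in the first line of `__iter__` — which is what a sequential variant would remove.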

Potential Solution:

Modify RepeatRandomSampler, or add a similar new class alongside it called RepeatSequentialSampler, with the iteration logic changed as follows:

def __iter__(self):
    # Sequential order, repeat each index `repeat_count` times
    indexes = [idx for idx in range(self.num_samples) for _ in range(self.repeat_count)]
    return iter(indexes)

def __len__(self):
    return self.num_samples * self.repeat_count
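A self-contained version of the proposal (the class name follows the suggestion above; subclassing `torch.utils.data.Sampler` is omitted here to keep the sketch runnable on its own):

```python
class RepeatSequentialSampler:
    """Sketch of the proposed sampler: same repeat semantics,
    but indices stay in dataset order so curriculum-ordered data
    is consumed sequentially (no torch.randperm)."""

    def __init__(self, num_samples: int, repeat_count: int = 1):
        self.num_samples = num_samples
        self.repeat_count = repeat_count

    def __iter__(self):
        # Sequential order, each index repeated `repeat_count` times
        return iter(idx for idx in range(self.num_samples)
                    for _ in range(self.repeat_count))

    def __len__(self) -> int:
        return self.num_samples * self.repeat_count


order = list(RepeatSequentialSampler(num_samples=3, repeat_count=2))
# -> [0, 0, 1, 1, 2, 2]
```

GRPOTrainer could then select between the two samplers based on a boolean field such as `shuffle_dataset` in GRPOConfig; that flag name is an assumption, not existing TRL API.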

System Info

  • Platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.35
  • Python version: 3.12.9
  • PyTorch version: 2.5.1
  • CUDA device(s): NVIDIA A100-SXM4-80GB
  • Transformers version: 4.47.1
  • Accelerate version: 1.2.0
  • Accelerate config: not found
  • Datasets version: 3.1.0
  • HF Hub version: 0.29.1
  • TRL version: 0.12.1
  • bitsandbytes version: 0.45.0
  • DeepSpeed version: 0.16.1
  • Diffusers version: 0.32.2
  • Liger-Kernel version: 0.4.2
  • LLM-Blender version: not installed
  • OpenAI version: 1.64.0
  • PEFT version: 0.14.0

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete
