
add generation_kwargs to GRPOTrainer, so people have more control when training #3562

@avishaiElmakies


Feature request

I would like to have more control over the GRPO generation hyperparameters.
Currently, GRPOConfig exposes only some of the generation options, e.g. temperature, top_k, top_p, repetition_penalty, and a few others.
The trainer builds its own GenerationConfig and overrides whatever generation config the model carries, so it is not possible to change the model's config to get the desired effect:

```python
self.generation_config = GenerationConfig(
    max_new_tokens=self.max_completion_length,
    do_sample=True,
    pad_token_id=processing_class.pad_token_id,
    bos_token_id=processing_class.bos_token_id,
    eos_token_id=processing_class.eos_token_id,
    temperature=self.temperature,
    top_p=self.top_p,
    top_k=self.top_k,
    min_p=self.min_p,
    repetition_penalty=self.repetition_penalty,
    cache_implementation=args.cache_implementation,
)
```

```python
with unwrap_model_for_generation(
    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
    with (
        FSDP.summon_full_params(self.model_wrapped, recurse=False)
        if self.is_fsdp_enabled
        else nullcontext()
    ):
        prompt_completion_ids = unwrapped_model.generate(
            prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
        )
```
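
One possible shape for the change (a minimal sketch, not the trainer's actual code): let GRPOConfig accept an optional `generation_kwargs` dict and merge it into the GenerationConfig the trainer builds, with user-supplied keys taking precedence. The `generation_kwargs` name and the merge order are assumptions on my part; the snippet reuses the same `self`/`args` context as the trainer code quoted above.

```python
from transformers import GenerationConfig

# Hypothetical sketch: `generation_kwargs` would be a new GRPOConfig field,
# e.g. {"suppress_tokens": [0, 1]}. Defaults mirror what the trainer already
# sets; user-provided keys override them.
base_kwargs = dict(
    max_new_tokens=self.max_completion_length,
    do_sample=True,
    pad_token_id=processing_class.pad_token_id,
    bos_token_id=processing_class.bos_token_id,
    eos_token_id=processing_class.eos_token_id,
    temperature=self.temperature,
    top_p=self.top_p,
    top_k=self.top_k,
    min_p=self.min_p,
    repetition_penalty=self.repetition_penalty,
    cache_implementation=args.cache_implementation,
)
base_kwargs.update(args.generation_kwargs or {})  # user overrides win
self.generation_config = GenerationConfig(**base_kwargs)
```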

Motivation

This would give more flexibility when training with the trainer. For example, we may want to suppress some tokens during generation, but that is currently not possible.
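
For instance, usage could look like this (hypothetical, assuming the proposed `generation_kwargs` argument exists; the token ID is just a placeholder):

```python
from trl import GRPOConfig

# `generation_kwargs` does not exist in GRPOConfig yet; this shows the requested API.
config = GRPOConfig(
    output_dir="grpo-out",
    max_completion_length=256,
    generation_kwargs={"suppress_tokens": [128009]},  # example token ID to suppress
)
```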

Your contribution

I might be able to help at a later date, so if someone would like to pick this up in the meantime, that would be great.
