Feature request
I would like to have more control over the GRPO generation hyperparameters.
Currently, GRPOConfig
exposes only a subset of the generation options: e.g. temperature, top_k, top_p, repetition_penalty, and a few others.
The trainer builds its own GenerationConfig and passes it to generate(), overriding the generation_config stored on the model, so editing the model's config has no effect.
trl/trl/trainer/grpo_trainer.py
Lines 668 to 680 in 1314aac
```python
self.generation_config = GenerationConfig(
    max_new_tokens=self.max_completion_length,
    do_sample=True,
    pad_token_id=processing_class.pad_token_id,
    bos_token_id=processing_class.bos_token_id,
    eos_token_id=processing_class.eos_token_id,
    temperature=self.temperature,
    top_p=self.top_p,
    top_k=self.top_k,
    min_p=self.min_p,
    repetition_penalty=self.repetition_penalty,
    cache_implementation=args.cache_implementation,
)
```
trl/trl/trainer/grpo_trainer.py
Lines 1079 to 1089 in 1314aac
```python
with unwrap_model_for_generation(
    self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model:
    with (
        FSDP.summon_full_params(self.model_wrapped, recurse=False)
        if self.is_fsdp_enabled
        else nullcontext()
    ):
        prompt_completion_ids = unwrapped_model.generate(
            prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
        )
```
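For reference, a stopgap that relies on the `generation_config` attribute the trainer stores (shown in the first snippet) is to mutate it after construction. This is not an official TRL API, it only affects the transformers `generate()` path shown above (not vLLM generation), and `my_dataset`, the model id, the zero-reward function, and the suppressed token id are placeholders for illustration:

```python
from trl import GRPOConfig, GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",   # any causal LM id; placeholder
    reward_funcs=lambda completions, **kwargs: [0.0] * len(completions),
    args=GRPOConfig(output_dir="grpo-out"),
    train_dataset=my_dataset,           # placeholder dataset with a "prompt" column
)

# The trainer already built a GenerationConfig (first snippet above);
# set options GRPOConfig does not expose directly on it.
trainer.generation_config.suppress_tokens = [12345]  # hypothetical token id
```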
Motivation
This would give more flexibility when training with the trainer. For example, we may want to suppress some tokens during generation, which is currently not possible; a sketch of one possible API is shown below.
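One possible shape for the feature (the `generation_kwargs` field below is hypothetical, not an existing GRPOConfig argument): a pass-through dict that is merged into the GenerationConfig built in the first snippet, so any `generate()` option can reach the trainer:

```python
from trl import GRPOConfig

# Hypothetical API sketch: `generation_kwargs` does not exist today.
config = GRPOConfig(
    output_dir="grpo-out",
    temperature=0.7,
    # Merged into the trainer's GenerationConfig, so options GRPOConfig
    # does not expose (here: suppressing token ids) pass straight through.
    generation_kwargs={"suppress_tokens": [12345]},
)
```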
Your contribution
I might be able to help with this at a later date, so if someone else would like to pick it up in the meantime, that would be great.