Enable number of printed completions to be set #3149
trl/trainer/grpo_trainer.py
@@ -905,6 +907,8 @@ def _generate_and_score_completions(
"reward": rewards.tolist(),
}
df = pd.DataFrame(table)
if self.num_completions_to_log is not None:
For consistency, I've enabled subsampling here too, but we could skip it for WandB since, AFAIK, it doesn't really matter if there are a lot of completions.
I think for WandB we should just keep everything? It's more of an issue of the logs being spammed.
Sounds good, I'll revert
LGTM apart from one comment: what happens with num_samples=0?
LGTM :)
The failing CI is unrelated, btw; you can safely ignore it.
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
…e (don't print)" This reverts commit f6f93c7.
>>> print_prompt_completions_sample(prompts, completions, rewards, 42, -1)
>>> print_prompt_completions_sample(prompts, completions, rewards, 42, 0)
>>> print_prompt_completions_sample(prompts, completions, rewards, 42, 1)
╭────────────── Step 42 ───────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Reward ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ │
│ └────────────┴────────────┴────────┘ │
╰──────────────────────────────────────╯
>>> print_prompt_completions_sample(prompts, completions, rewards, 42, 2)
╭─────────────── Step 42 ────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Reward ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ │
│ ├────────────┼──────────────┼────────┤ │
│ │ The sun is │ in the sky. │ 0.69 │ │
│ └────────────┴──────────────┴────────┘ │
╰────────────────────────────────────────╯
>>> print_prompt_completions_sample(prompts, completions, rewards, 42, 3)
╭─────────────── Step 42 ────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Reward ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ │
│ ├────────────┼──────────────┼────────┤ │
│ │ The sun is │ in the sky. │ 0.69 │ │
│ └────────────┴──────────────┴────────┘ │
╰────────────────────────────────────────╯
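The doctest above suggests how the last argument behaves: -1 and 0 print nothing, and a value larger than the number of available rows (3 requested, only 2 rows here) prints every row. The following is a minimal sketch of that row-selection logic, assuming uniform sampling without replacement and that selected rows keep their original order. The helper name select_rows_to_print is hypothetical and is not part of TRL.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def select_rows_to_print(
    rows: Sequence[T], num_samples: int, seed: Optional[int] = None
) -> List[T]:
    """Hypothetical helper: pick at most num_samples rows to print.

    num_samples <= 0 selects nothing; num_samples >= len(rows) keeps every row.
    """
    if num_samples <= 0:
        # Matches the doctest: -1 and 0 produce no table at all.
        return []
    if num_samples >= len(rows):
        # Asking for more rows than exist simply returns all of them.
        return list(rows)
    rng = random.Random(seed)
    # Sample distinct indices uniformly, then sort so output follows input order.
    indices = sorted(rng.sample(range(len(rows)), num_samples))
    return [rows[i] for i in indices]


if __name__ == "__main__":
    rows = [("The sky is", "blue.", 0.12), ("The sun is", "in the sky.", 0.69)]
    for n in (-1, 0, 1, 2, 3):
        print(n, len(select_rows_to_print(rows, n, seed=0)))
```

Sorting the sampled indices is a design choice of this sketch so that the printed table reads in the same order as the batch; the actual trainer may not do this.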
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
What does this PR do?
This PR exposes a num_completions_to_print arg in the GRPOConfig so that users can control how many completions are printed to the terminal. I found that the default (log everything) made the logs very verbose and large. After discussing with @edbeeching, I decided not to expose this for WandB, since there we don't have to worry about the logs becoming overloaded.
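A minimal sketch of how such an option could be wired, assuming that None means "print every completion" (the old default) and that only the terminal output is subsampled while WandB still receives the full table. SketchGRPOConfig and split_for_logging are hypothetical stand-ins, not TRL's actual GRPOConfig or trainer code, and for simplicity this sketch keeps the first n rows rather than sampling them.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")


@dataclass
class SketchGRPOConfig:
    # Hypothetical stand-in for the option this PR adds to GRPOConfig.
    # None keeps the previous behaviour: print every completion.
    num_completions_to_print: Optional[int] = None


def split_for_logging(
    config: SketchGRPOConfig, rows: Sequence[T]
) -> Tuple[List[T], List[T]]:
    """Return (rows to print in the terminal, rows to send to WandB)."""
    # WandB always gets everything; only the terminal output is trimmed.
    wandb_rows = list(rows)
    n = config.num_completions_to_print
    if n is None:
        terminal_rows = list(rows)
    else:
        # Negative values are treated like 0, i.e. nothing is printed.
        terminal_rows = list(rows)[: max(n, 0)]
    return terminal_rows, wandb_rows
```

Keeping the two outputs separate mirrors the decision in the review thread: the terminal is where log spam hurts, so that is the only place the count is capped.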