Description
TRLParser allows parsing a YAML config and command-line arguments. Logically, the command-line arguments should override the YAML config, so if config.yaml contains `gradient_accumulation_steps: 4` and the user calls `python script.py --config config.yaml --gradient_accumulation_steps 1`, then we would expect `gradient_accumulation_steps` to be set to 1.
But currently TRLParser will set it to 4, choosing the config over the user's input.
This is because the merge happens in the YAMLParser, which takes the command-line arguments already parsed into dataclasses together with the YAML file parsed into a config. To decide whether to override a dataclass field with the config value, it only checks whether the field differs from its default value: https://github.com/huggingface/trl/blob/main/trl/commands/cli_utils.py#L90
If the user specifies the default value on the command line, e.g. `--gradient_accumulation_steps 1`, then this is indistinguishable from the user not passing the argument at all and the dataclass falling back to its default value of 1.
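To make the ambiguity concrete, here is a minimal sketch of a default-comparison merge. It is not the actual TRL code: `TrainArgs` and `merge_by_default_check` are hypothetical names, and the YAML file is represented by a plain dict.

```python
from dataclasses import dataclass, fields


@dataclass
class TrainArgs:
    # Hypothetical stand-in for a training-arguments dataclass.
    gradient_accumulation_steps: int = 1
    learning_rate: float = 5e-5


def merge_by_default_check(parsed, yaml_config):
    """Apply a YAML value wherever the parsed field still equals its default."""
    for f in fields(parsed):
        if f.name in yaml_config and getattr(parsed, f.name) == f.default:
            setattr(parsed, f.name, yaml_config[f.name])
    return parsed


# Plain dict standing in for the parsed config.yaml.
yaml_config = {"gradient_accumulation_steps": 4}

# Case 1: the user explicitly passed --gradient_accumulation_steps 1.
explicit = merge_by_default_check(TrainArgs(gradient_accumulation_steps=1), yaml_config)

# Case 2: the user did not pass the flag at all.
omitted = merge_by_default_check(TrainArgs(), yaml_config)

# Both cases look identical to the merge, so the YAML value wins in both.
print(explicit.gradient_accumulation_steps, omitted.gradient_accumulation_steps)
```

Once the arguments have been parsed into a dataclass, the information about which flags were actually typed is gone, so no comparison against defaults can recover it.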
The simplest way to make the config parser and argument parser work together is to manually parse the YAML file and prepend its contents to the args before parsing them, like HfArgumentParser does with its config file: https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/hf_argparser.py#L331
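A rough sketch of the prepend approach, using plain `argparse` rather than TRLParser or HfArgumentParser. `config_to_argv` is a hypothetical helper, and the dict stands in for the parsed YAML file. It relies on argparse letting a later occurrence of a flag overwrite an earlier one.

```python
import argparse


def config_to_argv(config):
    """Flatten a config dict into ['--key', 'value', ...] (scalar values only)."""
    argv = []
    for key, value in config.items():
        argv += [f"--{key}", str(value)]
    return argv


parser = argparse.ArgumentParser()
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=5e-5)

# Plain dict standing in for the parsed config.yaml.
yaml_config = {"gradient_accumulation_steps": 4, "learning_rate": 1e-4}

# What the user typed after the --config flag.
cli_args = ["--gradient_accumulation_steps", "1"]

# Config values go first, so the user's flags are parsed last and win.
ns = parser.parse_args(config_to_argv(yaml_config) + cli_args)
print(ns.gradient_accumulation_steps, ns.learning_rate)
```

This sketch only handles scalar values; booleans, lists, and nested config entries would need their own conversion to command-line form.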
The best way, in my opinion, is just to parse the YAML file and then use it to set the defaults of all the dataclasses' fields with `self.set_defaults`: https://github.com/python/cpython/blob/3.12/Lib/argparse.py#L1427
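A minimal sketch of the `set_defaults` idea, again with plain `argparse` and a dict standing in for the parsed YAML file (the argument names are only illustrative). Since HfArgumentParser subclasses `argparse.ArgumentParser`, the same call should be available there, though how it would be wired into TRLParser is left open here.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=None)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=5e-5)

# Plain dict standing in for the parsed config.yaml.
yaml_config = {"gradient_accumulation_steps": 4}

# The YAML values replace the defaults declared above.
parser.set_defaults(**yaml_config)

# An explicit flag still wins, even when it equals the original default.
with_flag = parser.parse_args(["--gradient_accumulation_steps", "1"])

# Without the flag, the YAML value is used instead of the original default.
without_flag = parser.parse_args([])

print(with_flag.gradient_accumulation_steps, without_flag.gradient_accumulation_steps)
```

With this ordering the precedence becomes command line, then YAML, then dataclass default, and no comparison against default values is needed.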
Advice and pointers welcome!