Closed
Description
Hi here! 🤗 Apparently the TRL CLI command trl sft does not pass the value of the --torch_dtype flag through correctly: by the time the trainer calls getattr(torch, model_init_kwargs["torch_dtype"]), the value is no longer a string.
This most likely affects the other trainer implementations as well. When the CLI parses the torch_dtype argument, it already converts it to a torch.dtype, so the SFTTrainer (in this case) receives torch_dtype=torch.bfloat16 instead of "bfloat16" and then attempts getattr(torch, torch.bfloat16), which fails because getattr expects the attribute name as a string.
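A minimal sketch of the failure mode described above, reproduced outside of TRL. It assumes only that torch is installed; the plain dict model_init_kwargs stands in for the trainer's kwargs, and the second case mimics what the trainer receives once the CLI has already turned "bfloat16" into torch.bfloat16.

```python
import torch

# What the trainer expects: the dtype *name* as a string
model_init_kwargs = {"torch_dtype": "bfloat16"}
resolved = getattr(torch, model_init_kwargs["torch_dtype"])
print(resolved)  # torch.bfloat16

# What it actually receives after CLI parsing: an already-converted torch.dtype
model_init_kwargs = {"torch_dtype": torch.bfloat16}
error = None
try:
    getattr(torch, model_init_kwargs["torch_dtype"])
except TypeError as err:
    # getattr only accepts attribute names (str), not torch.dtype objects
    error = err
    print(f"TypeError: {err}")
```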
So there are two potential fixes:
- Handling the received type of torch_dtype within each ...Trainer subclass, so that it is passed to model_init_kwargs as-is without the need to call getattr(torch, ...)
- Keeping torch_dtype as a string and letting each ...Trainer subclass do the str -> torch.dtype conversion instead, which is more convenient IMO
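As a rough sketch of how a trainer could tolerate both input types, here is a hypothetical helper, resolve_torch_dtype (not an existing TRL function), that accepts either a string or an already-parsed torch.dtype. It assumes None and "auto" should be passed through untouched, since transformers' from_pretrained accepts torch_dtype="auto".

```python
from typing import Optional, Union

import torch


def resolve_torch_dtype(
    value: Optional[Union[str, torch.dtype]],
) -> Optional[Union[str, torch.dtype]]:
    """Hypothetical helper: normalise torch_dtype before it reaches from_pretrained."""
    # None and "auto" are passed through unchanged ("auto" lets transformers pick the dtype)
    if value is None or value == "auto":
        return value
    # Already converted (e.g. by the CLI parser): use it as-is, no getattr needed
    if isinstance(value, torch.dtype):
        return value
    # A string such as "bfloat16": look it up on the torch module and check it is a dtype
    if isinstance(value, str):
        dtype = getattr(torch, value, None)
        if isinstance(dtype, torch.dtype):
            return dtype
    raise ValueError(f"Invalid torch_dtype: {value!r}")
```

With a helper along these lines, the trainer no longer depends on whether the CLI hands it "bfloat16" or torch.bfloat16, which covers either of the two fixes.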
To reproduce
trl sft --model_name_or_path=facebook/opt-125m --dataset_name=imdb --dataset_text_field=text --max_steps=1 --torch_dtype=bfloat16 --output_dir=./test