[BUG] TRL CLI not capturing torch_dtype correctly #1751

@alvarobartt

Description

Hi here! 🤗 The TRL CLI command trl sft does not handle the value passed to the --torch_dtype flag correctly: by the time getattr(torch, model_init_kwargs["torch_dtype"]) is called, the value is no longer a string.

Most likely this also affects the rest of the implementations: the CLI converts the provided torch_dtype to a torch.dtype while parsing, so the SFTTrainer in this case receives torch_dtype=torch.bfloat16 and then attempts getattr(torch, torch.bfloat16), which fails because the attribute name is not a string.
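The failure can be sketched in isolation, without running the CLI. This is only an illustration: the model_init_kwargs dict is built by hand to mimic what the trainer would receive after the CLI has already converted the flag, which is an assumption about the call path and not a copy of the TRL code.

```python
import torch

# Assumption: this mimics the kwargs after CLI parsing has already turned
# the "bfloat16" string into a torch.dtype object.
model_init_kwargs = {"torch_dtype": torch.bfloat16}

error = None
try:
    # getattr() requires a string attribute name, so a torch.dtype fails here.
    getattr(torch, model_init_kwargs["torch_dtype"])
except TypeError as exc:
    error = exc
    print(type(exc).__name__, exc)
```

Passing the string "bfloat16" to the same getattr call works, which is why the problem only shows up once the value has been converted upstream.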

So there are two potential fixes:

  • Handling the received type for torch_dtype within each ...Trainer subclass, so that a torch.dtype is passed to the model_init_kwargs as is, without calling getattr(torch, ...)
  • Keeping the torch_dtype as a string in the CLI and letting each ...Trainer subclass do the str -> torch.dtype conversion instead, which is more convenient IMO
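Either fix comes down to a conversion step that tolerates both input types. A minimal sketch of such a step is below; resolve_torch_dtype is a hypothetical helper name, not an existing TRL function, and the pass-through for None and "auto" is an assumption based on the values that transformers' from_pretrained accepts for torch_dtype.

```python
import torch


def resolve_torch_dtype(value):
    """Hypothetical helper: normalize a torch_dtype argument.

    Accepts None, the string "auto", a dtype name such as "bfloat16",
    or an already-converted torch.dtype.
    """
    # Already usable as is: nothing to convert.
    if value is None or isinstance(value, torch.dtype):
        return value
    if isinstance(value, str):
        # Assumption: "auto" is passed through for from_pretrained to resolve.
        if value == "auto":
            return value
        dtype = getattr(torch, value, None)
        # Guard against names that exist on torch but are not dtypes (e.g. "nn").
        if isinstance(dtype, torch.dtype):
            return dtype
    raise ValueError(f"Unsupported torch_dtype: {value!r}")


print(resolve_torch_dtype("bfloat16"))      # value as given on the CLI
print(resolve_torch_dtype(torch.bfloat16))  # value after CLI conversion
```

With a helper like this, the trainer would behave the same whether the CLI converts the flag or leaves it as a string.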

To reproduce

trl sft --model_name_or_path=facebook/opt-125m --dataset_name=imdb --dataset_text_field=text --max_steps=1 --torch_dtype=bfloat16 --output_dir=./test
