
@yash-srivastava19
Contributor

#1751 mentioned that the TRL CLI is not correctly capturing torch_dtype. I thought the issue was urgent, so I quickly patched a hacky fix, which at least lets the SFTTrainer initialize.

Original issue:
On running the following command:

trl sft --model_name_or_path=facebook/opt-125m --dataset_name=imdb  --dataset_text_field=text --max_steps=1 --torch_dtype=bfloat16 --output_dir=./test

The error was that trl sft does not check whether the value is still a string before calling getattr(torch, model_init_kwargs["torch_dtype"]), which fails when a torch.dtype is passed instead.
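
For illustration, a minimal self-contained sketch of the failure mode, assuming the CLI has already resolved the flag into a torch.dtype before the trainer sees it:

import torch

# What the trainer effectively runs when torch_dtype arrives as a
# torch.dtype instead of a str: getattr requires a string attribute name.
dtype = torch.bfloat16
getattr(torch, dtype)  # raises TypeError: attribute name must be string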

The fix keeps the pipeline from breaking at this stage. Although it is a hacky fix, I'm willing to work on it further :)

The error after that comes from the transformers library, which isn't able to serialize the dtype object (traceback below):

Traceback (most recent call last):
...
TypeError: Object of type dtype is not JSON serializable
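
This second failure is easy to reproduce in isolation: the standard-library JSON encoder has no handler for torch.dtype objects. A minimal sketch (the str-based workaround below is a hypothetical illustration, not the actual transformers fix):

import json
import torch

# The default JSON encoder does not know how to handle torch.dtype.
try:
    json.dumps({"torch_dtype": torch.bfloat16})
except TypeError as err:
    print(err)  # Object of type dtype is not JSON serializable

# Hypothetical workaround: serialize the dtype back to a plain string.
print(json.dumps({"torch_dtype": str(torch.bfloat16).removeprefix("torch.")}))
# {"torch_dtype": "bfloat16"}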


@alvarobartt
Member

Hi @yash-srivastava19, thanks for this PR, but this is not how we should fix it. Ideally we should catch this either by checking that the received value is a torch.dtype, or by ensuring that the str provided as torch_dtype via the CLI is not transformed into a torch.dtype before instantiating the SFTTrainer, for example.

So a more suitable fix should be the following:

# Keep "auto", None, and already-resolved torch.dtype values as-is;
# otherwise treat the value as an attribute name on the torch module.
model_init_kwargs["torch_dtype"] = (
  model_init_kwargs["torch_dtype"]
  if model_init_kwargs["torch_dtype"] in ["auto", None]
  or isinstance(model_init_kwargs["torch_dtype"], torch.dtype)
  else getattr(torch, model_init_kwargs["torch_dtype"])
)
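
To make the behavior concrete, a small standalone sketch of the same logic; the helper name normalize_torch_dtype is hypothetical, pulled out purely for illustration:

import torch

def normalize_torch_dtype(value):
    # Same logic as the expression above: pass through "auto", None, and
    # values that are already a torch.dtype; resolve str names via getattr.
    if value in ["auto", None] or isinstance(value, torch.dtype):
        return value
    return getattr(torch, value)  # e.g. "bfloat16" -> torch.bfloat16

assert normalize_torch_dtype("bfloat16") is torch.bfloat16
assert normalize_torch_dtype(torch.float16) is torch.float16
assert normalize_torch_dtype("auto") == "auto"
assert normalize_torch_dtype(None) is None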

Anyway, I'll let the authors chime in with their thoughts and ideas about a potential fix. Thanks 🤗

@younesbelkada
Contributor

Thanks a lot for this! I second what @alvarobartt said above; we can change this fix to something like:

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index e739b2d..80e11ad 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -159,11 +159,13 @@ class SFTTrainer(Trainer):
             raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
         else:
             model_init_kwargs = args.model_init_kwargs
-            model_init_kwargs["torch_dtype"] = (
-                model_init_kwargs["torch_dtype"]
-                if model_init_kwargs["torch_dtype"] in ["auto", None]
-                else getattr(torch, model_init_kwargs["torch_dtype"])
-            )
+            torch_dtype = model_init_kwargs["torch_dtype"]
+
+            # Convert to `torch.dtype` if a str is passed
+            if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                torch_dtype = getattr(torch, torch_dtype)
+
+            model_init_kwargs["torch_dtype"] = torch_dtype

         if infinite is not None:
             warnings.warn(

And it worked fine on my end! Would you be happy to apply these changes instead in this PR?
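
For reference, a standalone sketch of how the patched logic in the diff treats each kind of input (this loop is illustrative, not the actual trainer code):

import torch

# Recreation of the patched logic: only plain strings other than "auto"
# are converted; "auto", None, and torch.dtype values pass through.
for value in ["bfloat16", "auto", None, torch.float32]:
    torch_dtype = value
    if isinstance(torch_dtype, str) and torch_dtype != "auto":
        torch_dtype = getattr(torch, torch_dtype)
    print(repr(value), "->", repr(torch_dtype))
# 'bfloat16' -> torch.bfloat16
# 'auto' -> 'auto'
# None -> None
# torch.float32 -> torch.float32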

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yash-srivastava19
Contributor Author

> Thanks a lot for this! I second what @alvarobartt said above; we can change this fix to something like: [...] And it worked fine on my end! Would you be happy to apply these changes instead in this PR?

Yes, that is a much better approach. Agreed.

@yash-srivastava19
Contributor Author

Was the JSON encoding error rectified as well, or does it persist even after the fix?

@younesbelkada
Contributor

Thanks! That's another issue we can fix in a follow-up PR!

@alvarobartt
Member

Hi @yash-srivastava19, a friendly ping to check on the status of this PR 👍🏻 Is it something you are still happy/comfortable to work on, or would you prefer us to take over instead? Just let us know, thanks 🤗

@alvarobartt
Member

Hi @yash-srivastava19, we'll be closing this PR in favour of #1807, and you've been included as a contributor there 🤗 Thanks a lot for the effort!

@alvarobartt closed this on Jul 5, 2024