Issue #1751 Fix #1754
Conversation
Hi @yash-srivastava19, thanks for this PR, but this is not how we should fix that, since ideally we should catch it by checking whether the received type is a torch.dtype. A more suitable fix would be the following:

model_init_kwargs["torch_dtype"] = (
    model_init_kwargs["torch_dtype"]
    if model_init_kwargs["torch_dtype"] in ["auto", None]
    or isinstance(model_init_kwargs["torch_dtype"], torch.dtype)
    else getattr(torch, model_init_kwargs["torch_dtype"])
)

Anyway, I'll let the authors chime in with their thoughts and ideas about a potential fix! Thanks anyway 🤗
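For context on the failure this guards against, here is a minimal sketch (illustrative, not code from the PR) of why the unguarded getattr call breaks when a torch.dtype is passed:

import torch

# A string works: it names an attribute on the torch module
assert getattr(torch, "bfloat16") is torch.bfloat16

# A torch.dtype does not: getattr requires a string attribute name
try:
    getattr(torch, torch.bfloat16)
except TypeError as err:
    print(err)  # attribute name must be string, not 'torch.dtype'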
Thanks a lot for this! I second what @alvarobartt said above; we can change this fix to something like:
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index e739b2d..80e11ad 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -159,11 +159,13 @@ class SFTTrainer(Trainer):
             raise ValueError("You passed model_init_kwargs to the SFTConfig, but your model is already instantiated.")
         else:
             model_init_kwargs = args.model_init_kwargs
-            model_init_kwargs["torch_dtype"] = (
-                model_init_kwargs["torch_dtype"]
-                if model_init_kwargs["torch_dtype"] in ["auto", None]
-                else getattr(torch, model_init_kwargs["torch_dtype"])
-            )
+            torch_dtype = model_init_kwargs["torch_dtype"]
+
+            # Convert to `torch.dtype` if a str is passed
+            if isinstance(torch_dtype, str) and torch_dtype != "auto":
+                torch_dtype = getattr(torch, torch_dtype)
+
+            model_init_kwargs["torch_dtype"] = torch_dtype
         if infinite is not None:
             warnings.warn(
And it worked fine on my end! Would you be happy to apply these changes instead in this PR?
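For clarity, the patched logic can be exercised on its own. This is a sketch under stated assumptions (the helper name normalize_torch_dtype is hypothetical, not part of TRL) showing the cases it now handles:

import torch

def normalize_torch_dtype(torch_dtype):
    # Mirrors the patched SFTTrainer block: convert a string such as
    # "bfloat16" to the torch.dtype it names, while passing through
    # "auto", None, and torch.dtype values untouched.
    if isinstance(torch_dtype, str) and torch_dtype != "auto":
        torch_dtype = getattr(torch, torch_dtype)
    return torch_dtype

assert normalize_torch_dtype("bfloat16") is torch.bfloat16
assert normalize_torch_dtype("auto") == "auto"
assert normalize_torch_dtype(None) is None
assert normalize_torch_dtype(torch.float16) is torch.float16  # previously crashed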
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Yes, that is much better. Agreed.
Was the JSON encoding error rectified as well, or does it persist even after the fix?
Thanks! That's another issue we can fix in a follow-up PR!
Hi @yash-srivastava19, friendly ping to check on the status of this PR 👍🏻 Is it something you are still happy/comfortable to work on? Or would you prefer us to take over instead? Just let us know, thanks 🤗
Hi @yash-srivastava19, thanks a lot for the effort! We'll be closing this PR in favour of #1807, and you've been included as a contributor there 🤗
#1751 mentioned that the TRL CLI is not completely capturing torch_dtype. I thought the issue was urgent, so I quickly patched a hacky fix, which at least lets the SFTTrainer initialize.
Original issue:
On running the following command:
The error was that trl sft does not check whether the value is a string before calling
getattr(torch, model_init_kwargs["torch_dtype"])
, so it fails when a torch.dtype object is passed. The fix keeps the pipeline from breaking at this stage. Although it is a hacky fix, I'm willing to work on it further :)
The error after that comes from the transformers library, which isn't able to serialize the dtype object (screenshot attached).
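For context, a minimal sketch of that serialization failure (an illustration of the reported symptom, not code from transformers) and one common workaround:

import json
import torch

try:
    json.dumps({"torch_dtype": torch.bfloat16})
except TypeError as err:
    print(err)  # Object of type dtype is not JSON serializable

# One common workaround: serialize the dtype by name instead
print(json.dumps({"torch_dtype": str(torch.bfloat16).removeprefix("torch.")}))
# {"torch_dtype": "bfloat16"}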