✌️ Add support for FSDP2 #3317
Conversation
I think it's OK to bump the accelerate version requirement this high, since trl is also in beta.
So it seems that
Hi, we had a meeting with @qgallouedec and he mentioned that I should share two thoughts on FSDP2.
@qgallouedec Would you mind sharing the script you ran? Because I got errors trying to use FSDP2 with this PR.

# accelerate launch --config_file fsdp.yaml error.py
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from accelerate import PartialState
from peft import get_peft_model, LoraConfig
model_id = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map={"": PartialState().process_index},
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
use_cache=False,
)
# apply_liger_kernel_to_qwen2(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = 151643
dummy_dataset = Dataset.from_dict({"text": ["Dummy dataset"] * 32, })
training_args = SFTConfig(
output_dir="trainer_output",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=2,
report_to="none",
bf16=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False}
)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dummy_dataset,
processing_class=tokenizer,
)
trainer.train()
trainer.save_model()

error:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
FSDP2 fails with LoRA and the following error:
Update: fixed by huggingface/accelerate#3545
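For reference, attaching LoRA with peft looks roughly like the snippet below (an illustrative sketch with assumed hyperparameters; the exact configuration of the failing run is not shown in this thread):

from peft import LoraConfig, get_peft_model  # also imported in the repro script above

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,                                 # assumed rank
    lora_alpha=32,                        # assumed scaling
    target_modules=["q_proj", "v_proj"],  # assumed attention projections for Qwen2-style models
)
model = get_peft_model(model, peft_config)  # wrap the base model before building the SFTTrainer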
@lewtun should we merge in the meantime?
@lewtun @qgallouedec yeah, maybe we need to write some custom summoning code like this
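For context, a minimal sketch of what such summoning could look like with the FSDP2 FSDPModule API (an assumption based on torch >= 2.6, where FSDPModule exposes unshard()/reshard(); this is not the code linked above):

from contextlib import contextmanager
from torch.distributed.fsdp import FSDPModule  # FSDP2 wrapper class, torch >= 2.6

@contextmanager
def summon_full_params_fsdp2(model):
    # all-gather the sharded parameters of every FSDP2-wrapped submodule
    fsdp_modules = [m for m in model.modules() if isinstance(m, FSDPModule)]
    for m in fsdp_modules:
        m.unshard()
    try:
        yield model
    finally:
        # re-shard to free the gathered parameters again
        for m in fsdp_modules:
            m.reshard()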
LGTM! Let's focus on efficient summoning in a follow-up PR
What does this PR do?
This PR adds support for FSDP2 by (a) adding an accelerate config and (b) updating prepare_fsdp(). Useful migration guide: https://huggingface.co/docs/accelerate/main/en/concept_guides/fsdp1_vs_fsdp2
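As a rough illustration of how the two FSDP versions can be told apart at runtime, recent accelerate releases expose an fsdp_version field on the FSDP plugin (a hedged sketch under that assumption, not the actual diff in this PR):

from accelerate import Accelerator

accelerator = Accelerator()
fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)  # only set when FSDP is enabled
if fsdp_plugin is not None and getattr(fsdp_plugin, "fsdp_version", 1) == 2:
    # take the FSDP2 code path (e.g. the updated prepare_fsdp() logic)
    pass
else:
    # fall back to the existing FSDP1 behaviour
    pass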
The FSDP2 API is still in beta, so we can hold off on merging this if we feel the accelerate version is too high to bump right now. The alternative would be to check the accelerate version being used and only allow FSDP2 to be run for versions > 1.6.0 (a sketch of such a check is shown below).

In particular, FULL_STATE_DICT is not supported in FSDP2 yet, so it is not possible to save transformers-friendly checkpoints. This is currently being worked on in accelerate, so we could hold off until then.

Update: now supported on accelerate@main and tested with:

A quick test on SFT shows the two FSDP versions are almost identical, modulo grad norms:
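A minimal sketch of the kind of version gate mentioned above, assuming the packaging library is available (fsdp2_available is a hypothetical helper, not something added by this PR):

import accelerate
from packaging import version

def fsdp2_available() -> bool:
    # per the note above, only allow FSDP2 for accelerate versions > 1.6.0
    return version.parse(accelerate.__version__) > version.parse("1.6.0")

if not fsdp2_available():
    raise ImportError("FSDP2 requires accelerate > 1.6.0; please upgrade accelerate or use FSDP1.")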
Commands to test with on 2 GPUs
Working SFT variants
Before submitting
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.