✌️ Add support for FSDP2 #3317

Merged · 10 commits into main · May 6, 2025

Conversation

@lewtun (Member) commented Apr 17, 2025

What does this PR do?

This PR adds support for FSDP2 by (a) adding an accelerate config and (b) updating prepare_fsdp().

Useful migration guide: https://huggingface.co/docs/accelerate/main/en/concept_guides/fsdp1_vs_fsdp2
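
For context on how code can tell the two modes apart, here is a minimal, hedged sketch (an illustration, not this PR's actual diff) of detecting FSDP2 through accelerate's FSDP plugin; the getattr fallbacks are there because older accelerate versions do not expose fsdp_version:

# Hedged sketch, not the PR's code: check whether accelerate was configured for FSDP2.
def is_fsdp2(accelerator) -> bool:
    plugin = getattr(accelerator.state, "fsdp_plugin", None)
    # fsdp_version was introduced with accelerate's FSDP2 support; default to 1 if absent.
    return plugin is not None and getattr(plugin, "fsdp_version", 1) == 2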

The FSDP2 API is still in beta, so we can hold off on merging this if we feel the accelerate version is too high to bump right now. The alternative would be to check the accelerate version being used and only allow FSDP2 to be run for versions > 1.6.0
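
As a rough illustration of that version-gating alternative (a hypothetical helper, not something this PR adds), the check could look roughly like this:

# Hedged sketch of the version gate: refuse to run FSDP2 on accelerate < 1.6.0.
from packaging import version

import accelerate

def require_fsdp2_support(min_version: str = "1.6.0") -> None:
    if version.parse(accelerate.__version__) < version.parse(min_version):
        raise ImportError(
            f"FSDP2 requires accelerate>={min_version}, but accelerate=={accelerate.__version__} "
            "is installed. Upgrade accelerate or fall back to FSDP1."
        )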

In particular, FULL_STATE_DICT is not supported in FSDP2 yet, so it is not possible to save transformers-friendly checkpoints. This is currently being worked on in accelerate, so we could hold off until then.

Update: now supported on accelerate@main and tested with:

uv pip install "accelerate @ git+https://github.com/huggingface/accelerate.git@c5caa11e8557633ba2187a84931b78cc25098c05"

A quick test on SFT shows the two FSDP versions are almost identical, modulo grad norms:

[Screenshot 2025-04-17 at 16:03:02]

Commands to test with on 2 GPUs

accelerate launch --config_file examples/accelerate_configs/fsdp2.yaml --num_processes 2  trl/scripts/sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-5 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --eos_token '<|im_end|>' \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir Qwen2-0.5B-SFT

# LoRA
accelerate launch --config_file examples/accelerate_configs/fsdp2.yaml --num_processes 2 trl/scripts/sft.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-4 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --eos_token '<|im_end|>' \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 100 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --output_dir Qwen2-0.5B-SFT

Working SFT variants

  • Full training FSDP1
  • Full training FSDP2
  • LoRA FSDP1
  • LoRA FSDP2

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@lewtun lewtun marked this pull request as ready for review April 17, 2025 14:08
@lewtun lewtun requested review from qgallouedec and kashif April 17, 2025 14:09
@fabianlim (Contributor) commented Apr 17, 2025

@lewtun if it's FSDP2, is it guaranteed that summoning named_parameters will give the full params? I'm looking at this

@qgallouedec (Member)

> if we feel the accelerate version is too high

I think it's ok to bump high since trl is also beta.

> The alternative would be to check the accelerate version being used and only allow FSDP2 to be run for versions > 1.6.0

It seems that mergekit requires accelerate>=1.3.0,<1.4.dev0, so the above solution is probably the best.
Maybe just documenting it is enough? Something like # requires accelerate>=1.6 at the top of the FSDP2 config file?

@qgallouedec (Member) commented Apr 21, 2025

I tried myself: looks good!

[Screenshot 2025-04-21 at 16:06:02]

and for the record, FSDP is way faster in my case:

[Screenshot 2025-04-21 at 16:07:15]

@fabianlim (Contributor)

Hi, we had a meeting with @qgallouedec and he mentioned that I should share two thoughts on FSDP2.

  1. Calling named_parameters may not scale, as it summons the whole model. I have some hacky code to summon the model one FSDP module at a time, but perhaps we can make this nicer and contribute it (see the rough sketch after this list). AFAIK, FSDP2 unfortunately does not have a nice API to summon parameters in a sharded manner.
  2. I have a bug report in vllm warning that FSDP1 does not play nice with vllm. This concerns the new collocation PR we are currently working on, which runs both vllm and training on the same GPU to improve utilization.
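
To make point 1 concrete, here is a minimal, hedged sketch of a per-parameter variant of that idea (an assumption about how such a helper could look, not the hacky code referenced above). It relies on FSDP2 storing parameters as DTensors (torch >= 2.4) and materializes them one at a time via full_tensor() instead of summoning the whole model at once:

# Hedged sketch: gather FSDP2-sharded parameters one at a time.
import torch
from torch.distributed.tensor import DTensor

def iter_full_params(model: torch.nn.Module):
    """Yield (name, full_tensor) pairs without materializing the whole model at once."""
    for name, param in model.named_parameters():
        if isinstance(param, DTensor):
            # full_tensor() all-gathers the shards of this single parameter on every rank.
            yield name, param.full_tensor()
        else:
            yield name, param.data

A per-module variant would group parameters by FSDP-wrapped submodule instead, which is closer to what is described above; either way, peak memory stays well below a full-model summon.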

@BenasdTW (Contributor) commented May 1, 2025

> I tried myself: looks good!

@qgallouedec Would you mind sharing the script you ran? I'm asking because I got errors trying to use FSDP2 with this PR.
My code:

# accelerate launch --config_file fsdp.yaml error.py
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from accelerate import PartialState
from peft import get_peft_model, LoraConfig

model_id = "Qwen/Qwen2.5-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map={"": PartialState().process_index},
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    use_cache=False,
)
# apply_liger_kernel_to_qwen2(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token_id = 151643

dummy_dataset = Dataset.from_dict({"text": ["Dummy dataset"] * 32, })

training_args = SFTConfig(
    output_dir="trainer_output",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    report_to="none",
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False}
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dummy_dataset,
    processing_class=tokenizer,
)

trainer.train()
trainer.save_model()

error:

Converting train dataset to ChatML: 100%|████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 7169.37 examples/s]
Adding EOS to train dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 9362.94 examples/s]
Tokenizing train dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 2514.95 examples/s]
Truncating train dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 13881.24 examples/s]
[rank0]:[W501 21:01:27.110829382 ProcessGroupNCCL.cpp:4561] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
Converting train dataset to ChatML: 100%|████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4829.89 examples/s]
Adding EOS to train dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 4561.20 examples/s]
Tokenizing train dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 2098.89 examples/s]
Truncating train dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 7907.25 examples/s]
[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspaces/LLMTrain/error.py", line 67, in <module>
[rank1]:     
[rank1]:     ^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2239, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2368, in _inner_training_loop
[rank1]:     model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
[rank1]:                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py", line 1438, in prepare
[rank1]:     result = tuple(
[rank1]:              ^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py", line 1439, in <genexpr>
[rank1]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank1]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py", line 1281, in _prepare_one
[rank1]:     return self.prepare_model(obj, device_placement=device_placement)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/accelerator.py", line 1605, in prepare_model
[rank1]:     model = fsdp2_prepare_model(self, model)
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 644, in fsdp2_prepare_model
[rank1]:     fsdp2_load_full_state_dict(accelerator, model, original_sd)
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 512, in fsdp2_load_full_state_dict
[rank1]:     sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
[rank1]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 741, in distribute_tensor
[rank1]:     local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
[rank1]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/placement_types.py", line 176, in _shard_tensor
[rank1]:     mesh_scatter(output, scatter_list, mesh, mesh_dim=mesh_dim)
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_collective_utils.py", line 123, in mesh_scatter
[rank1]:     fut = scatter(
[rank1]:           ^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 4110, in scatter
[rank1]:     work = group.scatter(output_tensors, input_tensors, opts)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: RuntimeError: NCCL Error 2: unhandled system error (run with NCCL_DEBUG=INFO for details)
W0501 21:01:29.642000 39036 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 39276 closing signal SIGTERM
E0501 21:01:30.058000 39036 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 1 (pid: 39277) of binary: /opt/conda/bin/python3
Traceback (most recent call last):
  File "/opt/conda/bin/accelerate", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
    args.func(args)
  File "/opt/conda/lib/python3.11/site-packages/accelerate/commands/launch.py", line 1179, in launch_command
    multi_gpu_launcher(args)
  File "/opt/conda/lib/python3.11/site-packages/accelerate/commands/launch.py", line 809, in multi_gpu_launcher
    distrib_run.run(args)
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/run.py", line 909, in run
    elastic_launch(
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
error.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-05-01_21:01:29
  host      : 711701cc4fe8
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 39277)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lewtun (Member, Author) commented May 2, 2025

FSDP2 fails with LoRA and the following error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/fsx/lewis/git/hf/trl/trl/scripts/sft.py", line 148, in <module>
[rank1]:     main(script_args, training_args, model_args)
[rank1]:   File "/fsx/lewis/git/hf/trl/trl/scripts/sft.py", line 128, in main
[rank1]:     trainer.train()
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/trainer.py", line 2238, in train
[rank1]:     return inner_training_loop(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/trainer.py", line 2367, in _inner_training_loop
[rank1]:     model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
[rank1]:                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/accelerator.py", line 1438, in prepare
[rank1]:     result = tuple(
[rank1]:              ^^^^^^
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/accelerator.py", line 1439, in <genexpr>
[rank1]:     self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
[rank1]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/accelerator.py", line 1281, in _prepare_one
[rank1]:     return self.prepare_model(obj, device_placement=device_placement)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/accelerator.py", line 1605, in prepare_model
[rank1]:     model = fsdp2_prepare_model(self, model)
[rank1]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 644, in fsdp2_prepare_model
[rank1]:     fsdp2_load_full_state_dict(accelerator, model, original_sd)
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 513, in fsdp2_load_full_state_dict
[rank1]:     to_contiguous, casting_dtype = _infer_parameter_dtype(
[rank1]:                                    ^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/utils/fsdp_utils.py", line 475, in _infer_parameter_dtype
[rank1]:     old_param = model.get_parameter_or_buffer(param_name)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/modeling_utils.py", line 5400, in get_parameter_or_buffer
[rank1]:     raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
[rank1]: AttributeError: `base_model.model.model.embed_tokens.weight` is neither a parameter nor a buffer.

Update: fixed by huggingface/accelerate#3545

@lewtun (Member, Author) commented May 2, 2025

> @lewtun if it's FSDP2, is it guaranteed that summoning named_parameters will give the full params? I'm looking at this

Good catch, I don't think summoning will work with FSDP2 for now at least...

@qgallouedec (Member)

> Good catch, I don't think summoning will work with FSDP2 for now at least...

@lewtun should we merge in the meantime?

@fabianlim (Contributor)

@lewtun @qgallouedec yea maybe we need to write some custom summoning code like this

@lewtun lewtun marked this pull request as draft May 5, 2025 10:28
@lewtun lewtun marked this pull request as ready for review May 5, 2025 10:31
@qgallouedec (Member) left a comment

LGTM! Let's focus on efficient summoning in a follow-up PR

@qgallouedec qgallouedec changed the title Add support for FSDP2 ✌️ Add support for FSDP2 May 6, 2025
@lewtun lewtun merged commit 45f4c58 into main May 6, 2025
7 checks passed
@lewtun lewtun deleted the fix-fsdp2 branch May 6, 2025 06:29