
[Liger] liger DPO support #2568


Merged: 38 commits merged into main on Jun 12, 2025

Conversation

@kashif (Collaborator) commented Jan 14, 2025

What does this PR do?

Add support for the Liger-Kernel DPO loss in the DPO trainer.

Needs: linkedin/Liger-Kernel#521

PEFT support: #3065
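
To make the intent concrete, here is a minimal usage sketch, assuming the feature is exposed through a `use_liger_loss` flag on `DPOConfig`; the flag name, model id, and dataset below are illustrative assumptions, not taken from this PR's diff.

# Minimal sketch: enabling the Liger-Kernel fused DPO loss from the trainer config.
# `use_liger_loss=True` is an assumed flag name; requires `pip install liger-kernel`.
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer

train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = DPOConfig(
    output_dir="qwen-dpo-liger",
    use_liger_loss=True,  # route the DPO loss through Liger-Kernel's chunked fused-linear kernel
    per_device_train_batch_size=2,
)

trainer = DPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # illustrative model id; any supported causal LM works
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()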

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec (Member)
liger loss isn't compatible with ref precomputing right? If so we could add a warning or an error.
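
A hedged sketch of the kind of guard being suggested here; `precompute_ref_log_probs` is the existing DPOConfig option, while `use_liger_loss` is an assumed flag name, and this is not the merged implementation.

# Illustrative guard: the fused Liger loss computes reference log probs inside the kernel,
# so precomputed reference log probs would be ignored; reject the combination early.
def validate_liger_config(args) -> None:
    if getattr(args, "use_liger_loss", False) and args.precompute_ref_log_probs:
        raise ValueError(
            "use_liger_loss=True is incompatible with precompute_ref_log_probs=True; "
            "the Liger fused DPO loss computes reference log probs on the fly."
        )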

@VProv commented Mar 26, 2025

@VProv mentioned this pull request on Mar 26, 2025
@kashif (Collaborator, Author) commented Mar 26, 2025

@VProv, at the moment, I was having issues getting the same outputs/metrics with and without liger in the trainer.

@VProv commented Mar 26, 2025

> @VProv, at the moment, I was having issues getting the same outputs/metrics with and without liger in the trainer.

What setup are you using?

@vaibhavjindal (Contributor)
Hi, I am working on fixing the output/metrics issue.
Added a PR in liger-kernel: linkedin/Liger-Kernel#676

@vaibhavjindal (Contributor)
@kashif @qgallouedec can you please review the following PR which fixes the output/metrics issue? Thanks :)
#3346

@hanbyul-kim

Hi, thanks for sharing your work! Can I use your code with DeepSpeed ZeRO-3? I tried running it with that setup, but it doesn't seem to be working. Based on my analysis of the error log, I think it's related to parameter partitioning.

[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/dpo_loss.py", line 94, in forward
[rank5]:     return super().forward(
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 241, in forward
[rank5]:     accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 159, in accumulate_chunk
[rank5]:     ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk)
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 120, in fused_fwd_bwd
[rank5]:     return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)(
[rank5]:   File "/root/.dpo_trainer_venv/lib/python3.10/site-packages/torch/_functorch/apis.py", line 440, in wrapper
[rank5]:     return eager_transforms.grad_and_value_impl(
[rank5]:   File "/root/.dpo_trainer_venv/lib/python3.10/site-packages/torch/_functorch/vmap.py", line 48, in fn
[rank5]:     return f(*args, **kwargs)
[rank5]:   File "/root/.dpo_trainer_venv/lib/python3.10/site-packages/torch/_functorch/eager_transforms.py", line 1409, in grad_and_value_impl
[rank5]:     output = func(*args, **kwargs)
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 377, in _compute_loss
[rank5]:     ) = LigerFusedLinearPreferenceBase.chunk_forward(
[rank5]:   File "/mnt/nappipe/users/hanbyul-kim/RORL/apply_liger_loss/Liger-Kernel/src/liger_kernel/chunked_loss/fused_linear_preference.py", line 289, in chunk_forward
[rank5]:     logits_chunk = input_chunk @ weight.t()
[rank5]: RuntimeError: size mismatch, got input (322), mat (322x4096), vec (0)

@hanbyul-kim

Continuing my analysis, I can confirm that it's definitely connected to DeepSpeed ZeRO-3. When I switched to stage 2, it ran smoothly without any issues.
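
For readers hitting the same traceback: under ZeRO-3 the lm_head weight is partitioned across ranks, so its local shard can have size 0, which matches the `vec (0)` in the size-mismatch error above. A hedged sketch of a common workaround, using DeepSpeed's public `GatheredParameters` context manager to materialize the weight around the fused loss call; this is only an illustration, not the fix that eventually landed, and `model.lm_head` is an assumed attribute name.

# Sketch: temporarily gather the ZeRO-3-partitioned output projection so the chunked
# Liger loss sees the full [vocab_size, hidden_size] weight instead of an empty shard.
import deepspeed

def gathered_lm_head(model):
    # modifier_rank=None gathers the parameters read-only on every rank
    return deepspeed.zero.GatheredParameters(list(model.lm_head.parameters()), modifier_rank=None)

# usage (sketch):
# with gathered_lm_head(model):
#     loss, aux = liger_dpo_loss(model.lm_head.weight, hidden_states, labels, ...)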

@kashif (Collaborator, Author) commented May 5, 2025

Thanks @hanbyul-kim for the report.

@vaibhavjindal (Contributor)
@kashif just wanted to circle back and see if we can merge this now? We wanted to try it out internally at LinkedIn.


if is_wandb_available():
    import wandb


def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:

Review comment (Member):
pad_token_id isn't used?


if is_wandb_available():
    import wandb


def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    """Shift input ids one token to the right, and pad with pad_token_id"""

Review comment (Member):
this docstring ain't accurate I think
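
For reference, a minimal sketch of the helper the two comments are about: once `pad_token_id` is dropped from the signature, the function only prepends `decoder_start_token_id`, so the docstring should say that rather than mention padding. This is an illustrative version, not necessarily the exact code merged here.

import torch

def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    """Shift input ids one token to the right, prepending `decoder_start_token_id`."""
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids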

@kashif merged commit 53c4a7c into main on Jun 12, 2025
11 checks passed
@kashif deleted the liger-dpo branch on June 12, 2025 at 10:25