trl/trainer/sft_trainer.py (3 changes: 1 addition & 2 deletions)
@@ -173,7 +173,6 @@ def __init__(
             )
         if isinstance(model, str):
             model = self._create_model_from_path(model, args)
-        self.use_liger = is_liger_kernel_available() and isinstance(model, AutoLigerKernelForCausalLM)
Member:
What if the model used is already a liger model (and args.use_liger = False)?

Member Author:
Good point, but as recommended by the Liger maintainer, we shouldn't be passing Liger models at init (they should just be patched via the config).

Should we deprecate passing the Liger model to the trainer or would you prefer an alternative?
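
(For context, a minimal sketch of the config-driven path being recommended here; use_liger mirrors the args.use_liger flag discussed in this PR, and the model id, output_dir, and dataset are illustrative.)

from trl import SFTConfig, SFTTrainer

# Pass the model as a string so the trainer loads it itself; with use_liger=True
# it goes through AutoLigerKernelForCausalLM rather than AutoModelForCausalLM.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=SFTConfig(output_dir="out", use_liger=True),
    train_dataset=train_dataset,  # assumed to exist
)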

Member:
It seems that this command doesn't achieve what it's supposed to:

>>> from liger_kernel.transformers import AutoLigerKernelForCausalLM
>>> model =  AutoLigerKernelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
Applied Liger kernels to Qwen2
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
>>> isinstance(model, AutoLigerKernelForCausalLM)
False

I don't know of an easy way to test whether a model is a Liger model (perhaps @ByronHsu does?).

Anyway, with your change you can still pass a Liger model to the trainer, but you'll need to specify use_liger=True, which sounds good to me.
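
(For context, a short sketch of why the isinstance check comes back False, plus one heuristic, unverified way to detect a Liger-patched model by scanning its submodules.)

from liger_kernel.transformers import AutoLigerKernelForCausalLM

model = AutoLigerKernelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

# Like the transformers Auto classes, from_pretrained returns an instance of the
# concrete architecture class, so isinstance against the Auto class is always False.
print(isinstance(model, AutoLigerKernelForCausalLM))  # False
print(type(model).__name__)                           # Qwen2ForCausalLM

# Heuristic (assumption): Liger swaps in modules whose class names start with
# "Liger" (e.g. its RMSNorm), so scanning submodules is one rough check.
print(any(type(m).__name__.startswith("Liger") for m in model.modules()))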

Contributor:
This line doesn't work when a Liger model is wrapped with PEFT before being passed to the trainer. It does not respect args.use_liger either.

Contributor:
> Good point, but as recommended by the Liger maintainer, we shouldn't be passing Liger models at init (they should just be patched via the config).
>
> Should we deprecate passing the Liger model to the trainer or would you prefer an alternative?

Deprecating it might cause some problems. LoRA+ requires the model instance to be created beforehand (to create the optimizer). The use_liger flag does not convert a PEFT-wrapped model to a Liger model.

# Imports added for context; the create_loraplus_optimizer import path below
# assumes a recent PEFT release that ships peft.optimizers.
import torch
from peft import get_peft_model
from peft.optimizers import create_loraplus_optimizer

model = get_peft_model(model, lora_config)
optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=torch.optim.AdamW,
    lr=lr,
    eps=eps,
    betas=betas,
    weight_decay=weight_decay,
    loraplus_lr_ratio=loraplus_lr_ratio,
)
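
(Continuing the snippet above, a rough sketch of how the pre-built model and LoRA+ optimizer could then be handed to the trainer; the optimizers tuple is the standard transformers Trainer argument, and the dataset and config values are illustrative.)

from trl import SFTConfig, SFTTrainer

trainer = SFTTrainer(
    model=model,                       # already PEFT-wrapped (and possibly Liger-patched beforehand)
    args=SFTConfig(output_dir="out"),  # use_liger left False: the trainer must not re-load the model
    train_dataset=train_dataset,       # assumed to exist
    optimizers=(optimizer, None),      # pass the LoRA+ optimizer, default scheduler
)
trainer.train()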

Contributor:

> The use_liger flag does not convert a PEFT-wrapped model to a Liger model.

In sft_trainer.py:

        if args.use_liger:
            if not is_liger_kernel_available():
                raise ImportError("Please install Liger-kernel for use_liger=True")
            model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        return model
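
(One possible workaround for the PEFT case, sketched under the assumption of a Qwen2-based model: apply Liger's model-level monkey patch before loading and wrapping, so the use_liger loading path above is never needed for an already-instantiated model.)

from liger_kernel.transformers import apply_liger_kernel_to_qwen2
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Patch the Qwen2 modeling code first, then instantiate the model...
apply_liger_kernel_to_qwen2()
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

# ...and wrap with PEFT as usual; the Liger kernels are already in place.
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))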

Contributor:

Also, many VLMs can't be loaded with AutoLigerKernelForCausalLM. For example, monkey-patching via apply_liger_kernel_to_qwen2_vl() is required for Qwen2-VL.
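
(For reference, a rough sketch of the Qwen2-VL monkey-patching path; the model id is illustrative.)

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from transformers import Qwen2VLForConditionalGeneration

# AutoLigerKernelForCausalLM can't load this VLM, so patch the architecture
# explicitly and load it with its own class instead.
apply_liger_kernel_to_qwen2_vl()
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")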


         # PEFT configuration and model wrapping
         if peft_config is not None:
@@ -472,7 +471,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         )

         # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
-        if "labels" in inputs and not self.use_liger:
+        if "labels" in inputs and not self.args.use_liger:
             shift_logits = outputs.logits[..., :-1, :].contiguous()
             shift_labels = inputs["labels"][..., 1:].contiguous()
