Restore the effect of liger_kernel's monkey_patch on global modules in UT. #3680

Merged: 5 commits into huggingface:main on Jul 6, 2025

Conversation

@YangKai0616 (Contributor)

What does this PR do?

liger_kernel's monkey patch modifies transformers modules globally, replacing classes and functions at module level. When unit tests (UT) run in the same process, these global changes should be restored promptly after the test finishes.

Reproduce the issue

Based on the implementation of tests/slow/test_sft_slow.py::test_sft_trainer_with_liger, I wrote a simple script to demonstrate the issue.

import importlib

from transformers import AutoModelForCausalLM
from transformers.models.llama import modeling_llama
from transformers.utils import is_liger_kernel_available

if __name__ == '__main__':
    model_path = 'trl-internal-testing/tiny-LlamaForCausalLM-3.2'
    model_init_kwargs = {}

    # Module-level symbols before any patching
    print(f"ori_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, ori_LlamaMLP:{modeling_llama.LlamaMLP}, ori_apply_rotary_pos_emb:{modeling_llama.apply_rotary_pos_emb}")
    model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)

    if is_liger_kernel_available():
        from liger_kernel.transformers import _apply_liger_kernel_to_instance

        # Patching the instance also replaces classes/functions at module level
        _apply_liger_kernel_to_instance(model=model)
        print(f"liger_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, liger_LlamaMLP:{modeling_llama.LlamaMLP}, liger_apply_rotary_pos_emb:{modeling_llama.apply_rotary_pos_emb}")

    # Reloading the module restores the original definitions
    importlib.reload(modeling_llama)
    print(f"reload_LlamaRMSNorm:{modeling_llama.LlamaRMSNorm}, reload_LlamaMLP:{modeling_llama.LlamaMLP}, reload_apply_rotary_pos_emb:{modeling_llama.apply_rotary_pos_emb}")

The script output is as follows:

ori_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, ori_LlamaMLP:<class 'transformers.models.llama.modeling_llama.LlamaMLP'>, ori_apply_rotary_pos_emb:<function apply_rotary_pos_emb at 0x7f9a0797a700>
liger_LlamaRMSNorm:<class 'liger_kernel.transformers.rms_norm.LigerRMSNorm'>, liger_LlamaMLP:<class 'liger_kernel.transformers.swiglu.LigerSwiGLUMLP'>, liger_apply_rotary_pos_emb:<function liger_rotary_pos_emb at 0x7f9a069fa980>
reload_LlamaRMSNorm:<class 'transformers.models.llama.modeling_llama.LlamaRMSNorm'>, reload_LlamaMLP:<class 'transformers.models.llama.modeling_llama.LlamaMLP'>, reload_apply_rotary_pos_emb:<function apply_rotary_pos_emb at 0x7f9a07979f80>

Testing Done

If these changes are not reverted, they affect how the related models are loaded in subsequent operations within the same process. Reloading the modules patched by liger_kernel restores the original global configuration.
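
The same reload can be wrapped in a small, model-agnostic helper. This is a minimal sketch; the helper name restore_module_patches is illustrative and not part of this PR's diff:

import importlib

def restore_module_patches(model):
    # Reload the modeling module backing `model` so that any module-level
    # classes or functions replaced by liger_kernel's monkey patch are
    # restored to their original transformers definitions.
    modeling_module = importlib.import_module(type(model).__module__)
    importlib.reload(modeling_module)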

@YangKai0616 (Contributor, Author)

@kashif and @shirinyamani, please help review.

@kashif (Collaborator) commented Jul 4, 2025

actually your approach is also good... perhaps we can just use it with the addCleanup:

    @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
    @require_liger_kernel
    def test_sft_trainer_with_liger(self, model_name, packing):
        """
        Tests if passing use_liger=True to SFTConfig loads and runs the trainer with AutoLigerKernelForCausalLM as
        expected.
        """
        import importlib
        
        def cleanup_liger_patches(trainer):
            """Clean up liger_kernel patches by reloading the model's specific module"""
            try:
                # Get the specific module that was used by the trainer's model
                module_path = trainer.model.__module__
                reload_module = importlib.import_module(module_path)
                importlib.reload(reload_module)
            except Exception:
                pass  # Continue if reload fails
        
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = SFTConfig(
                output_dir=tmp_dir,
                logging_strategy="no",
                report_to="none",
                per_device_train_batch_size=2,
                max_steps=2,
                packing=packing,
                max_length=self.max_length,
                use_liger_kernel=True,
            )

            trainer = SFTTrainer(
                model_name,
                args=training_args,
                train_dataset=self.train_dataset,
                eval_dataset=self.eval_dataset,
            )

            # Register cleanup now that we have the trainer
            self.addCleanup(cleanup_liger_patches, trainer)

            trainer.train()

        release_memory(trainer.model, trainer)
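
Registering the reload with addCleanup means it runs even if trainer.train() raises, so later tests in the same process see the unpatched classes again. For the Llama case, a quick, hypothetical check to confirm the cleanup took effect could look like this:

from transformers.models.llama import modeling_llama

# After the cleanup has run, the module-level symbols should point back at
# the original transformers implementations rather than the liger variants.
assert "liger" not in modeling_llama.LlamaRMSNorm.__module__
assert "liger" not in modeling_llama.apply_rotary_pos_emb.__module__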

@YangKai0616 (Contributor, Author)

> actually your approach is also good... perhaps we can just use it with the addCleanup:

Great idea, thanks for the suggestion!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@kashif merged commit c30344e into huggingface:main on Jul 6, 2025 (9 of 10 checks passed).
marcandrelarochelle pushed a commit to marcandrelarochelle/trl that referenced this pull request on Jul 29, 2025: …n UT. (huggingface#3680)

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>