Restore the effect of liger_kernel's monkey_patch on global modules in UT. #3680
…t of liger_kernel's monkey_patch on global modules.
@kashif and @shirinyamani pls help review.
actually your approach is also good... perhaps we can just use it with the
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
@require_liger_kernel
def test_sft_trainer_with_liger(self, model_name, packing):
    """
    Tests if passing use_liger=True to SFTConfig loads and runs the trainer with AutoLigerKernelForCausalLM as
    expected.
    """
    import importlib

    def cleanup_liger_patches(trainer):
        """Clean up liger_kernel patches by reloading the model's specific module"""
        try:
            # Get the specific module that was used by the trainer's model
            module_path = trainer.model.__module__
            reload_module = importlib.import_module(module_path)
            importlib.reload(reload_module)
        except Exception:
            pass  # Continue if reload fails

    with tempfile.TemporaryDirectory() as tmp_dir:
        training_args = SFTConfig(
            output_dir=tmp_dir,
            logging_strategy="no",
            report_to="none",
            per_device_train_batch_size=2,
            max_steps=2,
            packing=packing,
            max_length=self.max_length,
            use_liger_kernel=True,
        )
        trainer = SFTTrainer(
            model_name,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
        )
        # Register cleanup now that we have the trainer
        self.addCleanup(cleanup_liger_patches, trainer)
        trainer.train()
    release_memory(trainer.model, trainer)
Great idea, thanks for the suggestion!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…n UT. (huggingface#3680) Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
What does this PR do?
liger_kernel monkey-patches modules globally. When unit tests (UT) run in the same process, the effect of liger_kernel's monkey_patch should be restored promptly after each test.
Reproduce the issue
Based on the implementation of tests/slow/test_sft_slow.py::test_sft_trainer_with_liger, I wrote a simple script to demonstrate the issue. The script output is as follows:
Testing Done
If these changes are not manually reverted, they affect how the related models are loaded in subsequent operations within the same process. Reloading the modules that liger patched restores the global configuration.
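The leak-and-restore behavior described above can be illustrated with a minimal sketch. It does not use liger_kernel or transformers: the stdlib `string` module stands in for a patched modeling module, and `apply_fake_patch` and `restore_by_reload` are hypothetical helper names introduced only for this example.

```python
import importlib
import string


def apply_fake_patch():
    """Stand-in (hypothetical) for a library that monkey-patches a global module attribute."""
    string.capwords = lambda s, sep=None: "PATCHED"


def restore_by_reload(module):
    """Re-execute the module's source so its top-level names are rebound."""
    return importlib.reload(module)


apply_fake_patch()
# The patch is visible to every caller in this process, not just the patcher.
leaked = string.capwords("hello world")

restore_by_reload(string)
# After the reload, the module attribute points at a freshly defined function.
restored = string.capwords("hello world")

print(leaked, "|", restored)  # prints: PATCHED | Hello World
```

Note that `importlib.reload` only rebinds names inside the reloaded module. Code that captured a reference before the reload (for example via `from string import capwords`) keeps pointing at the old object, so the reload should target the module that later code actually looks attributes up on.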