Add single device KD recipe #1539
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1539
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 6ea3329 with merge base 63208c6.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct
# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
Can you add a bit more detail here? Tbh when I first looked at it I thought you had accidentally just copy-pasted a tune run command from another config 😅. Maybe add a couple of statements explicitly saying something like: "Run this to download the model: {tune download ...}. You will then need to fine-tune the teacher model, which you can do with {tune run ...}".
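Something like this sketch of the expanded header (wording is illustrative and just reuses the two commands already in the config):

```yaml
# This config requires a fine-tuned teacher checkpoint.
# Run this to download the base model:
#   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct
# You will then need to fine-tune the teacher model, which you can do with:
#   tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
```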
# Teacher checkpoint
teacher_checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/lora_finetuned_single_device_epoch_1/
Should make sure that this directory matches the output of whatever command you give for LoRA finetuning at the top of this file
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16
compile: False
Out of curiosity, did you try with compile yet?
Looks like some nice improvements in training speed. I'm curious about the difference in the loss curves, do you see non-determinism across runs there without compile enabled?
torchtune/modules/loss/kd_losses.py
Outdated
The Kullback-Leibler divergence loss for valid indexes.
Implementation of https://github.com/jongwooko/distillm/blob/master/distillm/losses.py.
"""
I'm surprised the linter didn't yell at you for this; can we add args (well, I guess just a single arg) with typehints here?
Also two nits on the link: (1) don't include the period at the end (makes it not clickable), and (2) replace master with a specific commit hash (in case things change in the future)
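For illustration, a sketch combining both suggestions (Args with typehints, and a pinned link without the trailing period). The single init arg being ignore_index is an assumption, and the commit hash is a placeholder:

```python
import torch


class ForwardKLLoss(torch.nn.Module):
    """
    The Kullback-Leibler divergence loss for valid indexes.
    Implementation of
    https://github.com/jongwooko/distillm/blob/<commit-hash>/distillm/losses.py

    Args:
        ignore_index (int): Target value that is ignored and does not
            contribute to the loss. Default: -100
    """

    def __init__(self, ignore_index: int = -100) -> None:
        super().__init__()
        self.ignore_index = ignore_index
```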
""" | ||
|
||
teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32) | ||
inf_mask = torch.isinf(student_logits) |
Noob q: why would student logits be infinite? Does it mean there is some numerical issue? (I know it's in the original implementation, just curious about the rationale)
The teacher_logits could be infinite too. I believe the original implementation only considered that the student logits could be infinite because that's the model that's training. An inf in the student logits would cause the torch.sum part to be inf.
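To make the rationale concrete, here is a minimal sketch of the masking, loosely following the distillm forward-KL implementation (the shapes and the ignore_index default are assumptions):

```python
import torch
import torch.nn.functional as F


def forward_kl(student_logits, teacher_logits, labels, ignore_index=-100):
    # Teacher probabilities and student log-probabilities in fp32 for stability
    teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    # Zero out vocab positions where the student logits are inf, otherwise the
    # sum over the vocab dimension below would itself become inf
    inf_mask = torch.isinf(student_logits)
    prod_probs = (teacher_prob * student_logprob).masked_fill(inf_mask, 0.0)
    x = torch.sum(prod_probs, dim=-1).view(-1)
    # Average only over valid (non-ignored) token positions
    mask = (labels != ignore_index).int().view(-1)
    return -torch.sum(x * mask, dim=0) / torch.sum(mask, dim=0)
```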
torchtune/modules/loss/kd_losses.py
Outdated
the cross entropy normally, but upcasting only one chunk at a time saves considerable memory.
"""
Similar comment here about init args
standard_loss = fkl_loss(logits, teacher_logits, labels)

# Assert
assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
It would also be good to run the loss from the jongwooko repo with identical inputs, use that to determine the expected value, and then compare both the chunked and standard losses to it (that way we know we have numerical parity with a reference implementation).
Good idea. I couldn't find any tests in the repo, so I randomly generated the logits and ran them through the distillm implementation.
Looks great! You can also use something like the fixed_init_tensor util in case you don't want to generate it manually. But in this case the tensors are relatively small anyway, so no need for it.
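A rough sketch of how that comparison could be wired up (the chunked-loss forward signature and constructor args are assumptions; the reference value would be computed once with the distillm code on the same seeded inputs and hard-coded in the test):

```python
import torch

from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss


def test_forward_kl_matches_reference():
    torch.manual_seed(0)
    bsz, seq_len, vocab = 2, 8, 32
    logits = torch.randn(bsz, seq_len, vocab)
    teacher_logits = torch.randn(bsz, seq_len, vocab)
    labels = torch.randint(0, vocab, (bsz, seq_len))

    standard_loss = ForwardKLLoss()(logits, teacher_logits, labels)
    chunked_loss = ForwardKLWithChunkedOutputLoss(num_output_chunks=2)(
        list(logits.chunk(2, dim=1)),
        list(teacher_logits.chunk(2, dim=1)),
        labels,
    )

    # expected_loss = torch.tensor(...)  # value from the distillm reference impl
    # torch.testing.assert_close(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
    # torch.testing.assert_close(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
    torch.testing.assert_close(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
```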
OK I took another, more thorough pass, and this looks great! I left a handful more comments but after that there are no real concerns from me
torchtune/_recipe_registry.py
Outdated
Config(
    name="llama3_1/kd_single_device",
    file_path="llama3_1/kd_single_device.yaml",
),
Need to remove this in the latest version
# Logging
output_dir: /tmp/qwen_kd
metric_logger:
  _component_: torchtune.training.metric_logging.TensorBoardLogger
nit: at least before landing, make sure to switch this to torchtune.training.metric_logging.DiskLogger, since tensorboard is technically a dev dependency of our library.
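i.e. something along these lines (assuming DiskLogger takes a log_dir argument):

```yaml
# Logging
output_dir: /tmp/qwen_kd
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}
```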
recipes/kd_single_device.py
Outdated
library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with
8-bit AdamW and Paged AdamW.
Might wanna check this if you haven't already (I would be surprised if it doesn't work though)
recipes/kd_single_device.py
Outdated
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
    # set num_output_chunks for model
    assert (
        self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks
    ), "Number of output chunks for loss_fn and kd_loss_fn must be the same."
    self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
    self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks)
    if self._model_compile:
        log.info("Compiling loss with torch.compile...")
        # For CEWithChunkedOutputLoss, if we compile the entire class
        # we lose the benefits from the chunked loss.
        # Therefore, we only compile the cross entropy function + upcasting
        self._loss_fn.compute_cross_entropy = torch.compile(
            self._loss_fn.compute_cross_entropy, backend=backend
        )
else:
    if self._model_compile:
        log.info("Compiling loss with torch.compile...")
        self._loss_fn = torch.compile(self._loss_fn, backend=backend)
Oh we have some utilities for this now, you can try to use those instead. See usage in our LoRA single-device recipe here. (If they don't work for KD out of the box lmk, we can refactor as needed)
Can also try compiling KD loss function
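Roughly, assuming the same helpers the LoRA single-device recipe uses (training.compile_loss / training.compile_model; exact signatures may differ), the branching above could collapse into a sketch like:

```python
from torchtune import training


def setup_losses(loss_fn, kd_loss_fn, compile_enabled: bool):
    # Sketch: replaces the manual torch.compile branching in the recipe
    if compile_enabled:
        # compile_loss handles the chunked-output case internally
        loss_fn = training.compile_loss(loss_fn)
        # The KD loss could be wrapped the same way, if the utility supports it
        kd_loss_fn = training.compile_loss(kd_loss_fn)
    return loss_fn, kd_loss_fn
```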
recipes/kd_single_device.py
Outdated
if compile_model:
    log.info("Compiling model layers with torch.compile...")
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    for m in reversed(list(model.modules())):
        if isinstance(m, modules.transformer.TransformerSelfAttentionLayer):
            m.compile(backend=backend)
Similar comment here about using the compile utilities
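i.e., again assuming the shared utility compiles the transformer layers in place (reusing the training import from the sketch above), the manual loop could become:

```python
# Sketch: replace the manual layer loop with the shared utility
if compile_model:
    training.compile_model(model)
```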
recipes/kd_single_device.py
Outdated
training.log_memory_stats(memory_stats)
return model

def _setup_teacher_model(
Wonder if it's worth it to also compile the teacher model when compile=True?
I got an error when trying to compile the teacher model. I'm using torch.no_grad when running inference with the teacher model to reduce memory consumption, but it seems that torch.no_grad isn't compatible with torch.compile (pytorch/pytorch#100241).
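For reference, the teacher forward is guarded roughly like this sketch (the argument names are assumptions), which is the part that clashes with torch.compile:

```python
# Teacher is inference-only, so gradients are disabled to save memory
with torch.no_grad():
    teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos)
```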
recipes/kd_single_device.py
Outdated
# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
Nit: I'm not positive, but it might be slightly slower to do things this way. Each .item() call causes a sync, and you could probably get away with calling just running_class_loss.item() and running_kd_loss.item() and deriving loss_to_log from those (I think).
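A sketch of that suggestion (the running_* names and the kd_ratio weighting are assumptions for illustration):

```python
# One host sync per loss component; combine the scalars on the CPU side
class_loss_to_log = running_class_loss.item()
kd_loss_to_log = running_kd_loss.item()
loss_to_log = (1 - kd_ratio) * class_loss_to_log + kd_ratio * kd_loss_to_log
```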
Super nit / discussion point: kd -> distillation. I think it's much easier to understand at a glance what it is. I didn't understand the abbreviation KD until I read through the PR description. We could rename from kd_single_device to knowledge_distillation_single_device.
Yeah this sounds good to me. I agree that it's nice to just be explicit in the recipe name so it's obvious what the recipe is doing (even if the name is a bit longer as a result).
metric_logger._component_=torchtune.training.metric_logging.DiskLogger \
metric_logger.filename={log_file} \
compile={compile} \
kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \
Do you actually need this override? This should be the default in the config, right?
Removed.
print(loss_values)
print(expected_loss_values)
remove
Codecov Report
Attention: Patch coverage is

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1539      +/-   ##
==========================================
- Coverage   72.26%   69.00%   -3.26%
==========================================
  Files         290      295       +5
  Lines       14554    15079     +525
==========================================
- Hits        10517    10406     -111
- Misses       4037     4673     +636

☔ View full report in Codecov by Sentry.
Thank you for adding this!
Context
What is the purpose of this PR? Is it to add a new feature, fix a bug, update tests and/or documentation, or something else?
Changelog
What are the changes made in this PR?
- The new recipe (knowledge_distillation_single_device.py) is similar to lora_finetune_single_device.py. The main difference is that a KD loss, ForwardKLLoss (in kd_losses.py), is added to the CE loss (see the sketch below).
- Added the knowledge_distillation_single_device.yaml config.
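Conceptually, the combined objective in the training step looks something like the following sketch (the kd_ratio weighting and exact names are assumptions):

```python
# Standard cross-entropy against the labels, plus forward KL against the teacher
class_loss = loss_fn(logits, labels)
kd_loss = kd_loss_fn(logits, teacher_logits, labels)
loss = (1 - kd_ratio) * class_loss + kd_ratio * kd_loss
```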
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
- run unit tests via pytest tests
- run recipe tests via pytest tests -m integration_test
Llama3.1 KD Training
Legend: KD Llama3.1 student (blue), LoRA Llama3.1 student (orange)
(training curve screenshots)
Qwen2 KD Training
Legend: KD Qwen2 0.5B student (green), LoRA Qwen2 0.5B student (grey), LoRA Qwen2 1.5B teacher (blue)
(training curve screenshots)
Llama3.1 Eval Results
Qwen2 Eval Results
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring: torchtune/torchtune/modules/vision_transformer.py, line 285 at commit 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models