
Conversation

@lindawangg (Contributor) commented on Sep 11, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Changelog

What are the changes made in this PR?

  • KD recipe (knowledge_distillation_single_device.py) is similar to lora_finetune_single_device.py. Main differences are:
    • adds kd loss, currently just ForwardKLLoss (in kd_losses.py), to CE loss (a rough sketch of the combined step follows this list)
    • adds teacher model inference to get logits
  • KD config: knowledge_distillation_single_device.yaml
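
As a rough illustration of the core change (not the recipe's exact code; the function and argument names below, such as kd_step and kd_ratio, are placeholders), the student's cross-entropy loss is blended with a forward-KL distillation loss computed against logits from a frozen teacher:

import torch
from torch import nn

def kd_step(
    student: nn.Module,
    teacher: nn.Module,
    ce_loss_fn: nn.Module,
    kd_loss_fn: nn.Module,
    tokens: torch.Tensor,
    labels: torch.Tensor,
    kd_ratio: float = 0.5,
) -> torch.Tensor:
    student_logits = student(tokens)
    # The teacher is frozen; run it without gradient tracking just to get target logits.
    with torch.no_grad():
        teacher_logits = teacher(tokens)
    class_loss = ce_loss_fn(student_logits, labels)
    kd_loss = kd_loss_fn(student_logits, teacher_logits, labels)
    # Blend the standard CE loss with the distillation loss.
    return (1 - kd_ratio) * class_loss + kd_ratio * kd_loss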

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
CUDA_VISIBLE_DEVICES=0 tune run knowledge_distillation_single_device --config qwen2/knowledge_distillation_single_device
Llama3.1 KD Training

Legend: KD Llama3.1 student (blue), LoRA Llama3.1 student (orange)
[three training curve plots]

Qwen2 KD Training

Legend: KD Qwen2 0.5B student (green), LoRA Qwen2 0.5B student (grey), LoRA Qwen2 1.5B teacher (blue)
[three training curve plots]

Llama3.1 Eval Results
[eval results table]
Qwen2 Eval Results
[eval results table]

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;


pytorch-bot bot commented Sep 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1539

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6ea3329 with merge base 63208c6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 11, 2024
@lindawangg lindawangg marked this pull request as ready for review September 12, 2024 02:52
Comment on lines 6 to 7
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct
# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device
Contributor:

Can you add a bit more detail here? Tbh when I first looked at it I thought you had accidentally just copy-pasted a tune run command from another config 😅 . Maybe just add a couple statements explicitly saying something like "Run this to download the model: {tune download ...}. You will then need to fine-tune the teacher model, you can do this with {tune run...}"

# Teacher checkpoint
teacher_checkpointer:
_component_: torchtune.training.FullModelMetaCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/lora_finetuned_single_device_epoch_1/
Contributor:

Should make sure that this directory matches the output of whatever command you give for LoRA finetuning at the top of this file

epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16
compile: False
Contributor:

Out of curiosity, did you try with compile yet?

Contributor Author:

Yeah, tested that compile works (shown in pink):
[compile vs. no-compile training curves]

Contributor:

Looks like some nice improvements in training speed. I'm curious about the difference in the loss curves, do you see non-determinism across runs there without compile enabled?

Contributor Author:

The loss curves between compile and non-compile are fairly similar.
[loss curve comparison plots]

I do see some non-determinism during eval. The losses are all around 1.2, but there are slight differences in the eval metrics. It's also interesting that fine-tuning on the alpaca dataset actually hurts performance on all benchmarks except TruthfulQA.
[eval results table]

Comment on lines 15 to 17
The Kullback-Leibler divergence loss for valid indexes.
Implementation of https://github.com/jongwooko/distillm/blob/master/distillm/losses.py.
"""
Contributor:

I'm surprised the linter didn't yell at you for this, can we add args (well I guess just single arg) with typehints here?

Contributor:

Also two nits on the link: (1) don't include the period at the end (makes it not clickable), and (2) replace master with a specific commit hash (in case things change in the future)

"""

teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
inf_mask = torch.isinf(student_logits)
Contributor:

Noob q: why would student logits be infinite? Does it mean there is some numerical issue? (I know it's in the original implementation, just curious about the rationale)

Contributor Author:

The teacher_logits could be infinite too. I believe the original implementation only considered that the student logits could be infinite because that's the model that's training. The inf in the student logits would cause the torch.sum part to be inf.
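
For reference, a minimal sketch of the forward-KL loss with this inf masking, modeled on the distillm implementation linked in the docstring (not a verbatim copy of the PR's ForwardKLLoss):

import torch
import torch.nn.functional as F

def forward_kl(
    student_logits: torch.Tensor,   # [batch, seq_len, vocab]
    teacher_logits: torch.Tensor,   # [batch, seq_len, vocab]
    labels: torch.Tensor,           # [batch, seq_len]
    ignore_index: int = -100,
) -> torch.Tensor:
    # Upcast to fp32 for numerical stability.
    teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32)
    # Zero out vocab positions where the student produced +/-inf logits so they
    # don't turn the per-token sum into inf/NaN.
    inf_mask = torch.isinf(student_logits)
    prod = (teacher_prob * student_logprob).masked_fill(inf_mask, 0)
    # -sum_k p_teacher(k) * log p_student(k), per token position.
    x = torch.sum(prod, dim=-1).view(-1)
    # Average only over valid (non-padding) token positions.
    mask = (labels != ignore_index).int().view(-1)
    return -torch.sum(x * mask, dim=0) / torch.sum(mask, dim=0)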

Comment on lines 61 to 62
the cross entropy normally, but upcasting only one chunk at a time saves considerable memory.
"""
Contributor:

Similar comment here about init args

standard_loss = fkl_loss(logits, teacher_logits, labels)

# Assert
assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
Contributor:

Would also be good to run the loss from jongwooko repo with identical sets of values, use that to determine the expected value, then compare both chunked and standard losses to that (that way we know that we have numerical parity with a reference implementation)

Contributor Author:

Good idea. I couldn't find any tests in the repo, so I randomly generated the logits and ran them through the distillm implementation.

Contributor:

Looks great! You can also use something like the fixed_init_tensor util in case you don't want to generate them manually. But in this case the tensors are relatively small anyway, so no need for it.

@ebsmothers (Contributor) left a comment:

OK I took another, more thorough pass, and this looks great! I left a handful more comments but after that there are no real concerns from me


Comment on lines 284 to 287
Config(
name="llama3_1/kd_single_device",
file_path="llama3_1/kd_single_device.yaml",
),
Contributor:

Need to remove this in the latest version

# Logging
output_dir: /tmp/qwen_kd
metric_logger:
_component_: torchtune.training.metric_logging.TensorBoardLogger
Contributor:

nit: at least before landing make sure to switch this to torchtune.training.metric_logging.DiskLogger, since tensorboard is technically a dev dependency of our library

Comment on lines 72 to 73
library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with
8-bit AdamW and Paged AdamW.
Contributor:

Might wanna check this if you haven't already (I would be surprised if it doesn't work though)


Comment on lines 264 to 283
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
# set num_output_chunks for model
assert (
self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks
), "Number of output chunks for loss_fn and kd_loss_fn must be the same."
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks)
if self._model_compile:
log.info("Compiling loss with torch.compile...")
# For CEWithChunkedOutputLoss, if we compile the entire class
# we lose the benefits from the chunked loss.
# Therefore, we only compile the cross entropy function + upcasting
self._loss_fn.compute_cross_entropy = torch.compile(
self._loss_fn.compute_cross_entropy, backend=backend
)
else:
if self._model_compile:
log.info("Compiling loss with torch.compile...")
self._loss_fn = torch.compile(self._loss_fn, backend=backend)
Contributor:

Oh we have some utilities for this now, you can try to use those instead. See usage in our LoRA single-device recipe here. (If they don't work for KD out of the box lmk, we can refactor as needed)

Contributor:

Can also try compiling KD loss function
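
A sketch of how the KD recipe could use those utilities, assuming they are training.compile_model and training.compile_loss as in the LoRA single-device recipe, and that compile_loss returns the compiled loss module (handling the chunked-loss case internally):

from torch import nn
from torchtune import training

def setup_compile(model: nn.Module, loss_fn: nn.Module, kd_loss_fn: nn.Module):
    # Compile the transformer layers in place via the shared utility ...
    training.compile_model(model)
    # ... and compile both the CE loss and the KD loss.
    loss_fn = training.compile_loss(loss_fn)
    kd_loss_fn = training.compile_loss(kd_loss_fn)
    return loss_fn, kd_loss_fn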

Comment on lines 415 to 420
if compile_model:
log.info("Compiling model layers with torch.compile...")
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
for m in reversed(list(model.modules())):
if isinstance(m, modules.transformer.TransformerSelfAttentionLayer):
m.compile(backend=backend)
Contributor:

Similar comment here about using the compile utilities

training.log_memory_stats(memory_stats)
return model

def _setup_teacher_model(
Contributor:

Wonder if it's worth it to also compile the teacher model when compile=True?

Contributor Author:

I got an error when trying to compile the teacher model. I'm using torch.no_grad when running inference on the teacher model to reduce memory consumption. However, it seems that torch.no_grad isn't compatible with torch.compile (pytorch/pytorch#100241).
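
For context, the teacher-side pattern under discussion looks roughly like the following sketch (illustrative names, not the recipe's exact code): the teacher stays frozen in eval mode and its forward pass runs under torch.no_grad, which is what conflicts with torch.compile per the issue above.

import torch
from torch import nn

def setup_teacher(teacher: nn.Module) -> nn.Module:
    # The teacher is never updated: keep it in eval mode with gradients disabled.
    teacher.eval()
    for p in teacher.parameters():
        p.requires_grad_(False)
    return teacher

def get_teacher_logits(teacher: nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    # no_grad avoids storing activations for backward, saving memory during
    # teacher inference.
    with torch.no_grad():
        return teacher(tokens)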

# Update the number of steps when the weights are updated
self.global_step += 1

loss_to_log = running_loss.item()
Contributor:

Nit: I'm not positive, but it might be slightly slower to do things this way. Each .item() will cause a sync and you could probably get away with calling just running_class_loss.item() and running_kd_loss.item() and figuring out loss_to_log from that (I think).

Contributor Author:

Calling .item() twice instead of three times helps slightly; I don't see much of a difference:
[training speed comparison]

I removed running_loss and now compute the logged loss from the class and KD losses.
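
Roughly, the resulting logging pattern looks like this sketch (kd_ratio and the running_* names are assumptions based on this thread, not the recipe's exact code):

import torch

def losses_for_logging(
    running_class_loss: torch.Tensor,
    running_kd_loss: torch.Tensor,
    kd_ratio: float,
) -> tuple[float, float, float]:
    # Two device-to-host syncs instead of three: derive the combined loss on
    # the host from the two component values.
    class_loss_to_log = running_class_loss.item()
    kd_loss_to_log = running_kd_loss.item()
    loss_to_log = (1 - kd_ratio) * class_loss_to_log + kd_ratio * kd_loss_to_log
    return class_loss_to_log, kd_loss_to_log, loss_to_log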

@joecummings (Member) left a comment:

Super nit / discussion point:

kd -> distillation. I think it's much easier to understand at a glance what it is. I didn't understand the abbreviation KD until I read through the PR description.

@lindawangg (Contributor Author) replied, quoting the above:

Super nit / discussion point:

kd -> distillation. I think it's much easier to understand at a glance what it is. I didn't understand the abbreviation KD until I read through the PR description.

We could rename from kd_single_device to knowledge_distillation_single_device. Distillation alone might be confusing since there are many types.

@ebsmothers (Contributor) replied, quoting the above:

We could rename from kd_single_device to knowledge_distillation_single_device. Distillation might be confusing since there's many types.

Yeah this sounds good to me. I agree that it's nice to just be explicit in the recipe name so it's obvious what the recipe is doing (even if the name is a bit longer as a result).

metric_logger._component_=torchtune.training.metric_logging.DiskLogger \
metric_logger.filename={log_file} \
compile={compile} \
kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \
Contributor:

Do you actually need this override? This should be the default in the config, right?

Contributor Author:

Removed.

Comment on lines 119 to 120
print(loss_values)
print(expected_loss_values)
Contributor:

remove

@codecov-commenter commented on Sep 19, 2024

Codecov Report

Attention: Patch coverage is 22.01835% with 340 lines in your changes missing coverage. Please review.

Project coverage is 69.00%. Comparing base (dd348ce) to head (6ea3329).
Report is 474 commits behind head on main.

Files with missing lines                                      Patch %   Lines missing
recipes/knowledge_distillation_single_device.py                 0.00%   261 ⚠️
...cipes/test_knowledge_distillation_single_device.py          22.00%    78 ⚠️
torchtune/modules/loss/kd_losses.py                            96.77%     1 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1539      +/-   ##
==========================================
- Coverage   72.26%   69.00%   -3.26%     
==========================================
  Files         290      295       +5     
  Lines       14554    15079     +525     
==========================================
- Hits        10517    10406     -111     
- Misses       4037     4673     +636     

☔ View full report in Codecov by Sentry.

@ebsmothers (Contributor) left a comment:

Thank you for adding this!

@ebsmothers ebsmothers merged commit 4234b78 into pytorch:main Sep 19, 2024
17 checks passed
@lindawangg mentioned this pull request on Sep 20, 2024