An experiment with distillation using various combinations of naïve implementations, adversarial training, Jacobian penalties, and approximate Hessian penalties.

Distillation Proof-of-Concept

A PoC to learn a little more about how model distillation works.

I had the idea of using something like a Jacobian matrix to "disentangle" or "rectify" neural networks, to simplify their geometry and perhaps make them more generalizable, more reliable, and more efficient.

It turns out that this is already a thing (Jacobian matching and Jacobian regularization are established techniques).

But I thought it'd be cool to play with anyway, so I compared the teacher and student with the naïve distillation process, and then added in a Jacobian penalty to see if that would reduce complexity and inference time. It did!

On a lark, I decided to play with a Hessian penalty as well. Computing the exact Hessian seemed like it would be prohibitively expensive, but I found an approximation that I thought might have much the same effect at a far lower computational cost. It greatly reduced the inference time without significantly reducing accuracy, though (as we might expect) it left the Jacobian norm rather high. Since adversarial attacks exploit abrupt changes in the model's response, and a Hessian penalty smooths those changes too, the high Jacobian norm isn't cause for concern in and of itself; it's more of an aesthetic issue for me.

That said, adversarial attacks did greatly reduce the accuracy of the naïve model, so I added adversarial training to the same four combinations, aiming to preserve their positive attributes while increasing their resilience to attack. This had the desired effect; indeed, the combination of adversarial training and the pseudo-Hessian penalty produced the model with the highest post-distillation accuracy, a 90% reduction in parameters, a ~45% decrease in inference time, and the highest resilience to adversarial attacks at ε = 0.05 and 0.10.

I assume other researchers have investigated using an approximation of the Hessian as a penalty, but I learned a lot 🙂

Instructions

# Create the `distillation` environment from the env file.
conda env create -f environment.yml

# Activate the environment.
conda activate distillation

# Verify environment was created correctly.
conda env list

# Download the MNIST dataset.
python3 data.py

# Train the teacher, which will then train the student with:
# - normal (naïve) distillation
# - a Jacobian penalty (the original idea)
# - an approximate Hessian penalty (YOLO)
# - adversarial training + naïve distillation
# - adversarial training + a Jacobian penalty
# - adversarial training + an approximate Hessian penalty
# - adversarial training + a Jacobian penalty + an approximate Hessian penalty
python3 train.py

Results

Without Jacobian Penalty

$ python3 train.py
Batch shape: torch.Size([64, 1, 28, 28])
Labels: tensor([7, 6, 5, 1, 4, 6, 3, 5, 0, 1])
...
Training complete.
Evaluating models...
Teacher accuracy: 0.9815
Student accuracy: 0.9664

With Jacobian Penalty

$ python3 train.py
Batch shape: torch.Size([64, 1, 28, 28])
Labels: tensor([0, 6, 2, 1, 1, 9, 5, 5, 7, 5])
...
Training complete.
Evaluating models...
Teacher accuracy: 0.9795
Student accuracy: 0.9681
Teacher params: 535818
Student params: 50890
Teacher time: 0.002031879425048828
Student time: 0.0009510612487792969
Teacher Jacobian norm: 42575.689453125
Student Jacobian norm: 0.798276400566101
Epsilon 0.05 - Teacher acc: 0.6835 | Student acc: 0.5324
Epsilon 0.10 - Teacher acc: 0.2377 | Student acc: 0.1116
Epsilon 0.20 - Teacher acc: 0.0549 | Student acc: 0.0213

TL;DR: We:

  • cut params by ~90%
  • cut inference time by >50%
  • dropped the Jacobian norm by nearly five orders of magnitude
  • lost <1.2 percentage points of accuracy
  • but, unfortunately, got a model that isn't very robust to attack...

So let's add in adversarial training and Hessian penalties!

All Combinations

$ python3 train.py
Batch shape: torch.Size([64, 1, 28, 28])
Labels: tensor([1, 5, 7, 2, 3, 3, 6, 5, 4, 5])
...
Training complete.
Evaluating models...
Teacher accuracy: 0.9807
Student accuracy: 0.9688
Student (Jacobian) accuracy: 0.9667
Student (Hessian) accuracy: 0.9676
Student (Jacobian + Hessian) accuracy: 0.9662
Student (adversarial) accuracy: 0.9674
Student (adversarial + Jacobian) accuracy: 0.9680
Student (adversarial + Hessian) accuracy: 0.9729
Student (adversarial + Jacobian + Hessian) accuracy: 0.9689
Teacher params: 535818
Student params: 50890
Student (Jacobian) params: 50890
Student (Hessian) params: 50890
Student (Jacobian + Hessian) params: 50890
Student (adversarial) params: 50890
Student (adversarial + Jacobian) params: 50890
Student (adversarial + Hessian) params: 50890
Student (adversarial + Jacobian + Hessian) params: 50890
Teacher time: 0.0029256391525268556
Student time: 0.0024872779846191405
Student (Jacobian) time: 0.0019856834411621093
Student (Hessian) time: 0.0012114858627319336
Student (Jacobian + Hessian) time: 0.0021112918853759765
Student (adversarial) time: 0.0012415409088134765
Student (adversarial + Jacobian) time: 0.0015691232681274414
Student (adversarial + Hessian) time: 0.0016123437881469726
Student (adversarial + Jacobian + Hessian) time: 0.001298227310180664
Teacher Jacobian norm: 48166.989453125
Student Jacobian norm: 30920.7513671875
Student (Jacobian) Jacobian norm: 0.29223392605781556
Student (Hessian) Jacobian norm: 27608.5517578125
Student (Jacobian + Hessian) Jacobian norm: 1.4812594175338745
Student (adversarial) Jacobian norm: 43936.63515625
Student (adversarial + Jacobian) Jacobian norm: 1.83348046541214
Student (adversarial + Hessian) Jacobian norm: 48621.565234375
Student (adversarial + Jacobian + Hessian) Jacobian norm: 1.0918631672859191
Epsilon: 0.05
Teacher accuracy under attack: 0.7455
Student accuracy under attack: 0.5367
Student (Jacobian) accuracy under attack: 0.5645
Student (Hessian) accuracy under attack: 0.5242
Student (Jacobian + Hessian) accuracy under attack: 0.5532
Student (adversarial) accuracy under attack: 0.8900
Student (adversarial + Jacobian) accuracy under attack: 0.8910
Student (adversarial + Hessian) accuracy under attack: 0.8990
Student (adversarial + Jacobian + Hessian) accuracy under attack: 0.8941
Epsilon: 0.1
Teacher accuracy under attack: 0.2930
Student accuracy under attack: 0.1105
Student (Jacobian) accuracy under attack: 0.1138
Student (Hessian) accuracy under attack: 0.1061
Student (Jacobian + Hessian) accuracy under attack: 0.1485
Student (adversarial) accuracy under attack: 0.7802
Student (adversarial + Jacobian) accuracy under attack: 0.7909
Student (adversarial + Hessian) accuracy under attack: 0.8126
Student (adversarial + Jacobian + Hessian) accuracy under attack: 0.7859
Epsilon: 0.2
Teacher accuracy under attack: 0.0443
Student accuracy under attack: 0.0193
Student (Jacobian) accuracy under attack: 0.0188
Student (Hessian) accuracy under attack: 0.0144
Student (Jacobian + Hessian) accuracy under attack: 0.0316
Student (adversarial) accuracy under attack: 0.3848
Student (adversarial + Jacobian) accuracy under attack: 0.4330
Student (adversarial + Hessian) accuracy under attack: 0.4174
Student (adversarial + Jacobian + Hessian) accuracy under attack: 0.4379
Evaluation complete.

Conclusions

(J = Jacobian penalty, H = approximate Hessian penalty, A = adversarial training; times in ms, i.e. seconds × 10⁻³)

Model                Accuracy   Params  Time (ms)   Jacobian Norm   ε = 0.05   ε = 0.10   ε = 0.20
Teacher                0.9807   535818  2.92563915  48166.9894531     0.7455     0.2930     0.0443
Student                0.9688    50890  2.48727798  30920.7513672     0.5367     0.1105     0.0193
Student (J)            0.9667    50890  1.98568344      0.2922339🔥   0.5645     0.1138     0.0188
Student (H)            0.9676    50890  1.21148586🚀 27608.5517578     0.5242     0.1061     0.0144
Student (J + H)        0.9662    50890  2.11129189      1.4812594     0.5532     0.1485     0.0316
Student (A)            0.9674    50890  1.24154091  43936.6351563     0.8900     0.7802     0.3848
Student (A + J)        0.9680    50890  1.56912327      1.8334805     0.8910     0.7909     0.4330
Student (A + H)        0.9729😍   50890  1.61234379  48621.5652344     0.8990💅🏻   0.8126😎    0.4174
Student (A + J + H)    0.9689    50890  1.29822731      1.0918632     0.8941     0.7859     0.4379💪🏻
