Core Idea
Knowledge distillation (KD) addresses the question of whether a smaller model can be made to perform as well as a large model on classification tasks. Put another way: can the knowledge of a larger model be distilled into a smaller one?
The approach is described in the paper “Distilling the Knowledge in a Neural Network” 1.
The image above outlines the models and loss-function components that typically make up a KD setup. It is taken from the paper “On the Efficacy of Knowledge Distillation” 2.
Motivations
Knowledge distillation primarily helps port big, beefy models to models with smaller memory and compute footprints. This is useful for edge devices and sensors, where compute, memory and energy usage are major concerns. It can also simply be used to speed up inference in regular applications by serving a smaller model.
Smaller footprints can be achieved by using more efficient model architectures or low-precision models.
Training for Knowledge Distillation
This task requires two models, one of which is larger (in model size and complexity) than the other.
There are two stages to this training procedure:
- Train the larger model on your dataset. This is your regular training procedure.
- Train your smaller model on the same dataset, keeping the trained larger model fixed, with an additional KL-divergence loss between the smaller and larger models' predictions. This loss component pushes the outputs of the smaller model towards those of the larger model.
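The two stages can be sketched on a toy problem. This is a minimal, pure-Python illustration under loud assumptions. Both models are linear softmax classifiers of the same size, although in practice the teacher would be the larger network. The dataset is three synthetic Gaussian blobs. The helper names (`fit`, `student_grad` and so on) and the values of `ALPHA`, `BETA` and `T` are hypothetical choices made for this example. Because the teacher is frozen in stage two, its logits are computed once and reused.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot(label, n_classes):
    return [1.0 if k == label else 0.0 for k in range(n_classes)]


def predict_logits(weights, x):
    """Linear model: one weight row per class."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def fit(data, n_classes, grad_fn, epochs=100, lr=0.5):
    """Full-batch gradient descent on a linear softmax classifier.
    grad_fn(i, label, logits) returns dLoss/dlogits for example i."""
    n_features = len(data[0][0])
    weights = [[0.0] * n_features for _ in range(n_classes)]
    for _ in range(epochs):
        grads = [[0.0] * n_features for _ in range(n_classes)]
        for i, (x, label) in enumerate(data):
            g = grad_fn(i, label, predict_logits(weights, x))
            for k in range(n_classes):
                for j in range(n_features):
                    grads[k][j] += g[k] * x[j]
        for k in range(n_classes):
            for j in range(n_features):
                weights[k][j] -= lr * grads[k][j] / len(data)
    return weights


def accuracy(weights, data):
    hits = 0
    for x, label in data:
        logits = predict_logits(weights, x)
        hits += logits.index(max(logits)) == label
    return hits / len(data)


# Toy dataset: three Gaussian blobs; the trailing 1.0 acts as a bias feature.
random.seed(0)
CENTERS = [(3.0, 0.0), (-3.0, 0.0), (0.0, 3.0)]
K = len(CENTERS)
data = [([cx + random.gauss(0, 0.7), cy + random.gauss(0, 0.7), 1.0], label)
        for label, (cx, cy) in enumerate(CENTERS) for _ in range(50)]

# Stage 1: train the teacher with plain cross-entropy (dCE/dlogits = p - one_hot).
teacher = fit(data, K, lambda i, label, z: [
    p - t for p, t in zip(softmax(z), one_hot(label, K))])

# Stage 2: the teacher is frozen, so its logits are computed once and cached.
teacher_logits = [predict_logits(teacher, x) for x, _ in data]
ALPHA, BETA, T = 0.5, 0.5, 4.0  # hypothetical hyperparameter values


def student_grad(i, label, z):
    hard = [p - t for p, t in zip(softmax(z), one_hot(label, K))]
    # d/dz of T^2 * KL(teacher_T || student_T) is T * (p_student_T - p_teacher_T).
    soft = [T * (ps - pt) for ps, pt in
            zip(softmax(z, T), softmax(teacher_logits[i], T))]
    return [ALPHA * h + BETA * s for h, s in zip(hard, soft)]


student = fit(data, K, student_grad)
print(f"teacher: {accuracy(teacher, data):.2f}, student: {accuracy(student, data):.2f}")
```

With real networks and an autodiff framework, the model code changes but the structure stays the same. The teacher is trained on hard labels. The student is then trained on a weighted sum of the hard-label term and the temperature-scaled soft-target term.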
An Analogy
This training procedure lends itself to a nice analogy: a student-teacher relationship between the two models.
Consider a textbook. A student trying to understand it alone may find it too difficult. A teacher who has already digested its contents will have an easier time explaining its nuances to the student.
Here the textbook represents the dataset, and the student and teacher represent the smaller and larger models, respectively. The additional loss component in the student's training procedure is the knowledge transfer between the teacher and the student.
Understanding the loss function
$$ \mathcal{L}_{student} = \alpha \cdot \mathrm{CE}(y_{pred}, y_{true}) + \beta \cdot \mathrm{KL}(y_{teacher} \,\|\, y_{pred}) $$
The teacher's loss function is just the cross-entropy loss, as is usual for classification tasks.
The student's loss function is the same cross-entropy loss plus the KL-divergence loss between the student's outputs and the teacher's outputs. Put simply, the KL divergence measures how much one probability distribution differs from another (here, the class probabilities of the teacher and the student). It is not symmetric, so it is not a true distance. By minimizing the KL-divergence loss, we teach the student to make predictions similar to the teacher's.
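The student's loss can be written out as a small pure-Python function. This is a sketch, not a library API. `distillation_loss` and its defaults (`alpha=0.5`, `beta=0.5`, `temperature=4.0`) are hypothetical. It already applies the temperature to the KL term only and multiplies that term by $T^2$, as explained under Hyperparameters. The KL term is computed in the usual direction, KL(teacher || student).

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, true_class):
    """Cross-entropy against a one-hot target reduces to -log p[true_class]."""
    return -math.log(probs[true_class])


def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, true_class,
                      alpha=0.5, beta=0.5, temperature=4.0):
    """alpha * CE(student, label) + beta * T^2 * KL(teacher_T || student_T)."""
    # Hard-label term: unscaled student logits (temperature 1).
    hard = cross_entropy(softmax(student_logits), true_class)
    # Soft-target term: both models' logits divided by the temperature.
    soft = kl_divergence(softmax(teacher_logits, temperature),
                         softmax(student_logits, temperature))
    return alpha * hard + beta * temperature ** 2 * soft


print(distillation_loss([2.0, 1.0, 0.1, 0.1], [3.0, 1.5, 0.2, 0.2], true_class=0))
```

When the student's logits equal the teacher's, the KL term is zero and only the weighted cross-entropy remains. In a real framework, the same quantity would be computed on batched tensors and autodiff would produce the gradients.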
Where is the knowledge transfer coming from?
A teacher model's prediction for a given datapoint is not a single class but a probability distribution over the possible classes. This distribution contains the information that makes training the student easier.
Consider an animal classifier with four classes: dog, cat, eagle and snake. When training on an image of a dog, the cross-entropy loss is computed against the one-hot target [1, 0, 0, 0].
The teacher's output for the same image, on the other hand, might look more like [0.9, 0.08, 0.01, 0.01]. This is because a dog looks more like a cat than like an eagle or a snake: dogs and cats both have four legs and visible ears, unlike snakes and eagles. By applying a KL-divergence loss to these outputs, we convey to the student that the cat and dog classes are related, weighted in proportion to the teacher's output probabilities.
The key thing to note is that all we need from the teacher is its output on a given datapoint, i.e. the logits. You might expect that, with access to the teacher's parameters, transferring knowledge to the student would require something more complicated, but it does not.
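The intuition can be checked numerically. Suppose a softmax output $p$ is trained with cross-entropy against any target distribution $t$. The gradient with respect to the logits is then $p - t$. The KL divergence to $t$ differs from that cross-entropy only by the entropy of $t$, which is a constant, so it has the same gradient. The sketch below assumes an untrained student with equal logits and uses the illustrative teacher output from the example above.

```python
import math


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


CLASSES = ["dog", "cat", "eagle", "snake"]
one_hot = [1.0, 0.0, 0.0, 0.0]
teacher_probs = [0.9, 0.08, 0.01, 0.01]  # illustrative teacher output from the text

# Assumed untrained student: equal logits, so uniform probabilities.
student_probs = softmax([0.0, 0.0, 0.0, 0.0])

# Gradient of the loss w.r.t. the student's logits is (p_student - target).
grad_hard = [p - t for p, t in zip(student_probs, one_hot)]
grad_soft = [p - t for p, t in zip(student_probs, teacher_probs)]

for name, gh, gs in zip(CLASSES, grad_hard, grad_soft):
    print(f"{name:>5}: one-hot {gh:+.2f}   teacher {gs:+.2f}")
```

With the one-hot target, the cat, eagle and snake logits are all pushed down equally (+0.25 each). With the teacher's target, the cat logit is pushed down less (+0.17) than eagle and snake (+0.24 each). That difference is the "cats resemble dogs" signal.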
Hyperparameters
New approaches mean more hyperparameters. There are a few new ones to consider when setting up knowledge distillation.
Temperature
| Temperature $T$ | Logits divided by $T$ | Softmax probabilities |
|---|---|---|
| 1 | [30, 5, 2, 2] | [1.0000e+00, 1.3888e-11, 6.9144e-13, 6.9144e-13] |
| 10 | [3, 0.5, 0.2, 0.2] | [0.8308, 0.0682, 0.0505, 0.0505] |
As we saw earlier, the teacher's predictions are where the knowledge transfer comes from. If the teacher is extremely confident in its prediction, its output will barely differ from the true one-hot label. Consider the class scores (the model's outputs before softmax) in the first row of the table above: the corresponding softmax output is almost identical to the ground truth [1, 0, 0, 0]. Ideally, we want the student to learn that there is a relation between the first and second classes. So before computing the softmax, we divide the score of each class by a constant known as the temperature. With a temperature of 10, the scaled scores are those in the second row, and the corresponding softmax probabilities are spread more evenly across the other classes. Higher temperatures produce softer probabilities, which give the student a stronger signal to learn from the teacher.
Both the student's and the teacher's pre-softmax outputs are divided by the temperature, but only for the KL component of the loss. The cross-entropy term still uses the student's unscaled outputs. Dividing by $T$ shrinks the gradients of the KL term by roughly a factor of $1/T^2$, so to offset this, the KL term is multiplied by the square of the temperature, $T^2$.
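The table can be reproduced with a temperature-scaled softmax. This is a minimal sketch; `softmax_with_temperature` is an illustrative helper name, not a library function. The printed probabilities should match the table up to rounding.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Divide the logits by T, then apply a numerically stable softmax."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [30.0, 5.0, 2.0, 2.0]
results = {T: softmax_with_temperature(logits, T) for T in (1.0, 10.0)}
for T, probs in results.items():
    print(f"T={T:>4}: " + ", ".join(f"{p:.4e}" for p in probs))
```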
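The origin of the $T^2$ factor can be sketched with the argument from 1. Here $z_i$ denotes the student's logits, $v_i$ the teacher's, $K$ the number of classes, and $p_i$ and $q_i$ the student's and teacher's softmax outputs at temperature $T$ (notation introduced only for this derivation). The gradient of the KL term with respect to a student logit is

$$ \frac{\partial}{\partial z_i} \mathrm{KL}(q \,\|\, p) = \frac{1}{T}\left(p_i - q_i\right) $$

Assume $T$ is large compared with the logits and each model's logits have zero mean. Then $p_i \approx (1 + z_i/T)/K$ and $q_i \approx (1 + v_i/T)/K$, so

$$ \frac{\partial}{\partial z_i} \mathrm{KL}(q \,\|\, p) \approx \frac{1}{K T^2}\left(z_i - v_i\right) $$

The soft-target gradients therefore shrink like $1/T^2$. Multiplying the KL term by $T^2$ keeps its weight relative to the cross-entropy term roughly unchanged as $T$ varies.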
Weights for the student loss function
We need to choose the weights $\alpha$ and $\beta$ in the loss function, which balance the regular cross-entropy loss against the KL-divergence loss. They determine how much we value the student learning directly from the labels versus learning from the teacher. The right balance depends on the task. With $\beta$ small relative to $\alpha$, the KL loss acts as a regularizer for the student model. However, Hinton et al. 1 report their best results with a considerably lower weight on the cross-entropy term with the true labels.
Teacher / Student Model architectures
It would be tempting to use a very large model as your teacher, since it has the capacity to generalize well on your dataset and pass that on to your smaller model. However, paper 2 shows that a large mismatch between the capacities of the student and teacher models can have the opposite effect and hurt the student's performance. When the teacher is too large, the student has difficulty mimicking its predictions.
Early stopping
With the KL-divergence loss, the authors of 2 show that towards the end of training the KL-divergence term dominates the loss and in turn hurts the overall accuracy of the student model. They therefore propose stopping the KL-divergence loss early, which they find greatly benefits the student model.
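One simple way to implement the idea described above is to switch the KL term off after a chosen epoch. This is a hypothetical sketch, not necessarily the exact schedule used in 2. `student_loss`, `kd_weight`, the default weights and `stop_epoch=100` are illustrative assumptions. `ce_term` and `kl_term` stand for already-computed loss values, with the KL term evaluated at temperature T.

```python
def kd_weight(epoch, beta, stop_epoch):
    """Hypothetical schedule: keep the KL weight at beta until stop_epoch,
    then switch it off so the student trains on cross-entropy alone."""
    return beta if epoch < stop_epoch else 0.0


def student_loss(ce_term, kl_term, epoch, alpha=1.0, beta=0.5,
                 temperature=4.0, stop_epoch=100):
    """alpha * CE + kd_weight * T^2 * KL, with the KL term dropped after stop_epoch."""
    weight = kd_weight(epoch, beta, stop_epoch)
    return alpha * ce_term + weight * temperature ** 2 * kl_term


print(student_loss(0.3, 0.02, epoch=10))   # KL term still active
print(student_loss(0.3, 0.02, epoch=150))  # KL term switched off
```

A smooth decay of the weight (for example linear over several epochs) would be an equally plausible variant. The hard cutoff is used here only because it is the simplest reading of "early stopping".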
Related Topics
Quantization
Model quantization 3 is a technique that converts a model's weights from high-precision numbers (e.g. 32-bit floating point) to low-precision ones (e.g. 8-bit integers). As an analogy, multiplying 2 by 10 is easier than multiplying 2.2334 by 10.2938. The downside is that this conversion is not trivial, and it sometimes has to be done during training as well (quantization-aware training). Knowledge distillation, with a traditional full-precision large model as the teacher, is one way to stabilize the training of such quantized models.
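Here is a minimal sketch that makes the high-to-low precision conversion concrete, using symmetric per-tensor 8-bit quantization. This scheme is one common choice, assumed here for illustration; the survey 3 covers many others. `quantize_int8` and `dequantize` are illustrative helpers. Real toolkits also quantize activations and often use per-channel scales.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Recover approximate float weights from the integers and the shared scale."""
    return [qi * scale for qi in q]


weights = [2.2334, -10.2938, 0.5, 0.0]
q, scale = quantize_int8(weights)
approx = dequantize(q, scale)
print(q)
print([round(a, 4) for a in approx])
```

Each weight is stored as a small integer plus one shared float scale, and the rounding error per weight is at most half the scale. Quantization-aware training, optionally guided by a full-precision teacher, tries to compensate for that lost precision.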
Pruning
Pruning is a technique for removing redundant or ineffective components of a trained neural network. Not all parts of a neural network contribute equally to its performance. Pruning identifies the parts that contribute little and removes them to reduce model size and speed up inference. One way to combine pruning and knowledge distillation 4 is to:
- Train a large model on a dataset
- Prune it to arrive at a smaller model
- Perform knowledge distillation with the trained model as the teacher and the pruned model as the student.
This lets us systematically reduce model size while using distillation to improve the pruned model's performance.
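The pruning step can be sketched with magnitude pruning, which zeroes the weights with the smallest absolute values. This criterion is an assumption for illustration, since the text does not specify one. `magnitude_prune` is a hypothetical helper operating on a flat list of weights.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the fraction `sparsity` of weights with the smallest absolute value.
    Returns the pruned weights and a 0/1 mask (1 = weight kept)."""
    n_prune = int(len(weights) * sparsity)
    # Indices ordered from smallest to largest magnitude.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    mask = [1] * len(weights)
    for i in order[:n_prune]:
        mask[i] = 0
    pruned = [w * m for w, m in zip(weights, mask)]
    return pruned, mask


weights = [0.8, -0.05, 0.3, -1.2, 0.01, 0.6]
pruned, mask = magnitude_prune(weights, sparsity=0.5)
print(mask)
```

In the pipeline above, the pruned model would then be fine-tuned as the student with the distillation loss, using the original trained model as the teacher. The mask would be reapplied after each update so that pruned weights stay at zero.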
Conclusion
Knowledge distillation is an interesting way to improve the quality of small, efficient models, letting us combine much of the performance of large models with low-resource architectures.
For more information on this topic, please see the references, in particular 5.
References
1. Distilling the Knowledge in a Neural Network - https://arxiv.org/pdf/1503.
2. On the Efficacy of Knowledge Distillation - https://arxiv.org/abs/1910.01348
3. A Survey on Methods and Theories of Quantized Neural Networks - https://arxiv.org/abs/1808.04752
4. Using Distillation to Improve Network Performance after Pruning and Quantization - https://dl.acm.org/doi/abs/10.1145/3366750.3366751
5. Distiller library's documentation on knowledge distillation