Knowledge Distillation Trends
Knowledge Distillation is a well-known model compression technique that aims to distill the “knowledge” of large-capacity models, trained with expensive compute and storage, into smaller, lightweight models that are better suited for deployment and offer faster inference.
A neural network typically outputs probabilities using a softmax function. The inputs to the final softmax function are known as logits.
What is Knowledge Distillation?
The concept of Knowledge Distillation was first introduced by Hinton et al. in Distilling the Knowledge in a Neural Network, 2015 as an efficient technique for transferring knowledge by minimizing the KL divergence between the (softened) output distributions of a teacher and a student model. KL divergence measures how much one probability distribution differs from another; by minimizing it, we encourage the student's output distribution to match that of the teacher.
This allows us to transfer the generalization ability of the larger teacher model by using its class probabilities as “soft targets” for training the smaller student model, in addition to the available class labels. These soft probabilities reveal information about the underlying data distribution, such as similarities between classes, that hard ground-truth labels cannot encode.
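As a concrete illustration, here is a minimal sketch of this loss in PyTorch. The temperature `T` and mixing weight `alpha` are illustrative hyperparameters, not values prescribed by the paper:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Vanilla KD: KL divergence between softened distributions plus cross-entropy."""
    # Soft targets: temperature-scaled class probabilities from the teacher.
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    log_student = F.log_softmax(student_logits / T, dim=-1)
    # The T**2 factor keeps the soft-target gradients comparable to the hard-label term.
    kd = F.kl_div(log_student, soft_targets, reduction="batchmean") * (T ** 2)
    # Hard targets: standard cross-entropy with the ground-truth labels.
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1.0 - alpha) * ce
```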
Why should I care? Benefits of Knowledge Distillation
With the increasing size of current SOTA models, compression techniques such as Knowledge Distillation are of prime importance. Take the leading segmentation model, for instance: the Segment Anything Model (SAM) by Meta AI is a 640M-parameter model trained on 11M images with over a billion masks. Maintaining this model even for inference is quite expensive. However, if we apply Knowledge Distillation with a pre-trained SAM model as the teacher, we can distill its knowledge into a much smaller model, retaining strong performance at a much lower cost.
Coupled with Self-Supervised Learning, Knowledge Distillation can provide lightweight and efficient models while maintaining performance. Knowledge Distillation typically relies on supervisory signals in the form of labels; however, these are not needed when it is used alongside Self-Supervised Learning. The teacher model can be pre-trained with Self-Supervised Learning on unlabelled data, and the smaller student network can then be taught to mimic the teacher's representations. This enables the student model to learn more meaningful and generalizable features.
Why does Knowledge Distillation work?
Explaining how and why Knowledge Distillation outperforms learning from raw data remains a challenge. The most prominent empirical evidence comes from Explaining Knowledge Distillation by Quantifying the Knowledge by Cheng et al., 2020, which builds on information bottleneck theory and explains the success of KD by quantifying the knowledge encoded in the intermediate layers of a network.
- Knowledge Distillation makes a network learn more visual concepts: The authors define visual concepts as information that is discarded less than the average information discarded in background regions. For example, in image classification, visual concepts correspond to objects in the foreground (task-relevant), while objects in the background are task-irrelevant. The authors quantify this notion of task-relevant visual concepts and show that a well-trained teacher network encodes more visual concepts than a baseline supervised model; since the student model learns to mimic the teacher, it will also encode more task-relevant visual concepts.
- Knowledge Distillation ensures that a network is prone to learning various visual concepts simultaneously: The authors study two properties of networks (student model and baseline) as training progresses: 1. whether a network learns the various visual concepts of a specific image quickly, and 2. whether a network learns the various visual concepts of different images simultaneously. They find that a student model learns more foreground features and fewer background features than a baseline supervised model.
- Knowledge Distillation yields more stable optimization directions than learning from raw data: In supervised training, a model tries various directions in the optimization space and eventually learns to discard the wrong ones (information bottleneck theory); in the case of KD, however, the teacher network guides the student network to learn representations without any significant “detours”.
Types of Knowledge Distillation
Recent works in Knowledge Distillation can be broadly classified into two categories:
- Logits Distillation: These methods represent the vanilla Knowledge Distillation strategy, wherein the logits from the larger teacher network are used along with class labels to train the smaller student network. They incur only marginal computational and storage costs but generally perform worse than feature-based distillation.
- Feature-Based Distillation: These methods differ in that, instead of taking the logits from the teacher network, we use features from intermediate layers of the teacher network and align them with features from the student network. Most feature-based methods perform significantly better than logit-based methods; however, they involve considerably higher computational and storage costs.
Tricky Bits of Knowledge Distillation
Shared Temperature Scaling
In most Knowledge Distillation methodologies, the temperature in the softmax function is shared between the student and teacher.
This neglects the possibility of distinct temperature values in the KL divergence and implicitly enforces an exact match between the student's and teacher's logits.
The authors of Logit Standardization in Knowledge Distillation, 2024 address this by proposing a weighted logit standard deviation as an adaptive temperature and presenting Z-score logit standardization as a pre-processing step before applying the softmax.
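To illustrate the idea, here is a minimal sketch of Z-score standardization applied to the logits before the softened KL term. The base temperature and overall weighting are illustrative and simplified relative to the paper:

```python
import torch.nn.functional as F

def zscore_standardize(logits, eps=1e-7):
    """Z-score standardization along the class dimension,
    applied to logits before the softmax."""
    mean = logits.mean(dim=-1, keepdim=True)
    std = logits.std(dim=-1, keepdim=True)
    return (logits - mean) / (std + eps)

def standardized_kd_loss(student_logits, teacher_logits, base_T=2.0):
    # After standardization, the student only has to match the shape/ranking
    # of the teacher's logits rather than their exact scale.
    log_s = F.log_softmax(zscore_standardize(student_logits) / base_T, dim=-1)
    p_t = F.softmax(zscore_standardize(teacher_logits) / base_T, dim=-1)
    return F.kl_div(log_s, p_t, reduction="batchmean") * (base_T ** 2)
```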
Teacher vs Student Disparity
Cho et al. revealed some key properties regarding Knowledge Distillation in their paper On the Efficacy of Knowledge Distillation, 2019.
- More accurate models often don't make better teachers. A possible explanation is that as the teacher becomes both more confident and more accurate, its output probabilities start to resemble a one-hot encoding of the true label, and thus the information available to the student decreases.
- Larger models often do not make better teachers, owing to a capacity mismatch between the student and the teacher.
- Teacher accuracy is a poor predictor of student performance.
Decoupling Knowledge Distillation
Zhao et al. in Decoupled Knowledge Distillation, 2022 brought renewed attention to logit-based distillation by providing a novel viewpoint: reformulating the classical Knowledge Distillation loss into two parts:
- Target Class Knowledge Distillation (TCKD)
- Non-Target Class Knowledge Distillation (NCKD)
They revealed that the classical Knowledge Distillation loss is a coupled formulation that suppresses the effectiveness of NCKD and limits the flexibility to balance the two parts.
This reformulation then allows the individual effects of NCKD and TCKD to be investigated, revealing key information about classical Knowledge Distillation.
- The authors performed ablation studies and found that target-class-related knowledge (the TCKD part) is not as important as the knowledge among non-target classes. They back this up by comparing the performance of a baseline, classical Knowledge Distillation, TCKD alone, and NCKD alone. For more details, please refer to Section 3.2 of the paper.
- TCKD transfers knowledge concerning the “difficulty” of training samples. Following the reformulation, the authors hypothesise this and validate their claim by experimenting with various strategies to increase the difficulty of the training data.
- NCKD is the prominent reason why logit distillation works, but it is greatly suppressed. When comparing baselines with NCKD alone, one can see comparable or even better performance than classical Knowledge Distillation. Thus, the knowledge among non-target classes is of prime importance to Knowledge Distillation. The reformulation shows that NCKD is weighted by the complement of the teacher's confidence in the target class. The more confident the teacher is on a training sample, the more reliable and valuable the knowledge it could provide; yet, because of this weighting, exactly that knowledge ends up being suppressed.
Following these findings, the authors propose a novel knowledge distillation framework, Decoupled Knowledge Distillation (DKD), that considers TCKD and NCKD in a decoupled formulation with independently tunable weights.
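A simplified sketch of this decoupled loss is shown below. The temperature and the weights alpha and beta are illustrative defaults, and the implementation is condensed relative to the official code:

```python
import torch
import torch.nn.functional as F

def dkd_loss(student_logits, teacher_logits, target, alpha=1.0, beta=8.0, T=4.0):
    """Decoupled KD (sketch): weight the target-class term (TCKD) and the
    non-target-class term (NCKD) independently instead of coupling them."""
    gt_mask = F.one_hot(target, num_classes=student_logits.size(-1)).bool()
    p_s = F.softmax(student_logits / T, dim=-1)
    p_t = F.softmax(teacher_logits / T, dim=-1)

    # TCKD: KL over the binary split (target mass vs. all non-target mass).
    def binary_probs(p):
        pt = p.masked_select(gt_mask).unsqueeze(-1)   # (B, 1) target-class mass
        return torch.cat([pt, 1.0 - pt], dim=-1)      # (B, 2)

    b_s, b_t = binary_probs(p_s), binary_probs(p_t)
    tckd = F.kl_div(torch.log(b_s + 1e-8), b_t, reduction="batchmean")

    # NCKD: KL over the distribution among non-target classes only.
    # Masking out the target logit makes the softmax renormalize over the rest.
    neg_inf = torch.finfo(student_logits.dtype).min
    log_s_nt = F.log_softmax(student_logits.masked_fill(gt_mask, neg_inf) / T, dim=-1)
    p_t_nt = F.softmax(teacher_logits.masked_fill(gt_mask, neg_inf) / T, dim=-1)
    nckd = F.kl_div(log_s_nt, p_t_nt, reduction="batchmean")

    return (alpha * tckd + beta * nckd) * (T ** 2)
```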
Knowledge Distillation for Masked Image Modelling
Following its success in language, Masked Image Modelling (MIM) has emerged as a powerful technique for acquiring meaningful representations by reconstructing masked images, and it is one of the most prominent approaches to feature pre-training.
The leading work in Masked Image Modeling is based on Masked Autoencoders (MAE). These models consist of two components: the encoder, which projects unmasked patches to a latent space, and the decoder, which predicts the pixel values of masked patches.
This architecture has two key design principles:
- An asymmetric design, in the sense that the encoder operates only on the visible tokens while the decoder operates on the latent representations together with mask tokens
- A lightweight decoder that reconstructs the image (a minimal sketch of this asymmetry follows below)
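To make these two design principles concrete, here is a minimal, illustrative sketch of the asymmetric encoder/decoder flow. Positional embeddings, the patchifier, and the masked-patch-only reconstruction loss are omitted, and all dimensions are assumptions for illustration:

```python
import torch
import torch.nn as nn

class TinyMAE(nn.Module):
    """Minimal sketch of the MAE asymmetry: the encoder sees only the visible
    patch tokens, while a lightweight decoder reconstructs every patch from the
    encoded tokens plus learned mask tokens."""

    def __init__(self, dim=256, patch_pixels=16 * 16 * 3, mask_ratio=0.75):
        super().__init__()
        self.mask_ratio = mask_ratio
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True), num_layers=4)
        # Lightweight decoder: far fewer layers than the encoder.
        self.decoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True), num_layers=1)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.to_pixels = nn.Linear(dim, patch_pixels)  # predict raw patch pixels

    def forward(self, patch_tokens):                   # (B, N, dim) embedded patches
        B, N, D = patch_tokens.shape
        num_keep = int(N * (1 - self.mask_ratio))
        # Randomly keep a subset of patches; the encoder never sees the rest.
        ids = torch.rand(B, N, device=patch_tokens.device).argsort(dim=1)
        keep_ids = ids[:, :num_keep]
        visible = torch.gather(patch_tokens, 1, keep_ids.unsqueeze(-1).expand(-1, -1, D))
        latent = self.encoder(visible)
        # Decoder input: encoded visible tokens plus mask tokens for the masked patches.
        mask_tokens = self.mask_token.expand(B, N - num_keep, D)
        decoded = self.decoder(torch.cat([latent, mask_tokens], dim=1))
        return self.to_pixels(decoded)                 # per-patch pixel predictions
```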
The most straightforward approach of directly applying existing Knowledge Distillation methods empirically does not result in much improvement. These methods fail to leverage key design elements of MAE, such as the asymmetric design, feeding in patches as opposed to the full image, and the lightweight decoder.
Bai et al., in Masked Autoencoders Enable Efficient Knowledge Distillers, 2022, study an alternative solution that applies Knowledge Distillation directly at the pre-training stage. This allows them to stick to the core design principles of MAEs: both the teacher and the student network operate on visible patches (asymmetric design), while the lightweight decoder learns to reconstruct the images.
Since MAE pre-training involves no categorical labels, distilling logits can hardly transfer meaningful representations. Therefore, the authors use intermediate features instead.
In DMAE, masked inputs are fed into both the student and the teacher model. A small projection head is then used to align features from the intermediate layers of the teacher with those of the student. This addresses the possible feature-dimension mismatch between teacher and student models and provides extra flexibility for feature alignment. The overall loss function consists of two components (sketched after the list below):
- Reconstruction Loss
- Feature Alignment Loss
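The sketch below shows one way these two terms might be combined; the projection head, feature dimensions, and loss weight are illustrative assumptions rather than the exact DMAE implementation:

```python
import torch.nn as nn
import torch.nn.functional as F

class DMAEStyleLoss(nn.Module):
    """Sketch of a DMAE-style objective: MAE pixel reconstruction plus alignment
    of intermediate student features with (frozen) teacher features."""

    def __init__(self, student_dim=384, teacher_dim=1024, align_weight=1.0):
        super().__init__()
        # Small projection head bridging the student/teacher feature-dimension mismatch.
        self.proj = nn.Linear(student_dim, teacher_dim)
        self.align_weight = align_weight

    def forward(self, pred_pixels, target_pixels, student_feats, teacher_feats):
        # 1) Reconstruction loss on the predicted patch pixels.
        recon = F.mse_loss(pred_pixels, target_pixels)
        # 2) Feature alignment between intermediate layers; the teacher is a fixed target.
        align = F.mse_loss(self.proj(student_feats), teacher_feats.detach())
        return recon + self.align_weight * align
```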
Notably, the authors report significant performance gains compared to vanilla MAEs and other Knowledge Distillation techniques, with only a marginal increase in training cost.
- When compared to a baseline MAE pre-trained for 100 epochs and then fine-tuned for 100 epochs, DMAE leads to a +2.4% improvement on ImageNet.
- This performance improvement comes at the cost of a mere 5 additional GPU hours, as reported by the authors.
Asymmetric Masked Distillation
Zhao et al. in Asymmetric Masked Distillation for Pre-Training Small Foundation Models, 2024 aim to let the teacher acquire more contextual information than the student model by following an asymmetric masking strategy: the teacher is supplied with patches sampled at a lower masking ratio, while the student's visible patches are a subset of the teacher's.
As a result, the student model receives even fewer input patches, which increases the difficulty of its reconstruction task, while the teacher model can provide more contextual information to the student. The teacher's masking ratio thus becomes a trade-off between the contextual information it provides and the computational cost.
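A minimal sketch of this asymmetric sampling is shown below, using the 50%/75% teacher/student masking ratios mentioned in the paper; the helper itself is illustrative:

```python
import torch

def asymmetric_masks(batch_size, num_patches, teacher_ratio=0.5, student_ratio=0.75):
    """Sample visible-patch indices so that the student's visible patches are
    a subset of the teacher's, giving the teacher more context."""
    n_teacher = int(num_patches * (1 - teacher_ratio))   # patches the teacher sees
    n_student = int(num_patches * (1 - student_ratio))   # patches the student sees
    # Random permutation of patch indices for each sample in the batch.
    perm = torch.rand(batch_size, num_patches).argsort(dim=1)
    teacher_visible = perm[:, :n_teacher]
    # The student keeps a prefix of the teacher's visible patches => guaranteed subset.
    student_visible = teacher_visible[:, :n_student]
    return teacher_visible, student_visible
```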
The authors report Top-1 classification accuracy on ImageNet-1k with masking ratios of 75% for the student and 50% for the teacher, showing that their approach outperforms DMAE and related works.
Conclusion
In this article, we went over recent trends in Knowledge Distillation and saw how modern Self-Supervised Learning architectures and techniques, such as Masked Autoencoders (MAE) and Masked Image Modelling (MIM), can be combined with Knowledge Distillation.
References
- Distilling the Knowledge in a Neural Network, Hinton et al. 2015
- On the Efficacy of Knowledge Distillation, Cho and Hariharan. 2019
- Explaining Knowledge Distillation by Quantifying the Knowledge, Cheng et al. 2020
- Decoupled Knowledge Distillation, Zhao et al. 2022
- Masked Autoencoders Are Scalable Vision Learners, He et al. 2022
- Masked Autoencoders Enable Efficient Knowledge Distillers, Bai et al. 2022
- Asymmetric Masked Distillation for Pre-Training Small Foundation Models, Zhao et al. 2024
- Logit Standardization in Knowledge Distillation, Sun et al. 2024
Saurav,
Machine Learning Advocate Engineer
lightly.ai