Transformer model distillation
Overview
Transformer models that are pre-trained on large corpora, such as BERT, XLNet and XLM, have been shown to improve the accuracy of many NLP tasks. However, such models have two distinct disadvantages: (1) model size and (2) speed, since these large models are computationally heavy.
One possible approach to overcome these drawbacks is to use Knowledge Distillation (KD). In this approach, a large model is trained on the data set and then used to teach a much smaller and more efficient network. This is often referred to as Student-Teacher training, where the teacher network adds its error to the student's loss function, thus helping the student network converge to a better solution.
Knowledge Distillation
One approach is similar to the method in Hinton 2015 [1]. The loss function is modified to include a measure of the divergence between the student and teacher distributions, which can be computed using KL divergence or MSE between the logits of the student and the teacher network.
\(loss = w_s \cdot loss_{student} + w_d \cdot KL(logits_{student} / T || logits_{teacher} / T)\)
where \(T\) is the temperature used to soften the logits prior to applying softmax, \(loss_{student}\) is the original loss of the student network obtained during regular training, and \(w_s\) and \(w_d\) weight the student and distillation loss terms, respectively.
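For concreteness, the weighted loss above could be computed along the following lines (a minimal PyTorch sketch, not the library's implementation; the function and argument names are illustrative):

```python
import torch.nn.functional as F

def kd_loss(student_loss, student_logits, teacher_logits,
            loss_w=1.0, dist_w=0.1, temperature=1.0):
    # Soften both distributions with the temperature before comparing them.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # kl_div expects log-probabilities for the input and probabilities for the target.
    distill_term = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    # Weighted sum of the regular student loss and the distillation term.
    return loss_w * student_loss + dist_w * distill_term
```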
TeacherStudentDistill
This class adds distillation support when training a model. To use it, the student model must handle training with a TeacherStudentDistill object; see nlp_architect.procedures.token_tagging.do_kd_training for an example of how to train a neural tagger from a transformer model using distillation.
class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')
Teacher-Student knowledge distillation helper. Use this object when training a model with KD and a teacher model.
Parameters:
- teacher_model (TrainableModel) – teacher model
- temperature (float, optional) – KD temperature. Defaults to 1.0.
- dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
- loss_w (float, optional) – student loss weight. Defaults to 1.0.
- loss_function (str, optional) – loss function to use (kl for KLDivLoss, mse for MSELoss). Defaults to kl.
static add_args(parser: argparse.ArgumentParser)
Add KD arguments to the parser.
Parameters:
- parser (argparse.ArgumentParser) – parser
distill_loss(loss, student_logits, teacher_logits)
Add KD loss.
Parameters:
- loss – student loss
- student_logits – student model logits
- teacher_logits – teacher model logits
Returns: KD loss
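Putting the pieces together, a training loop might use the helper roughly as follows (an illustrative sketch assuming only the API documented above; student, teacher, train_loader and criterion are hypothetical placeholders, and the actual integration in nlp_architect may differ):

```python
import torch
from nlp_architect.nn.torch.distillation import TeacherStudentDistill

# `student`, `teacher`, `train_loader` and `criterion` are placeholders for
# the student network, a trained teacher model, a data loader and a loss function.
distill = TeacherStudentDistill(
    teacher_model=teacher,   # trained teacher (a TrainableModel)
    temperature=2.0,         # softens the logits for the distillation term
    dist_w=0.1,              # distillation loss weight
    loss_w=1.0,              # student loss weight
    loss_function="kl",      # or "mse"
)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)

for inputs, labels in train_loader:
    student_logits = student(inputs)
    with torch.no_grad():
        # Assumes the teacher can be called on the same inputs to produce logits.
        teacher_logits = teacher(inputs)
    loss = criterion(student_logits, labels)  # regular hard-label student loss
    # Combine the student loss with the KD term using the weights configured above.
    loss = distill.distill_loss(loss, student_logits, teacher_logits)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```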
Supported models
NeuralTagger
Useful for training taggers from Transformer models. NeuralTagger models that use LSTM- and CNN-based embedders are ~3M parameters in size (~30-100x smaller than BERT models) and ~10x faster on average.
Usage:
- Train a transformer tagger using TransformerTokenClassifier or using the nlp-train transformer_token command.
- Train a neural tagger (NeuralTagger) using the trained transformer model and a TeacherStudentDistill model that was configured with the transformer model. This can be done using NeuralTagger's train loop or by using the nlp-train tagger_kd command.
Note
More models supporting distillation will be added in future releases.
Pseudo Labeling
This method can be used to produce pseudo-labels when training the student on unlabeled examples. The pseudo-label is produced by applying argmax to the logits of the teacher model, resulting in the following loss:
\(loss = CE(\hat{y}, \hat{y}_t)\)
where CE is the cross-entropy loss, \(\hat{y}\) is the entity label class predicted by the student model and \(\hat{y}_t\) is the label predicted by the teacher model.
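A minimal sketch of this pseudo-labeling loss in PyTorch (illustrative only; the function name is hypothetical and the library's implementation may differ):

```python
import torch.nn.functional as F

def pseudo_label_loss(student_logits, teacher_logits):
    # Hard pseudo-labels: the teacher's most likely class for each example.
    pseudo_labels = teacher_logits.argmax(dim=-1)
    # Cross entropy of the student's predictions against the teacher's pseudo-labels.
    return F.cross_entropy(student_logits, pseudo_labels)
```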
[1] Geoffrey Hinton, Oriol Vinyals, Jeff Dean. Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531