Transformer model distillation
Overview
Transformer models pretrained on large corpora, such as BERT, XLNet and XLM, have been shown to improve the accuracy of many NLP tasks. However, such models have two distinct disadvantages: (1) model size and (2) speed, since such large models are computationally heavy.
One possible approach to overcoming these drawbacks is Knowledge Distillation (KD). In this approach, a large model is first trained on the data set and then used to teach a much smaller and more efficient network. This is often referred to as Student-Teacher training, where the teacher network adds its error to the student's loss function, thus helping the student network converge to a better solution.
Knowledge Distillation
One approach is similar to the method in Hinton 2015 [1]. The loss function is modified to include a measure of the divergence between the student and teacher distributions, computed using either KL divergence or MSE between the logits of the student and teacher networks.
\(loss = w_s \cdot loss_{student} + w_d \cdot KL(logits_{student} / T \,\|\, logits_{teacher} / T)\)
where \(T\) is a temperature value used to soften the logits before applying softmax, and \(loss_{student}\) is the original loss of the student network obtained during regular training. The two losses are weighted by \(w_s\) and \(w_d\), respectively.
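The loss above can be traced numerically with a minimal pure-Python sketch (no PyTorch, so it runs anywhere). The helper names softmax, kl_divergence and kd_loss are hypothetical and not part of nlp_architect. One assumption to flag: the KL term uses the teacher's softened distribution as the reference, KL(p_teacher || p_student), which is what PyTorch's KLDivLoss computes when given student log-probabilities as input and teacher probabilities as target. The library's exact direction, and whether it applies the extra T^2 scaling suggested by Hinton et al., are not confirmed here.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax: a higher T gives a softer (flatter) distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(student_loss, student_logits, teacher_logits,
            temperature=1.0, w_s=1.0, w_d=0.1):
    """Hypothetical sketch of w_s * loss_student + w_d * KL(...) for one example."""
    # Soften both sets of logits with the same temperature T
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # Assumption: the teacher distribution is the reference, i.e. KL(p_teacher || p_student)
    distill = kl_divergence(p_teacher, p_student)
    # Weighted sum of the student's own loss and the distillation term
    return w_s * student_loss + w_d * distill


teacher = [4.0, 1.0, -2.0]
student = [2.0, 1.5, -1.0]
print(kd_loss(0.7, student, teacher, temperature=2.0))
```

Raising the temperature flattens both distributions, so the student is also trained on the teacher's relative scores for the non-maximal classes rather than only on its top prediction.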
TeacherStudentDistill
This class can be added to a model to support distillation.
To add support for distillation, the student model must handle training with the TeacherStudentDistill class; see nlp_architect.procedures.token_tagging.do_kd_training for an example of how to train a neural tagger from a transformer model using distillation.

class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')
Teacher-Student knowledge distillation helper. Use this object when training a model with KD and a teacher model.
Parameters:
 teacher_model (TrainableModel) – teacher model
 temperature (float, optional) – KD temperature. Defaults to 1.0.
 dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
 loss_w (float, optional) – student loss weight. Defaults to 1.0.
 loss_function (str, optional) – loss function to use ('kl' for KLDivLoss, 'mse' for MSELoss). Defaults to 'kl'.

static add_args(parser: argparse.ArgumentParser)
Add KD arguments to the parser.
Parameters: parser (argparse.ArgumentParser) – parser

distill_loss(loss, student_logits, teacher_logits)
Add the KD loss.
Parameters:
 loss – student loss
 student_logits – student model logits
 teacher_logits – teacher model logits
Returns: KD loss
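The documented constructor parameters (temperature, dist_w, loss_w, loss_function) and the distill_loss(loss, student_logits, teacher_logits) signature suggest how the pieces fit together. The class below, DistillLossSketch, is a hypothetical pure-Python stand-in written for illustration only; it is not the library's implementation. It assumes distill_loss returns loss_w * loss + dist_w * divergence, that 'kl' means KL(p_teacher || p_student) on temperature-softened distributions, and that 'mse' is taken between the temperature-scaled logits. Whether the library scales logits by T in MSE mode is also an assumption.

```python
import math


def _softmax(logits, temperature):
    """Temperature-scaled, numerically stable softmax."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


class DistillLossSketch:
    """Hypothetical stand-in mirroring TeacherStudentDistill's loss parameters.

    Not the nlp_architect implementation; the combination rule is an assumption.
    """

    def __init__(self, temperature=1.0, dist_w=0.1, loss_w=1.0, loss_function="kl"):
        if loss_function not in ("kl", "mse"):
            raise ValueError("loss_function must be 'kl' or 'mse'")
        self.temperature = temperature
        self.dist_w = dist_w
        self.loss_w = loss_w
        self.loss_function = loss_function

    def distill_loss(self, loss, student_logits, teacher_logits):
        if self.loss_function == "kl":
            # KL(p_teacher || p_student) between temperature-softened distributions
            p_t = _softmax(teacher_logits, self.temperature)
            p_s = _softmax(student_logits, self.temperature)
            dist = sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)
        else:
            # Assumption: MSE between the temperature-scaled logits
            diffs = [(s - t) / self.temperature
                     for s, t in zip(student_logits, teacher_logits)]
            dist = sum(d * d for d in diffs) / len(diffs)
        # Weighted sum: loss_w scales the student loss, dist_w the distillation term
        return self.loss_w * loss + self.dist_w * dist


kd = DistillLossSketch(temperature=2.0, dist_w=0.1, loss_w=1.0, loss_function="kl")
print(kd.distill_loss(0.7, [2.0, 1.5, -1.0], [4.0, 1.0, -2.0]))
```

Validating loss_function in the constructor surfaces a typo such as 'KL' or 'l1' immediately, instead of at the first training step.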
Supported models
NeuralTagger
Useful for training taggers from Transformer models. NeuralTagger models that use LSTM- and CNN-based embedders are ~3M parameters in size (~30-100x smaller than BERT models) and ~10x faster on average.
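The size comparison above can be sanity-checked with simple arithmetic. The sketch assumes the commonly reported sizes of BERT-base (~110M parameters) and BERT-large (~340M parameters) and takes the ~3M NeuralTagger size quoted above; size_ratios is a hypothetical helper.

```python
# Assumed parameter counts: commonly reported sizes of BERT-base and BERT-large
BERT_PARAMS = {"BERT-base": 110e6, "BERT-large": 340e6}
NEURAL_TAGGER_PARAMS = 3e6  # approximate NeuralTagger size quoted above


def size_ratios(teacher_sizes, student_size):
    """How many times larger each teacher model is than the student."""
    return {name: n / student_size for name, n in teacher_sizes.items()}


for name, ratio in size_ratios(BERT_PARAMS, NEURAL_TAGGER_PARAMS).items():
    print(f"{name} is ~{ratio:.0f}x larger than the NeuralTagger")
# prints "BERT-base is ~37x larger than the NeuralTagger"
# prints "BERT-large is ~113x larger than the NeuralTagger"
```

This gives roughly 37x for BERT-base and 113x for BERT-large.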
Usage:
1. Train a transformer tagger using TransformerTokenClassifier or using the nlp-train transformer_token command.
2. Train a neural tagger (NeuralTagger) using the trained transformer model and a TeacherStudentDistill object configured with the transformer model as the teacher. This can be done using NeuralTagger's train loop or by using the nlp-train tagger_kd command.
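The second usage step (training the neural tagger against an already-trained transformer) can be illustrated with a toy, framework-free loop: the teacher's outputs are computed once and kept frozen, while only the student is updated on the weighted sum of its own cross-entropy loss and the distillation term. Everything here is hypothetical and simplified: the "student" is just a trainable logit vector for a single example, the gradients are written out by hand, and train_student, loss_w and dist_w only loosely mirror the documented parameters. It is not how NeuralTagger or nlp-train tagger_kd are implemented.

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled, numerically stable softmax."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def train_student(teacher_logits, label, temperature=2.0, loss_w=1.0,
                  dist_w=0.5, lr=0.5, steps=200):
    """Hypothetical toy KD loop: the student is a trainable logit vector for one example."""
    # The teacher is already trained and frozen, so its soft targets are computed once
    soft_targets = softmax(teacher_logits, temperature)
    student = [0.0] * len(teacher_logits)
    for _ in range(steps):
        p_hard = softmax(student)               # used by the CE term against the gold label
        p_soft = softmax(student, temperature)  # used by the distillation term
        grad = []
        for j in range(len(student)):
            g_ce = p_hard[j] - (1.0 if j == label else 0.0)     # dCE/dz_j
            g_kd = (p_soft[j] - soft_targets[j]) / temperature  # dKL(teacher || student)/dz_j
            grad.append(loss_w * g_ce + dist_w * g_kd)
        # Gradient step on the student only; the teacher is never updated
        student = [z - lr * g for z, g in zip(student, grad)]
    return student


student = train_student([3.0, 1.0, -1.0], label=0)
print([round(p, 3) for p in softmax(student)])
```

The gold label alone treats classes 1 and 2 identically; the distillation term pulls the student toward the teacher's ranking of those wrong classes, so class 1 ends up scored above class 2.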
Note
More models supporting distillation will be added in future releases.
Pseudo Labeling
This method can be used to produce pseudo-labels when training the student on unlabeled examples. The pseudo-guess is produced by applying arg max to the logits of the teacher model, which results in the following loss:
\(loss = CE(\hat{y}, \hat{y}_t)\)
where CE is the cross-entropy loss, \(\hat{y}\) is the entity label class predicted by the student model and \(\hat{y}_t\) is the label predicted by the teacher model.
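The pseudo-labeling loss can be sketched in plain Python for a sequence of tokens: the teacher's arg max at each token becomes a hard label, and the student's cross-entropy against that label is averaged over tokens. The helper names (log_softmax, pseudo_label, pseudo_label_loss) and the per-token averaging are assumptions made for illustration, not the library's implementation.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one token's class logits."""
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]


def pseudo_label(teacher_logits):
    """Teacher's guess for one token: the arg max class index."""
    return max(range(len(teacher_logits)), key=lambda j: teacher_logits[j])


def pseudo_label_loss(student_seq, teacher_seq):
    """Mean CE(y_hat, y_hat_t) over the tokens of an unlabeled sentence.

    Assumption: per-token losses are averaged; the library may reduce differently.
    """
    total = 0.0
    for s_logits, t_logits in zip(student_seq, teacher_seq):
        y_t = pseudo_label(t_logits)          # hard pseudo-label from the teacher
        total += -log_softmax(s_logits)[y_t]  # CE against a one-hot target
    return total / len(student_seq)


teacher_seq = [[2.0, 0.1, -1.0], [0.0, 1.5, 0.2]]
student_seq = [[1.0, 0.5, 0.0], [0.3, 0.2, 0.9]]
print(pseudo_label_loss(student_seq, teacher_seq))
```

Unlike the KL term, this loss discards the teacher's confidence on the non-maximal classes and keeps only its top guess per token.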
[1] Distilling the Knowledge in a Neural Network: Geoffrey Hinton, Oriol Vinyals, Jeff Dean, https://arxiv.org/abs/1503.02531