Transformer model distillation

Overview

Transformer models pre-trained on large corpora, such as BERT, XLNet and XLM, have been shown to improve the accuracy of many NLP tasks. However, such models have two distinct disadvantages: (1) model size and (2) speed, since these large models are computationally heavy.

One possible approach to overcome these drawbacks is to use Knowledge Distillation (KD). In this approach, a large model is trained on the data set and then used to teach a much smaller and more efficient network. This is often referred to as Student-Teacher training, where a teacher network adds its error to the student's loss function, thus helping the student network converge to a better solution.

Knowledge Distillation

One approach is similar to the method in Hinton 2015 [1]. The loss function is modified to include a measure of the divergence between the student and teacher distributions, which can be computed using KL divergence or MSE over the logits of the two networks.

\(loss = w_s \cdot loss_{student} + w_d \cdot KL\big(softmax(logits_{student} / T) \,\|\, softmax(logits_{teacher} / T)\big)\)

where \(T\) is the temperature used to soften the logits prior to applying softmax, and \(loss_{student}\) is the original loss of the student network obtained during regular training. The two loss terms are weighted by \(w_s\) and \(w_d\), respectively.
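The following is a minimal PyTorch sketch of this loss (an illustration only, not the library's implementation; the function and argument names are chosen here for clarity):

    import torch
    import torch.nn.functional as F

    def kd_loss(student_loss, student_logits, teacher_logits,
                loss_w=1.0, dist_w=0.1, temperature=1.0):
        """Combine the student's task loss with a temperature-softened KD term.

        A sketch of the formula above; not the library's implementation.
        """
        # Soften both distributions with the temperature before comparing them
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        # KL divergence term between the softened student and teacher
        # distributions (F.kl_div expects log-probabilities as its first argument)
        distill_term = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
        return loss_w * student_loss + dist_w * distill_term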

TeacherStudentDistill

This class adds distillation support to a model. To enable distillation, the student model's training procedure must handle a TeacherStudentDistill instance; see nlp_architect.procedures.token_tagging.do_kd_training for an example of training a neural tagger from a transformer model using distillation.

class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')[source]

Teacher-Student knowledge distillation helper. Use this object when training a model with KD and a teacher model.

Parameters:
  • teacher_model (TrainableModel) – teacher model
  • temperature (float, optional) – KD temperature. Defaults to 1.0.
  • dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
  • loss_w (float, optional) – student loss weight. Defaults to 1.0.
  • loss_function (str, optional) – loss function to use (kl for KLDivLoss, mse for MSELoss). Defaults to kl.
static add_args(parser: argparse.ArgumentParser)[source]

Add KD arguments to parser

Parameters:parser (argparse.ArgumentParser) – parser
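
For example, the KD options can be registered on an existing command-line parser (the exact flag names added are defined by the library):

    import argparse
    from nlp_architect.nn.torch.distillation import TeacherStudentDistill

    parser = argparse.ArgumentParser()
    TeacherStudentDistill.add_args(parser)  # registers the KD-related arguments
    args = parser.parse_args()              # KD options are now available on args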
distill_loss(loss, student_logits, teacher_logits)[source]

Add KD loss

Parameters:
  • loss – student loss
  • student_logits – student model logits
  • teacher_logits – teacher model logits
Returns:

KD loss

distill_loss_dict(loss, student_logits_dict, teacher_logits_dict)[source]

Add KD loss

Parameters:
  • loss – student loss
  • student_logits_dict – dictionary of student model logits
  • teacher_logits_dict – dictionary of teacher model logits
Returns:

KD loss

get_teacher_logits(inputs)[source]

Get teacher logits

Parameters:inputs – input
Returns:teacher logits
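
A minimal sketch of how these methods fit into a student's training loop (the teacher model, data loader, forward pass and optimizer below are placeholders, not part of the library API; only the distiller calls follow the reference above):

    from nlp_architect.nn.torch.distillation import TeacherStudentDistill

    # `teacher` is assumed to be an already-trained TrainableModel
    distiller = TeacherStudentDistill(teacher_model=teacher, temperature=2.0)

    for inputs, labels in train_loader:                    # placeholder data loader
        student_logits = student(inputs)                   # placeholder forward pass
        loss = loss_fn(student_logits, labels)             # placeholder task loss
        teacher_logits = distiller.get_teacher_logits(inputs)
        # Replace the task loss with the weighted combination of task + KD loss
        loss = distiller.distill_loss(loss, student_logits, teacher_logits)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()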

Supported models

NeuralTagger

Useful for training taggers from Transformer models. NeuralTagger models that use LSTM- or CNN-based embedders are ~3M parameters in size (~30-100x smaller than BERT models) and run ~10x faster on average.

Usage:

  1. Train a transformer tagger using TransformerTokenClassifier or the nlp-train transformer_token command
  2. Train a NeuralTagger using the trained transformer model as the teacher, together with a TeacherStudentDistill object configured with that model. This can be done using NeuralTagger's train loop or the nlp-train tagger_kd command; a minimal sketch follows this list.
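
As an illustration of step 2, the distiller wrapping the trained transformer tagger might be set up as follows (variable names are placeholders; nlp_architect.procedures.token_tagging.do_kd_training shows the actual training call):

    from nlp_architect.nn.torch.distillation import TeacherStudentDistill

    # `transformer_tagger` stands for the teacher model trained in step 1
    distiller = TeacherStudentDistill(teacher_model=transformer_tagger,
                                      temperature=2.0, dist_w=0.1,
                                      loss_w=1.0, loss_function="kl")
    # The distiller is then handed to NeuralTagger's train loop (or used via the
    # nlp-train tagger_kd command), which uses it to add the KD term to the
    # student loss; see do_kd_training for the exact call.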

Note

More models supporting distillation will be added in future releases

Pseudo Labeling

This method can be used to produce pseudo-labels when training the student on unlabeled examples. The pseudo-label is produced by applying argmax to the logits of the teacher model, resulting in the following loss:

\[loss = \begin{cases} CE(\hat{y}, y) & \text{labeled example} \\ CE(\hat{y}, \hat{y}_t) & \text{unlabeled example} \end{cases}\]

where CE is the cross-entropy loss, \(\hat{y}\) is the entity label class predicted by the student model, and \(\hat{y}_t\) is the label predicted by the teacher model.
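
A minimal PyTorch sketch of this pseudo-labeling loss (illustrative only; the boolean mask and variable names are assumptions, and logits/labels are assumed to be flattened to [num_tokens, num_classes] and [num_tokens]):

    import torch
    import torch.nn.functional as F

    def pseudo_label_loss(student_logits, teacher_logits, labels, is_labeled):
        """Cross-entropy against gold labels where available, otherwise against
        the teacher's argmax pseudo-labels (`is_labeled` is a boolean mask)."""
        # Teacher's hard guess, used as the target for unlabeled examples
        pseudo_labels = teacher_logits.argmax(dim=-1)
        # Gold label for labeled examples, teacher pseudo-label otherwise
        targets = torch.where(is_labeled, labels, pseudo_labels)
        return F.cross_entropy(student_logits, targets)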

[1] Geoffrey Hinton, Oriol Vinyals, Jeff Dean. Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531