# Source code for nlp_architect.nn.torch.distillation

# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import argparse
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F

from nlp_architect.models import TrainableModel

logger = logging.getLogger(__name__)

# Loss modules are instantiated once at import time and shared by every
# TeacherStudentDistill instance.
MSE_loss = nn.MSELoss(reduction="mean")
KL_loss = nn.KLDivLoss(reduction="batchmean")

# Registry mapping the ``loss_function`` / ``--kd_loss_fn`` names to the
# loss module used for the distillation term.
losses = {
    "kl": KL_loss,
    "mse": MSE_loss,
}

# Accepted values for the ``--teacher_model_type`` command line argument.
TEACHER_TYPES = ["bert"]

class TeacherStudentDistill:
    """
    Teacher-Student knowledge distillation helper.
    Use this object when training a model with KD and a teacher model.

    Args:
        teacher_model (TrainableModel): teacher model
        temperature (float, optional): KD temperature. Defaults to 1.0.
        dist_w (float, optional): distillation loss weight. Defaults to 0.1.
        loss_w (float, optional): student loss weight. Defaults to 1.0.
        loss_function (str, optional): loss function to use (kl for KLDivLoss,
            mse for MSELoss). Any other name falls back to KLDivLoss and a
            warning is logged. Defaults to "kl".
    """

    def __init__(
        self,
        teacher_model: TrainableModel,
        temperature: float = 1.0,
        dist_w: float = 0.1,
        loss_w: float = 1.0,
        loss_function="kl",
    ):
        self.teacher = teacher_model
        self.t = temperature
        self.dist_w = dist_w
        self.loss_w = loss_w
        if loss_function not in losses:
            # Keep the historical fallback to KL, but make it visible instead
            # of silently training with a different loss than requested.
            logger.warning(
                "Unknown KD loss function %r, falling back to 'kl'", loss_function
            )
        self.loss_fn = losses.get(loss_function, KL_loss)

    def get_teacher_logits(self, inputs):
        """
        Get teacher logits

        Args:
            inputs: input

        Returns:
            teacher logits, as returned by ``teacher_model.get_logits(inputs)``
        """
        return self.teacher.get_logits(inputs)

    @staticmethod
    def add_args(parser: argparse.ArgumentParser):
        """
        Add KD arguments to parser

        Args:
            parser (argparse.ArgumentParser): parser
        """
        parser.add_argument(
            "--teacher_model_path", type=str, required=True, help="Path to teacher model"
        )
        parser.add_argument(
            "--teacher_model_type",
            type=str,
            required=True,
            choices=TEACHER_TYPES,
            help="Teacher model class type",
        )
        parser.add_argument("--kd_temp", type=float, default=1.0, help="KD temperature value")
        # NOTE(review): the CLI default is "mse" while the constructor default
        # is "kl"; kept as-is so existing command lines behave the same.
        parser.add_argument(
            "--kd_loss_fn", type=str, choices=["kl", "mse"], default="mse", help="KD loss function"
        )
        parser.add_argument("--kd_dist_w", type=float, default=0.1, help="KD weight on loss")
        parser.add_argument(
            "--kd_student_w", type=float, default=1.0, help="KD student weight on loss"
        )

    def _soft_target_loss(self, student_logits, teacher_logits):
        """Return the raw (unweighted) distillation term for one pair of logits."""
        # The student side is passed as log-probabilities and the teacher side
        # as probabilities, which is the input/target convention of KLDivLoss.
        # NOTE(review): the same tensors are also fed to MSELoss when "mse" is
        # selected, i.e. log-probabilities are compared to probabilities.
        student_log_sm = F.log_softmax(student_logits / self.t, dim=-1)
        teacher_sm = F.softmax(teacher_logits / self.t, dim=-1)
        return self.loss_fn(input=student_log_sm, target=teacher_sm)

    def _combine(self, loss, distill_loss):
        """Return the weighted sum of the student loss and the distillation term."""
        # The distillation term is scaled by temperature squared in addition
        # to its weight.
        return (self.loss_w * loss) + (distill_loss * self.dist_w * (self.t ** 2))

    def distill_loss(self, loss, student_logits, teacher_logits):
        """
        Add KD loss

        Args:
            loss: student loss
            student_logits: student model logits
            teacher_logits: teacher model logits

        Returns:
            combined loss:
            ``loss_w * loss + dist_w * t**2 * loss_fn(student, teacher)``
        """
        return self._combine(loss, self._soft_target_loss(student_logits, teacher_logits))

    def distill_loss_dict(self, loss, student_logits_dict, teacher_logits_dict):
        """
        Add KD loss averaged over several student/teacher logits pairs

        Args:
            loss: student loss
            student_logits_dict: dict of student model logits; must not be
                empty (``torch.stack`` needs at least one tensor)
            teacher_logits_dict: dict of teacher model logits; must contain
                every key of ``student_logits_dict`` (extra keys are ignored)

        Returns:
            combined loss: ``loss_w * loss`` plus the mean of the per-key
            distillation terms, scaled by ``dist_w * t**2``
        """
        # Iterate the student's own keys rather than assuming they are the
        # integers 0..n-1, so arbitrary (e.g. string) keys work as well.
        distill_losses = [
            self._soft_target_loss(student_logits, teacher_logits_dict[key])
            for key, student_logits in student_logits_dict.items()
        ]
        distill_loss = torch.mean(torch.stack(distill_losses))
        return self._combine(loss, distill_loss)