nlp_architect.nn.torch package
Subpackages
Submodules
nlp_architect.nn.torch.distillation module
class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')
Bases: object
Teacher-Student knowledge distillation (KD) helper. Use this object when training a student model with KD from a teacher model.
Parameters: - teacher_model (TrainableModel) – teacher model
- temperature (float, optional) – KD temperature. Defaults to 1.0.
- dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
- loss_w (float, optional) – student loss weight. Defaults to 1.0.
- loss_function (str, optional) – loss function to use: 'kl' for KLDivLoss or 'mse' for MSELoss. Defaults to 'kl'.
static add_args(parser: argparse.ArgumentParser)
Add KD arguments to the parser.
Parameters: parser (argparse.ArgumentParser) – parser
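The general pattern of an add_args hook is to register extra options on a parser that the training script already owns. The sketch below illustrates that pattern only: the function add_kd_args and the flag names (--kd_temperature, --kd_dist_w, --kd_loss_w, --kd_loss_function) are hypothetical, chosen to mirror the constructor parameters, and are not necessarily the flags add_args actually registers.

```python
import argparse


def add_kd_args(parser: argparse.ArgumentParser) -> None:
    """Hypothetical stand-in for TeacherStudentDistill.add_args.

    Flag names are illustrative only; defaults mirror the constructor.
    """
    group = parser.add_argument_group("knowledge distillation")
    group.add_argument("--kd_temperature", type=float, default=1.0,
                       help="KD temperature")
    group.add_argument("--kd_dist_w", type=float, default=0.1,
                       help="distillation loss weight")
    group.add_argument("--kd_loss_w", type=float, default=1.0,
                       help="student loss weight")
    group.add_argument("--kd_loss_function", choices=["kl", "mse"],
                       default="kl", help="distillation loss function")


# Usage: extend an existing parser, then parse a sample command line.
parser = argparse.ArgumentParser()
add_kd_args(parser)
args = parser.parse_args(["--kd_temperature", "2.0", "--kd_loss_function", "mse"])
```

Using an argument group keeps the KD options visually separate in --help output without affecting how they are parsed.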
distill_loss(loss, student_logits, teacher_logits)
Add the KD term to the student loss.
Parameters: - loss – student loss
- student_logits – student model logits
- teacher_logits – teacher model logits
Returns: KD loss
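A minimal pure-Python sketch of how such a weighted KD loss is commonly computed for a single example, following the soft-target formulation of Hinton et al.: teacher and student logits are softened with the temperature T, compared with KL divergence, and the result is scaled by T² so gradient magnitudes stay comparable across temperatures. The function distill_loss_sketch and its helpers are hypothetical. Whether distill_loss applies the T² factor, and whether its 'mse' option compares raw logits or probabilities, are assumptions here.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-softened softmax, stabilised by subtracting the max."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions; p_i == 0 terms contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distill_loss_sketch(loss: float, student_logits: Sequence[float],
                        teacher_logits: Sequence[float], temperature: float = 1.0,
                        dist_w: float = 0.1, loss_w: float = 1.0,
                        loss_function: str = "kl") -> float:
    """Hypothetical sketch: loss_w * student_loss + dist_w * distillation_term."""
    if loss_function == "kl":
        # Teacher soft targets vs. softened student predictions,
        # scaled by T^2 (assumed) to keep gradient magnitudes comparable.
        p_teacher = softmax(teacher_logits, temperature)
        p_student = softmax(student_logits, temperature)
        kd = kl_divergence(p_teacher, p_student) * temperature ** 2
    elif loss_function == "mse":
        # Assumed variant: mean squared error between raw logits.
        kd = sum((s - t) ** 2 for s, t in zip(student_logits, teacher_logits))
        kd /= len(student_logits)
    else:
        raise ValueError(f"unknown loss_function: {loss_function!r}")
    return loss_w * loss + dist_w * kd


total = distill_loss_sketch(0.7, student_logits=[1.0, 0.5, -0.2],
                            teacher_logits=[2.0, 0.1, -1.0], temperature=2.0)
```

When the student already matches the teacher, the distillation term vanishes and only loss_w * loss remains.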
nlp_architect.nn.torch.quantization module
Quantization ops
class nlp_architect.nn.torch.quantization.FakeLinearQuantizationWithSTE
Bases: torch.autograd.function.Function
Simulates the error caused by quantization. Uses the Straight-Through Estimator (STE) for backpropagation.
static backward(ctx, grad_output)
Calculate estimated gradients for fake quantization using the Straight-Through Estimator (STE), following: https://openreview.net/pdf?id=B1ae1lZRb
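As a rough illustration, fake quantization quantizes and immediately dequantizes its input in the forward pass, so later computation sees the rounding and clipping error while still working in floating point. Because rounding has zero gradient almost everywhere, the STE treats the quantizer as the identity in the backward pass and returns grad_output unchanged. The element-wise functions below are a hypothetical pure-Python sketch of that idea, assuming a symmetric quantizer with a multiplicative scale; they are not the torch.autograd.Function itself, and some STE variants additionally zero the gradient for clipped values.

```python
from typing import List, Sequence


def fake_quantize(x: Sequence[float], scale: float, bits: int = 8) -> List[float]:
    """Forward pass: quantize to signed integers, clip, then dequantize."""
    max_q = 2 ** (bits - 1) - 1  # symmetric range [-max_q, max_q]
    out = []
    for v in x:
        q = max(-max_q, min(max_q, round(v * scale)))
        out.append(q / scale)
    return out


def ste_backward(grad_output: Sequence[float]) -> List[float]:
    """Backward pass: the STE passes the incoming gradient through unchanged."""
    return list(grad_output)


x = [0.1234, -0.5, 0.9]
scale = 127 / 0.9  # map max |x| onto the largest 8-bit value
y = fake_quantize(x, scale)
```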
class nlp_architect.nn.torch.quantization.QuantizationConfig(**kwargs)
Bases: nlp_architect.common.config.Config
Quantization configuration object.
ATTRIBUTES = {'activation_bits': 8, 'ema_decay': 0.9999, 'mode': 'none', 'requantize_output': True, 'start_step': 0, 'weight_bits': 8}
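QuantizationConfig is built from keyword arguments, and ATTRIBUTES lists its defaults. The snippet below is a hypothetical stand-in (make_quantization_config is not part of the library) showing the assumed behaviour: each keyword overrides the matching default, and unspecified attributes keep the values listed in ATTRIBUTES. Rejecting unknown keys is a choice of this sketch, not documented behaviour of Config.

```python
from typing import Any, Dict

# Defaults copied from QuantizationConfig.ATTRIBUTES.
QUANT_DEFAULTS: Dict[str, Any] = {
    "activation_bits": 8,
    "ema_decay": 0.9999,
    "mode": "none",
    "requantize_output": True,
    "start_step": 0,
    "weight_bits": 8,
}


def make_quantization_config(**kwargs: Any) -> Dict[str, Any]:
    """Hypothetical sketch: merge keyword overrides into the defaults."""
    unknown = set(kwargs) - set(QUANT_DEFAULTS)
    if unknown:
        raise ValueError(f"unknown quantization attributes: {sorted(unknown)}")
    return {**QUANT_DEFAULTS, **kwargs}


config = make_quantization_config(mode="ema", weight_bits=4)
```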
class nlp_architect.nn.torch.quantization.QuantizationMode
Bases: enum.Enum
Enumeration of quantization modes.
DYNAMIC = 2
EMA = 3
NONE = 1
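The layer constructors take mode as a string (default 'none'). One plausible way to map that string onto the enum is a case-insensitive lookup by member name, sketched below with a local copy of the enum. The parse_mode helper is hypothetical, and the exact conversion the layers perform is an assumption.

```python
from enum import Enum


class QuantizationMode(Enum):
    """Local copy of the documented enum values."""
    NONE = 1
    DYNAMIC = 2
    EMA = 3


def parse_mode(mode: str) -> QuantizationMode:
    """Hypothetical helper: 'none' -> NONE, 'ema' -> EMA, and so on."""
    try:
        return QuantizationMode[mode.upper()]
    except KeyError:
        valid = ", ".join(m.name.lower() for m in QuantizationMode)
        raise ValueError(f"invalid mode {mode!r}; expected one of: {valid}") from None
```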
class nlp_architect.nn.torch.quantization.QuantizedEmbedding(*args, weight_bits=8, start_step=0, mode='none', **kwargs)
Bases: nlp_architect.nn.torch.quantization.QuantizedLayer, torch.nn.modules.sparse.Embedding
Embedding layer with quantization-aware training capability.
class nlp_architect.nn.torch.quantization.QuantizedLayer(*args, weight_bits=8, start_step=0, mode='none', **kwargs)
Bases: abc.ABC
Quantized layer interface.
CONFIG_ATTRIBUTES = ['weight_bits', 'start_step', 'mode']
REPR_ATTRIBUTES = ['mode', 'weight_bits']
fake_quantized_weight
classmethod from_config(*args, config=None, **kwargs)
Initialize a quantized layer from a config.
weight_scale
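from_config lets a layer be built from a shared configuration object instead of repeating keyword arguments. One plausible pattern, sketched below with a hypothetical DemoLayer class, reads each name in CONFIG_ATTRIBUTES off the config and forwards it as a keyword argument alongside any positional arguments. This mirrors the documented signature but is not the library's implementation.

```python
from types import SimpleNamespace
from typing import Any


class DemoLayer:
    """Hypothetical layer following the QuantizedLayer interface conventions."""
    CONFIG_ATTRIBUTES = ["weight_bits", "start_step", "mode"]

    def __init__(self, in_features: int, out_features: int, *,
                 weight_bits: int = 8, start_step: int = 0, mode: str = "none"):
        self.in_features = in_features
        self.out_features = out_features
        self.weight_bits = weight_bits
        self.start_step = start_step
        self.mode = mode

    @classmethod
    def from_config(cls, *args: Any, config=None, **kwargs: Any) -> "DemoLayer":
        """Pull every CONFIG_ATTRIBUTES entry from config and pass it to __init__."""
        config_kwargs = {name: getattr(config, name) for name in cls.CONFIG_ATTRIBUTES}
        return cls(*args, **kwargs, **config_kwargs)


cfg = SimpleNamespace(weight_bits=4, start_step=100, mode="ema", activation_bits=8)
layer = DemoLayer.from_config(768, 3072, config=cfg)
```

Attributes on the config that are not listed in CONFIG_ATTRIBUTES (activation_bits here) are ignored, which is why a subclass such as QuantizedLinear extends the list with its own options.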
class nlp_architect.nn.torch.quantization.QuantizedLinear(*args, activation_bits=8, requantize_output=True, ema_decay=0.9999, **kwargs)
Bases: nlp_architect.nn.torch.quantization.QuantizedLayer, torch.nn.modules.linear.Linear
Linear layer with quantization-aware training capability.
CONFIG_ATTRIBUTES = ['weight_bits', 'start_step', 'mode', 'activation_bits', 'requantize_output', 'ema_decay']
REPR_ATTRIBUTES = ['mode', 'weight_bits', 'activation_bits', 'accumulation_bits', 'ema_decay', 'requantize_output']
inference_quantized_forward(input)
Simulate quantized inference: quantize the input and perform the calculation using only integer values. This function should only be used during inference.
quantized_bias
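Integer-only inference can be illustrated for a single output unit: input and weights are quantized to integers with their own scales, the dot product is accumulated in integers, the bias is quantized with the product of the two scales so it can be added directly to the accumulator, and the result is dequantized by dividing by that product. The function integer_linear_neuron is a hypothetical pure-Python sketch under a multiplicative-scale assumption; it omits output requantization (requantize_output) and accumulator bit-width limits.

```python
from typing import List, Sequence


def _quantize(x: Sequence[float], scale: float, bits: int) -> List[int]:
    """Symmetric quantization: round(x * scale), clipped to [-max_q, max_q]."""
    max_q = 2 ** (bits - 1) - 1
    return [max(-max_q, min(max_q, round(v * scale))) for v in x]


def integer_linear_neuron(inputs: Sequence[float], weights: Sequence[float],
                          bias: float, activation_bits: int = 8,
                          weight_bits: int = 8) -> float:
    """Hypothetical sketch of integer-only inference for one output unit."""
    max_a = 2 ** (activation_bits - 1) - 1
    max_w = 2 ** (weight_bits - 1) - 1
    # Dynamic scales: map the largest magnitude onto the largest integer.
    input_scale = max_a / max(abs(v) for v in inputs)
    weight_scale = max_w / max(abs(w) for w in weights)
    q_in = _quantize(inputs, input_scale, activation_bits)
    q_w = _quantize(weights, weight_scale, weight_bits)
    # Bias uses the accumulator's scale and is not clipped (wider integer).
    q_bias = round(bias * input_scale * weight_scale)
    acc = sum(a * b for a, b in zip(q_in, q_w)) + q_bias  # integer arithmetic only
    return acc / (input_scale * weight_scale)  # dequantize the accumulator


x = [0.5, -1.2, 0.3]
w = [0.8, 0.1, -0.4]
b = 0.05
approx = integer_linear_neuron(x, w, b)
exact = sum(xi * wi for xi, wi in zip(x, w)) + b
```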
nlp_architect.nn.torch.quantization.calc_max_quant_value(bits)
Calculate the maximum symmetric quantized value for the given number of bits.
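For a symmetric signed range, the largest representable magnitude with b bits is conventionally 2^(b-1) - 1 (127 for 8 bits), leaving -2^(b-1) unused so the range is symmetric around zero. Assuming calc_max_quant_value follows this convention, it reduces to the one-liner below; symmetric_max_q is a hypothetical name used for illustration.

```python
def symmetric_max_q(bits: int) -> int:
    """Largest magnitude in a symmetric signed range, assuming 2**(bits - 1) - 1."""
    return 2 ** (bits - 1) - 1
```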
nlp_architect.nn.torch.quantization.dequantize(input, scale)
Linear dequantization of input using the given scale.
nlp_architect.nn.torch.quantization.get_dynamic_scale(x, bits, with_grad=False)
Calculate a dynamic quantization scale for x from its maximum absolute value and the number of bits.
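The scale and dequantization functions fit together in a quantize/dequantize round trip: the dynamic scale maps the largest absolute value in x onto the largest quantized integer, values are multiplied by the scale, rounded, and clipped, and dequantization divides by the same scale. This pure-Python sketch assumes the scale is multiplicative (q = round(x * scale)); the helper names are hypothetical, and the torch functions operate on tensors rather than lists.

```python
from typing import List, Sequence


def dynamic_scale(x: Sequence[float], bits: int) -> float:
    """Scale that maps max(|x|) onto the largest symmetric quantized value."""
    max_q = 2 ** (bits - 1) - 1
    return max_q / max(abs(v) for v in x)


def quantize(x: Sequence[float], scale: float, bits: int) -> List[int]:
    """Multiply by the scale, round, and clip to the symmetric range."""
    max_q = 2 ** (bits - 1) - 1
    return [max(-max_q, min(max_q, round(v * scale))) for v in x]


def dequantize(q: Sequence[int], scale: float) -> List[float]:
    """Linear dequantization: divide by the same scale."""
    return [v / scale for v in q]


x = [0.02, -0.75, 1.5, -1.5]
scale = dynamic_scale(x, bits=8)  # 127 / 1.5
q = quantize(x, scale, bits=8)
x_hat = dequantize(q, scale)
```

With a dynamic scale no value is clipped, so the round-trip error per element is at most half a quantization step, 0.5 / scale.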