nlp_architect.nn.torch package
Submodules
nlp_architect.nn.torch.distillation module
class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')[source]
Bases: object
Teacher-Student knowledge distillation helper. Use this object when training a model with KD and a teacher model.
Parameters:
- teacher_model (TrainableModel) – teacher model
- temperature (float, optional) – KD temperature. Defaults to 1.0.
- dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
- loss_w (float, optional) – student loss weight. Defaults to 1.0.
- loss_function (str, optional) – loss function to use ('kl' for KLDivLoss, 'mse' for MSELoss). Defaults to 'kl'.
static add_args(parser: argparse.ArgumentParser)[source]
Add KD arguments to the parser.

Parameters:
- parser (argparse.ArgumentParser) – parser
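A minimal sketch of wiring these flags into a training script; the exact flag names registered by add_args are defined by the library, so only the call pattern is shown:

    import argparse
    from nlp_architect.nn.torch.distillation import TeacherStudentDistill

    parser = argparse.ArgumentParser()
    # ... add the script's own training arguments here ...
    TeacherStudentDistill.add_args(parser)  # registers the KD-related flags
    args = parser.parse_args()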
distill_loss(loss, student_logits, teacher_logits)[source]
Combine the student loss with the knowledge distillation loss.

Parameters:
- loss – student loss
- student_logits – student model logits
- teacher_logits – teacher model logits

Returns: the combined KD loss
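A hedged sketch of a KD training step. Here student_model, teacher_model, loss_fn, optimizer, and train_loader are placeholders for the surrounding training code, not part of this API; the sketch assumes both models return logits:

    import torch
    from nlp_architect.nn.torch.distillation import TeacherStudentDistill

    distiller = TeacherStudentDistill(teacher_model, temperature=2.0, dist_w=0.5)

    for inputs, labels in train_loader:
        optimizer.zero_grad()
        student_logits = student_model(inputs)          # placeholder forward pass
        with torch.no_grad():                           # the teacher is not updated
            teacher_logits = teacher_model(inputs)
        loss = loss_fn(student_logits, labels)          # regular student loss
        loss = distiller.distill_loss(loss, student_logits, teacher_logits)
        loss.backward()
        optimizer.step()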
nlp_architect.nn.torch.quantization module
Quantization ops
class nlp_architect.nn.torch.quantization.FakeLinearQuantizationWithSTE[source]
Bases: torch.autograd.function.Function

Simulates the error caused by quantization. Uses the Straight-Through Estimator (STE) for backpropagation.
static backward(ctx, grad_output)[source]
Calculate estimated gradients for fake quantization using the Straight-Through Estimator (STE), according to: https://openreview.net/pdf?id=B1ae1lZRb
static forward(ctx, input, scale, bits=8)[source]
Fake-quantize input according to scale and the number of bits: quantize, then dequantize, so the output carries the quantization error.
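An illustrative sketch, not the library's exact implementation, of fake linear quantization with an STE backward: the forward rounds onto the integer grid and dequantizes back, while the backward passes gradients through the rounding unchanged:

    import torch

    class FakeQuantSketch(torch.autograd.Function):  # hypothetical name, for illustration
        @staticmethod
        def forward(ctx, x, scale, bits=8):
            max_q = 2 ** (bits - 1) - 1                        # symmetric integer range
            q = torch.clamp(torch.round(x * scale), -max_q, max_q)
            return q / scale                                   # dequantize: keep the rounding error

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output, None, None                     # STE: round() treated as identity

    x = torch.randn(4, requires_grad=True)
    y = FakeQuantSketch.apply(x, torch.tensor(127.0))          # scale value is illustrative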
class nlp_architect.nn.torch.quantization.QuantizationConfig(**kwargs)[source]
Bases: nlp_architect.common.config.Config

Quantization Configuration Object
ATTRIBUTES = {'activation_bits': 8, 'ema_decay': 0.9999, 'mode': 'none', 'requantize_output': True, 'start_step': 0, 'weight_bits': 8}
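A short construction sketch, assuming (as with other nlp_architect Config objects) that keyword arguments override the ATTRIBUTES defaults:

    from nlp_architect.nn.torch.quantization import QuantizationConfig

    # enable EMA-mode quantization after 1000 steps (values illustrative)
    qcfg = QuantizationConfig(mode='ema', weight_bits=8, activation_bits=8, start_step=1000)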
class nlp_architect.nn.torch.quantization.QuantizationMode[source]
Bases: enum.Enum

An enumeration of quantization modes: NONE (no quantization), DYNAMIC (scale computed dynamically from the input), EMA (activation scale tracked with an exponential moving average).
NONE = 1
DYNAMIC = 2
EMA = 3
class nlp_architect.nn.torch.quantization.QuantizedEmbedding(*args, weight_bits=8, start_step=0, mode='none', **kwargs)[source]
Bases: nlp_architect.nn.torch.quantization.QuantizedLayer, torch.nn.modules.sparse.Embedding

Embedding layer with quantization-aware training capability
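A drop-in construction sketch replacing torch.nn.Embedding; the vocabulary size, embedding dimension, and mode are illustrative:

    from nlp_architect.nn.torch.quantization import QuantizedEmbedding

    # 8-bit weight quantization, enabled after 100 training steps (values illustrative)
    emb = QuantizedEmbedding(30000, 256, weight_bits=8, start_step=100, mode='dynamic')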
class nlp_architect.nn.torch.quantization.QuantizedLayer(*args, weight_bits=8, start_step=0, mode='none', **kwargs)[source]
Bases: abc.ABC

Quantized Layer interface
CONFIG_ATTRIBUTES = ['weight_bits', 'start_step', 'mode']

REPR_ATTRIBUTES = ['mode', 'weight_bits']
fake_quantized_weight

classmethod from_config(*args, config=None, **kwargs)[source]
Initialize quantized layer from config

weight_scale
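A from_config sketch, assuming the classmethod reads the layer's CONFIG_ATTRIBUTES from the config object and passes the remaining positional arguments to the layer constructor:

    from nlp_architect.nn.torch.quantization import QuantizationConfig, QuantizedLinear

    qcfg = QuantizationConfig(mode='ema', weight_bits=8)
    layer = QuantizedLinear.from_config(768, 768, config=qcfg)  # feature sizes are placeholders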
class nlp_architect.nn.torch.quantization.QuantizedLinear(*args, activation_bits=8, requantize_output=True, ema_decay=0.9999, **kwargs)[source]
Bases: nlp_architect.nn.torch.quantization.QuantizedLayer, torch.nn.modules.linear.Linear

Linear layer with quantization-aware training capability
CONFIG_ATTRIBUTES = ['weight_bits', 'start_step', 'mode', 'activation_bits', 'requantize_output', 'ema_decay']

REPR_ATTRIBUTES = ['mode', 'weight_bits', 'activation_bits', 'accumulation_bits', 'ema_decay', 'requantize_output']
inference_quantized_forward(input)[source]
Simulate quantized inference: quantize the input and perform the calculation using integer numbers only. This function should only be used during inference.
quantized_bias
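A usage sketch with QuantizedLinear as a drop-in for torch.nn.Linear: the regular forward is the quantization-aware training path, and inference_quantized_forward simulates integer-only inference (which assumes the quantization statistics have already been trained):

    import torch
    from nlp_architect.nn.torch.quantization import QuantizedLinear

    layer = QuantizedLinear(128, 64, mode='ema', weight_bits=8, activation_bits=8)

    x = torch.randn(32, 128)
    y = layer(x)                                      # training path: fake quantization

    layer.eval()
    with torch.no_grad():
        y_q = layer.inference_quantized_forward(x)    # integer-only inference simulation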
nlp_architect.nn.torch.quantization.calc_max_quant_value(bits)[source]
Calculate the maximum symmetric quantized value for a given number of bits
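Assuming the standard symmetric signed-integer convention, 2 ** (bits - 1) - 1, this gives for example:

    >>> calc_max_quant_value(8)
    127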
nlp_architect.nn.torch.quantization.dequantize(input, scale)[source]
Linear dequantization of the input according to the given scale
nlp_architect.nn.torch.quantization.get_dynamic_scale(x, bits, with_grad=False)[source]
Calculate a dynamic scale for quantization from the maximum absolute value of x and the number of bits
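A round-trip sketch tying these functions together; the manual torch.round step stands in for the module's quantization op and assumes the scale maps the input onto the integer grid:

    import torch
    from nlp_architect.nn.torch.quantization import dequantize, get_dynamic_scale

    x = torch.randn(4, 4)
    scale = get_dynamic_scale(x, bits=8)   # derived from max(|x|) and the bit width
    q = torch.round(x * scale)             # manual linear quantization (sketch)
    x_hat = dequantize(q, scale)           # approximately x, up to quantization error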