# ******************************************************************************
# Copyright 2017-2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# pylint: disable=bad-super-call
"""
Quantized BERT layers and model
"""
import logging
import os
import sys
import torch
from torch import nn
from transformers.modeling_bert import (
    ACT2FN,
    BertAttention,
    BertConfig,
    BertEmbeddings,
    BertEncoder,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertIntermediate,
    BertLayer,
    BertLayerNorm,
    BertModel,
    BertOutput,
    BertPooler,
    BertPreTrainedModel,
    BertSelfAttention,
    BertSelfOutput,
)

from nlp_architect.models.pretrained_models import S3_PREFIX
from nlp_architect.nn.torch.quantization import (
    QuantizationConfig,
    QuantizedEmbedding,
    QuantizedLayer,
    QuantizedLinear,
)

logger = logging.getLogger(__name__)

QUANT_WEIGHTS_NAME = "quant_pytorch_model.bin"

QUANT_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bert-base-uncased": S3_PREFIX + "/models/transformers/bert-base-uncased.json",  # noqa: E501
    "bert-large-uncased": S3_PREFIX + "/models/transformers/bert-large-uncased.json",  # noqa: E501
}


def quantized_linear_setup(config, name, *args, **kwargs):
    """
    Get a QuantizedLinear layer built from the quantization settings stored on
    `config` under `name`; fall back to a regular nn.Linear when `config` has
    no such attribute.
    """
    try:
        quant_config = QuantizationConfig.from_dict(getattr(config, name))
        linear = QuantizedLinear.from_config(*args, **kwargs, config=quant_config)
    except AttributeError:
        linear = nn.Linear(*args, **kwargs)
    return linear
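

# Example for quantized_linear_setup (the sizes below are illustrative): when the
# model config carries an "attention_query" quantization dict, the call returns a
# QuantizedLinear; otherwise it silently falls back to nn.Linear:
#
#   proj = quantized_linear_setup(config, "attention_query", 768, 768)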


def quantized_embedding_setup(config, name, *args, **kwargs):
    """
    Get a QuantizedEmbedding layer built from the quantization settings stored
    on `config` under `name`; fall back to a regular nn.Embedding when `config`
    has no such attribute.
    """
    try:
        quant_config = QuantizationConfig.from_dict(getattr(config, name))
        embedding = QuantizedEmbedding.from_config(*args, **kwargs, config=quant_config)
    except AttributeError:
        embedding = nn.Embedding(*args, **kwargs)
    return embedding
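

# quantized_embedding_setup mirrors quantized_linear_setup, e.g. (arguments taken
# from QuantizedBertEmbeddings below):
#
#   emb = quantized_embedding_setup(
#       config, "word_embeddings", config.vocab_size, config.hidden_size, padding_idx=0
#   )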


class QuantizedBertConfig(BertConfig):
    pretrained_config_archive_map = QUANT_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP


class QuantizedBertEmbeddings(BertEmbeddings):
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = quantized_embedding_setup(
            config, "word_embeddings", config.vocab_size, config.hidden_size, padding_idx=0
        )
        self.position_embeddings = quantized_embedding_setup(
            config, "position_embeddings", config.max_position_embeddings, config.hidden_size
        )
        self.token_type_embeddings = quantized_embedding_setup(
            config, "token_type_embeddings", config.type_vocab_size, config.hidden_size
        )
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
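        # LayerNorm and dropout stay as regular (non-quantized) modules; only the
        # three embedding lookups go through the quantized wrappers.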


class QuantizedBertSelfAttention(BertSelfAttention):
    def __init__(self, config):
        super(BertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )
        self.output_attentions = config.output_attentions
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
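        # The query/key/value projections become QuantizedLinear layers only when the
        # config defines "attention_query"/"attention_key"/"attention_value" settings.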
        self.query = quantized_linear_setup(
            config, "attention_query", config.hidden_size, self.all_head_size
        )
        self.key = quantized_linear_setup(
            config, "attention_key", config.hidden_size, self.all_head_size
        )
        self.value = quantized_linear_setup(
            config, "attention_value", config.hidden_size, self.all_head_size
        )
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)


class QuantizedBertSelfOutput(BertSelfOutput):
    def __init__(self, config):
        super(BertSelfOutput, self).__init__()
        self.dense = quantized_linear_setup(
            config, "attention_output", config.hidden_size, config.hidden_size
        )
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)


class QuantizedBertAttention(BertAttention):
    def __init__(self, config):
        super(BertAttention, self).__init__()
        self.self = QuantizedBertSelfAttention(config)
        self.output = QuantizedBertSelfOutput(config)

    def prune_heads(self, heads):
        raise NotImplementedError("pruning heads is not implemented for Quantized BERT")


class QuantizedBertOutput(BertOutput):
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = quantized_linear_setup(
            config, "ffn_output", config.intermediate_size, config.hidden_size
        )
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
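

# QuantizedBertIntermediate is referenced by QuantizedBertLayer below but its
# definition was missing here; this is a reconstruction that follows the pattern of
# the other quantized sub-modules. The "ffn_intermediate" quantization-config key is
# an assumption (chosen to mirror the "ffn_output" key used above).
class QuantizedBertIntermediate(BertIntermediate):
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = quantized_linear_setup(
            config, "ffn_intermediate", config.hidden_size, config.intermediate_size
        )
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act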


class QuantizedBertLayer(BertLayer):
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = QuantizedBertAttention(config)
        self.is_decoder = config.is_decoder
        if self.is_decoder:
            logger.warning("Using QuantizedBertLayer as a decoder has not been tested.")
            self.crossattention = QuantizedBertAttention(config)
        self.intermediate = QuantizedBertIntermediate(config)
        self.output = QuantizedBertOutput(config)


class QuantizedBertEncoder(BertEncoder):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList(
            [QuantizedBertLayer(config) for _ in range(config.num_hidden_layers)]
        )


class QuantizedBertPooler(BertPooler):
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = quantized_linear_setup(
            config, "pooler", config.hidden_size, config.hidden_size
        )
        self.activation = nn.Tanh()


class QuantizedBertPreTrainedModel(BertPreTrainedModel):
    config_class = QuantizedBertConfig
    base_model_prefix = "quant_bert"
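    # config_class is used by from_pretrained to load a QuantizedBertConfig;
    # base_model_prefix is consulted below when matching checkpoint keys
    # between base models and models with task heads.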

    def init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding, QuantizedLinear, QuantizedEmbedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, from_8bit=False, **kwargs):
        """Load a trained model, optionally from an 8-bit (quantized) checkpoint."""
        if not from_8bit:
            return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        config = kwargs.pop("config", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        # Load config
        if config is None:
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )
        # Load model
        model_file = os.path.join(pretrained_model_name_or_path, QUANT_WEIGHTS_NAME)
        # Instantiate model.
        model = cls(config)
        # Switch to eval mode so the quantized-inference variables (None by default)
        # are initialized and can be filled from the 8-bit checkpoint.
        model.eval()
        # Get state dict of model
        state_dict = torch.load(model_file, map_location="cpu")
        logger.info("loading weights file {}".format(model_file))
        # Load from a PyTorch state_dict
        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
            )
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + ".")

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ""
        model_to_load = model
        if not hasattr(model, cls.base_model_prefix) and any(
            s.startswith(cls.base_model_prefix) for s in state_dict.keys()
        ):
            start_prefix = cls.base_model_prefix + "."
        if hasattr(model, cls.base_model_prefix) and not any(
            s.startswith(cls.base_model_prefix) for s in state_dict.keys()
        ):
            model_to_load = getattr(model, cls.base_model_prefix)
        load(model_to_load, prefix=start_prefix)
        if len(missing_keys) > 0:
            logger.info(
                "Weights of {} not initialized from pretrained model: {}".format(
                    model.__class__.__name__, missing_keys
                )
            )
        if len(unexpected_keys) > 0:
            logger.info(
                "Weights from pretrained model not used in {}: {}".format(
                    model.__class__.__name__, unexpected_keys
                )
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
        if hasattr(model, "tie_weights"):
            model.tie_weights()  # make sure word embedding weights are still tied
        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
                "unexpected_keys": unexpected_keys,
                "error_msgs": error_msgs,
            }
            return model, loading_info
        return model
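
    # Typical 8-bit round trip (the directory name is illustrative):
    #   model.save_pretrained("quant_model_dir")   # writes quant_pytorch_model.bin
    #   model = QuantizedBertForSequenceClassification.from_pretrained(
    #       "quant_model_dir", from_8bit=True
    #   )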

    def save_pretrained(self, save_directory):
        """Save the trained model together with its 8-bit (quantized) weights."""
        super().save_pretrained(save_directory)
        # Take care of distributed/parallel training: save the wrapped model itself
        model_to_save = self.module if hasattr(self, "module") else self
        model_to_save.toggle_8bit(True)
        output_model_file = os.path.join(save_directory, QUANT_WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.toggle_8bit(False)

    def toggle_8bit(self, mode: bool):
        def _toggle_8bit(module):
            if isinstance(module, QuantizedLayer):
                module.mode_8bit = mode

        self.apply(_toggle_8bit)
        if mode:
            training = self.training
            self.eval()
            self.train(training)
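        # The eval()/train() round trip above re-applies the current mode so that
        # QuantizedLayer modules can refresh their inference-time quantized
        # parameters under the new mode_8bit setting (used by save_pretrained
        # before writing the 8-bit state dict).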


class QuantizedBertModel(QuantizedBertPreTrainedModel, BertModel):
    def __init__(self, config):
        # skip BertModel.__init__ (which would build the non-quantized
        # embeddings/encoder/pooler); only BertPreTrainedModel's init runs
        super(BertModel, self).__init__(config)
        self.embeddings = QuantizedBertEmbeddings(config)
        self.encoder = QuantizedBertEncoder(config)
        self.pooler = QuantizedBertPooler(config)
        self.apply(self.init_weights)


class QuantizedBertForSequenceClassification(
    QuantizedBertPreTrainedModel, BertForSequenceClassification
):
    def __init__(self, config):
        # skip BertForSequenceClassification.__init__ to avoid building the
        # non-quantized modules; only BertPreTrainedModel's init runs
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = QuantizedBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = quantized_linear_setup(
            config, "head", config.hidden_size, self.config.num_labels
        )
        self.apply(self.init_weights)


class QuantizedBertForQuestionAnswering(QuantizedBertPreTrainedModel, BertForQuestionAnswering):
    def __init__(self, config):
        # skip BertForQuestionAnswering.__init__ to avoid building the
        # non-quantized modules; only BertPreTrainedModel's init runs
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = QuantizedBertModel(config)
        self.qa_outputs = quantized_linear_setup(
            config, "head", config.hidden_size, config.num_labels
        )
        self.apply(self.init_weights)


class QuantizedBertForTokenClassification(
    QuantizedBertPreTrainedModel, BertForTokenClassification
):
    def __init__(self, config):
        # skip BertForTokenClassification.__init__ to avoid building the
        # non-quantized modules; only BertPreTrainedModel's init runs
        super(BertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = QuantizedBertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = quantized_linear_setup(
            config, "head", config.hidden_size, config.num_labels
        )
        self.apply(self.init_weights)