Compression of Google Neural Machine Translation Model


Google Neural Machine Translation (GNMT) is a sequence-to-sequence (Seq2seq) model that learns a mapping from an input text to an output text.

The example below demonstrates how to train a highly sparse GNMT model with minimal loss in accuracy. The model is based on the GNMT model presented in the paper Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation [1] and consists of approximately 210M floating-point parameters.

GNMT Model

The GNMT architecture is an encoder-decoder architecture with attention as presented in the original paper [1].

The encoder consists of an embedding layer followed by 1 bi-directional and 3 uni-directional LSTM layers with residual connections between them. The decoder consists of an embedding layer followed by 4 uni-directional LSTM layers and a linear Softmax layer. The attention mechanism connects the encoder’s bi-directional LSTM layer to all of the decoder’s LSTM layers.
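The encoder layout above follows the rule given in the --encoder_type help text further down this page. As a minimal sketch of that rule (the function and type names here are hypothetical and not part of the library), the split between bi-directional and uni-directional layers can be written as:

```python
from typing import NamedTuple


class EncoderLayout(NamedTuple):
    bi_layers: int   # number of bi-directional LSTM layers
    uni_layers: int  # number of uni-directional LSTM layers


def encoder_layout(encoder_type: str, num_encoder_layers: int) -> EncoderLayout:
    """Split the encoder depth into bi- and uni-directional layers.

    Illustrative helper only; it mirrors the wording of the
    --encoder_type help text, not the library's actual code.
    """
    if encoder_type == "uni":
        return EncoderLayout(0, num_encoder_layers)
    if encoder_type == "bi":
        # For bi, num_encoder_layers/2 bi-directional layers are built.
        return EncoderLayout(num_encoder_layers // 2, 0)
    if encoder_type == "gnmt":
        # For gnmt, 1 bi-directional layer and the rest uni-directional.
        return EncoderLayout(1, num_encoder_layers - 1)
    raise ValueError(f"unknown encoder_type: {encoder_type!r}")


print(encoder_layout("gnmt", 4))  # EncoderLayout(bi_layers=1, uni_layers=3)
```

With encoder_type gnmt and 4 encoder layers this gives the 1 bi-directional plus 3 uni-directional layers described above.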

The GNMT model was adapted from the model shown in Neural Machine Translation (seq2seq) Tutorial [2] and from its repository.

The Sparse model implementation can be found in GNMTModel and offers several options to build the GNMT model.

Sparsity - Pruning GNMT

Sparse neural networks are networks in which a portion of the weights are zero. A high sparsity ratio can help compress the model, accelerate inference, and reduce the power consumed by memory transfers and computation.

To produce a sparse network, the network weights are pruned during training by forcing some of them to zero. There are a number of methods for pruning neural networks. For example, the paper To prune, or not to prune: exploring the efficacy of pruning for model compression [3] presents a method for gradually pruning the weights with the smallest magnitude.
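To make the idea of magnitude-based pruning concrete, here is a minimal pure-Python sketch. It is an illustration under simplifying assumptions (a flat list of weights and a one-shot prune) and not the implementation used by the pruning package; the function names are hypothetical.

```python
from typing import List


def magnitude_prune(weights: List[float], sparsity: float) -> List[float]:
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")
    num_to_prune = int(round(sparsity * len(weights)))
    # Indices ordered from the smallest to the largest absolute value.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:num_to_prune]:
        pruned[i] = 0.0
    return pruned


def sparsity_ratio(weights: List[float]) -> float:
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


weights = [0.8, -0.05, 0.3, 0.01, -0.6, 0.02, 0.1, -0.9, 0.04, 0.5]
pruned = magnitude_prune(weights, sparsity=0.5)
print(pruned)                  # the five smallest-magnitude weights become 0.0
print(sparsity_ratio(pruned))  # 0.5
```

During training, a step like this would be applied repeatedly with a slowly increasing sparsity value, so that the remaining weights can adapt between pruning steps.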

The example below demonstrates how to prune the GNMT model up to 90% sparsity with minimal loss in BLEU score using the TensorFlow model_pruning package, which implements the method presented in [3].
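The gradual pruning method in [3] raises the sparsity from an initial value s_i to a final value s_f over n pruning steps with the schedule s_t = s_f + (s_i - s_f) * (1 - (t - t_0) / (n * dt))^3. The sketch below evaluates that schedule; the function name, argument names and the step range in the example are hypothetical and chosen for illustration only.

```python
def gradual_sparsity(
    step: int,
    begin_step: int,
    end_step: int,
    initial_sparsity: float = 0.0,
    final_sparsity: float = 0.9,
    exponent: float = 3.0,
) -> float:
    """Sparsity target at `step` following the cubic schedule of [3]."""
    if end_step <= begin_step:
        raise ValueError("end_step must be greater than begin_step")
    # Fraction of the pruning phase completed, clamped to [0, 1].
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** exponent


# Hypothetical pruning window of 100,000 steps.
for step in (0, 25_000, 50_000, 100_000):
    print(step, f"{gradual_sparsity(step, 0, 100_000):.4f}")
```

The schedule prunes quickly at the start, when many redundant weights are still present, and slows down as the target sparsity is approached.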

Post Training Weight Quantization

The weights of pre-trained GNMT models are usually represented in 32-bit floating-point format. The highly sparse pre-trained model below can be further compressed by uniformly quantizing the weights to 8-bit integers, gaining a further compression ratio of 4x with negligible accuracy loss. The weight quantization is implemented with the TensorFlow API. When the model is used for inference, the int8 weights of the sparse and quantized model are de-quantized back to fp32.
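The quantize and de-quantize round trip can be sketched in a few lines of plain Python. This sketch assumes a symmetric uniform scheme with a single scale per tensor; the text above does not specify which uniform scheme the implementation uses, so treat the details (and the function names) as assumptions.

```python
from typing import List, Tuple


def quantize_int8(weights: List[float]) -> Tuple[List[int], float]:
    """Uniformly quantize weights to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    # One scale for the whole tensor; fall back to 1.0 for an all-zero tensor.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized: List[int], scale: float) -> List[float]:
    """Map the 8-bit integers back to floating point."""
    return [q * scale for q in quantized]


weights = [0.0, 0.5, -1.27, 0.013, 1.0]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
print(q)
print(max(abs(a - b) for a, b in zip(weights, restored)))  # at most scale / 2
```

In this symmetric scheme a zero weight maps to the integer 0 and back to 0.0, so the sparsity pattern of a pruned model is preserved by the round trip.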


The models below were trained using the following datasets:

  • Europarlv7 [4]
  • Common Crawl Corpus
  • News Commentary 11
  • Development and test sets

All datasets are provided by the WMT Shared Task: Machine Translation of News.

You can use this script to download and prepare the data for training and evaluating your model.

Results & Pre-Trained Models

The following table presents some of our experiments and results. We provide pre-trained checkpoints for a 90% sparse GNMT model and for a similar 90% sparse model with a 2x2 block sparsity pattern. See the table below and our Model Zoo. You can use these models to run inference and evaluate them, as described in Run Inference using our Pre-Trained Models.

Model                       Sparsity  BLEU  Non-Zero Parameters  Data Type
Baseline                    0%        29.9  ~210M                Float32
Sparse                      90%       28.4  ~22M                 Float32
2x2 Block Sparse            90%       27.8  ~22M                 Float32
Quantized Sparse            90%       28.4  ~22M                 Integer8
Quantized 2x2 Block Sparse  90%       27.6  ~22M                 Integer8
  1. The pruning is applied to the embedding, decoder projection layer and all LSTM layers in both the encoder and decoder.
  2. BLEU score is measured using newstest2015 test set provided by the Shared Task.
  3. The accuracy of the quantized model was measured after converting the 8-bit weights back to floating point during inference.
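As a rough back-of-the-envelope check of the compression implied by the table, the sketch below computes the storage needed for the non-zero weight values alone. It assumes 4 bytes per Float32 value and 1 byte per Integer8 value, and it ignores the indices or metadata that any sparse storage format would add, so real file sizes will differ.

```python
BYTES_PER_VALUE = {"float32": 4, "int8": 1}


def values_size_mb(num_values: float, dtype: str) -> float:
    """Storage for the weight values only, in megabytes (10**6 bytes)."""
    return num_values * BYTES_PER_VALUE[dtype] / 1e6


# Approximate parameter counts taken from the table above.
baseline = values_size_mb(210e6, "float32")
sparse_fp32 = values_size_mb(22e6, "float32")
sparse_int8 = values_size_mb(22e6, "int8")

print(f"baseline:         {baseline:.0f} MB")
print(f"sparse, float32:  {sparse_fp32:.0f} MB ({baseline / sparse_fp32:.1f}x smaller)")
print(f"sparse, int8:     {sparse_int8:.0f} MB ({baseline / sparse_int8:.1f}x smaller)")
```

The factor of 4 between the two sparse rows corresponds to the 4x gain from quantization mentioned earlier.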

Running Modalities

Below are simple examples for training a 90% sparse GNMTModel, running inference using a pre-trained or newly trained model, quantizing a model to 8-bit integers, and running inference using a quantized model. Before inference, the int8 weights of the sparse and quantized model are de-quantized back to fp32.


Train a German to English GNMT model with 90% sparsity using the WMT16 dataset:

# Download the dataset to /tmp/wmt16_en_de

# Go to examples directory
cd <nlp_architect root>/examples

# Train the sparse GNMT
python -m sparse_gnmt.nmt \
    --src=de --tgt=en \
    --hparams_path=sparse_gnmt/standard_hparams/sparse_wmt16_gnmt_4_layer.json \
    --out_dir=<output directory> \
    --vocab_prefix=/tmp/wmt16_en_de/vocab.bpe.32000 \
    --train_prefix=/tmp/wmt16_en_de/train.tok.clean.bpe.32000 \
    --dev_prefix=/tmp/wmt16_en_de/newstest2013.tok.bpe.32000
  • Train using GPUs by adding --num_gpus=<n>
  • Model configuration JSON files are found in examples/sparse_gnmt/standard_hparams directory.
  • The sparsity policy can be re-configured by changing the parameters given in --pruning_hparams. For example, set target_sparsity=0.7 to train a 70% sparse GNMT model.
  • All pruning hyper parameters are listed in model_pruning.

During training, TensorFlow checkpoints, TensorBoard events, the hyper-parameters used and log files are saved in the given output directory.


Run inference using a trained model:

# Go to examples directory
cd <nlp_architect root>/examples

# Run Inference
python -m sparse_gnmt.nmt \
    --src=de --tgt=en \
    --hparams_path=sparse_gnmt/standard_hparams/sparse_wmt16_gnmt_4_layer.json \
    --ckpt=<path to a trained checkpoint> \
    --vocab_prefix=/tmp/wmt16_en_de/vocab.bpe.32000 \
    --out_dir=<output directory> \
    --inference_input_file=<file with lines in the source language> \
    --inference_output_file=<target file to place translations>
  • Measure performance and BLEU score against a reference file by adding --inference_ref_file=<reference file in the target language>
  • Run inference on GPUs by adding --num_gpus=<n>
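When inference is launched from a script, it can be convenient to assemble the command programmatically. The helper below is a hypothetical sketch, not part of the library: it only builds the argument list from the flags shown above, and all paths in the example are made up for illustration.

```python
import shlex
from typing import List, Optional


def build_inference_command(
    ckpt: str,
    vocab_prefix: str,
    out_dir: str,
    input_file: str,
    output_file: str,
    ref_file: Optional[str] = None,
    num_gpus: Optional[int] = None,
) -> List[str]:
    """Assemble the inference command from the flags documented above."""
    cmd = [
        "python", "-m", "sparse_gnmt.nmt",
        "--src=de", "--tgt=en",
        "--hparams_path=sparse_gnmt/standard_hparams/sparse_wmt16_gnmt_4_layer.json",
        f"--ckpt={ckpt}",
        f"--vocab_prefix={vocab_prefix}",
        f"--out_dir={out_dir}",
        f"--inference_input_file={input_file}",
        f"--inference_output_file={output_file}",
    ]
    # Optional flags: BLEU against a reference file, and GPU count.
    if ref_file is not None:
        cmd.append(f"--inference_ref_file={ref_file}")
    if num_gpus is not None:
        cmd.append(f"--num_gpus={num_gpus}")
    return cmd


# Hypothetical paths, for illustration only.
cmd = build_inference_command(
    ckpt="/tmp/my_model/translate.ckpt",
    vocab_prefix="/tmp/wmt16_en_de/vocab.bpe.32000",
    out_dir="/tmp/my_model/infer",
    input_file="/tmp/input.de",
    output_file="/tmp/output.en",
    ref_file="/tmp/reference.en",
)
print(shlex.join(cmd))
```

The resulting list can be passed to subprocess.run with the working directory set to the examples directory, matching the cd step in the commands above.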

Run Inference using our Pre-Trained Models

Run inference using our pre-trained models:

# Download pre-trained model zip file, e.g.

# Unzip checkpoint + vocabulary files
unzip <downloaded zip file> -d /tmp/gnmt_sparse_checkpoint

# Go to examples directory
cd <nlp_architect root>/examples

# Run Inference
python -m sparse_gnmt.nmt \
    --src=de --tgt=en \
    --hparams_path=sparse_gnmt/standard_hparams/sparse_wmt16_gnmt_4_layer.json \
    --ckpt=<path to the unzipped checkpoint in /tmp/gnmt_sparse_checkpoint> \
    --vocab_prefix=/tmp/gnmt_sparse_checkpoint/vocab.bpe.32000 \
    --out_dir=<output directory> \
    --inference_input_file=<file with lines in the source language> \
    --inference_output_file=<target file to place translations>

Important Note: Use the vocabulary files provided with the checkpoint when using our pre-trained models.

Quantized Inference

Add the following flags to the Inference command line in order to quantize the pre-trained models and run inference with the quantized models:

  • --quantize_ckpt=true: Produce a quantized checkpoint. Checkpoint will be saved in the output directory. Inference will run using the produced checkpoint.
  • --from_quantized_ckpt=true: Inference using an already quantized checkpoint

Custom Training/Inference Parameters

All customizable parameters can be obtained by running: python -m nlp-architect.examples.sparse_gnmt.nmt -h

-h, --help show this help message and exit
--num_units NUM_UNITS
 Network size.
--num_layers NUM_LAYERS
 Network depth.
--num_encoder_layers NUM_ENCODER_LAYERS
 Encoder depth, equal to num_layers if None.
--num_decoder_layers NUM_DECODER_LAYERS
 Decoder depth, equal to num_layers if None.
--encoder_type uni | bi | gnmt. For bi, we build num_encoder_layers/2 bi-directional layers. For gnmt, we build 1 bi-directional layer, and (num_encoder_layers - 1) uni-directional layers.
--residual Whether to add residual connections.
--time_major Whether to use time-major mode for dynamic RNN.
--num_embeddings_partitions NUM_EMBEDDINGS_PARTITIONS
 Number of partitions for embedding vars.
--attention luong | scaled_luong | bahdanau | normed_bahdanau or set to “” for no attention
--attention_architecture
 standard | gnmt | gnmt_v2. standard: use top layer to compute attention. gnmt: GNMT style of computing attention, use previous bottom layer to compute attention. gnmt_v2: similar to gnmt, but use current bottom layer to compute attention.
--output_attention
 Only used in standard attention_architecture. Whether use attention as the cell output at each timestep.
--pass_hidden_state
 Whether to pass encoder’s hidden state to decoder when using an attention based model.
--optimizer sgd | adam
--learning_rate LEARNING_RATE
 Learning rate. Adam: 0.001 | 0.0001
--warmup_steps WARMUP_STEPS
 How many steps we inverse-decay learning.
--warmup_scheme
 How to warmup learning rates. Options include: t2t: Tensor2Tensor’s way, start with lr 100 times smaller, then exponentiate until the specified lr.
--decay_scheme How we decay learning rate. Options include: luong234: after 2/3 num train steps, we start halving the learning rate for 4 times before finishing. luong5: after 1/2 num train steps, we start halving the learning rate for 5 times before finishing. luong10: after 1/2 num train steps, we start halving the learning rate for 10 times before finishing.
--num_train_steps NUM_TRAIN_STEPS
 Num steps to train.
--colocate_gradients_with_ops
 Whether try colocating gradients with corresponding op
--init_op uniform | glorot_normal | glorot_uniform
--init_weight INIT_WEIGHT
 for uniform init_op, initialize weights between [-this, this].
--src SRC Source suffix, e.g., en.
--tgt TGT Target suffix, e.g., de.
--train_prefix TRAIN_PREFIX
 Train prefix, expect files with src/tgt suffixes.
--dev_prefix DEV_PREFIX
 Dev prefix, expect files with src/tgt suffixes.
--test_prefix TEST_PREFIX
 Test prefix, expect files with src/tgt suffixes.
--out_dir OUT_DIR
 Store log/model files.
--vocab_prefix VOCAB_PREFIX
 Vocab prefix, expect files with src/tgt suffixes.
--embed_prefix EMBED_PREFIX
 Pretrained embedding prefix, expect files with src/tgt suffixes. The embedding files should be Glove formatted txt files.
--sos SOS Start-of-sentence symbol.
--eos EOS End-of-sentence symbol.
--share_vocab Whether to use the source vocab and embeddings for both source and target.
--check_special_token CHECK_SPECIAL_TOKEN
 Whether check special sos, eos, unk tokens exist in the vocab files.
--src_max_len SRC_MAX_LEN
 Max length of src sequences during training.
--tgt_max_len TGT_MAX_LEN
 Max length of tgt sequences during training.
--src_max_len_infer SRC_MAX_LEN_INFER
 Max length of src sequences during inference.
--tgt_max_len_infer TGT_MAX_LEN_INFER
 Max length of tgt sequences during inference. Also used to restrict the maximum decoding length.
--unit_type lstm | gru | layer_norm_lstm | nas | mlstm
 dense | sparse
 dense | sparse
--forget_bias FORGET_BIAS
 Forget bias for BasicLSTMCell.
--dropout DROPOUT
 Dropout rate (not keep_prob)
--max_gradient_norm MAX_GRADIENT_NORM
 Clip gradients to this norm.
--batch_size BATCH_SIZE
 Batch size.
--steps_per_stats STEPS_PER_STATS
 How many training steps to do per stats logging. Save checkpoint every 10x steps_per_stats
--max_train MAX_TRAIN
 Limit on the size of training data (0: no limit).
--num_buckets NUM_BUCKETS
 Put data into similar-length buckets.
--num_sampled_softmax NUM_SAMPLED_SOFTMAX
 Use sampled_softmax_loss if > 0. Otherwise, use full softmax loss.
--subword_option
 Set to bpe or spm to activate subword desegmentation.
--use_char_encode USE_CHAR_ENCODE
 Whether to split each word or bpe into character, and then generate the word-level representation from the character representation.
--num_gpus NUM_GPUS
 Number of gpus in each worker.
--log_device_placement
 Debug GPU allocation.
--metrics METRICS
 Comma-separated list of evaluation metrics (bleu,rouge,accuracy)
--steps_per_external_eval STEPS_PER_EXTERNAL_EVAL
 How many training steps to do per external evaluation. Automatically set based on data if None.
--scope SCOPE scope to put variables under
--hparams_path HPARAMS_PATH
 Path to standard hparams json file that overrides hparams values from FLAGS.
--random_seed RANDOM_SEED
 Random seed (>0, set a specific seed).
--override_loaded_hparams
 Override loaded hparams with values specified
--num_keep_ckpts NUM_KEEP_CKPTS
 Max number of checkpoints to keep.
--avg_ckpts Average the last N checkpoints for external evaluation. N can be controlled by setting --num_keep_ckpts.
--language_model
 True to train a language model, ignoring encoder
--ckpt CKPT Checkpoint file to load a model for inference.
--quantize_ckpt QUANTIZE_CKPT
 Set to True to produce a quantized checkpoint from existing checkpoint
--from_quantized_ckpt FROM_QUANTIZED_CKPT
 Set to True when the given checkpoint is quantized
--inference_input_file INFERENCE_INPUT_FILE
 Set to the text to decode.
--inference_list INFERENCE_LIST
 A comma-separated list of sentence indices (0-based) to decode.
--infer_batch_size INFER_BATCH_SIZE
 Batch size for inference mode.
--inference_output_file INFERENCE_OUTPUT_FILE
 Output file to store decoding results.
--inference_ref_file INFERENCE_REF_FILE
 Reference file to compute evaluation scores (if provided).
--infer_mode Which type of decoder to use during inference.
--beam_width BEAM_WIDTH
 beam width when using beam search decoder. If 0 (default), use standard decoder with greedy helper.
--length_penalty_weight LENGTH_PENALTY_WEIGHT
 Length penalty for beam search.
--sampling_temperature SAMPLING_TEMPERATURE
 Softmax sampling temperature for inference decoding, 0.0 means greedy decoding. This option is ignored when using beam search.
--num_translations_per_input NUM_TRANSLATIONS_PER_INPUT
 Number of translations generated for each sentence. This is only used for inference.
--jobid JOBID Task id of the worker.
--num_workers NUM_WORKERS
 Number of workers (inference only).
--num_inter_threads NUM_INTER_THREADS
 number of inter_op_parallelism_threads
--num_intra_threads NUM_INTRA_THREADS
 number of intra_op_parallelism_threads
--pruning_hparams PRUNING_HPARAMS
 model pruning parameters
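The learning-rate options in the listing above (the luong234, luong5 and luong10 decay schemes and the t2t warmup) are easier to follow with a small worked example. The sketch below is read directly off the help text and is an assumption about the behaviour, not a copy of the library's code; the function name and the step counts in the example are hypothetical.

```python
import math

# scheme name: (start fraction as numerator/denominator, number of halvings)
DECAY_SCHEMES = {
    "luong234": ((2, 3), 4),
    "luong5": ((1, 2), 5),
    "luong10": ((1, 2), 10),
}


def learning_rate_at(
    step: int,
    base_lr: float,
    num_train_steps: int,
    decay_scheme: str = "",
    warmup_steps: int = 0,
) -> float:
    """Learning rate at `step`, combining t2t-style warmup and luong decay."""
    lr = base_lr
    if step < warmup_steps:
        # Start 100 times smaller and grow exponentially to base_lr.
        warmup_factor = math.exp(math.log(0.01) / warmup_steps)
        lr *= warmup_factor ** (warmup_steps - step)
    if decay_scheme:
        (num, den), decay_times = DECAY_SCHEMES[decay_scheme]
        start_step = num_train_steps * num // den
        # Spread the halvings evenly over the remaining steps.
        decay_steps = max(1, (num_train_steps - start_step) // decay_times)
        if step >= start_step:
            halvings = min(decay_times, (step - start_step) // decay_steps)
            lr *= 0.5 ** halvings
    return lr


# Hypothetical run of 12,000 steps with 1,000 warmup steps.
for step in (0, 1_000, 8_000, 9_000, 12_000):
    lr = learning_rate_at(step, 1.0, 12_000, "luong234", warmup_steps=1_000)
    print(step, f"{lr:.4f}")
```

In this example the rate starts at about 1% of its base value, reaches the base value after warmup, stays constant until two thirds of training, and is then halved every 1,000 steps.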


[1] Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V. Le, Mohammad Norouzi, Wolfgang Macherey, Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey and others. Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation.
[2] Minh-Thang Luong, Eugene Brevdo and Rui Zhao. Neural Machine Translation (seq2seq) Tutorial.
[3] Michael Zhu and Suyog Gupta. To prune, or not to prune: exploring the efficacy of pruning for model compression.
[4] Philipp Koehn. A Parallel Corpus for Statistical Machine Translation. MT Summit 2005.