From d41ca096d37012f6889e16f26b071815641a5299 Mon Sep 17 00:00:00 2001 From: pxc Date: Thu, 19 Oct 2023 11:41:59 +0000 Subject: [PATCH] support training and inference of early-exit LLM --- README.md | 539 ++----------- README_Megatron_LM.md | 526 +++++++++++++ examples/early_exit/1-3B.sh | 155 ++++ examples/early_exit/13B.sh | 157 ++++ examples/early_exit/30B.sh | 157 ++++ examples/early_exit/7B.sh | 157 ++++ examples/early_exit/ee_inference_server.sh | 42 + megatron/arguments.py | 77 +- megatron/checkpointing.py | 3 + megatron/core/inference_params.py | 2 + megatron/core/models/gpt/__init__.py | 2 +- megatron/core/parallel_state.py | 99 ++- megatron/core/pipeline_parallel/schedules.py | 715 ++++++++++++++++-- ...y => early_exit_text_generation_server.py} | 20 +- megatron/global_vars.py | 50 +- megatron/model/__init__.py | 2 +- ...t_gpt_model.py => early_exit_gpt_model.py} | 127 ++-- megatron/model/language_model.py | 161 +++- megatron/model/module.py | 2 +- megatron/model/transformer.py | 367 +++++---- megatron/text_generation/api.py | 74 +- megatron/text_generation/communication.py | 115 ++- megatron/text_generation/forward_step.py | 49 +- megatron/text_generation/generation.py | 249 +++++- megatron/text_generation/inference_params.py | 102 +++ megatron/training.py | 23 +- ..._exit_gpt.py => pretrain_early_exit_gpt.py | 24 +- pretrain_gpt.py | 48 +- tools/checkpoint/checkpoint_converter.py | 250 ++++++ tools/checkpoint/loader_llama2_hf.py | 2 +- tools/checkpoint/loader_megatron.py | 62 +- tools/checkpoint/saver_megatron.py | 96 ++- tools/checkpoint/util.py | 2 +- tools/met_server.sh | 45 -- tools/prompt_example.jsonl | 2 + tools/request_client.py | 100 +-- ... run_early_exit_text_generation_server.py} | 24 +- 37 files changed, 3595 insertions(+), 1032 deletions(-) create mode 100644 README_Megatron_LM.md create mode 100755 examples/early_exit/1-3B.sh create mode 100755 examples/early_exit/13B.sh create mode 100755 examples/early_exit/30B.sh create mode 100755 examples/early_exit/7B.sh create mode 100755 examples/early_exit/ee_inference_server.sh rename megatron/{multi_exit_text_generation_server.py => early_exit_text_generation_server.py} (94%) rename megatron/model/{multi_exit_gpt_model.py => early_exit_gpt_model.py} (60%) create mode 100644 megatron/text_generation/inference_params.py rename pretrain_multi_exit_gpt.py => pretrain_early_exit_gpt.py (84%) create mode 100644 tools/checkpoint/checkpoint_converter.py delete mode 100755 tools/met_server.sh create mode 100644 tools/prompt_example.jsonl rename tools/{run_multi_exit_text_generation_server.py => run_early_exit_text_generation_server.py} (72%) diff --git a/README.md b/README.md index dfe29ffb..70b22c42 100644 --- a/README.md +++ b/README.md @@ -1,526 +1,95 @@ -Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf), [2](https://arxiv.org/pdf/2104.04473.pdf), and [3](https://arxiv.org/pdf/2205.05198)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel ([tensor](https://arxiv.org/pdf/1909.08053.pdf), [sequence](https://arxiv.org/pdf/2205.05198), and [pipeline](https://arxiv.org/pdf/2104.04473.pdf)), and multi-node pre-training of transformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision. - -Below are some of the projects where we have directly used Megatron: -* [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf) -* [BioMegatron: Larger Biomedical Domain Language Model](https://www.aclweb.org/anthology/2020.emnlp-main.379.pdf) -* [End-to-End Training of Neural Retrievers for Open-Domain Question Answering](https://arxiv.org/abs/2101.00408) -* [Large Scale Multi-Actor Generative Dialog Modeling](https://www.aclweb.org/anthology/2020.acl-main.8.pdf) -* [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150) -* [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf) -* [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html) -* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf) -* [Few-shot Instruction Prompts for Pretrained Language Models to Detect Social Biases](https://arxiv.org/abs/2112.07868) -* [Exploring the Limits of Domain-Adaptive Training for Detoxifying Large-Scale Language Models](https://arxiv.org/abs/2202.04173) -* [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](https://arxiv.org/abs/2201.11990) -* [Multi-Stage Prompting for Knowledgeable Dialogue Generation](https://arxiv.org/abs/2203.08745) -* [Evaluating Parameter Efficient Learning for Generation](https://aclanthology.org/2022.emnlp-main.319.pdf) - -Megatron is also used in [NeMo Megatron](https://developer.nvidia.com/nvidia-nemo#nemo-megatron), a framework to help enterprises overcome the challenges of building and training sophisticated natural language processing models with billions and trillions of parameters. - -Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. Each cluster node has 8 NVIDIA 80GB A100 GPUs. The graph below shows that we scale nearly linear up to 1 trillion parameter models running on 3072 GPUs. Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging. - -![Scaling Graph](images/Achieved_petaFLOPs.png) - -The following table shows both model (MFU) and hardware (HFU) FLOPs utilization for select configurations up to 1T parameters (see [our paper](https://arxiv.org/pdf/2205.05198) for a description of how these are calculated). As the model size increases, we achieve better GPU utilization and for the one trillion parameter model, we reach a MFU and HFU of 56.3% and 57.0%, respectively. Note that these numbers are also measured on benchmark runs and in this case are measured using a data parallel size of one. Data parallelism introduces some overhead due to the gradient all-reduce required between the data parallel groups. However, for large transformer models, this overhead is not large and can almost entirely eliminated by overlapping the gradient all-reduce with backpropagation. - -| Model Size | Model FLOPs Utilization | Hardware FLOPs Utilization | -| :---: | :---: | :---: | -| 22B | 41.5% | 43.7% | -| 175B | 51.4% | 52.8% | -| 530B | 56.0% | 57.0% | -| 1T | 56.3% | 57.0% | - -# Contents - * [Contents](#contents) - * [Setup](#setup) - * [Downloading Checkpoints](#downloading-checkpoints) - * [Usage](#usage) - * [Training](#training) - * [Data Preprocessing](#data-preprocessing) - * [BERT Pretraining](#bert-pretraining) - * [GPT Pretraining](#gpt-pretraining) - * [T5 Pretraining](#t5-pretraining) - * [Distributed Pretraining](#distributed-pretraining) - * [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation) - * [Distributed Optimizer](#distributed-optimizer) - * [FlashAttention](#flashattention) - * [GPT-3 Example](#gpt-3-example) - * [Retro](#retro) - * [Evaluation and Tasks](#evaluation-and-tasks) - * [GPT Text Generation](#gpt-text-generation) - * [GPT Evaluation](#gpt-evaluation) - * [WikiText Perplexity Evaluation](#wikitext-perplexity-evaluation) - * [LAMBADA Cloze Accuracy](#lambada-cloze-accuracy) - * [BERT Task Evaluation](#bert-task-evaluation) - * [RACE Evaluation](#race-evaluation) - * [MNLI Evaluation](#mnli-evaluation) - * [Llama-2 Inference and Finetuning](#llama-2-inference-and-finetuning) - * [Datasets](#datasets) - * [Collecting Wikipedia Training Data](#collecting-wikipedia-training-data) - * [Collecting GPT Webtext Data](#collecting-gpt-webtext-data) - * [Reproducibility](#reproducibility) - -# Setup -We strongly recommend using the latest release of [NGC's PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) with DGX nodes. If you can't use this for some reason, use the latest pytorch, cuda, nccl, and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start) releases. Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks. - -You can launch an instance of the PyTorch container and mount Megatron, your dataset, and checkpoints with the following Docker commands: -``` -docker pull nvcr.io/nvidia/pytorch:xx.xx-py3 -docker run --gpus all -it --rm -v /path/to/megatron:/workspace/megatron -v /path/to/dataset:/workspace/dataset -v /path/to/checkpoints:/workspace/checkpoints nvcr.io/nvidia/pytorch:xx.xx-py3 -``` - -## Downloading Checkpoints -We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m) and [GPT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m) checkpoints for use to evaluate or finetuning downstream tasks. To access these checkpoints, first [sign up](https://ngc.nvidia.com/signup) for and [setup](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1). - -Alternatively, you can directly download the checkpoints using: - -
-BERT-345M-uncased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_uncased/zip -O megatron_bert_345m_v0.1_uncased.zip
-BERT-345M-cased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_cased/zip -O megatron_bert_345m_v0.1_cased.zip
-GPT-345M: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip
-
- -The models require vocabulary files to run. The BERT WordPiece vocab file can be extracted from Google's pretrained BERT models: [uncased](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt), [cased](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt). The GPT [vocab file](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json) and [merge table](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt) can be downloaded directly. - -# Usage - -After installation, there are several possible workflows. The most comprehensive is: -1. Data preprocessing -2. Pretraining -3. Finetuning (Optional for zero-shot tasks) -4. Downstream task evaluation or text generation - -However, steps 1 and 2 can be replaced by using one of the pretrained models mentioned above. - -We've provided several scripts for pretraining both BERT and GPT in [`examples`](./examples) directory, as well as scripts for both zero-shot and fine-tuned downstream tasks including MNLI, RACE, WikiText103, and LAMBADA evaluation. There is also a script for GPT interactive text generation. - -# Training -## Data Preprocessing -The training data requires preprocessing. First, place your training data in a loose json format, with one json containing a text sample per line. For example: -
-{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
-{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
-
- -The name of the `text` field of the json can be changed by using the `--json-key` flag in [`preprocess_data.py`](./tools/preprocess_data.py) The other metadata are optional and are not used in training. +# README -The loose json is then processed into a binary format for training. To convert the json into mmap format use `preprocess_data.py`. An example script to prepare data for BERT training is: -
-python tools/preprocess_data.py \
-       --input my-corpus.json \
-       --output-prefix my-bert \
-       --vocab-file bert-vocab.txt \
-       --tokenizer-type BertWordPieceLowerCase \
-       --split-sentences
-
+EE-LLM is a framework for large-scale training and inference of early-exit (EE) large language models (LLMs), which is built upon [Megatron-LM](https://github.com/NVIDIA/Megatron-LM). -The output will be two files named, in this case, `my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`. The `--data-path` specified in later BERT training is the full path and new filename, but without the file extension. -For T5 use the same preprocessing as BERT, perhaps renaming it to: -
-       --output-prefix my-t5 \
-
+## Installation -Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type: -
-python tools/preprocess_data.py \
-       --input my-corpus.json \
-       --output-prefix my-gpt2 \
-       --vocab-file gpt2-vocab.json \
-       --tokenizer-type GPT2BPETokenizer \
-       --merge-file gpt2-merges.txt \
-       --append-eod
-
+The installation of EE-LLM is the same as Megatron-LM. +We recommand using the 22.12 version of [NGC's PyTorch container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) (nvcr.io/nvidia/pytorch:22.12-py3), which is also the development environment of EE-LLM. -Here the output files are named `my-gpt2_text_document.bin` and `my-gpt2_text_document.idx`. As before, in GPT training, use the longer name without the extension as `--data-path`. +For more details about the installation of Megatron-LM, please refer to Megatron-LM's [README](README_Megatron_LM.md). -Further command line arguments are described in the source file [`preprocess_data.py`](./tools/preprocess_data.py). -## BERT Pretraining +## Training +Below are several example training scripts used in our paper. -The [`examples/pretrain_bert.sh`](./examples/pretrain_bert.sh) script runs single GPU 345M parameter BERT pretraining. Debugging is the primary use for single GPU training, as the code base and command line arguments are optimized for highly distributed training. Most of the arguments are fairly self-explanatory. By default, the learning rate decays linearly over the training iterations starting at `--lr` to a minimum set by `--min-lr` over `--lr-decay-iters` iterations. The fraction of training iterations used for warmup is set by `--lr-warmup-fraction`. While this is single GPU training, the batch size specified by `--micro-batch-size` is a single forward-backward path batch-size and the code will perform gradient accumulation steps until it reaches `global-batch-size` which is the batch size per iteration. The data is partitioned into a 949:50:1 ratio for training/validation/test sets (default is 969:30:1). This partitioning happens on the fly, but is consistent across runs with the same random seed (1234 by default, or specified manually with `--seed`). We use `train-iters` as the training iterations requested. Alternatively, one can provide `--train-samples` which is total number of samples to train on. If this option is present, then instead of providing `--lr-decay-iters`, one will need to provide `--lr-decay-samples`. -The logging, checkpoint-saving, and evaluation intervals are specified. Checkpointing the activations facilitates the training of larger models and/or batches. Note that the `--data-path` now includes the additional `_text_sentence` suffix added in preprocessing, but does not include the file extensions. - -Further command line arguments are described in the source file [`arguments.py`](./megatron/arguments.py). - -To run `examples/pretrain_bert.sh`, make any desired modifications including setting the environment variables for `CHECKPOINT_PATH`, `VOCAB_FILE`, and `DATA_PATH`. Make sure to set these variables to their paths in the container. Then launch the container with Megatron and necessary paths mounted (as explained in [Setup](#setup)) and run the example script. - -## GPT Pretraining - -The `examples/pretrain_gpt.sh` script runs single GPU 345M parameter GPT pretraining. As mentioned above, single GPU training is primarily intended for debugging purposes, as the code is optimized for distributed training. - -It follows largely the same format as the previous BERT script with a few notable differences: the tokenization scheme used is BPE (which requires a merge table and a `json` vocabulary file) instead of WordPiece, the model architecture allows for longer sequences (note that the max position embedding must be greater than or equal to the maximum sequence length), and the `--lr-decay-style` has been set to cosine decay. Note that the `--data-path` now includes the additional `_text_document` suffix added in preprocessing, but does not include the file extensions. - -Further command line arguments are described in the source file [`arguments.py`](./megatron/arguments.py). - -`examples/pretrain_gpt.sh` can be launched the same way as described for BERT. Set the env vars and make any other modifications, launch the container with appropriate mounts, and run the script. +``` +# train 1.3B model +./examples/early_exit/1-3B.sh -## T5 Pretraining +# train 7B model +./examples/early_exit/7B.sh -Very similar to BERT and GPT, the `examples/pretrain_t5.sh` script runs single GPU "base" (~220M parameter) T5 pretraining. The primary difference from BERT and GPT is the addition of the following arguments to accommodate the T5 architecture: +# train 13B model +./example/early_exit/13B.sh -* `--kv-channels` sets the inner dimension of the "key" and "value" matrices of all attention mechanisms in the model. For BERT and GPT this defaults to the hidden size divided by the number of attention heads, but can be configured for T5. +# train 30B model +./example/early_exit/30B.sh +``` -* `--ffn-hidden-size` sets the hidden size in the feed-forward networks within a transformer layer. For BERT and GPT this defaults to 4 times the transformer hidden size, but can be configured for T5. -* `--encoder-seq-length` and `--decoder-seq-length` set the sequence length for the encoder and decoder separately. +The training data used in these scripts can be found in [Data-Juicer](https://github.com/alibaba/data-juicer/blob/main/configs/data_juicer_recipes/README.md). +You can modify the `DATA_PATH` environment variable in the scripts to use your own dataset. +Note that Megatron-LM can only recognize preprocessed binary data; +for more details about Megatron-LM's data preprocessing, please refer to [Data Preprocessing](README_Megatron_LM.md) -All of the other arguments remain as they were for BERT and GPT pretraining. Run this example with the same steps described above for the other scripts. +> Running the training scripts requires 16 Nvidia A100-80G GPUs or higher hardware specifications. To run them with fewer GPUs, please set the parallelism degrees therein to smaller values. -## Distributed Pretraining -The `examples/pretrain_{bert,gpt,t5}_distributed.sh` scripts use the PyTorch distributed launcher for distributed training. As such, multi-node training can be achieved by properly setting environment variables. See the official PyTorch [documentation](https://pytorch.org/docs/stable/elastic/run.html#launcher-api) for further description of these [environment variables](https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization). By default, multi-node training uses the [nccl](https://developer.nvidia.com/nccl) distributed backend. A simple set of additional arguments and the use of the PyTorch distributed module with the `torchrun` elastic launcher (equivalent to `python -m torch.distributed.run`) are the only additional requirements to adopt distributed training. See any of `examples/pretrain_{bert,gpt,t5}_distributed.sh` for more details. +Below are the new configurations of EE-LLM compared to Megatron-LM. You can customize your own early-exit LLM by modifying these configurations. -We use two types of parallelism: data and model parallelism. We facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of back propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back propagation computation. To switch between these two options use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model sizes. For example, for the 8.3 billion parameters model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory and for some configurations (e.g., 2.5 billion parameters using 2-way model parallel and 1.2 billion parameters with no model parallel) can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time. +### Configurations for model architectures -Second, we developed a simple and efficient two-dimensional model-parallel approach. To use tensor model parallelism (splitting execution of a single transformer module over multiple GPUs, see Section 3 of [our paper](https://arxiv.org/pdf/1909.08053.pdf)), add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use sequence parallelism specify `--sequence-parallel`, which requires tensor model parallel as it split among the same GPUs (more details in Section 4.2.2 of [our paper](https://arxiv.org/pdf/2205.05198.pdf)). +- `--exit-layer-nums`: indices of the Transformer layers converted to early-exit Transformer layers, starting from 1. + > For example, `--exit-layer-nums 6 12` will add early exits to the 6th and 12th Transformer layers. -To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches, see Section 2.2 of [our paper](https://arxiv.org/pdf/2104.04473.pdf)), use the `--pipeline-model-parallel-size` flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages would mean each stage gets 6 transformer layers each). +- `--pre-exit`: If set, the early-exit modules will be placed before the backbone of the Transformer layer, otherwise they will be placed after the backbone by default. + > For example, the overall model architectures represented by `--exit-layer-nums 6 12` and `--exit-layer-nums 7 13 --pre-exit` are the same. - +- `--untie-exit-output-weights`: If set, each early exit uses a different output word embedding, otherwise all early exits share the same output word embedding. -We have examples of how to use these two different forms of model parallelism the example scripts ending in `distributed_with_mp.sh`: +- `--use-exit-norm`: If set, add a Norm layer before the early-exit output word embedding. -Other than these minor changes, the distributed training is identical to the training on a single GPU. +- `--use-exit-mlp`: If set, add a MLP layer before the early-exit output word embedding. -The interleaved pipelining schedule (more details in Section 2.2.2 of [our paper](https://arxiv.org/pdf/2104.04473.pdf)) can be enabled using the `--num-layers-per-virtual-pipeline-stage` argument, which controls the number of transformer layers in a virtual stage (by default with the non-interleaved schedule, each GPU will execute a single virtual stage with `NUM_LAYERS / PIPELINE_MP_SIZE` transformer layers). The total number of layers in the transformer model should be divisible by this argument value. Additionally, the number of microbatches in the pipeline (computed as `GLOBAL_BATCH_SIZE / (DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE)`) should be divisible by the `PIPELINE_MP_SIZE` when using this schedule (this condition is checked in an assertion in the code). The interleaved schedule is not supported for pipelines with 2 stages (`PIPELINE_MP_SIZE=2`). +- `--use-exit-block`: If set, add a complete Transformer layer before the early-exit output word embedding. -## Activation Checkpointing and Recomputation +### Configurations for training -To reduce GPU memory usage so deploy a large model to a training system, we support activation checkpointing and recomputation. We support two levels of recompute granularity: `selective` and `full`. Selective recomputation is the default and recommended in almost all cases. It saves the activations that take less space and are expensive to recompute and recomputes activations that take a lot of space but are relatively cheap to recompute (see [our paper](https://arxiv.org/pdf/2205.05198) for details). To enable selective activation recompute simply use `--recompute-activations`. +- `--exit-layer-weight`: The targeted loss weights of early exits. Must correspond to `--exit-layer-nums` one-to-one. Default to 1.0. -For cases where memory is very tight, `full` checkpointing saves just the inputs to a transformer layer, or a block of transformer layers, and recomputes everything else. To turn on full activation recompute use `--recompute-granularity full`. When using full activation recomputation, there are two methods: `uniform` and `block`, chosen using the `--recompute-method` argument. +- `--exit-layer-weight-init`: The initial loss weights of early exits, which can be lower or higher than `--exit-layer-weight`. -* Uniform method uniformly divides the Transformer layers into groups of layers and stores the input activations of each group in the memory. The baseline group size is 1 and, in this case, the input activation of each Transformer layer is checkpointed. When the GPU memory is insufficient, increasing the number of layers per group reduces the memory usage thus enables running a bigger model. For example, when using the number of layers per group of 4, the input activation of each group of 4 Transformer layers is checkpointed. +- `--exit-layer-weight-warmup-iters`: The number of warm-up/cool-down iterations for early-exit loss weights (from `weight-init` to `weight`), default to 0. -* Block method checkpoints the input activations of a set number of individual Transformer layers per pipeline stage and do the rest of layers without any checkpointing. This method can be used to skip checkpointing some Transformer layers until the GPU memory is fully used, which is applicable only when there is unused GPU memory. Checkpointing fewer transformer layers avoids unnecessary activation recomputation in the backprop thus improves training performance. For example, when we specify 5 layers to checkpoint of 8 layers per pipeline stage, the input activations of only the first 5 Transformer layers are checkpointed and activation recomputation for the rest 3 layers is not needed in the backprop. +- `--exit-layer-weight-warmup-style`: The increment function of early-exit loss weights, default to linear. +- `--fill-explicit-bubbles`: Enable filling explicit bubbles of the 1F1B pipeline schedule with additional microbatches. [Experimental] -## Distributed Optimizer +- `--num-fill-warmup-microbatches`: The number of microbatches to be inserted during the warm-up phase of the 1F1B schedule. [Experimental] -Usage: `--use-distributed-optimizer`. Compatible with all model and data types. +- `--num-fill-cooldown-microbatches`: The number of microbatches to be inserted during the cool-down phase of the 1F1B schedule. [Experimental] -The distributed optimizer is a memory savings technique, whereby the optimizer state is evenly distributed across data parallel ranks (versus the traditional method of replicating the optimizer state across data parallel ranks). As described in [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054), our implementation distributes all optimizer state that does not overlap with the model state. For example, when using fp16 model params, the distributed optimizer maintains its own separate copy of fp32 main params & grads, which are distributed across DP ranks. When using bf16 model params, however, the distributed optimizer's fp32 main grads are the same as the model's fp32 grads, and so the grads in this case are not distributed (although the fp32 main params are still distributed, as they are separate from the bf16 model params). +- `--backward-forward-ratio`: An estimate of the ratio of time consumption between backward and forward computation during training, used to automatically calculate the optimal number of inserted microbatches. Default to 2.0. [Experimental] -Theoretical memory savings vary depending on the combination of the model's param dtype and grad dtype. In our implementation, the theoretical number of bytes per parameter is (where 'd' is the data parallel size): +## Inference -| | Non-distributed optim | Distributed optim | -|-|-|-| -| fp16 param, fp16 grads | 20 | 4 + 16/d | -| bf16 param, fp32 grads | 18 | 6 + 12/d | -| fp32 param, fp32 grads | 16 | 8 + 8/d | +We provided an text generation server for inference of early-exit LLMs. +To start a server, you can use the following script. +Before running, please set `CHECKPOINT_PATH` to the root folder path of the checkpoint, and set `TP` and `PP` appropriately according to the parallelism of the checkpoint. -## FlashAttention +``` +./example/early_exit/ee_inference_server.sh +``` -Usage: `--use-flash-attn`. Support attention head dimensions at most 128. +After the server is started, you can use `tools/request_client.py` to send requests to the server. +Below are some parameters for early-exit LLM inference, which can be found in `tools/request_client.py`. -[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and -memory-efficient algorithm to compute exact attention. It speeds up model -training and reduces memory requirement. +- `use_early_exit`: The early-exit feature is only enabled when this option is set, otherwise the model behaves exactly like a standard model without early exits. -To install FlashAttention: -```sh -pip install flash-attn -``` +- `early_exit_thres`: The confidence threshold used to determine whether to execute early exiting, ranging from 0.0 to 1.0. -## GPT-3 Example - -In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to configure Megatron to run [GPT-3](https://arxiv.org/abs/2005.14165) with 175 billion parameters on 1024 GPUs. The script is designed for [slurm](https://slurm.schedmd.com/documentation.html) with [pyxis](https://github.com/NVIDIA/pyxis) plugin but can be easily adopted to any other scheduler. It uses 8-way and 16-way tensor and pipeline parallelism, respectively. With options `global-batch-size 1536` and `rampup-batch-size 16 16 5859375`, the training will start with global batch size 16 and linearly increase the global batch size to 1536 over 5,859,375 samples with incremental steps 16. The training dataset can be either a single set or a multiple datasets combined with a set of weights. - -With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs. - - -## Retro - -See: - -- `tools/retro/README.md` for an overview. -- `tools/retro/examples/get_preprocess_cmd.sh` for an example of common preprocessing arguments. -- `tools/retro/examples/preprocess_data.sh` for an example of how to preprocess data. -- `tools/retro/examples/pretrain_model.sh` for an example of how to pretrain a model. - -Retro is a retrieval-enhanced model that is based on GPT. As described in [Improving language models by retrieving from trillions of tokens](https://arxiv.org/abs/2112.04426), Retro retrieves from a database of document chunks by performing locality search using a sample's tokens. The retrieval database can be large -- often billions or even trillions of tokens -- and provides a more efficient storage mechanism of factual knowledge, when compared to storing factual knowledge implicitly within the network's parameters. - -Using Retro requires two steps: 1) preprocessing the retrieval database and pretraining neighbors, and 2) pretraining a model using this data. Please see `tools/retro/README.md` for a detailed overview. - - - -# Evaluation and Tasks - -We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning. - -Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on fewer GPUs in downstream tasks. The following script accomplishes this. This example reads in a GPT model with 4-way tensor and 4-way pipeline model parallelism and writes out a model with 2-way tensor and 2-way pipeline model parallelism. - -
-python tools/checkpoint/util.py \
-        --model-type GPT \
-        --load-dir checkpoints/gpt3_tp4_pp4 \
-        --save-dir checkpoints/gpt3_tp2_pp2 \
-        --target-tensor-parallel-size 2 \
-        --target-pipeline-parallel-size 2
-
-
- -Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts. - -## GPT Text Generation - -We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server. - -Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on. - -
-tools/text_generation_cli.py localhost:5000
-
- -You can also use CURL or any other tools to query the server directly: - -
-curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8'  -d '{"prompts":["Hello world"], "tokens_to_generate":1}'
-
- -See [megatron/text_generation_server.py](megatron/text_generation_server.py) for more API options. - -### Detoxify GPT via Self-generation -We include an example in `examples/detxoify_lm/` to detoxify language models by leveraging the generative power of language models. - -See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus. - - -## GPT Evaluation -We include example scripts for GPT evaluation on WikiText perplexity evaluation and LAMBADA Cloze accuracy. - -### WikiText Perplexity Evaluation -For even comparison with prior works, we evaluate perplexity on the word-level [WikiText-103 test dataset](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip), and appropriately compute perplexity given the change in tokens when using our subword tokenizer. - -We use the following command to run WikiText-103 evaluation on a 345M parameter model. -
-TASK="WIKITEXT103"
-
-VALID_DATA=<wikitext path>.txt
-VOCAB_FILE=gpt2-vocab.json
-MERGE_FILE=gpt2-merges.txt
-CHECKPOINT_PATH=checkpoints/gpt2_345m
-
-COMMON_TASK_ARGS="--num-layers 24 \
-                  --hidden-size 1024 \
-                  --num-attention-heads 16 \
-                  --seq-length 1024 \
-                  --max-position-embeddings 1024 \
-                  --fp16 \
-                  --vocab-file $VOCAB_FILE"
-
-python tasks/main.py \
-       --task $TASK \
-       $COMMON_TASK_ARGS \
-       --valid-data $VALID_DATA \
-       --tokenizer-type GPT2BPETokenizer \
-       --merge-file $MERGE_FILE \
-       --load $CHECKPOINT_PATH \
-       --micro-batch-size 8 \
-       --log-interval 10 \
-       --no-load-optim \
-       --no-load-rng
-
- - -### LAMBADA Cloze Accuracy -To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceding tokens) we utilize a detokenized, processed version of the [LAMBADA dataset](https://github.com/cybertronai/bflm/blob/master/lambada_test.jsonl). - -We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the `--strict-lambada` flag should be used to require whole word matching. Make that `lambada` is part of the file path. - -
-TASK="LAMBADA"
-
-VALID_DATA=<lambada path>.json
-VOCAB_FILE=gpt2-vocab.json
-MERGE_FILE=gpt2-merges.txt
-CHECKPOINT_PATH=checkpoints/gpt2_345m
-COMMON_TASK_ARGS=<same as those in WikiText Perplexity Evaluation above>
-
-python tasks/main.py \
-       --task $TASK \
-       $COMMON_TASK_ARGS \
-       --valid-data $VALID_DATA \
-       --tokenizer-type GPT2BPETokenizer \
-       --strict-lambada \
-       --merge-file $MERGE_FILE \
-       --load $CHECKPOINT_PATH \
-       --micro-batch-size 8 \
-       --log-interval 10 \
-       --no-load-optim \
-       --no-load-rng
-
- -Further command line arguments are described in the source file [`main.py`](./tasks/main.py) - -## BERT Task Evaluation -### RACE Evaluation -The following script finetunes the BERT model for evaluation on the [RACE dataset](http://www.cs.cmu.edu/~glai1/data/race/). The `TRAIN_DATA` and `VALID_DATA` directory contain the RACE dataset as separate `.txt` files. Note that for RACE, the batch size is the number of RACE query's to evaluate. Since each RACE query has four samples, the effective batch size passed through the model will be four times the batch size specified on the command line. - -
-TRAIN_DATA="data/RACE/train/middle"
-VALID_DATA="data/RACE/dev/middle \
-            data/RACE/dev/high"
-VOCAB_FILE=bert-vocab.txt
-PRETRAINED_CHECKPOINT=checkpoints/bert_345m
-CHECKPOINT_PATH=checkpoints/bert_345m_race
-COMMON_TASK_ARGS="--num-layers 24 \
-                  --hidden-size 1024 \
-                  --num-attention-heads 16 \
-                  --seq-length 512 \
-                  --max-position-embeddings 512 \
-                  --fp16 \
-                  --vocab-file $VOCAB_FILE"
-
-COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
-                      --valid-data $VALID_DATA \
-                      --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
-                      --save-interval 10000 \
-                      --save $CHECKPOINT_PATH \
-                      --log-interval 100 \
-                      --eval-interval 1000 \
-                      --eval-iters 10 \
-                      --weight-decay 1.0e-1"
-
-python tasks/main.py \
-       --task RACE \
-       $COMMON_TASK_ARGS \
-       $COMMON_TASK_ARGS_EXT \
-       --tokenizer-type BertWordPieceLowerCase \
-       --epochs 3 \
-       --micro-batch-size 4 \
-       --lr 1.0e-5 \
-       --lr-warmup-fraction 0.06
-
- -### MNLI Evaluation -The following script finetunes the BERT model for evaluation with the [MultiNLI sentence pair corpus](https://www.nyu.edu/projects/bowman/multinli/). Because the matching tasks are quite similar, the script can be quickly tweaked to work with the [Quora Question Pairs](https://www.kaggle.com/quora/question-pairs-dataset) (QQP) dataset as well. - -
-
-TRAIN_DATA="data/glue_data/MNLI/train.tsv"
-VALID_DATA="data/glue_data/MNLI/dev_matched.tsv \
-            data/glue_data/MNLI/dev_mismatched.tsv"
-PRETRAINED_CHECKPOINT=checkpoints/bert_345m
-VOCAB_FILE=bert-vocab.txt
-CHECKPOINT_PATH=checkpoints/bert_345m_mnli
-COMMON_TASK_ARGS=<same as those in RACE Evaluation above>
-COMMON_TASK_ARGS_EXT=<same as those in RACE Evaluation above>
-
-python tasks/main.py \
-       --task MNLI \
-       $COMMON_TASK_ARGS \
-       $COMMON_TASK_ARGS_EXT \
-       --tokenizer-type BertWordPieceLowerCase \
-       --epochs 5 \
-       --micro-batch-size 8 \
-       --lr 5.0e-5 \
-       --lr-warmup-fraction 0.065
-
- -## Llama-2 Inference and Finetuning - -The Llama-2 [family of models](https://ai.meta.com/llama/) are an open-source set of pretrained & finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At the time of release, Llama-2 models achieved among the best results for open-source models, and were competitive with the closed-source GPT-3.5 model (see https://arxiv.org/pdf/2307.09288.pdf). - -The Llama-2 checkpoints can be loaded into Megatron for inference and finetuning. See documentation [here](docs/llama2.md). - -# Datasets -We do not host any datasets for GPT or BERT training, however, we detail their collection so that our results may be reproduced. - -## Collecting Wikipedia Training Data -We recommend following the Wikipedia data extraction process specified by Google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text." - -We recommend using the `--json` argument when using WikiExtractor, which will dump the Wikipedia data into loose json format (one json per line), making it more manageable on the file system and also readily consumable by our codebase. We recommend further preprocessing this json dataset by nltk punctuation standardization. For BERT training, use the `--split-sentences` flag to `preprocess_data.py` as described [above](#data-preprocessing) to include sentence breaks in the produced index. If you'd like to use Wikipedia data for GPT training you should still clean it with nltk/spacy/ftfy, but do not use the `--split-sentences` flag. - -## Collecting GPT Webtext Data -We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download urls. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in our [openwebtext](./tools/openwebtext) directory. For reddit URLs corresponding to content up to October 2018 we arrived at approximately 37GB of content. - -# Reproducibility -Megatron training is intended to be bitwise reproducible. This means that the same training config run twice in the same HW and SW environment should produce identical model checkpoints, losses and accuracy metric values (iteration time metrics may vary). - -There are currently two known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. The following workarounds should be applied in cases where reproducibility is required: -1. When training using `--bf16`, reproducbility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change, i.e. checkpointing and resume will occur at different iterations, the option `--no-bias-gelu-fusion` should be used. -2. Flash attention is non-deterministic. If reproducibility is required do not use `--use-flash-attn`. - -These sources of non-determinism are under active investigation. If you observe non-determinism in Megatron training under other circumstances please open an issue. +- `print_max_prob`: If set, the inference server will print the token with the highest confidence and the confidence values at all exits. \ No newline at end of file diff --git a/README_Megatron_LM.md b/README_Megatron_LM.md new file mode 100644 index 00000000..dfe29ffb --- /dev/null +++ b/README_Megatron_LM.md @@ -0,0 +1,526 @@ +Megatron ([1](https://arxiv.org/pdf/1909.08053.pdf), [2](https://arxiv.org/pdf/2104.04473.pdf), and [3](https://arxiv.org/pdf/2205.05198)) is a large, powerful transformer developed by the Applied Deep Learning Research team at NVIDIA. This repository is for ongoing research on training large transformer language models at scale. We developed efficient, model-parallel ([tensor](https://arxiv.org/pdf/1909.08053.pdf), [sequence](https://arxiv.org/pdf/2205.05198), and [pipeline](https://arxiv.org/pdf/2104.04473.pdf)), and multi-node pre-training of transformer based models such as [GPT](https://arxiv.org/abs/2005.14165), [BERT](https://arxiv.org/pdf/1810.04805.pdf), and [T5](https://arxiv.org/abs/1910.10683) using mixed precision. + +Below are some of the projects where we have directly used Megatron: +* [BERT and GPT Studies Using Megatron](https://arxiv.org/pdf/1909.08053.pdf) +* [BioMegatron: Larger Biomedical Domain Language Model](https://www.aclweb.org/anthology/2020.emnlp-main.379.pdf) +* [End-to-End Training of Neural Retrievers for Open-Domain Question Answering](https://arxiv.org/abs/2101.00408) +* [Large Scale Multi-Actor Generative Dialog Modeling](https://www.aclweb.org/anthology/2020.acl-main.8.pdf) +* [Local Knowledge Powered Conversational Agents](https://arxiv.org/abs/2010.10150) +* [MEGATRON-CNTRL: Controllable Story Generation with External Knowledge Using Large-Scale Language Models](https://www.aclweb.org/anthology/2020.emnlp-main.226.pdf) +* [RACE Reading Comprehension Dataset Leaderboard](http://www.qizhexie.com/data/RACE_leaderboard.html) +* [Training Question Answering Models From Synthetic Data](https://www.aclweb.org/anthology/2020.emnlp-main.468.pdf) +* [Few-shot Instruction Prompts for Pretrained Language Models to Detect Social Biases](https://arxiv.org/abs/2112.07868) +* [Exploring the Limits of Domain-Adaptive Training for Detoxifying Large-Scale Language Models](https://arxiv.org/abs/2202.04173) +* [Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model](https://arxiv.org/abs/2201.11990) +* [Multi-Stage Prompting for Knowledgeable Dialogue Generation](https://arxiv.org/abs/2203.08745) +* [Evaluating Parameter Efficient Learning for Generation](https://aclanthology.org/2022.emnlp-main.319.pdf) + +Megatron is also used in [NeMo Megatron](https://developer.nvidia.com/nvidia-nemo#nemo-megatron), a framework to help enterprises overcome the challenges of building and training sophisticated natural language processing models with billions and trillions of parameters. + +Our codebase is capable of efficiently training very large (hundreds of billions of parameters) language models with both model and data parallelism. To demonstrate how the code scales with multiple GPUs and model sizes, we consider GPT models from 1 billion all the way to 1 trillion parameters. All models use a vocabulary size of 51,200 and a sequence length of 2048. We vary hidden size, number of attention heads, and number of layers to arrive at a specific model size. As the model size increases, we also modestly increase the batch size. We leverage [NVIDIA's Selene supercomputer](https://www.top500.org/system/179842/) to perform scaling studies and use up to 3072 [A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for the largest model. Each cluster node has 8 NVIDIA 80GB A100 GPUs. The graph below shows that we scale nearly linear up to 1 trillion parameter models running on 3072 GPUs. Note that these results are from benchmark runs and these models were not trained to convergence; however, the FLOPs are measured for end-to-end training, i.e., includes all operations including data loading, optimization, and even logging. + +![Scaling Graph](images/Achieved_petaFLOPs.png) + +The following table shows both model (MFU) and hardware (HFU) FLOPs utilization for select configurations up to 1T parameters (see [our paper](https://arxiv.org/pdf/2205.05198) for a description of how these are calculated). As the model size increases, we achieve better GPU utilization and for the one trillion parameter model, we reach a MFU and HFU of 56.3% and 57.0%, respectively. Note that these numbers are also measured on benchmark runs and in this case are measured using a data parallel size of one. Data parallelism introduces some overhead due to the gradient all-reduce required between the data parallel groups. However, for large transformer models, this overhead is not large and can almost entirely eliminated by overlapping the gradient all-reduce with backpropagation. + +| Model Size | Model FLOPs Utilization | Hardware FLOPs Utilization | +| :---: | :---: | :---: | +| 22B | 41.5% | 43.7% | +| 175B | 51.4% | 52.8% | +| 530B | 56.0% | 57.0% | +| 1T | 56.3% | 57.0% | + +# Contents + * [Contents](#contents) + * [Setup](#setup) + * [Downloading Checkpoints](#downloading-checkpoints) + * [Usage](#usage) + * [Training](#training) + * [Data Preprocessing](#data-preprocessing) + * [BERT Pretraining](#bert-pretraining) + * [GPT Pretraining](#gpt-pretraining) + * [T5 Pretraining](#t5-pretraining) + * [Distributed Pretraining](#distributed-pretraining) + * [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation) + * [Distributed Optimizer](#distributed-optimizer) + * [FlashAttention](#flashattention) + * [GPT-3 Example](#gpt-3-example) + * [Retro](#retro) + * [Evaluation and Tasks](#evaluation-and-tasks) + * [GPT Text Generation](#gpt-text-generation) + * [GPT Evaluation](#gpt-evaluation) + * [WikiText Perplexity Evaluation](#wikitext-perplexity-evaluation) + * [LAMBADA Cloze Accuracy](#lambada-cloze-accuracy) + * [BERT Task Evaluation](#bert-task-evaluation) + * [RACE Evaluation](#race-evaluation) + * [MNLI Evaluation](#mnli-evaluation) + * [Llama-2 Inference and Finetuning](#llama-2-inference-and-finetuning) + * [Datasets](#datasets) + * [Collecting Wikipedia Training Data](#collecting-wikipedia-training-data) + * [Collecting GPT Webtext Data](#collecting-gpt-webtext-data) + * [Reproducibility](#reproducibility) + +# Setup +We strongly recommend using the latest release of [NGC's PyTorch container](https://ngc.nvidia.com/catalog/containers/nvidia:pytorch) with DGX nodes. If you can't use this for some reason, use the latest pytorch, cuda, nccl, and NVIDIA [APEX](https://github.com/NVIDIA/apex#quick-start) releases. Data preprocessing requires [NLTK](https://www.nltk.org/install.html), though this is not required for training, evaluation, or downstream tasks. + +You can launch an instance of the PyTorch container and mount Megatron, your dataset, and checkpoints with the following Docker commands: +``` +docker pull nvcr.io/nvidia/pytorch:xx.xx-py3 +docker run --gpus all -it --rm -v /path/to/megatron:/workspace/megatron -v /path/to/dataset:/workspace/dataset -v /path/to/checkpoints:/workspace/checkpoints nvcr.io/nvidia/pytorch:xx.xx-py3 +``` + +## Downloading Checkpoints +We have provided pretrained [BERT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_bert_345m) and [GPT-345M](https://ngc.nvidia.com/catalog/models/nvidia:megatron_lm_345m) checkpoints for use to evaluate or finetuning downstream tasks. To access these checkpoints, first [sign up](https://ngc.nvidia.com/signup) for and [setup](https://ngc.nvidia.com/setup/installers/cli) the NVIDIA GPU Cloud (NGC) Registry CLI. Further documentation for downloading models can be found in the [NGC documentation](https://docs.nvidia.com/dgx/ngc-registry-cli-user-guide/index.html#topic_6_4_1). + +Alternatively, you can directly download the checkpoints using: + +
+BERT-345M-uncased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_uncased/zip -O megatron_bert_345m_v0.1_uncased.zip
+BERT-345M-cased: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_bert_345m/versions/v0.1_cased/zip -O megatron_bert_345m_v0.1_cased.zip
+GPT-345M: wget --content-disposition https://api.ngc.nvidia.com/v2/models/nvidia/megatron_lm_345m/versions/v0.0/zip -O megatron_lm_345m_v0.0.zip
+
+ +The models require vocabulary files to run. The BERT WordPiece vocab file can be extracted from Google's pretrained BERT models: [uncased](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt), [cased](https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt). The GPT [vocab file](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json) and [merge table](https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt) can be downloaded directly. + +# Usage + +After installation, there are several possible workflows. The most comprehensive is: +1. Data preprocessing +2. Pretraining +3. Finetuning (Optional for zero-shot tasks) +4. Downstream task evaluation or text generation + +However, steps 1 and 2 can be replaced by using one of the pretrained models mentioned above. + +We've provided several scripts for pretraining both BERT and GPT in [`examples`](./examples) directory, as well as scripts for both zero-shot and fine-tuned downstream tasks including MNLI, RACE, WikiText103, and LAMBADA evaluation. There is also a script for GPT interactive text generation. + +# Training +## Data Preprocessing +The training data requires preprocessing. First, place your training data in a loose json format, with one json containing a text sample per line. For example: +
+{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
+{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
+
+ +The name of the `text` field of the json can be changed by using the `--json-key` flag in [`preprocess_data.py`](./tools/preprocess_data.py) The other metadata are optional and are not used in training. + +The loose json is then processed into a binary format for training. To convert the json into mmap format use `preprocess_data.py`. An example script to prepare data for BERT training is: +
+python tools/preprocess_data.py \
+       --input my-corpus.json \
+       --output-prefix my-bert \
+       --vocab-file bert-vocab.txt \
+       --tokenizer-type BertWordPieceLowerCase \
+       --split-sentences
+
+ +The output will be two files named, in this case, `my-bert_text_sentence.bin` and `my-bert_text_sentence.idx`. The `--data-path` specified in later BERT training is the full path and new filename, but without the file extension. + +For T5 use the same preprocessing as BERT, perhaps renaming it to: +
+       --output-prefix my-t5 \
+
+ +Some minor modifications are required for GPT data preprocessing, namely, the addition of a merge table, an end-of-document token, removal of sentence splitting, and a change to the tokenizer type: +
+python tools/preprocess_data.py \
+       --input my-corpus.json \
+       --output-prefix my-gpt2 \
+       --vocab-file gpt2-vocab.json \
+       --tokenizer-type GPT2BPETokenizer \
+       --merge-file gpt2-merges.txt \
+       --append-eod
+
+ +Here the output files are named `my-gpt2_text_document.bin` and `my-gpt2_text_document.idx`. As before, in GPT training, use the longer name without the extension as `--data-path`. + +Further command line arguments are described in the source file [`preprocess_data.py`](./tools/preprocess_data.py). + +## BERT Pretraining + + +The [`examples/pretrain_bert.sh`](./examples/pretrain_bert.sh) script runs single GPU 345M parameter BERT pretraining. Debugging is the primary use for single GPU training, as the code base and command line arguments are optimized for highly distributed training. Most of the arguments are fairly self-explanatory. By default, the learning rate decays linearly over the training iterations starting at `--lr` to a minimum set by `--min-lr` over `--lr-decay-iters` iterations. The fraction of training iterations used for warmup is set by `--lr-warmup-fraction`. While this is single GPU training, the batch size specified by `--micro-batch-size` is a single forward-backward path batch-size and the code will perform gradient accumulation steps until it reaches `global-batch-size` which is the batch size per iteration. The data is partitioned into a 949:50:1 ratio for training/validation/test sets (default is 969:30:1). This partitioning happens on the fly, but is consistent across runs with the same random seed (1234 by default, or specified manually with `--seed`). We use `train-iters` as the training iterations requested. Alternatively, one can provide `--train-samples` which is total number of samples to train on. If this option is present, then instead of providing `--lr-decay-iters`, one will need to provide `--lr-decay-samples`. + +The logging, checkpoint-saving, and evaluation intervals are specified. Checkpointing the activations facilitates the training of larger models and/or batches. Note that the `--data-path` now includes the additional `_text_sentence` suffix added in preprocessing, but does not include the file extensions. + +Further command line arguments are described in the source file [`arguments.py`](./megatron/arguments.py). + +To run `examples/pretrain_bert.sh`, make any desired modifications including setting the environment variables for `CHECKPOINT_PATH`, `VOCAB_FILE`, and `DATA_PATH`. Make sure to set these variables to their paths in the container. Then launch the container with Megatron and necessary paths mounted (as explained in [Setup](#setup)) and run the example script. + +## GPT Pretraining + +The `examples/pretrain_gpt.sh` script runs single GPU 345M parameter GPT pretraining. As mentioned above, single GPU training is primarily intended for debugging purposes, as the code is optimized for distributed training. + +It follows largely the same format as the previous BERT script with a few notable differences: the tokenization scheme used is BPE (which requires a merge table and a `json` vocabulary file) instead of WordPiece, the model architecture allows for longer sequences (note that the max position embedding must be greater than or equal to the maximum sequence length), and the `--lr-decay-style` has been set to cosine decay. Note that the `--data-path` now includes the additional `_text_document` suffix added in preprocessing, but does not include the file extensions. + +Further command line arguments are described in the source file [`arguments.py`](./megatron/arguments.py). + +`examples/pretrain_gpt.sh` can be launched the same way as described for BERT. Set the env vars and make any other modifications, launch the container with appropriate mounts, and run the script. + +## T5 Pretraining + +Very similar to BERT and GPT, the `examples/pretrain_t5.sh` script runs single GPU "base" (~220M parameter) T5 pretraining. The primary difference from BERT and GPT is the addition of the following arguments to accommodate the T5 architecture: + +* `--kv-channels` sets the inner dimension of the "key" and "value" matrices of all attention mechanisms in the model. For BERT and GPT this defaults to the hidden size divided by the number of attention heads, but can be configured for T5. + +* `--ffn-hidden-size` sets the hidden size in the feed-forward networks within a transformer layer. For BERT and GPT this defaults to 4 times the transformer hidden size, but can be configured for T5. + +* `--encoder-seq-length` and `--decoder-seq-length` set the sequence length for the encoder and decoder separately. + +All of the other arguments remain as they were for BERT and GPT pretraining. Run this example with the same steps described above for the other scripts. + +## Distributed Pretraining + +The `examples/pretrain_{bert,gpt,t5}_distributed.sh` scripts use the PyTorch distributed launcher for distributed training. As such, multi-node training can be achieved by properly setting environment variables. See the official PyTorch [documentation](https://pytorch.org/docs/stable/elastic/run.html#launcher-api) for further description of these [environment variables](https://pytorch.org/docs/stable/distributed.html#environment-variable-initialization). By default, multi-node training uses the [nccl](https://developer.nvidia.com/nccl) distributed backend. A simple set of additional arguments and the use of the PyTorch distributed module with the `torchrun` elastic launcher (equivalent to `python -m torch.distributed.run`) are the only additional requirements to adopt distributed training. See any of `examples/pretrain_{bert,gpt,t5}_distributed.sh` for more details. + +We use two types of parallelism: data and model parallelism. We facilitate two distributed data parallel implementations: a simple one of our own that performs gradient all-reduce at the end of back propagation step, and Torch's distributed data parallel wrapper that overlaps gradient reduction with back propagation computation. To switch between these two options use `--DDP-impl local` or `--DDP-impl torch`, respectively. As expected, Torch distributed data parallelism is more efficient at larger model sizes. For example, for the 8.3 billion parameters model running on 512 GPUs, the scaling increases from 60% to 76% when Torch's distributed data parallel is used. However, the overlapping method requires more memory and for some configurations (e.g., 2.5 billion parameters using 2-way model parallel and 1.2 billion parameters with no model parallel) can make the overall training slower as a result. We empirically found that using a smaller model in those cases improves the training time. + +Second, we developed a simple and efficient two-dimensional model-parallel approach. To use tensor model parallelism (splitting execution of a single transformer module over multiple GPUs, see Section 3 of [our paper](https://arxiv.org/pdf/1909.08053.pdf)), add the `--tensor-model-parallel-size` flag to specify the number of GPUs among which to split the model, along with the arguments passed to the distributed launcher as mentioned above. To use sequence parallelism specify `--sequence-parallel`, which requires tensor model parallel as it split among the same GPUs (more details in Section 4.2.2 of [our paper](https://arxiv.org/pdf/2205.05198.pdf)). + +To use pipeline model parallelism (sharding the transformer modules into stages with an equal number of transformer modules on each stage, and then pipelining execution by breaking the batch into smaller microbatches, see Section 2.2 of [our paper](https://arxiv.org/pdf/2104.04473.pdf)), use the `--pipeline-model-parallel-size` flag to specify the number of stages to split the model into (e.g., splitting a model with 24 transformer layers across 4 stages would mean each stage gets 6 transformer layers each). + + + +We have examples of how to use these two different forms of model parallelism the example scripts ending in `distributed_with_mp.sh`: + +Other than these minor changes, the distributed training is identical to the training on a single GPU. + +The interleaved pipelining schedule (more details in Section 2.2.2 of [our paper](https://arxiv.org/pdf/2104.04473.pdf)) can be enabled using the `--num-layers-per-virtual-pipeline-stage` argument, which controls the number of transformer layers in a virtual stage (by default with the non-interleaved schedule, each GPU will execute a single virtual stage with `NUM_LAYERS / PIPELINE_MP_SIZE` transformer layers). The total number of layers in the transformer model should be divisible by this argument value. Additionally, the number of microbatches in the pipeline (computed as `GLOBAL_BATCH_SIZE / (DATA_PARALLEL_SIZE * MICRO_BATCH_SIZE)`) should be divisible by the `PIPELINE_MP_SIZE` when using this schedule (this condition is checked in an assertion in the code). The interleaved schedule is not supported for pipelines with 2 stages (`PIPELINE_MP_SIZE=2`). + +## Activation Checkpointing and Recomputation + +To reduce GPU memory usage so deploy a large model to a training system, we support activation checkpointing and recomputation. We support two levels of recompute granularity: `selective` and `full`. Selective recomputation is the default and recommended in almost all cases. It saves the activations that take less space and are expensive to recompute and recomputes activations that take a lot of space but are relatively cheap to recompute (see [our paper](https://arxiv.org/pdf/2205.05198) for details). To enable selective activation recompute simply use `--recompute-activations`. + +For cases where memory is very tight, `full` checkpointing saves just the inputs to a transformer layer, or a block of transformer layers, and recomputes everything else. To turn on full activation recompute use `--recompute-granularity full`. When using full activation recomputation, there are two methods: `uniform` and `block`, chosen using the `--recompute-method` argument. + +* Uniform method uniformly divides the Transformer layers into groups of layers and stores the input activations of each group in the memory. The baseline group size is 1 and, in this case, the input activation of each Transformer layer is checkpointed. When the GPU memory is insufficient, increasing the number of layers per group reduces the memory usage thus enables running a bigger model. For example, when using the number of layers per group of 4, the input activation of each group of 4 Transformer layers is checkpointed. + +* Block method checkpoints the input activations of a set number of individual Transformer layers per pipeline stage and do the rest of layers without any checkpointing. This method can be used to skip checkpointing some Transformer layers until the GPU memory is fully used, which is applicable only when there is unused GPU memory. Checkpointing fewer transformer layers avoids unnecessary activation recomputation in the backprop thus improves training performance. For example, when we specify 5 layers to checkpoint of 8 layers per pipeline stage, the input activations of only the first 5 Transformer layers are checkpointed and activation recomputation for the rest 3 layers is not needed in the backprop. + + +## Distributed Optimizer + +Usage: `--use-distributed-optimizer`. Compatible with all model and data types. + +The distributed optimizer is a memory savings technique, whereby the optimizer state is evenly distributed across data parallel ranks (versus the traditional method of replicating the optimizer state across data parallel ranks). As described in [ZeRO: Memory Optimizations Toward Training Trillion Parameter Models](https://arxiv.org/abs/1910.02054), our implementation distributes all optimizer state that does not overlap with the model state. For example, when using fp16 model params, the distributed optimizer maintains its own separate copy of fp32 main params & grads, which are distributed across DP ranks. When using bf16 model params, however, the distributed optimizer's fp32 main grads are the same as the model's fp32 grads, and so the grads in this case are not distributed (although the fp32 main params are still distributed, as they are separate from the bf16 model params). + +Theoretical memory savings vary depending on the combination of the model's param dtype and grad dtype. In our implementation, the theoretical number of bytes per parameter is (where 'd' is the data parallel size): + +| | Non-distributed optim | Distributed optim | +|-|-|-| +| fp16 param, fp16 grads | 20 | 4 + 16/d | +| bf16 param, fp32 grads | 18 | 6 + 12/d | +| fp32 param, fp32 grads | 16 | 8 + 8/d | + +## FlashAttention + +Usage: `--use-flash-attn`. Support attention head dimensions at most 128. + +[FlashAttention](https://github.com/HazyResearch/flash-attention) is a fast and +memory-efficient algorithm to compute exact attention. It speeds up model +training and reduces memory requirement. + +To install FlashAttention: +```sh +pip install flash-attn +``` + +## GPT-3 Example + +In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to configure Megatron to run [GPT-3](https://arxiv.org/abs/2005.14165) with 175 billion parameters on 1024 GPUs. The script is designed for [slurm](https://slurm.schedmd.com/documentation.html) with [pyxis](https://github.com/NVIDIA/pyxis) plugin but can be easily adopted to any other scheduler. It uses 8-way and 16-way tensor and pipeline parallelism, respectively. With options `global-batch-size 1536` and `rampup-batch-size 16 16 5859375`, the training will start with global batch size 16 and linearly increase the global batch size to 1536 over 5,859,375 samples with incremental steps 16. The training dataset can be either a single set or a multiple datasets combined with a set of weights. + +With full global batch size of 1536 on 1024 A100 GPUs, each iteration takes around 32 seconds resulting in 138 teraFLOPs per GPU which is 44% of the theoretical peak FLOPs. + + +## Retro + +See: + +- `tools/retro/README.md` for an overview. +- `tools/retro/examples/get_preprocess_cmd.sh` for an example of common preprocessing arguments. +- `tools/retro/examples/preprocess_data.sh` for an example of how to preprocess data. +- `tools/retro/examples/pretrain_model.sh` for an example of how to pretrain a model. + +Retro is a retrieval-enhanced model that is based on GPT. As described in [Improving language models by retrieving from trillions of tokens](https://arxiv.org/abs/2112.04426), Retro retrieves from a database of document chunks by performing locality search using a sample's tokens. The retrieval database can be large -- often billions or even trillions of tokens -- and provides a more efficient storage mechanism of factual knowledge, when compared to storing factual knowledge implicitly within the network's parameters. + +Using Retro requires two steps: 1) preprocessing the retrieval database and pretraining neighbors, and 2) pretraining a model using this data. Please see `tools/retro/README.md` for a detailed overview. + + + +# Evaluation and Tasks + +We provide several command line arguments, detailed in the scripts listed below, to handle various zero-shot and fine-tuned downstream tasks. However, you can also finetune your model from a pretrained checkpoint on other corpora as desired. To do so, simply add the `--finetune` flag and adjust the input files and training parameters within the original training script. The iteration count will be reset to zero, and the optimizer and internal state will be reinitialized. If the fine-tuning is interrupted for any reason, be sure to remove the `--finetune` flag before continuing, otherwise the training will start again from the beginning. + +Because evaluation requires substantially less memory than training, it may be advantageous to merge a model trained in parallel for use on fewer GPUs in downstream tasks. The following script accomplishes this. This example reads in a GPT model with 4-way tensor and 4-way pipeline model parallelism and writes out a model with 2-way tensor and 2-way pipeline model parallelism. + +
+python tools/checkpoint/util.py \
+        --model-type GPT \
+        --load-dir checkpoints/gpt3_tp4_pp4 \
+        --save-dir checkpoints/gpt3_tp2_pp2 \
+        --target-tensor-parallel-size 2 \
+        --target-pipeline-parallel-size 2
+
+
+ +Several downstream tasks are described for both GPT and BERT models below. They can be run in distributed and model parallel modes with the same changes used in the training scripts. + +## GPT Text Generation + +We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server. + +Once the server is running you can use `tools/text_generation_cli.py` to query it, it takes one argument which is the host the server is running on. + +
+tools/text_generation_cli.py localhost:5000
+
+ +You can also use CURL or any other tools to query the server directly: + +
+curl 'http://localhost:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8'  -d '{"prompts":["Hello world"], "tokens_to_generate":1}'
+
+ +See [megatron/text_generation_server.py](megatron/text_generation_server.py) for more API options. + +### Detoxify GPT via Self-generation +We include an example in `examples/detxoify_lm/` to detoxify language models by leveraging the generative power of language models. + +See [examples/detxoify_lm/README.md](examples/detxoify_lm/README.md) for step-by-step tutorials on how to perform domain-adaptive training and detoxify LM using self-generated corpus. + + +## GPT Evaluation +We include example scripts for GPT evaluation on WikiText perplexity evaluation and LAMBADA Cloze accuracy. + +### WikiText Perplexity Evaluation +For even comparison with prior works, we evaluate perplexity on the word-level [WikiText-103 test dataset](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip), and appropriately compute perplexity given the change in tokens when using our subword tokenizer. + +We use the following command to run WikiText-103 evaluation on a 345M parameter model. +
+TASK="WIKITEXT103"
+
+VALID_DATA=<wikitext path>.txt
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+CHECKPOINT_PATH=checkpoints/gpt2_345m
+
+COMMON_TASK_ARGS="--num-layers 24 \
+                  --hidden-size 1024 \
+                  --num-attention-heads 16 \
+                  --seq-length 1024 \
+                  --max-position-embeddings 1024 \
+                  --fp16 \
+                  --vocab-file $VOCAB_FILE"
+
+python tasks/main.py \
+       --task $TASK \
+       $COMMON_TASK_ARGS \
+       --valid-data $VALID_DATA \
+       --tokenizer-type GPT2BPETokenizer \
+       --merge-file $MERGE_FILE \
+       --load $CHECKPOINT_PATH \
+       --micro-batch-size 8 \
+       --log-interval 10 \
+       --no-load-optim \
+       --no-load-rng
+
+ + +### LAMBADA Cloze Accuracy +To compute LAMBADA cloze accuracy (the accuracy of predicting the last token given the preceding tokens) we utilize a detokenized, processed version of the [LAMBADA dataset](https://github.com/cybertronai/bflm/blob/master/lambada_test.jsonl). + +We use the following command to run LAMBADA evaluation on a 345M parameter model. Note that the `--strict-lambada` flag should be used to require whole word matching. Make that `lambada` is part of the file path. + +
+TASK="LAMBADA"
+
+VALID_DATA=<lambada path>.json
+VOCAB_FILE=gpt2-vocab.json
+MERGE_FILE=gpt2-merges.txt
+CHECKPOINT_PATH=checkpoints/gpt2_345m
+COMMON_TASK_ARGS=<same as those in WikiText Perplexity Evaluation above>
+
+python tasks/main.py \
+       --task $TASK \
+       $COMMON_TASK_ARGS \
+       --valid-data $VALID_DATA \
+       --tokenizer-type GPT2BPETokenizer \
+       --strict-lambada \
+       --merge-file $MERGE_FILE \
+       --load $CHECKPOINT_PATH \
+       --micro-batch-size 8 \
+       --log-interval 10 \
+       --no-load-optim \
+       --no-load-rng
+
+ +Further command line arguments are described in the source file [`main.py`](./tasks/main.py) + +## BERT Task Evaluation +### RACE Evaluation +The following script finetunes the BERT model for evaluation on the [RACE dataset](http://www.cs.cmu.edu/~glai1/data/race/). The `TRAIN_DATA` and `VALID_DATA` directory contain the RACE dataset as separate `.txt` files. Note that for RACE, the batch size is the number of RACE query's to evaluate. Since each RACE query has four samples, the effective batch size passed through the model will be four times the batch size specified on the command line. + +
+TRAIN_DATA="data/RACE/train/middle"
+VALID_DATA="data/RACE/dev/middle \
+            data/RACE/dev/high"
+VOCAB_FILE=bert-vocab.txt
+PRETRAINED_CHECKPOINT=checkpoints/bert_345m
+CHECKPOINT_PATH=checkpoints/bert_345m_race
+COMMON_TASK_ARGS="--num-layers 24 \
+                  --hidden-size 1024 \
+                  --num-attention-heads 16 \
+                  --seq-length 512 \
+                  --max-position-embeddings 512 \
+                  --fp16 \
+                  --vocab-file $VOCAB_FILE"
+
+COMMON_TASK_ARGS_EXT="--train-data $TRAIN_DATA \
+                      --valid-data $VALID_DATA \
+                      --pretrained-checkpoint $PRETRAINED_CHECKPOINT \
+                      --save-interval 10000 \
+                      --save $CHECKPOINT_PATH \
+                      --log-interval 100 \
+                      --eval-interval 1000 \
+                      --eval-iters 10 \
+                      --weight-decay 1.0e-1"
+
+python tasks/main.py \
+       --task RACE \
+       $COMMON_TASK_ARGS \
+       $COMMON_TASK_ARGS_EXT \
+       --tokenizer-type BertWordPieceLowerCase \
+       --epochs 3 \
+       --micro-batch-size 4 \
+       --lr 1.0e-5 \
+       --lr-warmup-fraction 0.06
+
+ +### MNLI Evaluation +The following script finetunes the BERT model for evaluation with the [MultiNLI sentence pair corpus](https://www.nyu.edu/projects/bowman/multinli/). Because the matching tasks are quite similar, the script can be quickly tweaked to work with the [Quora Question Pairs](https://www.kaggle.com/quora/question-pairs-dataset) (QQP) dataset as well. + +
+
+TRAIN_DATA="data/glue_data/MNLI/train.tsv"
+VALID_DATA="data/glue_data/MNLI/dev_matched.tsv \
+            data/glue_data/MNLI/dev_mismatched.tsv"
+PRETRAINED_CHECKPOINT=checkpoints/bert_345m
+VOCAB_FILE=bert-vocab.txt
+CHECKPOINT_PATH=checkpoints/bert_345m_mnli
+COMMON_TASK_ARGS=<same as those in RACE Evaluation above>
+COMMON_TASK_ARGS_EXT=<same as those in RACE Evaluation above>
+
+python tasks/main.py \
+       --task MNLI \
+       $COMMON_TASK_ARGS \
+       $COMMON_TASK_ARGS_EXT \
+       --tokenizer-type BertWordPieceLowerCase \
+       --epochs 5 \
+       --micro-batch-size 8 \
+       --lr 5.0e-5 \
+       --lr-warmup-fraction 0.065
+
+ +## Llama-2 Inference and Finetuning + +The Llama-2 [family of models](https://ai.meta.com/llama/) are an open-source set of pretrained & finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At the time of release, Llama-2 models achieved among the best results for open-source models, and were competitive with the closed-source GPT-3.5 model (see https://arxiv.org/pdf/2307.09288.pdf). + +The Llama-2 checkpoints can be loaded into Megatron for inference and finetuning. See documentation [here](docs/llama2.md). + +# Datasets +We do not host any datasets for GPT or BERT training, however, we detail their collection so that our results may be reproduced. + +## Collecting Wikipedia Training Data +We recommend following the Wikipedia data extraction process specified by Google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text." + +We recommend using the `--json` argument when using WikiExtractor, which will dump the Wikipedia data into loose json format (one json per line), making it more manageable on the file system and also readily consumable by our codebase. We recommend further preprocessing this json dataset by nltk punctuation standardization. For BERT training, use the `--split-sentences` flag to `preprocess_data.py` as described [above](#data-preprocessing) to include sentence breaks in the produced index. If you'd like to use Wikipedia data for GPT training you should still clean it with nltk/spacy/ftfy, but do not use the `--split-sentences` flag. + +## Collecting GPT Webtext Data +We utilize the publicly available [OpenWebText](https://github.com/eukaryote31/openwebtext) library from [jcpeterson](https://github.com/jcpeterson/openwebtext) and [eukaryote31's](https://github.com/eukaryote31/openwebtext) work to download urls. We then filtered, cleaned, and deduplicated all downloaded content according to the procedure described in our [openwebtext](./tools/openwebtext) directory. For reddit URLs corresponding to content up to October 2018 we arrived at approximately 37GB of content. + +# Reproducibility +Megatron training is intended to be bitwise reproducible. This means that the same training config run twice in the same HW and SW environment should produce identical model checkpoints, losses and accuracy metric values (iteration time metrics may vary). + +There are currently two known Megatron optimizations that break reproducibility whilst still producing almost identical training runs. The following workarounds should be applied in cases where reproducibility is required: +1. When training using `--bf16`, reproducbility is only obtained when the checkpointing and resume schedule of training is identical. If the checkpointing schedule will change, i.e. checkpointing and resume will occur at different iterations, the option `--no-bias-gelu-fusion` should be used. +2. Flash attention is non-deterministic. If reproducibility is required do not use `--use-flash-attn`. + +These sources of non-determinism are under active investigation. If you observe non-determinism in Megatron training under other circumstances please open an issue. diff --git a/examples/early_exit/1-3B.sh b/examples/early_exit/1-3B.sh new file mode 100755 index 00000000..a5481fa3 --- /dev/null +++ b/examples/early_exit/1-3B.sh @@ -0,0 +1,155 @@ +#!/bin/bash + +PROJECT_NAME=EE-LLM +GROUP_NAME=1B-EXIT-6-12-untie-300B + +RUN_NAME=`date "+%m%d-%H%M"` + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export OMP_NUM_THREADS=4 + +# NCCL configuration +# export NCCL_IB_HCA= +# export NCCL_IB_TC= +# export NCCL_IB_SL= +# export NCCL_IB_GID_INDEX= +# export NCCL_SOCKET_IFNAME= +# export NCCL_DEBUG=WARN + +# Checkpoint configuration +CHECKPOINT_HOME= +CHECKPOINT_PATH=$CHECKPOINT_HOME/$PROJECT_NAME/$GROUP_NAME + +# data configuration +DATA_HOME= +TOKENIZER_PATH= +DATASET_ARXIV=${DATA_HOME}/redpajama-arxiv/all +DATASET_BOOKS=${DATA_HOME}/redpajama-book/all +DATASET_C4=${DATA_HOME}/redpajama-c4/all +DATASET_CC=${DATA_HOME}/redpajama-cc/all +DATASET_STACKEXCHANGE=${DATA_HOME}/redpajama-pile-stackexchange/all +DATASET_CODE=${DATA_HOME}/redpajama-stack-code/all +DATASET_WIKIPEDIA=${DATA_HOME}/redpajama-wiki/all +DATASET_PILE_EUROPARL=${DATA_HOME}/the-pile-europarl/all +DATASET_PILE_FREELAW=${DATA_HOME}/the-pile-freelaw/all +DATASET_PILE_HACKERNEWS=${DATA_HOME}/the-pile-hackernews/all +DATASET_PILE_NIH=${DATA_HOME}/the-pile-nih/all +DATASET_PILE_PHILPAPER=${DATA_HOME}/the-pile-philpaper/all +DATASET_PILE_PMA=${DATA_HOME}/the-pile-pubmed-abstract/all +DATASET_PILE_PMC=${DATA_HOME}/the-pile-pubmed-central/all +DATASET_PILE_USPTO=${DATA_HOME}/the-pile-uspto/all + +DATA_PATH="\ + 0.0362 ${DATASET_ARXIV} \ + 0.0657 ${DATASET_BOOKS} \ + 0.2264 ${DATASET_C4} \ + 0.4491 ${DATASET_CC} \ + 0.0246 ${DATASET_STACKEXCHANGE} \ + 0.0810 ${DATASET_CODE} \ + 0.0548 ${DATASET_WIKIPEDIA} \ + 0.0010 ${DATASET_PILE_EUROPARL} \ + 0.0162 ${DATASET_PILE_FREELAW} \ + 0.0006 ${DATASET_PILE_HACKERNEWS} \ + 0.0005 ${DATASET_PILE_NIH} \ + 0.0006 ${DATASET_PILE_PHILPAPER} \ + 0.0065 ${DATASET_PILE_PMA} \ + 0.0318 ${DATASET_PILE_PMC} \ + 0.0050 ${DATASET_PILE_USPTO} \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type SentencePieceTokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 990,9,1 \ +" + +# Distributed configuration +# MASTER_ADDR=127.0.0.1 +# MASTER_PORT=5900 +# RANK=0 +# WORLD_SIZE=2 +NPROC_PER_NODE=8 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + " + +# Parallisim configuration +TP=1 +PP=4 + +MICRO_BATCH=4 +GLOBAL_BATCH=2048 + +# Train iteration +LOG_INTERVAL=2 +SAVE_INTERVAL=$(( 240 * 10 )) # 10B data +TRAIN_ITER=$(( $SAVE_INTERVAL * 30)) # 300B data +EVAL_INTERVAL=$(( 240 * 5)) + +# GPT configuration +NLAYERS=24 +HIDDEN=2048 +HEADS=32 +SEQ=2048 + +GPT_ARGS=" + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --num-layers $NLAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads $HEADS \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --sequence-parallel \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --lr 0.0003 \ + --train-iters $TRAIN_ITER \ + --lr-decay-style cosine \ + --min-lr 3.0e-5 \ + --weight-decay 1e-1 \ + --lr-warmup-iters 2000 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.01 \ + --clip-grad 1.0 \ + --bf16 \ + --disable-bias-linear \ + --use-flash-attn \ + --normalization RMSNorm \ + --position-embedding-type rope \ + --swiglu \ +" + +# Early-exit configuration +EE_ARGS=" + --exit-layer-nums 7 13 \ + --exit-layer-weight 0.25 0.5 \ + --pre-exit \ +" + +OUTPUT_ARGS=" + --log-interval 2 \ + --log-timers-to-tracker \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $EVAL_INTERVAL \ + --eval-iters 0 \ + --wandb-project $PROJECT_NAME \ + --wandb-group $GROUP_NAME \ + --wandb-exp-name $RUN_NAME \ +" + +torchrun $DIST_ARGS \ + pretrain_early_exit_gpt.py \ + $GPT_ARGS \ + $EE_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH diff --git a/examples/early_exit/13B.sh b/examples/early_exit/13B.sh new file mode 100755 index 00000000..6ebd22a5 --- /dev/null +++ b/examples/early_exit/13B.sh @@ -0,0 +1,157 @@ +#!/bin/bash + +PROJECT_NAME=EE-LLM +GROUP_NAME=7B-EXIT-8-16-untie-300B + +RUN_NAME=`date "+%m%d-%H%M"` + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export OMP_NUM_THREADS=4 + +# NCCL configuration +# export NCCL_IB_HCA= +# export NCCL_IB_TC= +# export NCCL_IB_SL= +# export NCCL_IB_GID_INDEX= +# export NCCL_SOCKET_IFNAME= +# export NCCL_DEBUG=WARN + +# Checkpoint configuration +CHECKPOINT_HOME= +CHECKPOINT_PATH=$CHECKPOINT_HOME/$PROJECT_NAME/$GROUP_NAME + +# data configuration +DATA_HOME= +TOKENIZER_PATH= +DATASET_ARXIV=${DATA_HOME}/redpajama-arxiv/all +DATASET_BOOKS=${DATA_HOME}/redpajama-book/all +DATASET_C4=${DATA_HOME}/redpajama-c4/all +DATASET_CC=${DATA_HOME}/redpajama-cc/all +DATASET_STACKEXCHANGE=${DATA_HOME}/redpajama-pile-stackexchange/all +DATASET_CODE=${DATA_HOME}/redpajama-stack-code/all +DATASET_WIKIPEDIA=${DATA_HOME}/redpajama-wiki/all +DATASET_PILE_EUROPARL=${DATA_HOME}/the-pile-europarl/all +DATASET_PILE_FREELAW=${DATA_HOME}/the-pile-freelaw/all +DATASET_PILE_HACKERNEWS=${DATA_HOME}/the-pile-hackernews/all +DATASET_PILE_NIH=${DATA_HOME}/the-pile-nih/all +DATASET_PILE_PHILPAPER=${DATA_HOME}/the-pile-philpaper/all +DATASET_PILE_PMA=${DATA_HOME}/the-pile-pubmed-abstract/all +DATASET_PILE_PMC=${DATA_HOME}/the-pile-pubmed-central/all +DATASET_PILE_USPTO=${DATA_HOME}/the-pile-uspto/all + +DATA_PATH="\ + 0.0362 ${DATASET_ARXIV} \ + 0.0657 ${DATASET_BOOKS} \ + 0.2264 ${DATASET_C4} \ + 0.4491 ${DATASET_CC} \ + 0.0246 ${DATASET_STACKEXCHANGE} \ + 0.0810 ${DATASET_CODE} \ + 0.0548 ${DATASET_WIKIPEDIA} \ + 0.0010 ${DATASET_PILE_EUROPARL} \ + 0.0162 ${DATASET_PILE_FREELAW} \ + 0.0006 ${DATASET_PILE_HACKERNEWS} \ + 0.0005 ${DATASET_PILE_NIH} \ + 0.0006 ${DATASET_PILE_PHILPAPER} \ + 0.0065 ${DATASET_PILE_PMA} \ + 0.0318 ${DATASET_PILE_PMC} \ + 0.0050 ${DATASET_PILE_USPTO} \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type SentencePieceTokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 990,9,1 \ +" + +# Distributed configuration +# MASTER_ADDR=127.0.0.1 +# MASTER_PORT=5900 +# RANK=0 +# WORLD_SIZE=2 +NPROC_PER_NODE=8 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + " + +# Parallisim configuration +TP=2 +PP=4 + +MICRO_BATCH=1 +GLOBAL_BATCH=2048 + +# Train iteration +LOG_INTERVAL=2 +SAVE_INTERVAL=$(( 240 * 10 )) # 10B data +TRAIN_ITER=$(( $SAVE_INTERVAL * 80)) # 800B data +EVAL_INTERVAL=$(( 240 * 5)) + +# GPT configuration +NLAYERS=40 +HIDDEN=5120 +HEADS=40 +SEQ=2048 + +GPT_ARGS=" + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --num-layers $NLAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads $HEADS \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --sequence-parallel \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --lr 0.0003 \ + --train-iters $TRAIN_ITER \ + --lr-decay-style cosine \ + --min-lr 3.0e-5 \ + --weight-decay 1e-1 \ + --lr-warmup-iters 2000 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.01 \ + --clip-grad 1.0 \ + --bf16 \ + --disable-bias-linear \ + --use-flash-attn \ + --normalization RMSNorm \ + --position-embedding-type rope \ + --swiglu \ + --untie-embeddings-and-output-weights \ +" + +# Early-exit configuration +EE_ARGS=" + --untie-exit-output-weights \ + --exit-layer-nums 11 21 \ + --exit-layer-weight 0.1 0.2 \ + --pre-exit \ +" + +OUTPUT_ARGS=" + --log-interval 2 \ + --log-timers-to-tracker \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $EVAL_INTERVAL \ + --eval-iters 0 \ + --wandb-project $PROJECT_NAME \ + --wandb-group $GROUP_NAME \ + --wandb-exp-name $RUN_NAME \ +" + +torchrun $DIST_ARGS \ + pretrain_early_exit_gpt.py \ + $GPT_ARGS \ + $EE_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH diff --git a/examples/early_exit/30B.sh b/examples/early_exit/30B.sh new file mode 100755 index 00000000..2903db72 --- /dev/null +++ b/examples/early_exit/30B.sh @@ -0,0 +1,157 @@ +#!/bin/bash + +PROJECT_NAME=EE-LLM +GROUP_NAME=7B-EXIT-8-16-untie-300B + +RUN_NAME=`date "+%m%d-%H%M"` + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export OMP_NUM_THREADS=4 + +# NCCL configuration +# export NCCL_IB_HCA= +# export NCCL_IB_TC= +# export NCCL_IB_SL= +# export NCCL_IB_GID_INDEX= +# export NCCL_SOCKET_IFNAME= +# export NCCL_DEBUG=WARN + +# Checkpoint configuration +CHECKPOINT_HOME= +CHECKPOINT_PATH=$CHECKPOINT_HOME/$PROJECT_NAME/$GROUP_NAME + +# data configuration +DATA_HOME= +TOKENIZER_PATH= +DATASET_ARXIV=${DATA_HOME}/redpajama-arxiv/all +DATASET_BOOKS=${DATA_HOME}/redpajama-book/all +DATASET_C4=${DATA_HOME}/redpajama-c4/all +DATASET_CC=${DATA_HOME}/redpajama-cc/all +DATASET_STACKEXCHANGE=${DATA_HOME}/redpajama-pile-stackexchange/all +DATASET_CODE=${DATA_HOME}/redpajama-stack-code/all +DATASET_WIKIPEDIA=${DATA_HOME}/redpajama-wiki/all +DATASET_PILE_EUROPARL=${DATA_HOME}/the-pile-europarl/all +DATASET_PILE_FREELAW=${DATA_HOME}/the-pile-freelaw/all +DATASET_PILE_HACKERNEWS=${DATA_HOME}/the-pile-hackernews/all +DATASET_PILE_NIH=${DATA_HOME}/the-pile-nih/all +DATASET_PILE_PHILPAPER=${DATA_HOME}/the-pile-philpaper/all +DATASET_PILE_PMA=${DATA_HOME}/the-pile-pubmed-abstract/all +DATASET_PILE_PMC=${DATA_HOME}/the-pile-pubmed-central/all +DATASET_PILE_USPTO=${DATA_HOME}/the-pile-uspto/all + +DATA_PATH="\ + 0.0362 ${DATASET_ARXIV} \ + 0.0657 ${DATASET_BOOKS} \ + 0.2264 ${DATASET_C4} \ + 0.4491 ${DATASET_CC} \ + 0.0246 ${DATASET_STACKEXCHANGE} \ + 0.0810 ${DATASET_CODE} \ + 0.0548 ${DATASET_WIKIPEDIA} \ + 0.0010 ${DATASET_PILE_EUROPARL} \ + 0.0162 ${DATASET_PILE_FREELAW} \ + 0.0006 ${DATASET_PILE_HACKERNEWS} \ + 0.0005 ${DATASET_PILE_NIH} \ + 0.0006 ${DATASET_PILE_PHILPAPER} \ + 0.0065 ${DATASET_PILE_PMA} \ + 0.0318 ${DATASET_PILE_PMC} \ + 0.0050 ${DATASET_PILE_USPTO} \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type SentencePieceTokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 990,9,1 \ +" + +# Distributed configuration +# MASTER_ADDR=127.0.0.1 +# MASTER_PORT=5900 +# RANK=0 +# WORLD_SIZE=2 +NPROC_PER_NODE=8 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + " + +# Parallisim configuration +TP=4 +PP=4 + +MICRO_BATCH=1 +GLOBAL_BATCH=2048 + +# Train iteration +LOG_INTERVAL=2 +SAVE_INTERVAL=$(( 240 * 10 )) # 10B data +TRAIN_ITER=$(( $SAVE_INTERVAL * 80)) # 800B data +EVAL_INTERVAL=$(( 240 * 5)) + +# GPT configuration +NLAYERS=60 +HIDDEN=6656 +HEADS=52 +SEQ=2048 + +GPT_ARGS=" + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --num-layers $NLAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads $HEADS \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --sequence-parallel \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --lr 0.0003 \ + --train-iters $TRAIN_ITER \ + --lr-decay-style cosine \ + --min-lr 3.0e-5 \ + --weight-decay 1e-1 \ + --lr-warmup-iters 2000 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.01 \ + --clip-grad 1.0 \ + --bf16 \ + --disable-bias-linear \ + --use-flash-attn \ + --normalization RMSNorm \ + --position-embedding-type rope \ + --swiglu \ + --untie-embeddings-and-output-weights \ +" + +# Early-exit configuration +EE_ARGS=" + --untie-exit-output-weights \ + --exit-layer-nums 16 31 \ + --exit-layer-weight 0.1 0.2 \ + --pre-exit \ +" + +OUTPUT_ARGS=" + --log-interval 2 \ + --log-timers-to-tracker \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $EVAL_INTERVAL \ + --eval-iters 0 \ + --wandb-project $PROJECT_NAME \ + --wandb-group $GROUP_NAME \ + --wandb-exp-name $RUN_NAME \ +" + +torchrun $DIST_ARGS \ + pretrain_early_exit_gpt.py \ + $GPT_ARGS \ + $EE_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH diff --git a/examples/early_exit/7B.sh b/examples/early_exit/7B.sh new file mode 100755 index 00000000..31439e53 --- /dev/null +++ b/examples/early_exit/7B.sh @@ -0,0 +1,157 @@ +#!/bin/bash + +PROJECT_NAME=EE-LLM +GROUP_NAME=7B-EXIT-8-16-untie-300B + +RUN_NAME=`date "+%m%d-%H%M"` + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export OMP_NUM_THREADS=4 + +# NCCL configuration +# export NCCL_IB_HCA= +# export NCCL_IB_TC= +# export NCCL_IB_SL= +# export NCCL_IB_GID_INDEX= +# export NCCL_SOCKET_IFNAME= +# export NCCL_DEBUG=WARN + +# Checkpoint configuration +CHECKPOINT_HOME= +CHECKPOINT_PATH=$CHECKPOINT_HOME/$PROJECT_NAME/$GROUP_NAME + +# data configuration +DATA_HOME= +TOKENIZER_PATH= +DATASET_ARXIV=${DATA_HOME}/redpajama-arxiv/all +DATASET_BOOKS=${DATA_HOME}/redpajama-book/all +DATASET_C4=${DATA_HOME}/redpajama-c4/all +DATASET_CC=${DATA_HOME}/redpajama-cc/all +DATASET_STACKEXCHANGE=${DATA_HOME}/redpajama-pile-stackexchange/all +DATASET_CODE=${DATA_HOME}/redpajama-stack-code/all +DATASET_WIKIPEDIA=${DATA_HOME}/redpajama-wiki/all +DATASET_PILE_EUROPARL=${DATA_HOME}/the-pile-europarl/all +DATASET_PILE_FREELAW=${DATA_HOME}/the-pile-freelaw/all +DATASET_PILE_HACKERNEWS=${DATA_HOME}/the-pile-hackernews/all +DATASET_PILE_NIH=${DATA_HOME}/the-pile-nih/all +DATASET_PILE_PHILPAPER=${DATA_HOME}/the-pile-philpaper/all +DATASET_PILE_PMA=${DATA_HOME}/the-pile-pubmed-abstract/all +DATASET_PILE_PMC=${DATA_HOME}/the-pile-pubmed-central/all +DATASET_PILE_USPTO=${DATA_HOME}/the-pile-uspto/all + +DATA_PATH="\ + 0.0362 ${DATASET_ARXIV} \ + 0.0657 ${DATASET_BOOKS} \ + 0.2264 ${DATASET_C4} \ + 0.4491 ${DATASET_CC} \ + 0.0246 ${DATASET_STACKEXCHANGE} \ + 0.0810 ${DATASET_CODE} \ + 0.0548 ${DATASET_WIKIPEDIA} \ + 0.0010 ${DATASET_PILE_EUROPARL} \ + 0.0162 ${DATASET_PILE_FREELAW} \ + 0.0006 ${DATASET_PILE_HACKERNEWS} \ + 0.0005 ${DATASET_PILE_NIH} \ + 0.0006 ${DATASET_PILE_PHILPAPER} \ + 0.0065 ${DATASET_PILE_PMA} \ + 0.0318 ${DATASET_PILE_PMC} \ + 0.0050 ${DATASET_PILE_USPTO} \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --tokenizer-type SentencePieceTokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --split 990,9,1 \ +" + +# Distributed configuration +# MASTER_ADDR=127.0.0.1 +# MASTER_PORT=5900 +# RANK=0 +# WORLD_SIZE=2 +NPROC_PER_NODE=8 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes $WORLD_SIZE \ + --node_rank $RANK \ + " + +# Parallisim configuration +TP=1 +PP=4 + +MICRO_BATCH=2 +GLOBAL_BATCH=2048 + +# Train iteration +LOG_INTERVAL=2 +SAVE_INTERVAL=$(( 240 * 10 )) # 10B data +TRAIN_ITER=$(( $SAVE_INTERVAL * 80)) # 800B data +EVAL_INTERVAL=$(( 240 * 5)) + +# GPT configuration +NLAYERS=32 +HIDDEN=4096 +HEADS=32 +SEQ=2048 + +GPT_ARGS=" + --tensor-model-parallel-size $TP \ + --pipeline-model-parallel-size $PP \ + --num-layers $NLAYERS \ + --hidden-size $HIDDEN \ + --num-attention-heads $HEADS \ + --seq-length $SEQ \ + --max-position-embeddings $SEQ \ + --sequence-parallel \ + --micro-batch-size $MICRO_BATCH \ + --global-batch-size $GLOBAL_BATCH \ + --lr 0.0003 \ + --train-iters $TRAIN_ITER \ + --lr-decay-style cosine \ + --min-lr 3.0e-5 \ + --weight-decay 1e-1 \ + --lr-warmup-iters 2000 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.01 \ + --clip-grad 1.0 \ + --bf16 \ + --disable-bias-linear \ + --use-flash-attn \ + --normalization RMSNorm \ + --position-embedding-type rope \ + --swiglu \ + --untie-embeddings-and-output-weights \ +" + +# Early-exit configuration +EE_ARGS=" + --untie-exit-output-weights \ + --exit-layer-nums 9 17 \ + --exit-layer-weight 0.1 0.2 \ + --pre-exit \ +" + +OUTPUT_ARGS=" + --log-interval 2 \ + --log-timers-to-tracker \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $EVAL_INTERVAL \ + --eval-iters 0 \ + --wandb-project $PROJECT_NAME \ + --wandb-group $GROUP_NAME \ + --wandb-exp-name $RUN_NAME \ +" + +torchrun $DIST_ARGS \ + pretrain_early_exit_gpt.py \ + $GPT_ARGS \ + $EE_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH diff --git a/examples/early_exit/ee_inference_server.sh b/examples/early_exit/ee_inference_server.sh new file mode 100755 index 00000000..dfe9fd3b --- /dev/null +++ b/examples/early_exit/ee_inference_server.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +PROJECT_NAME=EE-LLM + +export OMP_NUM_THREADS=8 +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +# Tokenizer +TOKENIZER_PATH= +# Checkpoint +CHECKPOINT_PATH= +# Parallelism +TP= +PP= +# Server port +PORT=5000 + +MASTER_ADDR=127.0.0.1 +MASTER_PORT=5950 +NPROC_PER_NODE=$(( $TP * $PP )) +LOAD_ITERATION=0 + +DIST_ARGS=" + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT \ + --nproc_per_node $NPROC_PER_NODE \ + --nnodes 1 \ + --node_rank 0 \ + " + +SERVER_ARGS=" + --use-checkpoint-args \ + --tokenizer-type SentencePieceTokenizer \ + --tokenizer-model $TOKENIZER_PATH \ + --load $CHECKPOINT_PATH \ + --load-iteration $LOAD_ITERATION \ + --port $PORT +" + +torchrun $DIST_ARGS \ + tools/run_early_exit_text_generation_server.py \ + $SERVER_ARGS diff --git a/megatron/arguments.py b/megatron/arguments.py index e28a9391..cc853b95 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -8,6 +8,7 @@ import os import torch import types +import math import torch.nn.functional as F from megatron.global_vars import set_retro_args, get_retro_args @@ -40,6 +41,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False): parser = _add_transformer_engine_args(parser) parser = _add_retro_args(parser) parser = _add_experimental_args(parser) + parser = _add_multi_exit_args(parser) # Custom arguments. if extra_args_provider is not None: @@ -368,6 +370,26 @@ def validate_args(args, defaults={}): assert not args.standalone_embedding_stage, "early exit not support standalone embedding stage" assert args.num_layers_per_virtual_pipeline_stage is None, "early exit not support virtual pipeline" assert args.retro_add_retriever is False, "early exit not support retro_add_retriever" + assert args.exit_layer_weight_warmup_iters >= 0, '--exit-layer-weight-warmup-iters should be non-negative' + + # check bubble filling + if args.fill_explicit_bubbles: + assert args.pipeline_model_parallel_size > 1, "--fill-explicit-bubbles requires pipeline parallel size > 1" + # calculate for warmup + opt_num_fill_warmup_microbatches = int((args.pipeline_model_parallel_size - 1) * args.backward_forward_ratio / (1.0 + args.backward_forward_ratio)) + if args.num_fill_warmup_microbatches is None: + args.num_fill_warmup_microbatches = opt_num_fill_warmup_microbatches + elif args.num_fill_warmup_microbatches > opt_num_fill_warmup_microbatches: + if args.rank == 0: + print(f"WARNING: num_fill_warmup_microbatches is larger than optimal value {opt_num_fill_warmup_microbatches}, set to {opt_num_fill_warmup_microbatches}.") + args.num_fill_warmup_microbatches = opt_num_fill_warmup_microbatches + opt_num_fill_cooldown_microbatches = int((args.pipeline_model_parallel_size - 1) * args.backward_forward_ratio / (1.0 + args.backward_forward_ratio)) + if args.num_fill_cooldown_microbatches is None: + args.num_fill_cooldown_microbatches = opt_num_fill_cooldown_microbatches + elif args.num_fill_cooldown_microbatches > opt_num_fill_cooldown_microbatches: + if args.rank == 0: + print(f"WARNING: num_fill_cooldown_microbatches is larger than optimal value {opt_num_fill_cooldown_microbatches}, set to {opt_num_fill_cooldown_microbatches}.") + args.num_fill_cooldown_microbatches = opt_num_fill_cooldown_microbatches # Legacy RoPE arguments if args.use_rotary_position_embeddings: @@ -393,14 +415,19 @@ def validate_args(args, defaults={}): # multi exit checks. if len(args.exit_layer_weight) == 0: args.exit_layer_weight = [1.0 for _ in args.exit_layer_nums] + if len(args.exit_layer_weight_init) == 0: + args.exit_layer_weight_init = [0.0 for _ in args.exit_layer_nums] if len(args.exit_layer_temperature) == 0: args.exit_layer_temperature = [1.0 for _ in args.exit_layer_nums] if len(args.exit_layer_nums) != len(args.exit_layer_weight): raise RuntimeError("--exit-layer-nums and --exit-layer-weight must correspond one to one") + if len(args.exit_layer_nums) != len(args.exit_layer_weight_init): + raise RuntimeError("--exit-layer-nums and --exit-layer-weight-init must correspond one to one") if len(args.exit_layer_nums) != len(args.exit_layer_temperature): raise RuntimeError("--exit-layer-nums and --exit-layer-temperature must correspond one to one") if args.use_exit_mlp: assert len(args.exit_layer_nums) > 0, "--use-exit-mlp requires at least one early exit layer" + assert not args.pre_exit, "--use-exit-mlp not supports pre_exit" # Print arguments. _print_args("arguments", args) @@ -630,16 +657,7 @@ def _add_network_size_args(parser): help='Number of Experts in Switch Transformer (None means no Switch)') group.add_argument('--untie-embeddings-and-output-weights', action='store_true', help='Untie embeddings and output weights.'), - group.add_argument('--exit-layer-nums', type=int, nargs='+', default=[], - help='Layer number of early exit layers, start from 1.') - group.add_argument('--exit-layer-weight', type=float, nargs='+', default=[], - help='Loss weight of each early exit layer.') - group.add_argument('--exit-layer-temperature', type=float, nargs='+', default=[], - help='Temperature of each early exit layer.') - group.add_argument('--use-exit-mlp', action='store_true', - help='Use exit mlp in each early exit layer.') - group.add_argument('--untie-exit-output-weights', action='store_true', - help='Untie output weights of different exit layer') + return parser @@ -1025,8 +1043,8 @@ def _add_mixed_precision_args(parser): help='hysteresis for dynamic loss scaling') group.add_argument('--fp32-residual-connection', action='store_true', help='Move residual connections to fp32.') - group.add_argument('--no-query-key-layer-scaling', action='store_false', - help='Do not scale Q * K^T by 1 / layer-number.', + group.add_argument('--query-key-layer-scaling', action='store_true', + help='Scale Q * K^T by 1 / layer-number.', dest='apply_query_key_layer_scaling') group.add_argument('--attention-softmax-in-fp32', action='store_true', help='Run attention masking and softmax in fp32. ' @@ -1057,7 +1075,6 @@ def _add_distributed_args(parser): '--tensor-model-parallel-size instead.') group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, help='Number of layers per virtual pipeline stage') - group.add_argument('--use-gpipe', action='store_true') group.add_argument('--overlap-p2p-communication', action='store_true', help='overlap pipeline parallel communication with forward and backward chunks', @@ -1159,6 +1176,8 @@ def _add_data_args(parser): group.add_argument('--vocab-size', type=int, default=None, help='Size of vocab before EOD or padding.') + group.add_argument('--padded-vocab-size', type=int, default=None, + help='Size of vocab after padding.') group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file.') group.add_argument('--merge-file', type=str, default=None, @@ -1210,6 +1229,38 @@ def _add_data_args(parser): return parser +def _add_multi_exit_args(parser): + group = parser.add_argument_group(title='multexit') + + group.add_argument('--exit-layer-nums', type=int, nargs='+', default=[], + help='Layer number of early exit layers, start from 1.') + group.add_argument('--exit-layer-weight', type=float, nargs='+', default=[], + help='Loss weight of each early exit layer.') + group.add_argument('--exit-layer-weight-warmup-iters', default=0, type=int) + group.add_argument('--exit-layer-weight-warmup-style', default='linear', type=str, + choices=['linear', 'cosine']) + group.add_argument('--exit-layer-weight-init', type=float, nargs='+', default=[]) + group.add_argument('--exit-layer-temperature', type=float, nargs='+', default=[], + help='Temperature of each early exit layer.') + group.add_argument('--use-exit-mlp', action='store_true', + help='Use exit mlp in each early exit layer.') + group.add_argument('--use-exit-block', action='store_true', + help='Use a transformer block in each early exit branch') + group.add_argument('--use-exit-norm', action='store_true', + help='Use exit norm in each early exit layer') + group.add_argument('--untie-exit-output-weights', action='store_true', + help='Untie output weights of different exit layer') + group.add_argument('--pre-exit', action='store_true', + help='Calcualte early exit output before its transformer layer.') + # todo @pxc: calculate number of fill warmup/cooldown microbatches automatically + group.add_argument('--fill-explicit-bubbles', action='store_true') + group.add_argument('--num-fill-warmup-microbatches', type=int, default=None) + group.add_argument('--num-fill-cooldown-microbatches', type=int, default=None) + group.add_argument('--backward-forward-ratio', type=float, default=2.0) + group.add_argument('--use-dynamic-exit-layer-weight', action='store_true') + return parser + + def _add_autoresume_args(parser): group = parser.add_argument_group(title='autoresume') diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 0db7d959..14cde010 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -518,7 +518,10 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('exit_layer_nums', force=True) _set_arg('exit_layer_weight', force=True) _set_arg('use_exit_mlp', force=True) + _set_arg('use_exit_block', force=True) + _set_arg('use_exit_norm', force=True) _set_arg('untie_exit_output_weights', force=True) + _set_arg('pre_exit', force=True) if checkpoint_version < 3.0: _set_arg('tensor_model_parallel_size', 'model_parallel_size') diff --git a/megatron/core/inference_params.py b/megatron/core/inference_params.py index 9650d4e4..f24d0b20 100644 --- a/megatron/core/inference_params.py +++ b/megatron/core/inference_params.py @@ -16,6 +16,8 @@ def __init__(self, max_batch_size, max_sequence_length, early_exit_thres=None, t self.has_early_exit = False self.is_first_step = True self.tokenizer = tokenizer + self.prev_has_early_exit = False + self.output_logits = dict() def early_exit(self, logits, layer_num=0): # to regularly recompute kv cache of the entire network diff --git a/megatron/core/models/gpt/__init__.py b/megatron/core/models/gpt/__init__.py index 2d5eb867..1f6d2fb1 100644 --- a/megatron/core/models/gpt/__init__.py +++ b/megatron/core/models/gpt/__init__.py @@ -1 +1 @@ -from .gpt_model import GPTModel +from .gpt_model import GPTModel \ No newline at end of file diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index e17ecf8a..38743e53 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -19,6 +19,8 @@ _EMBEDDING_GROUP = None # Position embedding group. _POSITION_EMBEDDING_GROUP = None +# Pipeline Endpoint group. +_PIPELINE_ENDPOINT_GROUP = None # Data parallel group that the current rank belongs to. _DATA_PARALLEL_GROUP = None _DATA_PARALLEL_GROUP_GLOO = None @@ -46,6 +48,9 @@ # A list of ranks that have a copy of the position embedding. _POSITION_EMBEDDING_GLOBAL_RANKS = None +# A list of ranks at the ends of pipeline. +_PIPELINE_ENDPOINT_GLOBAL_RANKS = None + # A list of global ranks for each pipeline group to ease calculation of the source # rank when broadcasting from the first or last pipeline stage. _PIPELINE_GLOBAL_RANKS = None @@ -73,6 +78,10 @@ _EARLY_EXIT_LAYER_NUMS = None +_EARLY_EXIT_STAGES = None + +_EMBEDDING_STAGES = None + def initialize_model_parallel( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, @@ -323,11 +332,18 @@ def initialize_model_parallel( global _EMBEDDING_GROUP global _EMBEDDING_GLOBAL_RANKS assert _EMBEDDING_GROUP is None, 'embedding group is already initialized' + global _PIPELINE_ENDPOINT_GROUP + global _PIPELINE_ENDPOINT_GLOBAL_RANKS + assert _PIPELINE_ENDPOINT_GROUP is None, 'pipeline endpoint group is already initialized' global _POSITION_EMBEDDING_GROUP global _POSITION_EMBEDDING_GLOBAL_RANKS assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized' global _EARLY_EXIT_LAYER_NUMS assert _EARLY_EXIT_LAYER_NUMS is None, 'early exit layer nums is already initialized' + global _EARLY_EXIT_STAGES + assert _EARLY_EXIT_STAGES is None, 'early exit stages is already initialized' + layer_per_stage = num_layers / pipeline_model_parallel_size + _EARLY_EXIT_STAGES = list(set(map(lambda layer_num: int((layer_num - 1) // layer_per_stage), early_exit_layer_nums))) for i in range(num_pipeline_model_parallel_groups): ranks = range(i, world_size, num_pipeline_model_parallel_groups) group = torch.distributed.new_group(ranks) @@ -336,17 +352,19 @@ def initialize_model_parallel( _PIPELINE_GLOBAL_RANKS = ranks # get early exit layers in this pipeline stage offset = ranks.index(rank) - layer_per_stage = num_layers / pipeline_model_parallel_size _EARLY_EXIT_LAYER_NUMS = list(filter(lambda x: (layer_per_stage * offset + 1) <= x <= (layer_per_stage * (offset + 1)), early_exit_layer_nums)) - - # Setup embedding group (to exchange gradients between - # first and last stages). + + # TODO (@pxc): Check the compatibility of tied exit embedding with interleaved pipeline + # Setup embedding group. if len(ranks) > 1: - embedding_ranks = [ranks[0], ranks[-1]] + embedding_ranks = {ranks[stage] for stage in _EARLY_EXIT_STAGES} + embedding_ranks.update([ranks[0], ranks[-1]]) + embedding_ranks = list(embedding_ranks) + pipeline_endpoint_ranks = [ranks[0], ranks[-1]] position_embedding_ranks = [ranks[0]] if pipeline_model_parallel_split_rank is not None: - if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks: - embedding_ranks = [ + if ranks[pipeline_model_parallel_split_rank] not in pipeline_endpoint_ranks: + pipeline_endpoint_ranks = [ ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1], @@ -355,14 +373,21 @@ def initialize_model_parallel( position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]] else: embedding_ranks = ranks + pipeline_endpoint_ranks = ranks position_embedding_ranks = ranks - group = torch.distributed.new_group(embedding_ranks) + embedding_group = torch.distributed.new_group(embedding_ranks) if rank in embedding_ranks: - _EMBEDDING_GROUP = group + _EMBEDDING_GROUP = embedding_group if rank in ranks: _EMBEDDING_GLOBAL_RANKS = embedding_ranks + pipeline_endpoint_group = torch.distributed.new_group(pipeline_endpoint_ranks) + if rank in pipeline_endpoint_ranks: + _PIPELINE_ENDPOINT_GROUP = pipeline_endpoint_group + if rank in ranks: + _PIPELINE_ENDPOINT_GLOBAL_RANKS = pipeline_endpoint_ranks + group = torch.distributed.new_group(position_embedding_ranks) if rank in position_embedding_ranks: _POSITION_EMBEDDING_GROUP = group @@ -522,6 +547,12 @@ def get_embedding_group(): return _EMBEDDING_GROUP +def get_pipeline_endpoint_group(): + """Get the pipeleine endpoint group the caller rank belongs to.""" + assert _PIPELINE_ENDPOINT_GROUP is not None, 'pipeline endpoint group is not initialized' + return _PIPELINE_ENDPOINT_GROUP + + def get_position_embedding_group(): """Get the position embedding group the caller rank belongs to.""" assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized' @@ -668,6 +699,20 @@ def is_pipeline_last_stage(ignore_virtual=False): return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1) +def is_output_embedding_pipeline_stage(ignore_virtual=False): + """Return True if in the pipeline has word embedding, False otherwise.""" + if not ignore_virtual: + virtual_pipeline_model_parallel_world_size = ( + get_virtual_pipeline_model_parallel_world_size() + ) + if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != ( + virtual_pipeline_model_parallel_world_size - 1 + ): + return False + rank = get_pipeline_model_parallel_rank() + return rank == (get_pipeline_model_parallel_world_size() - 1) or rank in _EARLY_EXIT_STAGES + + def is_rank_in_embedding_group(ignore_virtual=False): """Return true if current rank is in embedding group, False otherwise.""" rank = torch.distributed.get_rank() @@ -684,6 +729,13 @@ def is_rank_in_embedding_group(ignore_virtual=False): return False +def is_rank_in_pipeline_endpoint_group(): + """Return true if current rank is in embedding group, False otherwise.""" + rank = torch.distributed.get_rank() + global _PIPELINE_ENDPOINT_GLOBAL_RANKS + return rank in _PIPELINE_ENDPOINT_GLOBAL_RANKS + + def is_rank_in_position_embedding_group(): """Return true if current rank is in position embedding group, False otherwise.""" rank = torch.distributed.get_rank() @@ -864,19 +916,44 @@ def get_data_modulo_expert_parallel_rank(): else: return 0 + def has_early_exit(): """Return true if pipeline stage has early exit output""" return _EARLY_EXIT_LAYER_NUMS != None and len(_EARLY_EXIT_LAYER_NUMS) > 0 + def get_early_exit_layer_nums(): return _EARLY_EXIT_LAYER_NUMS + def set_early_exit_layer_nums(layer_nums): global _EARLY_EXIT_LAYER_NUMS assert type(layer_nums) == list _EARLY_EXIT_LAYER_NUMS = layer_nums +def get_early_exit_stages(): + return _EARLY_EXIT_STAGES + + +def is_exit_stage(): + return get_pipeline_model_parallel_rank() in _EARLY_EXIT_STAGES + + +def post_stage_has_early_exit(): + return (len(_EARLY_EXIT_STAGES) > 0) and (_EARLY_EXIT_STAGES[-1] > get_pipeline_model_parallel_rank()) + + +def pre_stage_has_early_exit(): + return (len(_EARLY_EXIT_STAGES) > 0) and (_EARLY_EXIT_STAGES[0] < get_pipeline_model_parallel_rank()) + + +def set_early_exit_stages(stages): + global _EARLY_EXIT_STAGES + assert type(stages) == list + _EARLY_EXIT_STAGES = stages + + def _set_global_memory_buffer(): """Initialize global buffer""" global _GLOBAL_MEMORY_BUFFER @@ -914,6 +991,8 @@ def destroy_model_parallel(): _CONTEXT_PARALLEL_GLOBAL_RANKS = None global _EMBEDDING_GROUP _EMBEDDING_GROUP = None + global _PIPELINE_ENDPOINT_GROUP + _PIPELINE_ENDPOINT_GROUP = None global _POSITION_EMBEDDING_GROUP _POSITION_EMBEDDING_GROUP = None global _TENSOR_AND_DATA_PARALLEL_GROUP @@ -940,3 +1019,5 @@ def destroy_model_parallel(): _GLOBAL_MEMORY_BUFFER = None global _EARLY_EXIT_LAYER_NUMS _EARLY_EXIT_LAYER_NUMS = None + global _EARLY_EXIT_STAGES + _EARLY_EXIT_STAGES = None diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 19ffc9f0..b3bbf224 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -2,7 +2,9 @@ import contextlib from typing import Callable, Iterator, List, Optional, Union +from functools import partial +import math import torch from torch.autograd.variable import Variable @@ -94,18 +96,78 @@ def forward_step(data_iterator, model): """ args = get_args() pipeline_model_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() + + # early exit weight + if parallel_state.has_early_exit(): + exit_layer_weight = dict(filter(lambda p: p[0] in parallel_state.get_early_exit_layer_nums(), zip(args.exit_layer_nums, args.exit_layer_weight))) + exit_layer_weight_init = dict(filter(lambda p: p[0] in parallel_state.get_early_exit_layer_nums(), zip(args.exit_layer_nums, args.exit_layer_weight_init))) + early_exit_loss_weight = EarlyExitLossWeight(exit_layer_weight, exit_layer_weight_init, + args.exit_layer_weight_warmup_iters, args.exit_layer_weight_warmup_style) + else: + early_exit_loss_weight = None + if pipeline_model_parallel_size > 1: if len(args.exit_layer_nums) > 0: - forward_backward_func = forward_backward_pipelining_with_early_exit + if args.fill_explicit_bubbles: + forward_backward_func = partial(early_exit_forward_backward_pipelining_with_bubble_filling, + num_fill_warmup_microbatches=args.num_fill_warmup_microbatches, + num_fill_cooldown_microbatches=args.num_fill_cooldown_microbatches, + early_exit_loss_weight=early_exit_loss_weight) + else: + forward_backward_func = partial(early_exit_forward_backward_pipelining, early_exit_loss_weight=early_exit_loss_weight) elif parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: forward_backward_func = forward_backward_pipelining_with_interleaving else: forward_backward_func = forward_backward_pipelining_without_interleaving else: - forward_backward_func = forward_backward_no_pipelining + if len(args.exit_layer_nums) > 0: + forward_backward_func = partial(early_exit_forward_backward_no_pipelining, early_exit_loss_weight=early_exit_loss_weight) + else: + forward_backward_func = forward_backward_no_pipelining return forward_backward_func +class EarlyExitLossWeight(): + + def __init__(self, exit_layer_loss_weight, exit_layer_loss_weight_init, + exit_layer_weight_warmup_iters, exit_layer_weight_warmup_style): + args = get_args() + self.warmup = exit_layer_weight_warmup_iters > 0 and args.curr_iteration < exit_layer_weight_warmup_iters + if self.warmup: + self.warmup_iters = exit_layer_weight_warmup_iters + self.exit_layer_loss_weight = {layer_num: weight for layer_num, weight in exit_layer_loss_weight_init.items()} + self.exit_layer_loss_weight_init = exit_layer_loss_weight_init + self.exit_layer_loss_weight_delta = { + layer_num: exit_layer_loss_weight[layer_num] - exit_layer_loss_weight_init[layer_num] + for layer_num in exit_layer_loss_weight.keys() + } + if exit_layer_weight_warmup_style == 'cosine': + self.update_func = self.cosine_warmup + else: # linear + self.update_func = self.linear_warmup + else: + self.exit_layer_loss_weight = exit_layer_loss_weight + + def cosine_warmup(self, inc_ratio): + for layer_num in self.exit_layer_loss_weight.keys(): + self.exit_layer_loss_weight[layer_num] = 0.5 * (math.cos(math.pi * (inc_ratio + 1.0)) + 1.0) \ + * self.exit_layer_loss_weight_delta[layer_num] + self.exit_layer_loss_weight_init[layer_num] + + def linear_warmup(self, inc_ratio): + for layer_num in self.exit_layer_loss_weight.keys(): + self.exit_layer_loss_weight[layer_num] = inc_ratio * self.exit_layer_loss_weight_delta[layer_num] \ + + self.exit_layer_loss_weight_init[layer_num] + + def get_weight(self, layer): + return self.exit_layer_loss_weight[layer] + + def update(self): + if self.warmup: + iteration = get_args().curr_iteration + if iteration <= self.warmup_iters: + self.update_func(float(iteration) / self.warmup_iters) + return + def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): '''Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. @@ -162,7 +224,6 @@ def forward_step( config, collect_non_loss_data=False, checkpoint_activations_microbatch=None, - early_exit_loss=None, ): """Forward step for passed-in model. @@ -187,36 +248,21 @@ def forward_step( context_manager = contextlib.nullcontext() with context_manager: if checkpoint_activations_microbatch is None: - lm_output, loss_func = forward_step_func(data_iterator, model) + output_tensor, loss_func = forward_step_func(data_iterator, model) else: - lm_output, loss_func = forward_step_func( + output_tensor, loss_func = forward_step_func( data_iterator, model, checkpoint_activations_microbatch ) - if parallel_state.has_early_exit(): - output_tensor, early_exit_losses = lm_output - else: - output_tensor = lm_output - - loss_dict = {} if parallel_state.is_pipeline_last_stage(): if not collect_non_loss_data: output_tensor = loss_func(output_tensor) loss, loss_reduced = output_tensor output_tensor = loss / num_microbatches - loss_dict.update(loss_reduced) + forward_data_store.append(loss_reduced) else: data = loss_func(output_tensor, non_loss_data=True) - loss_dict.update(data) - - if parallel_state.has_early_exit(): - for _, output in early_exit_losses.items(): - loss, loss_reduced = output - early_exit_loss.append(loss / num_microbatches) - loss_dict.update(loss_reduced) - - if loss_dict: - forward_data_store.append(loss_dict) + forward_data_store.append(data) if config.timers is not None: config.timers('forward-compute').stop() @@ -235,7 +281,7 @@ def forward_step( return [output_tensor] -def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config, early_exit_loss=None): +def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, config): """Backward step through passed-in output tensor. If last stage, output_tensor_grad is None, otherwise gradient of loss @@ -269,14 +315,7 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c if output_tensor_grad[0] is None and config.grad_scale_func is not None: output_tensor[0] = config.grad_scale_func(output_tensor[0]) - if early_exit_loss is not None and len(early_exit_loss) > 0: - if output_tensor_grad[0] is not None: - fake_loss = torch.sum(torch.stack(early_exit_loss), dim=0) + torch.sum(output_tensor[0] * output_tensor_grad[0]) - else: - fake_loss = torch.sum(torch.stack(early_exit_loss), dim=0) + output_tensor[0] - custom_backward(fake_loss, None) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) - elif config.deallocate_pipeline_outputs: + if config.deallocate_pipeline_outputs: custom_backward(output_tensor[0], output_tensor_grad[0]) else: torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) @@ -1317,7 +1356,99 @@ def enable_grad_sync(): return forward_data_store -def forward_backward_pipelining_with_early_exit( + +def early_exit_forward_backward_no_pipelining( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, # unused + micro_batch_size: int, # unused + decoder_seq_length: int = None, # unused + forward_only: bool = False, + collect_non_loss_data: bool = False, + early_exit_loss_weight: EarlyExitLossWeight = None, +): + """Run forward and backward passes with no pipeline parallelism + (no inter-stage communication). + + Returns dictionary with losses. + + + See get_forward_backward_func() for argument details + """ + + if isinstance(model, list): + assert len(model) == 1, "non-pipeline-parallel schedule does not support model chunking" + model = model[0] + if isinstance(data_iterator, list): + assert ( + len(data_iterator) == 1 + ), "non-pipeline-parallel schedule does not support model chunking" + data_iterator = data_iterator[0] + + config = get_model_config(model) + if config.timers is not None: + config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + + no_sync_func = config.no_sync_func + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + + if early_exit_loss_weight: + early_exit_loss_weight.update() + + forward_data_store = [] + backward_data_store = [] + input_tensor, output_tensor_grad = None, None + with no_sync_func(): + for i in range(num_microbatches - 1): + output_tensor, early_exit_output = early_exit_forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + ) + if not forward_only: + exit_loss = cal_early_exit_loss(early_exit_output, backward_data_store, num_microbatches, early_exit_loss_weight) + early_exit_backward_step(input_tensor, output_tensor, output_tensor_grad, config, early_exit_loss=exit_loss) + + # Run computation for last microbatch out of context handler (want to + # synchronize gradients). + output_tensor, early_exit_output = early_exit_forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + ) + + if not forward_only: + exit_loss = cal_early_exit_loss(early_exit_output, backward_data_store, num_microbatches, early_exit_loss_weight) + early_exit_backward_step(input_tensor, output_tensor, output_tensor_grad, config, early_exit_loss=exit_loss) + + if config.timers is not None: + config.timers('forward-backward').stop() + + if config.finalize_model_grads_func is not None and not forward_only: + # Finalize model grads (perform full grad all-reduce / reduce-scatter for + # data parallelism and layernorm all-reduce for sequence parallelism). + config.finalize_model_grads_func([model]) + + forward_data_store = [{**f, **b} for (f, b) in zip(forward_data_store, backward_data_store)] + + return forward_data_store + + +def early_exit_forward_backward_pipelining( *, forward_step_func, data_iterator: Union[Iterator, List[Iterator]], @@ -1328,6 +1459,7 @@ def forward_backward_pipelining_with_early_exit( decoder_seq_length: int = None, forward_only: bool = False, collect_non_loss_data: bool = False, + early_exit_loss_weight: EarlyExitLossWeight = None, ): if isinstance(model, list): assert ( @@ -1346,6 +1478,9 @@ def forward_backward_pipelining_with_early_exit( "Non-interleaved pipeline parallelism does not support overlapping p2p communication" ) + if config.timers is not None: + config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + # Disable async grad reductions no_sync_func = config.no_sync_func if no_sync_func is None: @@ -1369,6 +1504,9 @@ def enable_grad_sync(): disable_grad_sync() + if early_exit_loss_weight: + early_exit_loss_weight.update() + # Compute number of warmup microbatches. num_warmup_microbatches = ( parallel_state.get_pipeline_model_parallel_world_size() @@ -1417,7 +1555,7 @@ def enable_grad_sync(): input_tensors = [] output_tensors = [] if has_early_exit: - early_exit_losses = [] + early_exit_loss_funcs = [] forward_data_store = [] @@ -1432,11 +1570,8 @@ def enable_grad_sync(): else: checkpoint_activations_microbatch = None - early_exit_loss = None - if has_early_exit: - early_exit_loss = [] input_tensor = recv_forward(recv_tensor_shapes, config) - output_tensor = forward_step( + output_tensor, early_exit_output = early_exit_forward_step( forward_step_func, data_iterator, model, @@ -1446,7 +1581,6 @@ def enable_grad_sync(): config, collect_non_loss_data, checkpoint_activations_microbatch, - early_exit_loss=early_exit_loss ) send_forward(output_tensor, send_tensor_shapes, config) @@ -1456,7 +1590,7 @@ def enable_grad_sync(): input_tensors.append(input_tensor) output_tensors.append(output_tensor) if has_early_exit: - early_exit_losses.append(early_exit_loss) + early_exit_loss_funcs.append(early_exit_output) # Before running 1F1B, need to receive first forward tensor. # If all microbatches are run in warmup / cooldown phase, then no need to @@ -1475,10 +1609,7 @@ def enable_grad_sync(): ) >= config.num_microbatches_with_partial_activation_checkpoints else: checkpoint_activations_microbatch = None - early_exit_loss = None - if has_early_exit: - early_exit_loss = [] - output_tensor = forward_step( + output_tensor, early_exit_output = early_exit_forward_step( forward_step_func, data_iterator, model, @@ -1488,7 +1619,6 @@ def enable_grad_sync(): config, collect_non_loss_data, checkpoint_activations_microbatch, - early_exit_loss=early_exit_loss ) if forward_only: @@ -1505,18 +1635,21 @@ def enable_grad_sync(): input_tensors.append(input_tensor) output_tensors.append(output_tensor) if has_early_exit: - early_exit_losses.append(early_exit_loss) - + early_exit_loss_funcs.append(early_exit_output) # Pop input_tensor and output_tensor from the start of the list for # the backward pass. input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - early_exit_loss = early_exit_losses.pop(0) if has_early_exit else None - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config, - early_exit_loss=early_exit_loss + if has_early_exit: + exit_loss = cal_early_exit_loss(early_exit_loss_funcs.pop(0), forward_data_store, num_microbatches, early_exit_loss_weight) + else: + exit_loss = None + + input_tensor_grad = early_exit_backward_step( + input_tensor, output_tensor, output_tensor_grad, config, + early_exit_loss=exit_loss ) if last_iteration: @@ -1530,7 +1663,6 @@ def enable_grad_sync(): # Run cooldown backward passes. if not forward_only: for i in range(num_warmup_microbatches): - # Enable async grad reduction in the last backward pass # Note: If grad sync function is provided, only enable # async grad reduction in first pipeline stage. Other @@ -1542,21 +1674,482 @@ def enable_grad_sync(): input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - early_exit_loss = early_exit_losses.pop(0) if has_early_exit else None - + if has_early_exit: + exit_loss = cal_early_exit_loss(early_exit_loss_funcs.pop(0), forward_data_store, num_microbatches, early_exit_loss_weight) + else: + exit_loss = None output_tensor_grad = recv_backward(send_tensor_shapes, config) - input_tensor_grad = backward_step( - input_tensor, output_tensor, output_tensor_grad, model_type, config, - early_exit_loss=early_exit_loss + input_tensor_grad = early_exit_backward_step( + input_tensor, output_tensor, output_tensor_grad, config, + early_exit_loss=exit_loss ) send_backward(input_tensor_grad, recv_tensor_shapes, config) - # Launch any remaining grad reductions - if no_sync_context is not None: - enable_grad_sync() - if config.grad_sync_func is not None: - config.grad_sync_func(model.parameters()) + if config.timers is not None: + config.timers('forward-backward').stop() + + if config.finalize_model_grads_func is not None and not forward_only: + # Finalize model grads (perform full grad all-reduce / reduce-scatter for + # data parallelism, layernorm all-reduce for sequence parallelism, and + # embedding all-reduce for pipeline parallelism). + config.finalize_model_grads_func([model]) return forward_data_store + + +def early_exit_forward_backward_pipelining_with_bubble_filling( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int = None, + forward_only: bool = False, + collect_non_loss_data: bool = False, + num_fill_warmup_microbatches: int = 0, + num_fill_cooldown_microbatches: int = 0, + backward_forward_ratio: float = 2.0, + early_exit_loss_weight: EarlyExitLossWeight = None, +): + if isinstance(model, list): + assert ( + len(model) == 1 + ), "non-interleaved pipeline parallelism does not support model chunking" + model = model[0] + if isinstance(data_iterator, list): + assert ( + len(data_iterator) == 1 + ), "non-pipeline-parallel schedule does not support model chunking" + data_iterator = data_iterator[0] + + config = get_model_config(model) + if config.overlap_p2p_comm: + raise ValueError( + "Non-interleaved pipeline parallelism does not support overlapping p2p communication" + ) + + if config.timers is not None: + config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + + # Disable async grad reductions + no_sync_func = config.no_sync_func + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + no_sync_context = None + has_early_exit = parallel_state.has_early_exit() + + def disable_grad_sync(): + """Disable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is None: + no_sync_context = no_sync_func() + no_sync_context.__enter__() + + def enable_grad_sync(): + """Enable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is not None: + no_sync_context.__exit__(None, None, None) + no_sync_context = None + + disable_grad_sync() + + if early_exit_loss_weight: + early_exit_loss_weight.update() + + # Compute number of warmup microbatches. + num_warmup_microbatches = ( + parallel_state.get_pipeline_model_parallel_world_size() + - parallel_state.get_pipeline_model_parallel_rank() + ) + + # todo @(pxc): check mirco_batch num globally + assert num_warmup_microbatches < num_microbatches + + num_microbatches_remaining = num_microbatches - num_warmup_microbatches - num_fill_warmup_microbatches + + num_microbatches_for_loss = num_microbatches - num_fill_warmup_microbatches + + cur_warmup_bubble_size = backward_forward_ratio * (num_warmup_microbatches - 1) - parallel_state.get_pipeline_model_parallel_rank() + next_warmup_bubble_size = max(0, cur_warmup_bubble_size - backward_forward_ratio - 1.0) + + cur_num_partial_forward_microbatches = min(num_fill_warmup_microbatches, int(cur_warmup_bubble_size / (1.0 + backward_forward_ratio))) + next_num_partial_forward_microbatches = min(num_fill_warmup_microbatches, int(next_warmup_bubble_size / (1.0 + backward_forward_ratio))) + + cur_cooldown_bubble_size = backward_forward_ratio * parallel_state.get_pipeline_model_parallel_rank() + pre_cooldown_bubble_size = max(0, cur_cooldown_bubble_size - backward_forward_ratio) + + cur_num_partial_backward_microbatches = min(num_fill_cooldown_microbatches, int(cur_cooldown_bubble_size / (1.0 + backward_forward_ratio))) + pre_num_partial_backward_microbatches = min(num_fill_cooldown_microbatches, int(pre_cooldown_bubble_size / (1.0 + backward_forward_ratio))) + + num_cooldown_microbatches = max(0, num_warmup_microbatches - num_fill_cooldown_microbatches + cur_num_partial_backward_microbatches) + + # Checkpoint the activations of partial Transformer layers in a number of micro-batches + # within the maximum outstanding micro-batch backpropagations. + # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' + # checkpoint partial Transformer layers (or skip checkpointing) and + # the rest of micro-batches within a window of micro-batches checkpoint + # all Transformer layers. The window of micro-batches is set by the maximum + # outstanding backpropagations and becomes smaller at later pipeline stages. + # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf + max_outstanding_backprops = None + if config.num_microbatches_with_partial_activation_checkpoints is not None: + max_outstanding_backprops = num_warmup_microbatches + model_type = get_model_type(model) + + rank = parallel_state.get_pipeline_model_parallel_rank() + recv_tensor_shapes = get_tensor_shapes( + rank=rank - 1, + model_type=model_type, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=config, + ) + send_tensor_shapes = get_tensor_shapes( + rank=rank, + model_type=model_type, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + decoder_seq_length=decoder_seq_length, + config=config, + ) + + # Input, output tensors only need to be saved when doing backward passes + input_tensors = None + output_tensors = None + if not forward_only: + input_tensors = [] + output_tensors = [] + if has_early_exit: + early_exit_loss_funcs = [] + + forward_data_store = [] + + # Run warmup forward passes. + for i in range(num_warmup_microbatches): + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + i % max_outstanding_backprops + >= config.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + input_tensor = recv_forward(recv_tensor_shapes, config) + output_tensor, early_exit_output = early_exit_forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches_for_loss, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + ) + send_forward(output_tensor, send_tensor_shapes, config) + + if not forward_only: + if not has_early_exit: + deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + if has_early_exit: + early_exit_loss_funcs.append(early_exit_output) + + warmup_input_tensors = [] + warmup_output_tensors = [] + warmup_early_exit_outputs = [] + + # Fill warmup bubbles + for i in range(num_fill_warmup_microbatches): + if (has_early_exit or parallel_state.post_stage_has_early_exit()) and \ + (i < cur_num_partial_forward_microbatches): + is_last = i + 1 > next_num_partial_forward_microbatches + input_tensor = recv_forward(recv_tensor_shapes, config) + output_tensor, early_exit_output = early_exit_forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches_for_loss, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + ) + if not is_last: + send_forward(output_tensor, send_tensor_shapes, config) + if not forward_only: + if not has_early_exit: + deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + warmup_input_tensors.append(input_tensor) + warmup_output_tensors.append(output_tensor) + warmup_early_exit_outputs.append(early_exit_output) + else: + # align data iterator of all pipeline stages + next(data_iterator) + + if not forward_only: + for i in reversed(range(cur_num_partial_forward_microbatches)): + # compute first backward without recv backward + is_last = i + 1 > next_num_partial_forward_microbatches + if is_last: + output_tensor_grad = None + else: + output_tensor_grad = recv_backward( + send_tensor_shapes, config + ) + + input_tensor = warmup_input_tensors.pop(-1) + output_tensor = warmup_output_tensors.pop(-1) + exit_output = warmup_early_exit_outputs.pop(-1) + if exit_output: + exit_loss = cal_early_exit_loss( + exit_output, forward_data_store, + num_microbatches_for_loss, early_exit_loss_weight) + input_tensor_grad = early_exit_backward_step( + input_tensor, output_tensor, output_tensor_grad, config, + early_exit_loss=exit_loss + ) + send_backward(input_tensor_grad, recv_tensor_shapes, config) + elif output_tensor_grad: + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + send_backward(input_tensor_grad, recv_tensor_shapes, config) + + + # Run 1F1B in steady state. + for i in range(num_microbatches_remaining): + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + (i + num_warmup_microbatches) % max_outstanding_backprops + ) >= config.num_microbatches_with_partial_activation_checkpoints + else: + checkpoint_activations_microbatch = None + + if not forward_only: + output_tensor_grad = recv_backward( + send_tensor_shapes, config + ) + # Pop input_tensor and output_tensor from the start of the list for + # the backward pass. + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + + if has_early_exit: + exit_loss = cal_early_exit_loss( + early_exit_loss_funcs.pop(0), forward_data_store, + num_microbatches_for_loss, early_exit_loss_weight) + else: + exit_loss = None + + input_tensor_grad = early_exit_backward_step( + input_tensor, output_tensor, output_tensor_grad, config, + early_exit_loss=exit_loss + ) + input_tensor = recv_forward(recv_tensor_shapes, config) + send_backward(input_tensor_grad, recv_tensor_shapes, config) + else: + input_tensor = recv_forward(recv_tensor_shapes, config) + + output_tensor, early_exit_output = early_exit_forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches_for_loss, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + ) + send_forward(output_tensor, send_tensor_shapes, config) + if not forward_only: + if not has_early_exit: + deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + # Add input_tensor and output_tensor to end of list. + input_tensors.append(input_tensor) + output_tensors.append(output_tensor) + if has_early_exit: + early_exit_loss_funcs.append(early_exit_output) + + if not forward_only: + # Run cooldown backward passes. + for i in range(num_cooldown_microbatches): + is_partial_backward = (num_cooldown_microbatches - i) <= (cur_num_partial_backward_microbatches - pre_num_partial_backward_microbatches) + if i == num_cooldown_microbatches - 1: + if config.grad_sync_func is None or rank == 0: + enable_grad_sync() + + input_tensor = input_tensors.pop(0) + output_tensor = output_tensors.pop(0) + if has_early_exit: + exit_loss = cal_early_exit_loss(early_exit_loss_funcs.pop(0), forward_data_store, num_microbatches_for_loss, early_exit_loss_weight) + else: + exit_loss = None + output_tensor_grad = recv_backward(send_tensor_shapes, config) + + input_tensor_grad = early_exit_backward_step( + input_tensor, output_tensor, output_tensor_grad, config, + early_exit_loss=exit_loss + ) + if not is_partial_backward: + send_backward(input_tensor_grad, recv_tensor_shapes, config) + + if config.timers is not None: + config.timers('forward-backward').stop() + + if config.finalize_model_grads_func is not None and not forward_only: + # Finalize model grads (perform full grad all-reduce / reduce-scatter for + # data parallelism, layernorm all-reduce for sequence parallelism, and + # embedding all-reduce for pipeline parallelism). + config.finalize_model_grads_func([model]) + + return forward_data_store + + +def cal_early_exit_loss(early_exit_loss_funcs, forward_data_store, num_microbatches, early_exit_loss_weight): + exit_loss_dict = {} + exit_losses = [] + with torch.enable_grad(): + for layer_num, exit_loss_func in early_exit_loss_funcs.items(): + loss = exit_loss_func(log_dict=exit_loss_dict) + loss_weight = early_exit_loss_weight.get_weight(layer_num) + exit_losses.append(loss.multiply_(loss_weight)) + exit_loss_dict[f'exit weight [{layer_num}]'] = early_exit_loss_weight.get_weight(layer_num) + forward_data_store.append(exit_loss_dict) + return torch.sum(torch.stack(exit_losses), dim=0).div(num_microbatches) + + +def early_exit_forward_step( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data=False, + checkpoint_activations_microbatch=None, +): + """Forward step for early exit model. + + If first stage, input tensor is obtained from data_iterator, otherwise + passed-in input_tensor is used. + + Returns output tensor.""" + if config.timers is not None: + config.timers('forward-compute', log_level=2).start() + + unwrap_output_tensor = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_output_tensor = True + + set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor") + set_input_tensor(input_tensor) + + if config.enable_autocast: + context_manager = torch.autocast("cuda", dtype=config.autocast_dtype) + else: + context_manager = contextlib.nullcontext() + with context_manager: + if checkpoint_activations_microbatch is None: + lm_output, loss_func = forward_step_func(data_iterator, model) + else: + lm_output, loss_func = forward_step_func( + data_iterator, model, checkpoint_activations_microbatch + ) + early_exit_output = None + if parallel_state.has_early_exit(): + output_tensor, early_exit_output = lm_output + else: + output_tensor = lm_output + loss_dict = {} + + if parallel_state.is_pipeline_last_stage(): + output_tensor = loss_func(output_tensor=output_tensor, + log_dict=loss_dict, + log_key='lm loss') + output_tensor.div_(num_microbatches) + + if loss_dict: + forward_data_store.append(loss_dict) + + if config.timers is not None: + config.timers('forward-compute').stop() + + if unwrap_output_tensor: + return output_tensor, early_exit_output + + return [output_tensor], early_exit_output + + +def early_exit_backward_step(input_tensor, output_tensor, output_tensor_grad, config, early_exit_loss=None): + """Backward step through passed-in output tensor. + + If last stage, output_tensor_grad is None, otherwise gradient of loss + with respect to stage's output tensor. + + Returns gradient of loss with respect to input tensor (None if first + stage).""" + + if config.timers is not None: + config.timers('backward-compute', log_level=2).start() + + # Retain the grad on the input_tensor. + unwrap_input_tensor_grad = False + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + unwrap_input_tensor_grad = True + for x in input_tensor: + if x is not None: + x.retain_grad() + + if not isinstance(output_tensor, list): + output_tensor = [output_tensor] + if not isinstance(output_tensor_grad, list): + output_tensor_grad = [output_tensor_grad] + + # Backward pass. + if output_tensor_grad[0] is None and config.grad_scale_func is not None: + output_tensor[0] = config.grad_scale_func(output_tensor[0]) + if early_exit_loss is not None: + if output_tensor_grad[0] is not None: + fake_loss = early_exit_loss + torch.sum(output_tensor[0] * output_tensor_grad[0]) + elif output_tensor[0].numel() == 1: + fake_loss = early_exit_loss + output_tensor[0] + else: + fake_loss = early_exit_loss + custom_backward(fake_loss, None) + elif config.deallocate_pipeline_outputs: + custom_backward(output_tensor[0], output_tensor_grad[0]) + else: + torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) + + # Collect the grad of the input_tensor. + input_tensor_grad = [None] + if input_tensor is not None: + input_tensor_grad = [] + for x in input_tensor: + if x is None: + input_tensor_grad.append(None) + else: + input_tensor_grad.append(x.grad) + + if unwrap_input_tensor_grad: + input_tensor_grad = input_tensor_grad[0] + + if config.timers is not None: + config.timers('backward-compute').stop() + + return input_tensor_grad diff --git a/megatron/multi_exit_text_generation_server.py b/megatron/early_exit_text_generation_server.py similarity index 94% rename from megatron/multi_exit_text_generation_server.py rename to megatron/early_exit_text_generation_server.py index a076eace..16c8d610 100644 --- a/megatron/multi_exit_text_generation_server.py +++ b/megatron/early_exit_text_generation_server.py @@ -1,4 +1,3 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import datetime import time import torch @@ -58,13 +57,16 @@ async def generate(self, req): stop_token_ids=req['stop_sequences'], prevent_newline_after_colon=req['prevent_newline_after_colon'], random_seed=req['random_seed'], - early_exit_thres=req['early_exit_thres']) + early_exit_thres=req['early_exit_thres'], + use_early_exit=req['use_early_exit'], + print_max_prob=req['print_max_prob']) end_time = time.time() print(f"Response(use {end_time - start_time}s): " + str(response)) return { "text": response, "segments": response_seg, - "logprobs": response_logprobs + "logprobs": response_logprobs, + "requst_time": end_time - start_time } def put(self): @@ -123,6 +125,11 @@ def put(self): else: raw_req['early_exit_thres'] = 40.0 + if "print_max_prob" in raw_req: + raw_req['print_max_prob'] = True + else: + raw_req['print_max_prob'] = False + top_k = 0.0 if "top_k" in raw_req: top_k = raw_req["top_k"] @@ -213,6 +220,11 @@ def put(self): else: raw_req['random_seed'] = 1234 + if "use_early_exit" in raw_req: + raw_req['use_early_exit'] = True + else: + raw_req['use_early_exit'] = False + no_log = False if "no_log" in raw_req: no_log = raw_req["no_log"] @@ -253,5 +265,5 @@ def __init__(self, model): api = Api(self.app) api.add_resource(MegatronGenerate, '/api', resource_class_args=[model]) - def run(self, host, port): + def run(self, host, port): self.app.run(host=host, port=port, threaded=True, debug=False) diff --git a/megatron/global_vars.py b/megatron/global_vars.py index 38912b08..caa257a6 100644 --- a/megatron/global_vars.py +++ b/megatron/global_vars.py @@ -20,7 +20,6 @@ _GLOBAL_ADLR_AUTORESUME = None _GLOBAL_TIMERS = None _GLOBAL_SIGNAL_HANDLER = None -_GLOBAL_WANDB = None def get_args(): @@ -89,7 +88,7 @@ def _set_signal_handler(): -def set_global_variables(args, build_tokenizer=True): +def set_global_variables(args, build_tokenizer=True, init_wandb=True): """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" assert args is not None @@ -100,8 +99,9 @@ def set_global_variables(args, build_tokenizer=True): _build_num_microbatches_calculator(args) if build_tokenizer: _ = _build_tokenizer(args) + if init_wandb: + _set_wandb_writer(args) _set_tensorboard_writer(args) - _set_wandb_writer(args) _set_adlr_autoresume(args) _set_timers(args) @@ -171,35 +171,25 @@ def _set_wandb_writer(args): is_pipeline_stage_main = ((args.rank + 1) % pipeline_group_size) == 0 pipeline_stage_id = int(args.rank // pipeline_group_size) description = os.environ.get('RUN_DESCRIPTION', default='') - if getattr(args, 'wandb_project', '') and is_pipeline_stage_main: + if getattr(args, 'wandb_project') and is_pipeline_stage_main: if args.wandb_exp_name == '': raise ValueError("Please specify the wandb experiment name!") - try: - import wandb - is_master = args.rank == (args.world_size - 1) - name = f'{args.wandb_exp_name}-master' if is_master \ - else f'{args.wandb_exp_name}-worker-{pipeline_stage_id}', - if args.wandb_save_dir: - save_dir = args.wandb_save_dir - else: - # Defaults to the save dir. - save_dir = os.path.join(args.save, 'wandb') - wandb_kwargs = { - 'dir': save_dir, - 'name': name, - 'project': args.wandb_project, - 'group': args.wandb_group, - 'config': vars(args), - 'force': False, - 'notes': description, - 'save_code': False, - 'tags': ['master' if is_master else 'worker']} - os.makedirs(wandb_kwargs['dir'], exist_ok=True) - wandb.init(**wandb_kwargs) - _GLOBAL_WANDB_WRITER = wandb - except Exception: - print("WARNING: Skip wandb setup. Please execute " - "'wandb login' to enable wandb.", flush=True) + import wandb + is_master = args.rank == (args.world_size - 1) + name = f'{args.wandb_exp_name}-master' if is_master \ + else f'{args.wandb_exp_name}-worker-{pipeline_stage_id}' + wandb.init( + project=args.wandb_project, + group=args.wandb_group, + name=name, + save_code=False, + config=args, + force=False, + notes=description, + tags=['master'if is_master else 'worker'] + ) + _GLOBAL_WANDB_WRITER = wandb + def _set_adlr_autoresume(args): """Initialize ADLR autoresume.""" diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index a0eb8363..52b93997 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -5,7 +5,7 @@ from .bert_model import BertModel from .gpt_model import GPTModel -from .multi_exit_gpt_model import MultiExitGPTModel +from .early_exit_gpt_model import EarlyExitGPTModel from .t5_model import T5Model from .language_model import get_language_model from .module import Float16Module diff --git a/megatron/model/multi_exit_gpt_model.py b/megatron/model/early_exit_gpt_model.py similarity index 60% rename from megatron/model/multi_exit_gpt_model.py rename to megatron/model/early_exit_gpt_model.py index ed760a2d..8df8242a 100644 --- a/megatron/model/multi_exit_gpt_model.py +++ b/megatron/model/early_exit_gpt_model.py @@ -1,8 +1,7 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""GPT-2 model.""" +"""Early-exit GPT model.""" import torch +import torch.nn.functional as F from megatron import get_args from megatron.core import tensor_parallel, mpu @@ -17,7 +16,9 @@ def post_language_model_processing(lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy, - temperature=1.0): + temperature=1.0, + log_dict=None, + log_key=None): # Output. Format [s b h] output = parallel_lm_logits( @@ -43,8 +44,49 @@ def post_language_model_processing(lm_output, labels, logit_weights, loss = loss.transpose(0,1).contiguous() return loss -class MultiExitGPTModel(MegatronModule): - """Multi-Exit GPT Language model.""" + +def early_exit_processing(lm_output, labels, logit_weights, + parallel_output, + fp16_lm_cross_entropy, + temperature=1.0, + log_dict=None, + log_key=None): + output = parallel_lm_logits( + lm_output, + logit_weights, + parallel_output) + + if labels is None: + # [s b h] => [b s h] + return output.transpose(0,1).contiguous() + else: + # [b s] => [s b] + labels = labels.transpose(0,1).contiguous() + + if temperature != 1.0: + output.div_(temperature) + + with torch.no_grad(): + max_log_probs, max_idx = torch.max(F.log_softmax(output, dim=2), dim=2) + dynamic_loss_weights = torch.exp(max_log_probs) + if log_dict: + log_dict[log_key] = dynamic_loss_weights.mean() + + if fp16_lm_cross_entropy: + assert output.dtype == torch.half + loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) + + loss.multiply_(dynamic_loss_weights) + + # [s b] => [b, s] + loss = loss.transpose(0,1).contiguous() + return loss + + +class EarlyExitGPTModel(MegatronModule): + """Early-exit GPT Language model.""" def __init__(self, config, @@ -60,7 +102,6 @@ def __init__(self, self.post_process = post_process self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights - self.untie_exit_output_weights = args.untie_exit_output_weights self.language_model, self._language_model_key = get_language_model( config=config, @@ -71,42 +112,52 @@ def __init__(self, post_process=self.post_process) self.has_early_exit = mpu.has_early_exit() + self.use_dynamic_exit_layer_weight = args.use_dynamic_exit_layer_weight if not args.untie_embeddings_and_output_weights: self.initialize_word_embeddings() - self.initialize_exit_output_weights() - if self.has_early_exit: - self.exit_layer_loss_weight = dict(filter(lambda p: p[0] in mpu.get_early_exit_layer_nums(), zip(args.exit_layer_nums, args.exit_layer_weight))) - self.exit_layer_temperature = dict(filter(lambda p: p[0] in mpu.get_early_exit_layer_nums(), zip(args.exit_layer_nums, args.exit_layer_temperature))) - self.language_model.set_exit_output_weights(self.exit_output_weight) + self.exit_layer_loss_weight = dict(filter(lambda p: p[0] in mpu.get_early_exit_layer_nums(), \ + zip(args.exit_layer_nums, args.exit_layer_weight))) + self.exit_layer_temperature = dict(filter(lambda p: p[0] in mpu.get_early_exit_layer_nums(), \ + zip(args.exit_layer_nums, args.exit_layer_temperature))) + self.language_model.initialize_exit_output_weights(config, self.shared_embedding_or_output_weight() \ + if not args.untie_embeddings_and_output_weights else None) if self.post_process: - self.output_weight = self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight() + self.output_weight = self.get_output_weight() def set_input_tensor(self, input_tensor): """See megatron.model.transformer.set_input_tensor()""" self.language_model.set_input_tensor(input_tensor) + def get_output_weight(self): + if self.untie_embeddings_and_output_weights: + return self.language_model.output_layer.weight + elif self.pre_process: + return self.language_model.embedding.word_embeddings.weight + else: + return self.word_embeddings.weight + def forward(self, input_ids, position_ids, attention_mask, retriever_input_ids=None, retriever_position_ids=None, retriever_attn_mask=None, labels=None, tokentype_ids=None, inference_params=None, - masked_loss_func=None): + exit_loss_func=None): - early_exit_losses = dict() + early_exit_output = list() if self.has_early_exit: - exit_post_process_func = partial( - post_language_model_processing, + exit_process_func = partial( + early_exit_processing if self.use_dynamic_exit_layer_weight else post_language_model_processing, labels=labels, parallel_output=self.parallel_output, fp16_lm_cross_entropy=self.fp16_lm_cross_entropy ) - lm_output, early_exit_losses = self.language_model( + lm_output, early_exit_output = self.language_model( input_ids, position_ids, attention_mask, @@ -114,8 +165,8 @@ def forward(self, input_ids, position_ids, attention_mask, retriever_position_ids=retriever_position_ids, retriever_attn_mask=retriever_attn_mask, inference_params=inference_params, - exit_post_process_func=exit_post_process_func, - exit_loss_func=masked_loss_func) + exit_process_func=exit_process_func, + exit_loss_func=exit_loss_func) else: lm_output = self.language_model( input_ids, @@ -126,14 +177,16 @@ def forward(self, input_ids, position_ids, attention_mask, retriever_attn_mask=retriever_attn_mask, inference_params=inference_params) - if self.post_process: + if inference_params is not None and inference_params.has_early_exited: + return lm_output + elif self.post_process: lm_output = post_language_model_processing( lm_output, labels, self.output_weight, self.parallel_output, self.fp16_lm_cross_entropy) if self.has_early_exit and inference_params is None: - return lm_output, early_exit_losses + return lm_output, early_exit_output else: return lm_output @@ -144,45 +197,19 @@ def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): = self.language_model.state_dict_for_save_checkpoint( prefix=prefix, keep_vars=keep_vars) # Save word_embeddings. - if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: + if mpu.is_output_embedding_pipeline_stage() and not self.pre_process and not self.untie_embeddings_and_output_weights: state_dict_[self._word_embeddings_for_head_key] \ = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars) - if self.has_early_exit and self.untie_exit_output_weights: - state_dict_[self._exit_output_key] = self.exit_output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) return state_dict_ def load_state_dict(self, state_dict, strict=True): """Customized load.""" # Load word_embeddings. - if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights: + if mpu.is_output_embedding_pipeline_stage() and not self.pre_process and not self.untie_embeddings_and_output_weights: self.word_embeddings.load_state_dict( state_dict[self._word_embeddings_for_head_key], strict=strict) - if self.has_early_exit and self.untie_exit_output_weights: - self.exit_output_layer.load_state_dict( - state_dict[self._exit_output_key], strict=strict) if self._language_model_key in state_dict: state_dict = state_dict[self._language_model_key] self.language_model.load_state_dict(state_dict, strict=strict) - - def initialize_exit_output_weights(self): - if not self.has_early_exit: - return - args = get_args() - self.exit_output_weight = dict() - self.exit_output_layer = torch.nn.ModuleList() - for layer_num in mpu.get_early_exit_layer_nums(): - if self.untie_exit_output_weights: - self.exit_output_layer.append(tensor_parallel.ColumnParallelLinear( - args.hidden_size, - args.padded_vocab_size, - config=self.config, - init_method=self.config.init_method, - bias=False)) - self.exit_output_weight[layer_num] = self.exit_output_layer[-1].weight - else: - # todo @pxc: fix bug when untie_embeddings_and_output_weights is True - assert not self.untie_embeddings_and_output_weights - self.exit_output_weight[layer_num] = self.shared_embedding_or_output_weight() - self._exit_output_key = 'exit_output_layer' \ No newline at end of file diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 560b921b..854e1117 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -12,7 +12,7 @@ from .enums import AttnMaskType, LayerType from .module import MegatronModule -from .transformer import ParallelTransformer, MultiExitParallelTransformer +from .transformer import ParallelTransformer, EarlyExitParallelTransformer from .utils import get_linear_layer from .utils import init_method_normal, scaled_init_method_normal @@ -40,7 +40,6 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, gradient_accumulation_fusion=args.gradient_accumulation_fusion, async_grad_allreduce=async_grad_allreduce, sequence_parallel=args.sequence_parallel) - # Gather if needed. if parallel_output: return logits_parallel @@ -64,7 +63,7 @@ def get_language_model(config, num_tokentypes, add_pooler, # Language model. if mpu.has_early_exit(): - language_model = MultiExitTransformerLanguageModel( + language_model = EarlyExitTransformerLanguageModel( config, encoder_attn_mask_type, num_tokentypes=num_tokentypes, @@ -636,23 +635,18 @@ def load_state_dict(self, state_dict, strict=True): strict=strict) -class MultiExitTransformerLanguageModel(TransformerLanguageModel): +class EarlyExitTransformerLanguageModel(TransformerLanguageModel): def __init__(self, config, encoder_attn_mask_type, num_tokentypes=0, add_encoder=True,add_decoder=False, decoder_attn_mask_type=AttnMaskType.causal, add_pooler=False, pre_process=True, post_process=True): - super(MultiExitTransformerLanguageModel, self).__init__( + super(EarlyExitTransformerLanguageModel, self).__init__( config, encoder_attn_mask_type, num_tokentypes, add_encoder, add_decoder, decoder_attn_mask_type, add_pooler, pre_process, post_process) - assert mpu.has_early_exit(), "MultiExitTransformerLanguageModel requires at least one early exit layer (in current pipeline stage)" - - def set_exit_output_weights(self, exit_output_weight): - assert self.encoder is not None, 'exit output weights is only available in MultiExitParallelTransformer' - assert type(self.encoder) is MultiExitParallelTransformer, 'exit output weights is only available in MultiExitParallelTransformer' - self.encoder.set_exit_output_weights(exit_output_weight=exit_output_weight) + assert mpu.has_early_exit(), "EarlyExitTransformerLanguageModel requires at least one early exit layer (in current pipeline stage)" def _build_encoder(self, config, args): - return MultiExitParallelTransformer( + return EarlyExitParallelTransformer( config, model_type=args.model_type, self_attn_mask_type=self.encoder_attn_mask_type, @@ -660,6 +654,31 @@ def _build_encoder(self, config, args): post_process=self.post_process, ) + def initialize_exit_output_weights(self, config, word_embedding=None): + args = get_args() + self.untie_exit_output_weights = args.untie_exit_output_weights + self.exit_output_weights = dict() + if self.untie_exit_output_weights: + self.exit_output_layer = torch.nn.ModuleList() + self._exit_output_key = 'exit_output_layer' + for layer_num in mpu.get_early_exit_layer_nums(): + if self.untie_exit_output_weights: + self.exit_output_layer.append(tensor_parallel.ColumnParallelLinear( + args.hidden_size, + args.padded_vocab_size, + config=config, + init_method=self.init_method, + bias=False)) + self.exit_output_weights[layer_num] = self.exit_output_layer[-1].weight + else: + # todo @pxc: fix bug when untie_embeddings_and_output_weights is True + assert not self.untie_embeddings_and_output_weights + self.exit_output_weights[layer_num] = word_embedding + assert self.encoder is not None, 'exit output weights is only available in EarlyExitParallelTransformer' + assert type(self.encoder) is EarlyExitParallelTransformer, 'exit output weights is only available in EarlyExitParallelTransformer' + self.encoder.set_exit_output_weights(exit_output_weights=self.exit_output_weights) + + def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, retriever_input_ids=None, @@ -669,7 +688,7 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, inference_params=None, pooling_sequence_index=0, enc_hidden_states=None, output_enc_hidden=False, - exit_post_process_func=None, + exit_process_func=None, exit_loss_func=None): # Encoder embedding. if self.pre_process: @@ -698,7 +717,119 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, retriever_attn_mask=retriever_attn_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, - exit_post_process_func=exit_post_process_func, + exit_process_func=exit_process_func, exit_loss_func=exit_loss_func) - return encoder_output, early_exit_output \ No newline at end of file + return encoder_output, early_exit_output + + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + state_dict_[self._word_embeddings_key] \ + = self.word_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.untie_exit_output_weights: + state_dict_[self._exit_output_key] = self.exit_output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) + + if self.add_position_embedding: + state_dict_[self._position_embeddings_key] \ + = self.position_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + if self.num_tokentypes > 0: + state_dict_[self._tokentype_embeddings_key] \ + = self.tokentype_embeddings.state_dict(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + if self.pre_process: + state_dict_[self._embedding_key] \ + = self.embedding.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.untie_exit_output_weights: + state_dict_[self._exit_output_key] = self.exit_output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) + + if self.add_encoder: + state_dict_[self._encoder_key] \ + = self.encoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.post_process: + if self.add_pooler: + state_dict_[self._pooler_key] \ + = self.pooler.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + if self.untie_embeddings_and_output_weights: + state_dict_[self._output_layer_key] \ + = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars) + + if self.add_decoder: + state_dict_[self._decoder_key] \ + = self.decoder.state_dict_for_save_checkpoint(prefix=prefix, + keep_vars=keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Embedding. + if self.pre_process: + if self._embedding_key in state_dict: + state_dict_ = state_dict[self._embedding_key] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if '_embeddings' in key: + state_dict_[key] = state_dict[key] + self.embedding.load_state_dict(state_dict_, strict=strict) + + # Exit Word embedding. + if self.untie_exit_output_weights: + self.exit_output_layer.load_state_dict( + state_dict[self._exit_output_key], strict=strict) + + # Encoder. + if self.add_encoder: + if self._encoder_key in state_dict: + state_dict_ = state_dict[self._encoder_key] + # For backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # For backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # For backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", + ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + self.encoder.load_state_dict(state_dict_, strict=strict) + + # Pooler. + if self.post_process: + if self.add_pooler: + assert 'pooler' in state_dict, \ + 'could not find data for pooler in the checkpoint' + self.pooler.load_state_dict(state_dict[self._pooler_key], + strict=strict) + if self.untie_embeddings_and_output_weights: + assert 'output_layer' in state_dict, \ + 'could not find data for output_layer in the checkpoint' + self.output_layer.load_state_dict(state_dict[self._output_layer_key], + strict=strict) \ No newline at end of file diff --git a/megatron/model/module.py b/megatron/model/module.py index c2887315..e44b0ae2 100644 --- a/megatron/model/module.py +++ b/megatron/model/module.py @@ -71,7 +71,7 @@ def initialize_word_embeddings(self): # 3. In the training loop, before an all-reduce between the grads of # the two word_embeddings layers to ensure that every applied weight # update is the same on both stages. - if mpu.is_pipeline_last_stage() and not self.pre_process: + if mpu.is_output_embedding_pipeline_stage() and not self.pre_process: assert not mpu.is_pipeline_first_stage() self._word_embeddings_for_head_key = 'word_embeddings_for_head' # set word_embeddings weights to 0 here, then copy first diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 8bdbec3f..09614b9c 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F from typing import Optional +from functools import partial from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches from .module import MegatronModule @@ -282,16 +283,9 @@ class ExitMLP(MegatronModule): def __init__(self, config): super(ExitMLP, self).__init__() - # todo @pxc: merge trunk and branch into one ParallelMLP self.trunk = ParallelMLP(config) self.branch = ParallelMLP(config) - def forward(self, hidden_states): - trunk_output, trunk_output_bias = self.trunk(hidden_states) - branch_output, branch_output_bias = self.branch(hidden_states) - - return trunk_output, trunk_output_bias, branch_output, branch_output_bias - class CoreAttention(MegatronModule): @@ -691,7 +685,7 @@ def forward(self, hidden_states, attention_mask, dim=3) # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - - query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) + query_layer = query_layer.reshape(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head) else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer, _ = self.key_value(encoder_output) @@ -1265,39 +1259,125 @@ def forward(self, hidden_states, attention_mask, else: return output -class MultiExitParallelTransformerLayer(ParallelTransformerLayer): - """TransformerLayer of Multi Exit Transformer +class EarlyExitTransformerLayer(MegatronModule): + """ """ - def __init__(self, config, layer_number, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding, drop_path_rate=0): - super(MultiExitParallelTransformerLayer, self).__init__(config, layer_number, layer_type, self_attn_mask_type, drop_path_rate) - def _build_mlp(self, config, num_experts=None): - assert num_experts == None, 'multi exit not supports MoE' - return ExitMLP(config) + def __init__(self, config, + layer_number, layer_type=LayerType.encoder, + self_attn_mask_type=AttnMaskType.padding, + drop_path_rate=0.): + args = get_args() + super(EarlyExitTransformerLayer, self).__init__() + self.layer_number = layer_number + self.layer_type = layer_type - def forward(self, hidden_states, attention_mask, - encoder_output=None, enc_dec_attn_mask=None, - retriever_input=None, - retriever_output=None, - retriever_attn_mask=None, - inference_params=None, - rotary_pos_emb=None): - # hidden_states: [s, b, h] + self.apply_residual_connection_post_norm \ + = config.apply_residual_connection_post_layernorm + + self.bf16 = config.bf16 + self.fp32_residual_connection = config.fp32_residual_connection + + # Early exit + self.pre_exit = args.pre_exit + self.use_exit_mlp = args.use_exit_mlp + self.use_exit_norm = args.use_exit_norm + self.use_exit_block = args.use_exit_block + self.exit_layer_temperature = args.exit_layer_temperature[args.exit_layer_nums.index(self.layer_number)] + self.exit_output_weight = None + + if self.use_exit_norm: + self.exit_norm = get_norm(config) + + if self.use_exit_block: + self.exit_block = ParallelTransformerLayer( + config, + layer_number=layer_number + args.num_layers, + layer_type=layer_type, + self_attn_mask_type=self_attn_mask_type, + drop_path_rate=0) + + # Normalize the input data. + self.input_norm = get_norm(config) + + # Self attention. + self.self_attention = self._build_attention(config, layer_number, self_attn_mask_type) + self.hidden_dropout = config.hidden_dropout + self.bias_dropout_fusion = config.bias_dropout_fusion + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None + + # Normalize the attention output + self.post_attention_norm = get_norm(config) + + # MLP + self.mlp = self._build_mlp(config, layer_number) + + # Set bias+dropout+add fusion grad_enable execution handler. + TORCH_MAJOR = int(torch.__version__.split('.')[0]) + TORCH_MINOR = int(torch.__version__.split('.')[1]) + use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + self.bias_dropout_add_exec_handler = \ + nullcontext if use_nvfuser else torch.enable_grad + + + def _build_attention(self, config, layer_num, self_attn_mask_type): + return ParallelAttention(config, + layer_num, + AttnType.self_attn, + self_attn_mask_type) + + def _build_mlp(self, config, layer_num): + if self.use_exit_mlp: + return ExitMLP(config) + return ParallelMLP(config) + + def set_exit_output_weight(self, weight): + self.exit_output_weight = weight + + def _forward_mlp(self, mlp, norm_output, residual, bias_dropout_add_func): + mlp_output, mlp_bias = mlp(norm_output) + + if self.drop_path is None: + if mlp_bias is not None: + mlp_bias = mlp_bias.expand_as(residual) + with self.bias_dropout_add_exec_handler(): + output = bias_dropout_add_func( + mlp_output, + mlp_bias, + residual, + self.hidden_dropout) + + output = core.utils.make_viewless_tensor(inp = output, + requires_grad = output.requires_grad, + keep_graph = True) + else: + if mlp_bias is not None: + mlp_output = mlp_output + mlp_bias + out = torch.nn.functional.dropout(mlp_output, + p=self.hidden_dropout, + training=self.training) + output = residual + self.drop_path(out) + return output + + def _forward_main(self, hidden_states, attention_mask, + inference_params=None, + rotary_pos_emb=None): + # hidden_states: [s, b, h] # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) + norm_output = self.input_norm(hidden_states) # Self attention. attention_output, attention_bias = \ self.self_attention( - layernorm_output, + norm_output, attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb) # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output + if self.apply_residual_connection_post_norm: + residual = norm_output else: residual = hidden_states @@ -1317,7 +1397,7 @@ def forward(self, hidden_states, attention_mask, if attention_bias is not None: attention_bias = attention_bias.expand_as(residual) with self.bias_dropout_add_exec_handler(): - layernorm_input = bias_dropout_add_func( + norm_input = bias_dropout_add_func( attention_output, attention_bias, residual, @@ -1326,63 +1406,108 @@ def forward(self, hidden_states, attention_mask, out = torch.nn.functional.dropout(attention_output + attention_bias, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + self.drop_path(out) + norm_input = residual + self.drop_path(out) # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - # trunk MLP. - mlp_output, mlp_bias, exit_mlp_output, exit_mlp_bias = self.mlp(layernorm_output) + norm_output = self.post_attention_norm(norm_input) # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output + if self.apply_residual_connection_post_norm: + residual = norm_output else: - residual = layernorm_input + residual = norm_input - # todo @pxc: reduce duplication - if self.drop_path is None: - if mlp_bias is not None: - mlp_bias = mlp_bias.expand_as(residual) - exit_mlp_bias = exit_mlp_bias.expand_as(residual) - with self.bias_dropout_add_exec_handler(): - output = bias_dropout_add_func( - mlp_output, - mlp_bias, - residual, - self.hidden_dropout) - exit_output = bias_dropout_add_func( - exit_mlp_output, - exit_mlp_bias, - residual, - self.hidden_dropout) + # MLP. + output = self._forward_mlp(mlp=self.mlp.trunk if self.use_exit_mlp else self.mlp, + norm_output=norm_output, + residual=residual, + bias_dropout_add_func=bias_dropout_add_func) + if not self.use_exit_mlp: + return output - # Jit compiled function creates 'view' tensor. This tensor - # potentially gets saved in the MPU checkpoint function context, - # which rejects view tensors. While making a viewless tensor here - # won't result in memory savings (like the data loader, or - # p2p_communication), it serves to document the origin of this - # 'view' tensor. - output = core.utils.make_viewless_tensor(inp = output, - requires_grad = output.requires_grad, - keep_graph = True) - exit_output = core.utils.make_viewless_tensor(inp = exit_output, - requires_grad = exit_output.requires_grad, - keep_graph = True) + # exit MLP. + exit_output = self._forward_mlp(mlp=self.mlp.branch, + norm_output=norm_output, + residual=residual, + bias_dropout_add_func=bias_dropout_add_func) + return output, exit_output + def _cal_exit_loss(self, hidden_states, exit_process_func, exit_loss_func, + inference_params=None, attention_mask=None, + rotary_pos_emb=None, lazy_hidden_states=False, + log_dict=None): + if lazy_hidden_states: + hidden_states = hidden_states() + if self.use_exit_block: + hidden_states = self.exit_block(hidden_states, attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + if self.use_exit_norm: + hidden_states = self.exit_norm(hidden_states) + return exit_loss_func(output_tensor=exit_process_func(lm_output=hidden_states, + temperature=self.exit_layer_temperature, + log_dict=log_dict, + log_key=f'dynamic exit weight [{self.layer_number}]'), + log_dict=log_dict, + log_key=f'early loss [{self.layer_number}]') + + def _forward_exit(self, hidden_states, exit_process_func, exit_loss_func, + inference_params, attention_mask=None, rotary_pos_emb=None): + if inference_params is not None and inference_params.use_early_exit: + if self.use_exit_block: + hidden_states = self.exit_block(hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + if self.use_exit_norm: + hidden_states = self.exit_norm(hidden_states) + exit_logits = exit_process_func(lm_output=hidden_states, + temperature=self.exit_layer_temperature) + exit = inference_params.do_early_exit(exit_logits, self.layer_number) + return exit_logits, exit else: - if mlp_bias is not None: - mlp_output = mlp_output + mlp_bias - exit_mlp_output = exit_mlp_output + exit_mlp_bias + lazy_exit_forward_func = partial(self._cal_exit_loss, + hidden_states=hidden_states, + exit_process_func=exit_process_func, + exit_loss_func=exit_loss_func, + lazy_hidden_states=False) + return lazy_exit_forward_func, False - output = residual + self.drop_path(torch.nn.functional.dropout(mlp_output, - p=self.hidden_dropout, - training=self.training)) - exit_output = residual + self.drop_path(torch.nn.functional.dropout(exit_mlp_output, - p=self.hidden_dropout, - training=self.training)) + def forward(self, hidden_states, attention_mask, + encoder_output=None, enc_dec_attn_mask=None, + retriever_input=None, + retriever_output=None, + retriever_attn_mask=None, + inference_params=None, + rotary_pos_emb=None, + exit_process_func=None, + exit_loss_func=None): + if self.pre_exit: + exit_output, exit = self._forward_exit(hidden_states=hidden_states, + inference_params=inference_params, + exit_process_func=exit_process_func, + exit_loss_func=exit_loss_func, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb) + if exit: + return hidden_states, exit_output, True + hidden_states = self._forward_main(hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) + if self.use_exit_mlp: + hidden_states, exit_hidden_states = hidden_states + else: + exit_hidden_states = hidden_states + if not self.pre_exit: + exit_output, exit = self._forward_exit(hidden_states=exit_hidden_states, + inference_params=inference_params, + exit_process_func=exit_process_func, + exit_loss_func=exit_loss_func, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb) + return hidden_states, exit_output, exit - return output, exit_output class NoopTransformerLayer(MegatronModule): """A single 'no-op' transformer layer. @@ -1916,35 +2041,30 @@ def load_state_dict(self, state_dict, strict=True): super().load_state_dict(state_dict_, strict) -class MultiExitParallelTransformer(ParallelTransformer): - """Multi Exit Transformer class.""" +class EarlyExitParallelTransformer(ParallelTransformer): + """Early-exit Transformer class.""" def __init__(self, config, model_type, layer_type=LayerType.encoder, self_attn_mask_type=AttnMaskType.padding, - post_layer_norm=True, + post_norm=True, pre_process=True, post_process=True, drop_path_rate=0.0): - super(MultiExitParallelTransformer, self).__init__( + super(EarlyExitParallelTransformer, self).__init__( config, model_type, layer_type, self_attn_mask_type, - post_layer_norm, pre_process, post_process, + post_norm, pre_process, post_process, drop_path_rate ) self.exit_states = list(map(lambda x: x in mpu.get_early_exit_layer_nums(), self.layer_nums)) - self.early_exit_output = dict() - self.exit_output_weights = dict() - args = get_args() - self.use_exit_mlp = args.use_exit_mlp - self.exit_layer_loss_weight = dict(filter(lambda p: p[0] in mpu.get_early_exit_layer_nums(), zip(args.exit_layer_nums, args.exit_layer_weight))) - self.exit_layer_temperature = dict(filter(lambda p: p[0] in mpu.get_early_exit_layer_nums(), zip(args.exit_layer_nums, args.exit_layer_temperature))) + def _build_layer(self, layer_number, args, config, model_type, layer_type, self_attn_mask_type): assert args.transformer_impl == 'local', "early exit only supports transformer_impl=='local'" assert model_type == ModelType.encoder_or_decoder, \ "early exit only supports model_type==ModelType.encoder_or_decoder" - if (layer_number in mpu.get_early_exit_layer_nums()) and args.use_exit_mlp: - return MultiExitParallelTransformerLayer( + if layer_number in mpu.get_early_exit_layer_nums(): + return EarlyExitTransformerLayer( config, layer_number, layer_type=layer_type, @@ -1958,8 +2078,8 @@ def _build_layer(self, layer_number, args, config, model_type, layer_type, self_ self_attn_mask_type=self_attn_mask_type, drop_path_rate=self.drop_path_rates[layer_number - 1]) - def set_exit_output_weights(self, exit_output_weight): - self.exit_output_weights = exit_output_weight + def set_exit_output_weights(self, exit_output_weights): + self.exit_output_weights = exit_output_weights def forward(self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, @@ -1968,7 +2088,7 @@ def forward(self, hidden_states, attention_mask, retriever_attn_mask=None, inference_params=None, rotary_pos_emb=None, - exit_post_process_func=None, + exit_process_func=None, exit_loss_func=None): if not self.pre_process: hidden_states = self.input_tensor @@ -1978,7 +2098,7 @@ def forward(self, hidden_states, attention_mask, requires_grad=True, keep_graph=True, ) - early_exit_losses = dict() + lazy_early_exit_loss_funcs = dict() # RNG context. if self.sequence_parallel: @@ -1997,52 +2117,35 @@ def forward(self, hidden_states, attention_mask, self.microbatch_count = 0 # Reset count on new batch size rampup interval self.num_microbatches_in_previous_step = get_num_microbatches() - # note: do not support recompute - forward_kwargs = { - 'encoder_output': encoder_output, - 'enc_dec_attn_mask': enc_dec_attn_mask, - 'inference_params': inference_params, - } - - # note: do not support transformer_engine and retriever - forward_kwargs['rotary_pos_emb'] = rotary_pos_emb - - for index, is_exit in enumerate(self.exit_states): + for index, is_exit_layer in enumerate(self.exit_states): layer = self._get_layer(index) - layer_num = index + self.offset - if is_exit: - if self.use_exit_mlp: - hidden_states, exit_hidden_states = layer(hidden_states, - attention_mask, - **forward_kwargs) - else: - hidden_states = layer(hidden_states, - attention_mask, - **forward_kwargs) - exit_hidden_states = hidden_states - exit_output_tensor = exit_post_process_func(lm_output=exit_hidden_states, - logit_weights=self.exit_output_weights[layer_num], - temperature=self.exit_layer_temperature[layer_num]) - if exit_loss_func: - early_exit_losses[layer_num] = exit_loss_func( - output_tensor = exit_output_tensor, - layer_num=layer_num, - weight=self.exit_layer_loss_weight[layer_num], - ) - if inference_params is not None and \ - inference_params.early_exit(exit_output_tensor, layer_num): - hidden_states = exit_hidden_states + + if is_exit_layer: + hidden_states, exit_output, exit = layer(hidden_states, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + exit_process_func=partial(exit_process_func, logit_weights=self.exit_output_weights[layer.layer_number]), + exit_loss_func=exit_loss_func) + if inference_params is None: + # only collect loss funcs in training mode + lazy_early_exit_loss_funcs[layer.layer_number] = exit_output + elif exit: + # change output in inference mode + return exit_output, exit_output + if exit: break else: - hidden_states = layer( - hidden_states, - attention_mask, - **forward_kwargs) + hidden_states = layer(hidden_states, + attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb) if torch.is_grad_enabled() and self.training: self.microbatch_count += 1 - if self.post_process and self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) + if self.post_process and self.post_norm: + hidden_states = self.final_norm(hidden_states) - return hidden_states, early_exit_losses + return hidden_states, lazy_early_exit_loss_funcs + \ No newline at end of file diff --git a/megatron/text_generation/api.py b/megatron/text_generation/api.py index ac6b7046..6d1fdba6 100644 --- a/megatron/text_generation/api.py +++ b/megatron/text_generation/api.py @@ -4,11 +4,12 @@ import torch - +import traceback from megatron.core import mpu from .communication import broadcast_float_list from .generation import ( generate_tokens_probs_and_return_on_first_stage, + generate_with_pipelined_early_exit_and_return_on_first_stage, score_and_return_on_first_stage, beam_search_and_return_on_first_stage) from .tokenization import ( @@ -34,7 +35,9 @@ def generate_and_post_process(model, prevent_newline_after_colon=False, random_seed=-1, return_logits=False, - early_exit_thres=100.0): + early_exit_thres=1.0, + use_early_exit=False, + print_max_prob=False): """Run inference and post-process outputs, i.e., detokenize, move to cpu and convert to list.""" @@ -57,7 +60,9 @@ def generate_and_post_process(model, stop_token_ids=stop_token_ids, prevent_newline_after_colon=prevent_newline_after_colon, random_seed=random_seed, - early_exit_thres=early_exit_thres) + early_exit_thres=early_exit_thres, + use_early_exit=use_early_exit, + print_max_prob=print_max_prob) # Only post-process on first stage. if mpu.is_pipeline_first_stage(): @@ -97,7 +102,9 @@ def generate(model, stop_token_ids=None, prevent_newline_after_colon=False, random_seed=-1, - early_exit_thres=100.0): + early_exit_thres=1.0, + use_early_exit=False, + print_max_prob=False): """Given prompts and input parameters, run inference and return: tokens: prompts plus the generated tokens. lengths: length of the prompt + generations. Note that we can @@ -113,7 +120,7 @@ def generate(model, temperature, add_BOS, use_stop_tokens_for_early_termination, stop_on_double_eol, stop_on_eol, prevent_newline_after_colon, - random_seed, early_exit_thres] + random_seed, early_exit_thres, use_early_exit, print_max_prob] if stop_token_ids != None: stop_token_ids = torch.tensor(stop_token_ids, dtype=torch.int64) values.append(len(stop_token_ids)) @@ -135,10 +142,12 @@ def generate(model, prevent_newline_after_colon = bool(values_float_tensor[11].item()) random_seed = int(values_float_tensor[12].item()) early_exit_thres = values_float_tensor[13].item() - - stop_tokens_length = int(values_float_tensor[14].item()) + use_early_exit = bool(values_float_tensor[14].item()) + print_max_prob = bool(values_float_tensor[15].item()) + + stop_tokens_length = int(values_float_tensor[16].item()) if stop_tokens_length > 0: - stop_token_ids = values_float_tensor[15: 15 + stop_tokens_length].int() + stop_token_ids = values_float_tensor[17: 17 + stop_tokens_length].int() else: stop_token_ids = None @@ -159,19 +168,42 @@ def generate(model, # Main inference function. # Note that the outputs are available on the first stage. - return generate_tokens_probs_and_return_on_first_stage( - model, context_tokens_tensor, context_length_tensor, - return_output_log_probs=return_output_log_probs, - top_k=top_k_sampling, - top_p=top_p_sampling, - top_p_decay=top_p_decay, - top_p_bound=top_p_bound, - temperature=temperature, - use_stop_tokens_for_early_termination=use_stop_tokens_for_early_termination, - stop_tokens=stop_token_ids, - prevent_newline_after_colon=prevent_newline_after_colon, - echo_prompts=echo_prompts, - early_exit_thres=early_exit_thres) + try: + if mpu.get_pipeline_model_parallel_world_size() > 1: + output = generate_with_pipelined_early_exit_and_return_on_first_stage( + model, context_tokens_tensor, context_length_tensor, + return_output_log_probs=return_output_log_probs, + top_k=top_k_sampling, + top_p=top_p_sampling, + top_p_decay=top_p_decay, + top_p_bound=top_p_bound, + temperature=temperature, + use_stop_tokens_for_early_termination=use_stop_tokens_for_early_termination, + stop_tokens=stop_token_ids, + prevent_newline_after_colon=prevent_newline_after_colon, + echo_prompts=echo_prompts, + early_exit_thres=early_exit_thres, + use_early_exit=use_early_exit, + print_max_prob=print_max_prob) + else: + output = generate_tokens_probs_and_return_on_first_stage( + model, context_tokens_tensor, context_length_tensor, + return_output_log_probs=return_output_log_probs, + top_k=top_k_sampling, + top_p=top_p_sampling, + top_p_decay=top_p_decay, + top_p_bound=top_p_bound, + temperature=temperature, + use_stop_tokens_for_early_termination=use_stop_tokens_for_early_termination, + stop_tokens=stop_token_ids, + prevent_newline_after_colon=prevent_newline_after_colon, + echo_prompts=echo_prompts, + early_exit_thres=early_exit_thres, + use_early_exit=use_early_exit, + print_max_prob=print_max_prob) + except Exception as e: + traceback.print_exc() + return output def beam_search_and_post_process(model, prompts=None, diff --git a/megatron/text_generation/communication.py b/megatron/text_generation/communication.py index dee32077..0d03185c 100644 --- a/megatron/text_generation/communication.py +++ b/megatron/text_generation/communication.py @@ -4,6 +4,7 @@ import torch +import torch.distributed as dist from megatron.core import mpu @@ -41,6 +42,31 @@ def send_to_next_pipeline_rank(tensor=None): torch.cuda.synchronize() +def recv_list_from_prev_pipeline_rank(recv_buffers): + if not mpu.is_pipeline_first_stage(): + assert recv_buffers is not None and type(recv_buffers) is list + recv_prev_ops = [torch.distributed.P2POp( + torch.distributed.irecv, recv_buffer, + mpu.get_pipeline_model_parallel_prev_rank()) for recv_buffer in recv_buffers] + reqs = torch.distributed.batch_isend_irecv(recv_prev_ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + + +def send_list_to_next_pipeline_rank(tensors): + if not mpu.is_pipeline_last_stage(): + assert tensors is not None and type(tensors) is list + send_next_ops = [torch.distributed.P2POp( + torch.distributed.isend, tensor, + mpu.get_pipeline_model_parallel_next_rank()) for tensor in tensors] + reqs = torch.distributed.batch_isend_irecv(send_next_ops) + for req in reqs: + req.wait() + # To protect against race condition when using batch_isend_irecv(). + torch.cuda.synchronize() + def _is_cuda(tensor): """Check if a tensor is not none and is cuda.""" @@ -79,6 +105,28 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None): return tensor +def broadcast_from_first_pipeline_stage(size, dtype, tensor=None): + """Broadcast a tensor from last pipeline stage to all ranks.""" + + is_first_stage = mpu.is_pipeline_first_stage() + # If first stage and last state are the same, then there is no + # pipeline parallelism and no need to communicate. + if mpu.is_pipeline_last_stage() and is_first_stage: + return tensor + + if is_first_stage: + _is_cuda_contiguous(tensor) + else: + tensor = torch.empty(size, + dtype=dtype, + device=torch.cuda.current_device()) + # Get the group and corresponding source rank. + src = mpu.get_pipeline_model_parallel_first_rank() + group = mpu.get_pipeline_model_parallel_group() + torch.distributed.broadcast(tensor, src, group) + + return tensor + def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): """Broadcast tensor values from last stage into the first stage.""" @@ -98,7 +146,7 @@ def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None): dtype=dtype, device=torch.cuda.current_device()) src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() + group = mpu.get_pipeline_endpoint_group() # Broadcast from last stage into the first stage. torch.distributed.broadcast(tensor, src, group) else: @@ -123,7 +171,7 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): _is_cuda(tensor) is_contiguous = tensor.is_contiguous() src = mpu.get_pipeline_model_parallel_last_rank() - group = mpu.get_embedding_group() + group = mpu.get_pipeline_endpoint_group() if is_contiguous: tensor_ = tensor else: @@ -140,6 +188,69 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None): tensor[...] = tensor_ +def get_exit_stages(): + early_exit_stage_ids = mpu.get_early_exit_stages() + last_stage_id = mpu.get_pipeline_model_parallel_world_size() - 1 + if last_stage_id not in early_exit_stage_ids: + return list(early_exit_stage_ids + [last_stage_id]) + return early_exit_stage_ids + + +EXIT=1 +CONTINUE=0 + +def send_token_and_probs_to_first_pipeline_stage(inference_params, token_tensor=None, prob_tensor=None, is_final=False): + signal_tensor = torch.empty(1, dtype=torch.int8, device=torch.cuda.current_device()) + if inference_params.has_early_exited or is_final: + signal_tensor[0] = EXIT + _is_cuda(token_tensor) + _is_cuda(prob_tensor) + else: + signal_tensor[0] = CONTINUE + dist.send(tensor=signal_tensor, dst=0, group=mpu.get_pipeline_model_parallel_group()) + if inference_params.has_early_exited or is_final: + dist.send(tensor=token_tensor, dst=0, group=mpu.get_pipeline_model_parallel_group()) + dist.send(tensor=prob_tensor, dst=0, group=mpu.get_pipeline_model_parallel_group()) + + +def recv_token_and_probs(inference_params, token_tensor_buffer, prob_tensor_buffer): + + is_contiguous = token_tensor_buffer.is_contiguous() + if is_contiguous: + token_tensor_ = token_tensor_buffer + prob_tensor_ = prob_tensor_buffer + else: + token_tensor_ = torch.empty(token_tensor_buffer.shape[0], + dtype=torch.int64, + device=torch.cuda.current_device()) + prob_tensor_ = torch.empty(prob_tensor_buffer.shape[0], + dtype=torch.float32, + device=torch.cuda.current_device()) + + # if first stage has early exit, get tensor directly + if mpu.has_early_exit(): + if inference_params.has_early_exited: + assert inference_params.tokens is not None + token_tensor_buffer[...] = inference_params.tokens + prob_tensor_buffer[...] = inference_params.probs + return + + exit_stages = get_exit_stages() + if exit_stages[0] == 0: + exit_stages.pop(0) + signal_tensor = torch.empty(1, dtype=torch.int8, device=torch.cuda.current_device()) + + # get tensor from subsequent stages one by one + for stage_id in exit_stages: + dist.recv(tensor=signal_tensor, src=stage_id, group=mpu.get_pipeline_model_parallel_group()) + if signal_tensor[0] == EXIT: + dist.recv(tensor=token_tensor_, src=stage_id, group=mpu.get_pipeline_model_parallel_group()) + dist.recv(tensor=prob_tensor_, src=stage_id, group=mpu.get_pipeline_model_parallel_group()) + break + + if not is_contiguous: + token_tensor_buffer[...] = token_tensor_ + prob_tensor_buffer[...] = prob_tensor_ def broadcast_tensor(size, dtype, tensor=None, rank=0): """ Given size and type of a tensor on all ranks and the tensor value diff --git a/megatron/text_generation/forward_step.py b/megatron/text_generation/forward_step.py index 6da2d1df..b8ec727f 100644 --- a/megatron/text_generation/forward_step.py +++ b/megatron/text_generation/forward_step.py @@ -7,10 +7,13 @@ import torch from megatron import get_args -from megatron.core import mpu, InferenceParams +from megatron.core import mpu +from .inference_params import InferenceParams from .communication import ( send_to_next_pipeline_rank, - recv_from_prev_pipeline_rank_) + recv_from_prev_pipeline_rank_, + send_list_to_next_pipeline_rank, + recv_list_from_prev_pipeline_rank) class ForwardStep: @@ -45,16 +48,11 @@ def __call__(self, tokens, position_ids, attention_mask): is being modified by the forward step.""" # Pipelining case. if self.pipeline_size_larger_than_one: - current_batch_x_seqlen = tokens.size(0) * tokens.size(1) - if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen: - micro_batch_size = \ - max(1, self.pipelining_batch_x_seqlen // tokens.size(1)) - return _with_pipelining_forward_step(self.model, + return _with_early_exit_pipelining_forward_step(self.model, tokens, position_ids, attention_mask, - self.inference_params, - micro_batch_size) + self.inference_params) return _no_pipelining_forward_step(self.model, tokens, @@ -178,3 +176,36 @@ def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask, inference_params.batch_size_offset = 0 return logits + + +def _allocate_early_exit_recv_buffers(batch_size, sequence_length): + if mpu.is_pipeline_first_stage(): + return None + args = get_args() + recv_size = (sequence_length, batch_size, args.hidden_size) + return [torch.empty(recv_size, + dtype=_get_recv_buffer_dtype(args), + device=torch.cuda.current_device()), + torch.empty(1, dtype=torch.int8, device=torch.cuda.current_device())] + + +def _with_early_exit_pipelining_forward_step(model, tokens, position_ids, attention_mask, + inference_params): + """No interleaving is supported.""" + sequence_length = tokens.size(1) + batch_size = tokens.size(0) + assert batch_size == 1, "early exit not support batch inference yet" + # Divide the batch dimension into micro batches. + # Preallocate recv buffer. + if not mpu.is_pipeline_first_stage(): + recv_buffers = _allocate_early_exit_recv_buffers(batch_size, sequence_length) + recv_list_from_prev_pipeline_rank(recv_buffers) + model.set_input_tensor(recv_buffers[0]) + inference_params.prev_has_early_exited = bool(recv_buffers[1]) + output_tensor = model(tokens, position_ids, attention_mask, inference_params=inference_params) + signal_tensor = torch.tensor([int(inference_params.has_early_exited or inference_params.prev_has_early_exited)], + dtype=torch.int8, + device=torch.cuda.current_device()) + send_list_to_next_pipeline_rank([output_tensor, signal_tensor]) + inference_params.sequence_len_offset += sequence_length + return output_tensor \ No newline at end of file diff --git a/megatron/text_generation/generation.py b/megatron/text_generation/generation.py index 9e344a52..32874979 100644 --- a/megatron/text_generation/generation.py +++ b/megatron/text_generation/generation.py @@ -6,12 +6,16 @@ import torch.nn.functional as F from megatron import get_args, get_tokenizer -from megatron.core import mpu, InferenceParams +from megatron.core import mpu from megatron.utils import get_ltor_masks_and_position_ids from .communication import ( copy_from_last_to_first_pipeline_stage, + send_token_and_probs_to_first_pipeline_stage, + recv_token_and_probs, broadcast_from_last_pipeline_stage, + broadcast_from_first_pipeline_stage, broadcast_from_last_to_first_pipeline_stage) +from .inference_params import InferenceParams from .forward_step import ForwardStep from .sampling import sample from .beam_utils import BeamHypotheses @@ -21,7 +25,7 @@ def score_and_return_on_first_stage(model, tokens, lengths): Arguments: model: no interleaving is supported. tokens: prompt tokens extended to be of size [b, max_prompt_length] - lengths: original prompt length, size: [b] + lengths: original prompt length, size: [b]tokenizer Note: Outside of model, other parameters only need to be available on rank 0. Outputs: @@ -96,7 +100,9 @@ def generate_tokens_probs_and_return_on_first_stage( stop_tokens=None, prevent_newline_after_colon=True, echo_prompts=False, - early_exit_thres=None + early_exit_thres=1.0, + use_early_exit=False, + print_max_prob=False, ): """Main token generation function. Arguments: @@ -138,7 +144,14 @@ def generate_tokens_probs_and_return_on_first_stage( if max_sequence_length * batch_size > args.max_tokens_to_oom: raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) - inference_params = InferenceParams(batch_size, max_sequence_length, early_exit_thres, tokenizer) + inference_params = InferenceParams(batch_size, max_sequence_length, + top_k=top_k, top_p=top_p, + temperature=temperature, + top_p_bound=top_p_bound, + top_p_decay=top_p_decay, + early_exit_thres=early_exit_thres, + use_early_exit=use_early_exit, + print_max_prob=print_max_prob) # forward step. forward_step = ForwardStep(model, inference_params=inference_params) @@ -239,10 +252,10 @@ def generate_tokens_probs_and_return_on_first_stage( # Update the context length for the next token generation. prev_context_length = context_length - if not inference_params.has_early_exit: + if not inference_params.has_early_exited: full_exit_context_length = prev_context_length inference_params.sequence_len_offset += tokens2use.size(1) - inference_params.has_early_exit = False + inference_params.has_early_exited = False inference_params.is_first_step = False # Check if all the sequences have hit the termination_id. @@ -426,6 +439,230 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto return tokens, scores +def generate_with_pipelined_early_exit_and_return_on_first_stage( + model, tokens, lengths, + return_output_log_probs=False, + top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0, + temperature=1.0, + use_stop_tokens_for_early_termination=True, + stop_on_double_eol=False, + stop_on_eol=False, + stop_tokens=None, + prevent_newline_after_colon=True, + echo_prompts=False, + early_exit_thres=1.0, + use_early_exit=False, + print_max_prob=False +): + """Main token generation function. + Arguments: + model: no interleaving is supported. + tokens: prompt tokens extended to be of size [b, max-sequence-length] + lengths: original prompt length, size: [b] + return_output_log_probs: flag to calculate the log probability of + the generated tokens. Note that the log probability is the one + from the original logit. + top_k, top_p: top-k and top-p sampling parameters. + Note that top-k = 1 is gready. Also, these paramters are + exclusive meaning that: + if top-k > 0 then we expect top-p=0. + if top-p > 0 then we check for top-k=0. + temperature: sampling temperature. + use_eod_token_for_early_termination: if True, do early termination if + all the sequences have reached this token. + prevent_newline_after_colon: if True, it will disable generating new line \n after : + Note: Outside of model, other parameters only need to be available on + rank 0. + Outputs: Note that is size is adjusted to a lower value than + max-sequence-length if generation is terminated early. + tokens: prompt and generated tokens. size: [b, :] + generated_sequence_lengths: total length (including prompt) of + the generated sequence. size: [b] + output_log_probs: log probability of the selected tokens. size: [b, s] + """ + + args = get_args() + tokenizer = get_tokenizer() + + batch_size = tokens.size(0) + min_prompt_length = lengths.min().item() + max_sequence_length = tokens.size(1) + + if max_sequence_length > args.max_position_embeddings: + raise ValueError(f"Length of prompt + tokens_to_generate ({max_sequence_length}) longer than allowed ({args.max_position_embeddings})") + + if max_sequence_length * batch_size > args.max_tokens_to_oom: + raise ValueError("Too many tokens. " + str(max_sequence_length*batch_size)+ " is greater than "+str(args.max_tokens_to_oom)) + + inference_params = InferenceParams(batch_size, max_sequence_length, + top_k=top_k, top_p=top_p, + temperature=temperature, + top_p_bound=top_p_bound, + top_p_decay=top_p_decay, + early_exit_thres=early_exit_thres, + use_early_exit=use_early_exit, + print_max_prob=print_max_prob) + + # forward step. + forward_step = ForwardStep(model, inference_params=inference_params) + + # Added termination_id to support the case that we want to terminate the + # generation once that id is generated. + if hasattr(args, 'eos_id'): + termination_id = args.eos_id + else: + termination_id = tokenizer.eod + + # =================== + # Pre-allocate memory + # =================== + + # Log probability of the sequence (prompt + generated tokens). + output_log_probs = None + output_log_probs_size = (batch_size, max_sequence_length - 1) + # Lengths of generated seuquence including including prompts. + generated_sequence_lengths = None + if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage() or mpu.has_early_exit(): + output_log_probs = torch.empty(output_log_probs_size, + dtype=torch.float32, + device=torch.cuda.current_device()) + if mpu.is_pipeline_first_stage(): + generated_sequence_lengths = torch.ones( + batch_size, dtype=torch.int64, + device=torch.cuda.current_device()) * max_sequence_length + + # Whether we have reached a termination id. + is_generation_done = torch.zeros(batch_size, dtype=torch.uint8, + device=torch.cuda.current_device()) + + # ============= + # Run infernece + # ============= + + with torch.no_grad(): + attention_mask, position_ids = _build_attention_mask_and_position_ids( + tokens) + prev_context_length = 0 + for context_length in range(min_prompt_length, max_sequence_length): + + # Pick the slice that we need to pass through the network. + tokens2use = tokens[:, prev_context_length:context_length] + positions2use = position_ids[:, prev_context_length:context_length] + attention_mask2use = attention_mask[ + ..., prev_context_length:context_length, :context_length] + + # clear inference states + inference_params.clear_early_exit_states() + + # logits will be meanigful only in the last pipeline stage. + logits = forward_step(tokens2use, positions2use, attention_mask2use) + + if mpu.is_pipeline_last_stage() and not (inference_params.has_early_exited or inference_params.prev_has_early_exited): + last_token_logits = logits[:, -1, :] + + # Calculate the log probabilities. + log_probs = F.log_softmax(logits, dim=2) + max_log_prob, token_id = torch.max(log_probs[:, -1, :], dim=1) + token = tokenizer.detokenize([int(token_id[-1])]) + if print_max_prob: + print(f"layer final: token [{token}], prob {float(torch.exp(max_log_prob[-1]))}") + inference_params.has_early_exited = max_log_prob[-1] >= inference_params.early_exit_thres + new_sample = sample(last_token_logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + vocab_size=tokenizer.vocab_size) + if top_p > 0.0 and top_p_decay > 0.0: + top_p = top_p * top_p_decay + if top_p_bound > 0.0: + top_p = max(top_p, top_p_bound) + + # If a prompt length is smaller or equal th current context + # length, it means we have started generating tokens + started = lengths <= context_length + # Update the tokens. + tokens[started, context_length] = new_sample[started] + # Pick the tokens that we need to get the log + # probabilities for. Note that next input token is + # the token which we selected in the current logits, + # so shift by 1. + indices = torch.unsqueeze( + tokens[ + :, + (prev_context_length + 1):(context_length + 1)], + 2) + output_log_probs[:, + prev_context_length:context_length] = \ + torch.gather(log_probs, 2, indices).squeeze(2) + send_token_and_probs_to_first_pipeline_stage(inference_params=inference_params, + token_tensor=tokens[:, context_length], + prob_tensor=output_log_probs[:, context_length - 1], + is_final=True) + elif mpu.is_pipeline_first_stage(): + recv_token_and_probs(inference_params=inference_params, + token_tensor_buffer=tokens[:, context_length], + prob_tensor_buffer=output_log_probs[:, context_length - 1]) + elif mpu.has_early_exit() and not(inference_params.has_early_exited or inference_params.prev_has_early_exited): + send_token_and_probs_to_first_pipeline_stage(inference_params=inference_params) + + # Update the context length for the next token generation. + prev_context_length = context_length + inference_params.is_first_step = False + + # Check if all the sequences have hit the termination_id. + # done = None + # if mpu.is_pipeline_first_stage(): + # # TODO(rprenger) These stopping methods are tokenizer dependent + # # instead tokenization should be in the inference loop so stop sequences can be used + # if stop_on_double_eol: + # hit_double_eol = (new_sample == 628).byte() & started.byte() + # hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte() + # done_token = hit_double_eol | hit_two_eols + # elif stop_on_eol: + # hit_double_eol = (new_sample == 628).byte() & started.byte() + # hit_eol = (new_sample == 198).byte() & started.byte() + # done_token = hit_double_eol | hit_eol + # else: + # done_token = (new_sample == termination_id).byte() & \ + # started.byte() + + # just_finished = (done_token & ~is_generation_done).bool() + # generated_sequence_lengths[just_finished.view(-1)] = \ + # context_length + 1 + # is_generation_done = is_generation_done | done_token + # done = torch.all(is_generation_done) + # done = broadcast_from_first_pipeline_stage(1, torch.uint8, + # tensor=done) + # if use_stop_tokens_for_early_termination and done: + # break + + # =================================================== + # Update the length of based on max generated length. + # =================================================== + + # tokens = tokens[:, :(context_length + 1)] + # if mpu.is_pipeline_last_stage(): + # if return_output_log_probs: + # output_log_probs = output_log_probs[:, :context_length] + + # ====================================== + # Broadcast to the first pipeline stage. + # ====================================== + + # if return_output_log_probs: + # output_log_probs_size = (batch_size, context_length) + # output_log_probs = broadcast_from_last_to_first_pipeline_stage( + # output_log_probs_size, torch.float32, output_log_probs) + if not echo_prompts and mpu.is_pipeline_first_stage(): + generated_sequence_lengths -= lengths + for i, (sequence, length) in enumerate(zip(tokens, lengths)): + tokens[i] = sequence.roll(-length.item(), dims=0) + if return_output_log_probs: + for i, (prob, length) in enumerate(zip(output_log_probs, lengths)): + output_log_probs[i] = prob.roll(-(length.item() - 1), dims=0) + return tokens, generated_sequence_lengths, output_log_probs, None + + def _build_attention_mask_and_position_ids(tokens): """Build the attention mask and postition ids for the input tokens.""" diff --git a/megatron/text_generation/inference_params.py b/megatron/text_generation/inference_params.py new file mode 100644 index 00000000..2e57d387 --- /dev/null +++ b/megatron/text_generation/inference_params.py @@ -0,0 +1,102 @@ +import torch +import numpy as np +import torch.nn.functional as F + +from megatron import get_tokenizer, get_args +from megatron.text_generation.sampling import sample +from megatron.text_generation.communication import send_token_and_probs_to_first_pipeline_stage +from megatron.core import mpu + +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + + def __init__(self, max_batch_size, max_sequence_length, + top_k=0, top_p=0, temperature=1.0, + top_p_decay=0, top_p_bound=0, + early_exit_thres=None, use_early_exit=False, + print_max_prob=False): + self.max_sequence_length = max_sequence_length + self.max_batch_size = max_batch_size + self.sequence_len_offset = 0 + self.batch_size_offset = 0 + self.key_value_memory_dict = {} + self.early_exit_thres = np.log(early_exit_thres) if early_exit_thres > 0 else float('-inf') + self.use_early_exit = use_early_exit + self.tokenizer = get_tokenizer() + self.use_pipeline_inference = get_args().pipeline_model_parallel_size > 1 + self.top_k = top_k + self.top_p = top_p + self.temperature=temperature + self.top_p_decay = top_p_decay + self.top_p_bound = top_p_bound + self.print_max_probs = print_max_prob + + self.has_early_exited = False + self.is_first_step = True + self.prev_has_early_exited = False + self.tokens = None + self.probs = None + + def clear_early_exit_states(self): + self.has_early_exited = False + self.prev_has_early_exited = False + self.tokens = None + self.probs = None + + def do_early_exit(self, logits, layer_num): + if self.has_early_exited or self.prev_has_early_exited: + return False + last_token_logits = logits[:, -1, :] + log_probs = F.log_softmax(last_token_logits, dim=1) + max_log_prob, token_id = torch.max(log_probs[:, :], dim=1) + token = self.tokenizer.detokenize([int(token_id[-1])]) + if self.print_max_probs: + print(f"layer [{layer_num}]: token [{token}], prob {float(torch.exp(max_log_prob[-1]))}") + self.has_early_exited = max_log_prob[-1] >= self.early_exit_thres + if self.use_pipeline_inference and self.has_early_exited: + # send token and probs to the first stage + tokens, probs = self.get_tokens_and_probs(last_token_logits) + self.send_to_first_pipeline_stage(tokens, probs) + return False + else: + return self.has_early_exited + + def get_tokens_and_probs(self, last_token_logits): + tokens = sample(last_token_logits, + top_k=self.top_k, + top_p=self.top_p, + temperature=self.temperature, + vocab_size=self.tokenizer.vocab_size) + if self.top_p > 0.0 and self.top_p_decay > 0.0: + top_p = self.top_p * self.top_p_decay + if self.top_p_bound > 0.0: + top_p = max(top_p, self.top_p_bound) + indices = torch.unsqueeze(tokens, 1) + log_probs = F.log_softmax(last_token_logits, dim=1) + output_log_probs = torch.gather(log_probs, 1, indices) + return tokens, output_log_probs + + def send_to_first_pipeline_stage(self, tokens, probs): + if mpu.is_pipeline_first_stage(): + self.tokens = tokens + self.probs = probs + else: + send_token_and_probs_to_first_pipeline_stage(self, tokens, probs) + + def swap_key_value_dict(self, batch_idx): + "swap between batches" + if len(self.key_value_memory_dict) == 0: + raise ValueError("should not swap when dict in empty") + + for layer_number in self.key_value_memory_dict.keys(): + inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number] + assert ( + len(batch_idx) == inference_key_memory.shape[1] + ) # make sure batch size is the same + new_inference_key_memory = inference_key_memory[:, batch_idx] + new_inference_value_memory = inference_value_memory[:, batch_idx] + self.key_value_memory_dict[layer_number] = ( + new_inference_key_memory, + new_inference_value_memory, + ) diff --git a/megatron/training.py b/megatron/training.py index 425ade4c..c4cb06bd 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -3,6 +3,8 @@ """Pretrain utilities.""" from datetime import datetime +from functools import partial +from contextlib import nullcontext import math import logging import sys @@ -404,7 +406,7 @@ def setup_model_and_optimizer(model_provider_func, -def train_step(forward_step_func, data_iterator, +def train_step(forward_backward_func, data_iterator, model, optimizer, opt_param_scheduler, config): """Single training step.""" args = get_args() @@ -416,9 +418,7 @@ def train_step(forward_step_func, data_iterator, optimizer.zero_grad() # Forward pass. - forward_backward_func = get_forward_backward_func() losses_reduced = forward_backward_func( - forward_step_func=forward_step_func, data_iterator=data_iterator, model=model, num_microbatches=get_num_microbatches(), @@ -735,6 +735,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, # Iterations. iteration = args.iteration + args.curr_iteration = args.iteration # Setup some training config params config.grad_scale_func = optimizer.scale_loss @@ -754,6 +755,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, print_datetime('before the start of training step') report_memory_flag = True + forward_backward_func = get_forward_backward_func() + forward_backward_func = partial(forward_backward_func, forward_step_func=forward_step_func) + while iteration < args.train_iters: if args.profile and \ iteration == args.profile_step_start and \ @@ -763,13 +767,14 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, update_num_microbatches(args.consumed_train_samples) args.curr_iteration = iteration + loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \ - train_step(forward_step_func, - train_data_iterator, - model, - optimizer, - opt_param_scheduler, - config) + train_step(forward_backward_func, + train_data_iterator, + model, + optimizer, + opt_param_scheduler, + config) iteration += 1 args.consumed_train_samples += mpu.get_data_parallel_world_size() * \ args.micro_batch_size * \ diff --git a/pretrain_multi_exit_gpt.py b/pretrain_early_exit_gpt.py similarity index 84% rename from pretrain_multi_exit_gpt.py rename to pretrain_early_exit_gpt.py index ec971726..f8ace67b 100644 --- a/pretrain_multi_exit_gpt.py +++ b/pretrain_early_exit_gpt.py @@ -1,6 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - -"""Pretrain GPT""" +"""Pretrain Early-exit LLM""" import torch from functools import partial @@ -11,7 +9,7 @@ from megatron.core import tensor_parallel from megatron.core.enums import ModelType from megatron.data.gpt_dataset import build_train_valid_test_datasets -from megatron.model import MultiExitGPTModel +from megatron.model import EarlyExitGPTModel from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import average_losses_across_data_parallel_group @@ -20,9 +18,10 @@ def model_provider(pre_process=True, post_process=True): """Build the model.""" - print_rank_0('building MultiExitGPT model ...') + print_rank_0('building EarlyExitGPT model ...') + args = get_args() config = core_transformer_config_from_args(get_args()) - model = MultiExitGPTModel( + model = EarlyExitGPTModel( config, num_tokentypes=0, parallel_output=True, @@ -64,17 +63,15 @@ def get_batch(data_iterator): return tokens, labels, loss_mask, attention_mask, position_ids -def loss_func(loss_mask, output_tensor, layer_num=None, weight=None): +def loss_func(loss_mask, output_tensor, log_dict, log_key): losses = output_tensor.float() loss_mask = loss_mask.view(-1).float() loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # Reduce loss for logging. averaged_loss = average_losses_across_data_parallel_group([loss]) - loss_name = 'lm loss' if layer_num is None else f'early loss [{layer_num}]' - if weight is not None: - loss *= weight - return loss, {loss_name: averaged_loss[0]} + log_dict[log_key] = averaged_loss[0] + return loss def forward_step(data_iterator, model): @@ -90,7 +87,7 @@ def forward_step(data_iterator, model): masked_loss_func = partial(loss_func, loss_mask) lm_output = model(tokens, position_ids, attention_mask, - labels=labels, masked_loss_func=masked_loss_func) + labels=labels, exit_loss_func=masked_loss_func) return lm_output, masked_loss_func @@ -118,9 +115,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): if __name__ == "__main__": - pretrain(train_valid_test_datasets_provider, model_provider, ModelType.encoder_or_decoder, forward_step, - args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) + args_defaults={'tokenizer_type': 'SentencePieceTokenizer'}) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index a8162fde..a6181421 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -14,18 +14,14 @@ from megatron.core.enums import ModelType from megatron.data.gpt_dataset import GPTDataset, build_train_valid_test_datasets import megatron.model -from megatron.core.models.gpt import GPTModel from megatron.training import pretrain from megatron.core.transformer.spec_utils import import_module from megatron.utils import get_ltor_masks_and_position_ids from megatron.utils import average_losses_across_data_parallel_group from megatron.arguments import core_transformer_config_from_args -from megatron.core.models.gpt.gpt_layer_specs import ( - gpt_layer_with_transformer_engine_spec, - gpt_layer_with_transformer_engine_spec_moe -) -def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.model.GPTModel]: + +def model_provider(pre_process=True, post_process=True) -> megatron.model.GPTModel: """Builds the model. If you set the use_mcore_models to True, it will return the mcore GPT model and if not the legacy GPT model. @@ -38,41 +34,17 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat Returns: Union[GPTModel, megatron.model.GPTModel]: The returned model """ - args = get_args() print_rank_0('building GPT model ...') config = core_transformer_config_from_args(get_args()) - if args.use_mcore_models: - if args.model_spec is not None: - transformer_layer_spec = import_module(args.model_spec) - else: - if args.num_experts is None: - transformer_layer_spec = gpt_layer_with_transformer_engine_spec - else: - transformer_layer_spec = gpt_layer_with_transformer_engine_spec_moe - - model = GPTModel( - config=config, - transformer_layer_spec=transformer_layer_spec, - vocab_size=args.padded_vocab_size, - max_sequence_length=args.max_position_embeddings, - pre_process=pre_process, - post_process=post_process, - fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, - parallel_output=True, - share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, - position_embedding_type=args.position_embedding_type, - rotary_percent=args.rotary_percent - ) - else: - model = megatron.model.GPTModel( - config, - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process - ) + model = megatron.model.GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) return model @@ -134,7 +106,7 @@ def loss_func(loss_mask: Tensor, output_tensor: Tensor): return loss, {'lm loss': averaged_loss[0]} -def forward_step(data_iterator, model: GPTModel): +def forward_step(data_iterator, model): """Forward training step. Args: diff --git a/tools/checkpoint/checkpoint_converter.py b/tools/checkpoint/checkpoint_converter.py new file mode 100644 index 00000000..0582cd3f --- /dev/null +++ b/tools/checkpoint/checkpoint_converter.py @@ -0,0 +1,250 @@ +import json +import os +import sys +import torch +import argparse +from collections import OrderedDict + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--load-dir', type=str) + parser.add_argument('--load-iteration', type=int) + parser.add_argument('--save-dir', type=str) + parser.add_argument('--conversion-type', choices=['exit-position', 'add-exit']) + parser.add_argument('--target-exit-position', choices=['pre', 'post'], default='post') + parser.add_argument('--add-exit-layer-nums', type=int, nargs='+', default=[]) + parser.add_argument('--use-exit-mlp', action='store_true') + parser.add_argument('--use-exit-block', action='store_true') + parser.add_argument('--use-exit-norm', action='store_true') + parser.add_argument('--megatron-path', type=str, default=None, + help='Base directory of deepspeed repository') + return parser.parse_args() + +def load_checkpoint_args(checkpoint_root_path): + if os.path.exists(os.path.join(checkpoint_root_path, 'mp_rank_00')): + checkpoint_rank_0_dir = 'mp_rank_00' + elif os.path.exists(os.path.join(checkpoint_root_path, 'mp_rank_00_000')): + checkpoint_rank_0_dir = 'mp_rank_00_000' + else: + raise FileNotFoundError(f'Checkpoint file {checkpoint_root_path} not found') + checkpoint_path = os.path.join(checkpoint_root_path, checkpoint_rank_0_dir, 'model_optim_rng.pt') + print(f"Loading args from {checkpoint_root_path}") + model = torch.load(checkpoint_path) + return model['args'] + +def change_exit_position(args, checkpoint_load_dir, checkpoint_save_dir): + checkpoint_args = load_checkpoint_args(checkpoint_load_dir) + cur_exit_position = 'pre' if checkpoint_args.pre_exit else 'post' + if cur_exit_position == args.target_exit_position: + print("No need to convert") + return + pipeline_parallel_size = checkpoint_args.pipeline_model_parallel_size + tensor_parallel_size = checkpoint_args.tensor_model_parallel_size + exit_layer_nums = checkpoint_args.exit_layer_nums + if args.target_exit_position == 'pre': + exit_layer_nums = [layer_num + 1 for layer_num in exit_layer_nums] + else: + exit_layer_nums = [layer_num - 1 for layer_num in exit_layer_nums] + use_pipeline_parallel = pipeline_parallel_size > 1 + for tensor_rank in range(tensor_parallel_size): + checkpoint_dicts = {} + exit_output_weights = [] + exit_output_weight_offset = 0 + # load all pipeline ranks + for pipeline_rank in range(pipeline_parallel_size): + if not use_pipeline_parallel: + checkpoint_name = os.path.join(checkpoint_load_dir, f'mp_rank_{tensor_rank:02d}', 'model_optim_rng.pt') + else: + checkpoint_name = os.path.join(checkpoint_load_dir, f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}', 'model_optim_rng.pt') + print(f'Loading checkpoint [pp:{pipeline_rank}, tp:{tensor_rank}] from {checkpoint_name} ...') + state_dict = torch.load(checkpoint_name, map_location='cpu') + checkpoint_dicts[pipeline_rank] = state_dict + # convert args + state_dict['args'].exit_layer_nums = exit_layer_nums + state_dict['args'].pre_exit = (args.target_exit_position == 'pre') + # get exit output weight + if checkpoint_args.untie_exit_output_weights and use_pipeline_parallel: + if 'exit_output_layer' in state_dict['model']['language_model']: + exit_weight_num = len(state_dict['model']['language_model']['exit_output_layer']) + for i in range(exit_weight_num): + exit_output_weights.append(state_dict['model']['language_model']['exit_output_layer'].pop(f'{i}.weight')) + # convert output weight position + if checkpoint_args.untie_exit_output_weights and use_pipeline_parallel: + layer_per_stage = checkpoint_args.num_layers / pipeline_parallel_size + for pipeline_rank in range(pipeline_parallel_size): + layer_nums = list(filter(lambda x: (layer_per_stage * pipeline_rank + 1) <= x <= (layer_per_stage * (pipeline_rank + 1)), exit_layer_nums)) + if len(layer_nums) > 0: + if 'exit_output_layer' not in checkpoint_dicts[pipeline_rank]['model']['language_model']: + checkpoint_dicts[pipeline_rank]['model']['language_model']['exit_output_layer'] = OrderedDict() + for i in range(len(layer_nums)): + checkpoint_dicts[pipeline_rank]['model']['language_model']['exit_output_layer'][f'{i}.weight'] = exit_output_weights[exit_output_weight_offset] + exit_output_weight_offset += 1 + elif 'exit_output_layer' in checkpoint_dicts[pipeline_rank]['model']['language_model']: + checkpoint_dicts[pipeline_rank]['model']['language_model'].pop('exit_output_layer') + # save back + for pipeline_rank in range(pipeline_parallel_size): + if not use_pipeline_parallel: + checkpoint_save_path = os.path.join(checkpoint_save_dir, f'mp_rank_{tensor_rank:02d}', 'model_optim_rng.pt') + else: + checkpoint_save_path = os.path.join(checkpoint_save_dir, f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}', 'model_optim_rng.pt') + dirname = os.path.dirname(checkpoint_save_path) + os.makedirs(dirname, exist_ok = True) + print(f'Saving checkpoint [pp:{pipeline_rank}, tp:{tensor_rank}] to {checkpoint_save_path} ...') + torch.save(checkpoint_dicts[pipeline_rank], checkpoint_save_path) + print('Exit Weight Position Conversion Completed') + + +def add_exit(args, checkpoint_load_dir, checkpoint_save_dir): + if len(args.add_exit_layer_nums) == 0: + print("No exit layer to add") + return + checkpoint_args = load_checkpoint_args(checkpoint_load_dir) + use_pre_exit = False + if len(checkpoint_args.exit_layer_nums) == 0: + if args.target_exit_position == 'pre': + use_pre_exit = True + else: + if checkpoint_args.pre_exit == (args.target_exit_position == 'pre'): + print("Can't add exit layers and change exit position at the same time") + return + use_pre_exit = checkpoint_args.pre_exit + target_exit_layer_nums = list(set(checkpoint_args.exit_layer_nums + args.add_exit_layer_nums)) + tensor_parallel_size = checkpoint_args.tensor_model_parallel_size + pipeline_parallel_size = checkpoint_args.pipeline_model_parallel_size + use_pipeline_parallel = pipeline_parallel_size > 1 + layer_per_stage = checkpoint_args.num_layers / pipeline_parallel_size + for tensor_rank in range(tensor_parallel_size): + checkpoint_dicts = {} + output_weight = None + final_norm_weight = None + final_norm_bias = None + # load all pipeline ranks + for pipeline_rank in range(pipeline_parallel_size): + if not use_pipeline_parallel: + checkpoint_name = os.path.join(checkpoint_load_dir, f'mp_rank_{tensor_rank:02d}', 'model_optim_rng.pt') + else: + checkpoint_name = os.path.join(checkpoint_load_dir, f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}', 'model_optim_rng.pt') + print(f'Loading checkpoint [pp:{pipeline_rank}, tp:{tensor_rank}] from {checkpoint_name} ...') + layer_num_offset = layer_per_stage * pipeline_rank + 1 + exit_layer_nums = list(filter(lambda x: (layer_per_stage * pipeline_rank + 1) <= x <= (layer_per_stage * (pipeline_rank + 1)), target_exit_layer_nums)) + state_dict = torch.load(checkpoint_name) + checkpoint_dicts[pipeline_rank] = state_dict + # convert args + state_dict['args'].exit_layer_nums = target_exit_layer_nums + state_dict['args'].pre_exit = use_pre_exit + state_dict['args'].untie_exit_output_weights = True + + # get ouptut weight + if checkpoint_args.untie_embeddings_and_output_weights: + if pipeline_rank == pipeline_parallel_size - 1: + output_weight = state_dict['model']['language_model']['output_layer']['weight'] + else: + if pipeline_rank == 0: + output_weight = state_dict['model']['language_model']['embedding']['word_embeddings']['weight'] + + # convert to exit mlp + if args.use_exit_mlp and (not hasattr(state_dict['args'], 'use_exit_mlp') or not state_dict['args'].use_exit_mlp): + state_dict['args'].use_exit_mlp = args.use_exit_mlp + for layer_num in exit_layer_nums: + layer_id = int(layer_num - layer_num_offset) + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.trunk.dense_h_to_4h.weight'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.weight'] + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.trunk.dense_4h_to_h.weight'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.weight'] + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_h_to_4h.weight'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.weight'] + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_4h_to_h.weight'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.weight'] + state_dict['model']['language_model']['encoder'].pop(f'layers.{layer_id}.mlp.dense_h_to_4h.weight') + state_dict['model']['language_model']['encoder'].pop(f'layers.{layer_id}.mlp.dense_4h_to_h.weight') + if checkpoint_args.add_bias_linear: + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.trunk.dense_h_to_4h.bias'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.bias'] + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.trunk.dense_4h_to_h.bias'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.bias'] + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_h_to_4h.bias'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_h_to_4h.bias'] + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.branch.dense_4h_to_h.bias'] = \ + state_dict['model']['language_model']['encoder'][f'layers.{layer_id}.mlp.dense_4h_to_h.bias'] + state_dict['model']['language_model']['encoder'].pop(f'layers.{layer_id}.mlp.dense_h_to_4h.bias') + state_dict['model']['language_model']['encoder'].pop(f'layers.{layer_id}.mlp.dense_4h_to_h.bias') + # convert to exit block + if args.use_exit_block: + state_dict['args'].use_exit_block = args.use_exit_block + # get last layer params + if pipeline_rank == pipeline_parallel_size - 1: + last_layer_id = int(layer_per_stage - 1) + last_layer_input_norm = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.input_norm.weight'] + last_layer_atten_qkv = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.self_attention.query_key_value.weight'] + last_layer_atten_dense = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.self_attention.dense.weight'] + last_layer_post_norm = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.post_attention_norm.weight'] + last_layer_mlp_h_to_4h = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.mlp.dense_h_to_4h.weight'] + last_layer_mlp_4h_to_h = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.mlp.dense_4h_to_h.weight'] + if checkpoint_args.add_bias_linear: + last_layer_atten_dense_bias = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.self_attention.dense.bias'] + last_layer_h_to_4h_bias = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.mlp.dense_h_to_4h.bias'] + last_layer_4h_to_h_bias = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.mlp.dense_4h_to_h.bias'] + if checkpoint_args.normalization == 'LayerNorm': + last_layer_input_norm_bias = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.input_norm.bias'] + last_layer_post_norm_bias = state_dict['model']['language_model']['encoder'][f'layers.{last_layer_id}.post_attention_norm.bias'] + # get final norm + if args.use_exit_norm: + state_dict['args'].use_exit_norm = args.use_exit_norm + if 'final_norm.weight' in state_dict['model']['language_model']['encoder']: + final_norm_weight = state_dict['model']['language_model']['encoder']['final_norm.weight'] + if checkpoint_args.normalization == 'LayerNorm': + final_norm_bias = state_dict['model']['language_mode']['encoder']['final_norm.bias'] + # get exit output weight + if len(exit_layer_nums) > 0 and 'exit_output_layer' not in state_dict['model']['language_model']: + state_dict['model']['language_model']['exit_output_layer'] = OrderedDict() + + for pipeline_rank in range(pipeline_parallel_size): + layer_num_offset = layer_per_stage * pipeline_rank + 1 + exit_layer_nums = list(filter(lambda x: (layer_per_stage * pipeline_rank + 1) <= x <= (layer_per_stage * (pipeline_rank + 1)), target_exit_layer_nums)) + # add exit output weight and exit norm + for i, layer_num in enumerate(exit_layer_nums): + layer_id = int(layer_num - layer_num_offset) + checkpoint_dicts[pipeline_rank]['model']['language_model']['exit_output_layer'][f'{i}.weight'] = output_weight + if args.use_exit_block: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.input_norm.weight'] = last_layer_input_norm + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.query_key_value.weight'] = last_layer_atten_qkv + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.dense.weight'] = last_layer_atten_dense + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.post_attention_norm.weight'] = last_layer_post_norm + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_h_to_4h.weight'] = last_layer_mlp_h_to_4h + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_4h_to_h.weight'] = last_layer_mlp_4h_to_h + if checkpoint_args.add_bias_linear: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.self_attention.dense.bias'] = last_layer_atten_dense_bias + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_h_to_4h.bias'] = last_layer_h_to_4h_bias + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.mlp.dense_4h_to_h.bias'] = last_layer_4h_to_h_bias + if checkpoint_args.normalization == 'LayerNorm': + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.input_norm.bias'] = last_layer_input_norm_bias + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_block.post_attention_norm.bias'] = last_layer_post_norm_bias + if args.use_exit_norm: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm.weight'] = final_norm_weight + if final_norm_bias is not None: + checkpoint_dicts[pipeline_rank]['model']['language_model']['encoder'][f'layers.{layer_id}.exit_norm'] = final_norm_bias + if not use_pipeline_parallel: + checkpoint_save_path = os.path.join(checkpoint_save_dir, f'mp_rank_{tensor_rank:02d}', 'model_optim_rng.pt') + else: + checkpoint_save_path = os.path.join(checkpoint_save_dir, f'mp_rank_{tensor_rank:02d}_{pipeline_rank:03d}', 'model_optim_rng.pt') + dirname = os.path.dirname(checkpoint_save_path) + os.makedirs(dirname, exist_ok = True) + print(f'Saving checkpoint [pp:{pipeline_rank}, tp:{tensor_rank}] to {checkpoint_save_path} ...') + torch.save(checkpoint_dicts[pipeline_rank], checkpoint_save_path) + print('Add Exit Layers Completed') + + +def convert(args): + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + checkpoint_load_dir = os.path.join(args.load_dir, 'iter_{:07d}'.format(args.load_iteration)) + checkpoint_save_dir = os.path.join(args.save_dir, 'iter_{:07d}'.format(args.load_iteration)) + if args.conversion_type == 'exit-position': + change_exit_position(args, checkpoint_load_dir, checkpoint_save_dir) + elif args.conversion_type == 'add-exit': + add_exit(args, checkpoint_load_dir, checkpoint_save_dir) + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/tools/checkpoint/loader_llama2_hf.py b/tools/checkpoint/loader_llama2_hf.py index 36b907d9..ed94319a 100644 --- a/tools/checkpoint/loader_llama2_hf.py +++ b/tools/checkpoint/loader_llama2_hf.py @@ -222,7 +222,7 @@ def check_for_arg(arg_name, default=None): check_for_arg('swiglu', False) # Determine how to make our models. - assert args.model_type == 'GPT', 'Llama-2 is a GPT model.' + assert args.model_type == 'GPT' or args.model_type == 'EarlyExitGPT', 'Llama-2 is a GPT model.' margs.model_type = ModelType.encoder_or_decoder # Suppress warning about torch.distributed not being initialized. diff --git a/tools/checkpoint/loader_megatron.py b/tools/checkpoint/loader_megatron.py index d9db1202..46ef3fc8 100644 --- a/tools/checkpoint/loader_megatron.py +++ b/tools/checkpoint/loader_megatron.py @@ -98,8 +98,8 @@ def check_for_arg(arg_name, default=None): elif args.model_type == 'BERT': from pretrain_bert import model_provider margs.model_type = ModelType.encoder_or_decoder - elif args.model_type == 'MultiExitGPT': - from pretrain_multi_exit_gpt import model_provider + elif args.model_type == 'EarlyExitGPT': + from pretrain_early_exit_gpt import model_provider margs.model_type = ModelType.encoder_or_decoder else: raise Exception(f'unrecognized model type: {args.model_type}') @@ -210,7 +210,10 @@ def get_models(count, dtype): md.exit_layer_nums = margs.exit_layer_nums if hasattr(margs, 'exit_layer_nums') else [] md.exit_layer_weight = margs.exit_layer_weight if hasattr(margs, 'exit_layer_weight') else [] md.use_exit_mlp = margs.use_exit_mlp if hasattr(margs, 'use_exit_mlp') else False + md.use_exit_block = margs.use_exit_block if hasattr(margs, 'use_exit_block') else False + md.use_exit_norm = margs.use_exit_norm if hasattr(margs, 'use_exit_norm') else False md.untie_exit_output_weights = margs.untie_exit_output_weights if hasattr(margs, 'untie_exit_output_weights') else False + md.pre_exit = margs.pre_exit md.checkpoint_args = checkpoint_args # Get first pipe stage @@ -218,6 +221,7 @@ def get_models(count, dtype): if len(md.exit_layer_nums) > 0: layer_per_stage = md.num_layers / margs.pipeline_model_parallel_size mpu.set_early_exit_layer_nums(list(filter(lambda x: 0 < x <= layer_per_stage, md.exit_layer_nums))) + mpu.set_early_exit_stages(list(set(map(lambda layer_num: int((layer_num - 1) // layer_per_stage), md.exit_layer_nums)))) all_models = [get_models(tp_size, md.params_dtype)] models = all_models[0][0] @@ -261,13 +265,18 @@ def queue_put(name, msg): layer_num = layer.layer_number has_early_exit = layer_num in md.exit_layer_nums use_exit_mlp = has_early_exit and md.use_exit_mlp - message["_ norm weight"] = layer.input_norm.weight.data + use_exit_block = has_early_exit and md.use_exit_block + use_exit_norm = has_early_exit and md.use_exit_norm + message["input norm weight"] = layer.input_norm.weight.data if norm_has_bias: message["input norm bias"] = layer.input_norm.bias.data message["post norm weight"] = layer.post_attention_norm.weight.data if norm_has_bias: message["post norm bias"] = layer.post_attention_norm.bias.data - + if use_exit_norm: + message["exit norm weight"] = layer.exit_norm.weight.data + if norm_has_bias: + message['exit norm bias'] = layer.exit_norm.bias.data if md.linear_bias: message["dense bias"] = layer.self_attention.dense.bias.data if use_exit_mlp: @@ -275,6 +284,15 @@ def queue_put(name, msg): message["mlp l1 exit bias"] = layer.mlp.branch.dense_4h_to_h.bias.data else: message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data + if use_exit_block: + message["exit block input norm weight"] = layer.exit_block.input_norm.weight.data + message["exit block post norm weight"] = layer.exit_block.post_attention_norm.weight.data + if norm_has_bias: + message["exit block input norm bias"] = layer.exit_block.input_norm.bias.data + message["exit block post norm bias"] = layer.exit_block.post_attention_norm.bias.data + if md.linear_bias: + message["exit block dense bias"] = layer.exit_block.self_attention.dense.bias.data + message["exit block mlp l1 bias"] = layer.exit_block.mlp.dense_4h_to_h.bias.data # Grab all parallel tensors for this layer qkv_weight = [] @@ -287,6 +305,12 @@ def queue_put(name, msg): mlp_l0_exit_bias = [] mlp_l1_exit_weight = [] exit_output_weight = [] + exit_block_qkv_weight = [] + exit_block_qkv_bias = [] + exit_block_dense_weight = [] + exit_block_mlp_l0_weight = [] + exit_block_mlp_l0_bias = [] + exit_block_mlp_l1_weight = [] for tp_rank, model in enumerate(models): layer = model.language_model.encoder.layers[layer_id] qkv_weight.append(layer.self_attention.query_key_value.weight.data) @@ -308,6 +332,14 @@ def queue_put(name, msg): mlp_l0_exit_bias.append(layer.mlp.branch.dense_h_to_4h.bias.data) else: mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data) + if use_exit_block: + exit_block_qkv_weight.append(layer.exit_block.self_attention.query_key_value.weight.data) + exit_block_dense_weight.append(layer.exit_block.self_attention.dense.weight.data) + exit_block_mlp_l0_weight.append(layer.exit_block.mlp.dense_h_to_4h.weight.data) + exit_block_mlp_l1_weight.append(layer.exit_block.mlp.dense_4h_to_h.weight.data) + if md.linear_bias: + exit_block_qkv_bias.append(layer.exit_block.self_attention.query_key_value.bias.data) + exit_block_mlp_l0_bias.append(layer.exit_block.mlp.dense_h_to_4h.bias.data) # Handle gated linear units if md.swiglu: @@ -351,9 +383,27 @@ def queue_put(name, msg): message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) if use_exit_mlp: message["mlp l0 exit bias"] = torch.cat(mlp_l0_exit_bias, dim=0) - + if use_exit_block: + if md.swiglu: + for tp_rank in range(tp_size): + exit_block_mlp_l0_weight[tp_rank] = torch.chunk(exit_block_mlp_l0_weight[tp_rank], 2, dim=0) + message["exit block mlp l0 weight W"] = torch.cat([w[0] for w in exit_block_mlp_l0_weight], dim=0) + message["exit block mlp l0 weight V"] = torch.cat([w[1] for w in exit_block_mlp_l0_weight], dim=0) + else: + message["exit block mlp l0 weight"] = torch.cat(exit_block_mlp_l0_weight, dim=0) + message["exit block qkv weight"] = torch.cat(exit_block_qkv_weight, dim=0) + message["exit block dense weight"] = torch.cat(exit_block_dense_weight, dim=1) + message["exit block mlp l1 weight"] = torch.cat(exit_block_mlp_l1_weight, dim=1) + if md.linear_bias: + message["exit block qkv bias"] = torch.cat(exit_block_qkv_bias, dim=0) + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_bias[tp_rank] = torch.chunk(exit_block_mlp_l0_bias[tp_rank], 2, dim=0) + message["exit block mlp l0 bias W"] = torch.cat([b[0] for b in exit_block_mlp_l0_bias],dim=0) + message["exit block mlp l0 bias V"] = torch.cat([b[1] for b in exit_block_mlp_l0_bias],dim=0) + else: + message["exit block mlp l0 bias"] = torch.cat(exit_block_mlp_l0_bias, dim=0) queue_put(f"transformer layer {total_layer_num}", message) - total_layer_num = total_layer_num + 1 # Send final norm from tp_rank 0 diff --git a/tools/checkpoint/saver_megatron.py b/tools/checkpoint/saver_megatron.py index fd5fc2a8..5b576af3 100644 --- a/tools/checkpoint/saver_megatron.py +++ b/tools/checkpoint/saver_megatron.py @@ -21,6 +21,8 @@ def add_arguments(parser): group.add_argument('--target-pipeline-parallel-size', type=int, help='Target tensor model parallel size, default to the pipeline parall size ' 'in the input checkpoint if provided by the loader, otherwise to 1') + group.add_argument('--target-exit-position', choices=['ignore', 'pre', 'post'], default='ignore', + help='Change the relative position of early exit') def save_checkpoint(queue, args): @@ -134,7 +136,7 @@ def check_message(msg): if md.model_type == 'BERT' and not md.bert_binary_head: sys.argv.append('--bert-no-binary-head') - if md.exit_layer_nums is not None and len(md.exit_layer_nums) > 0: + if hasattr(md, 'exit_layer_nums') and len(md.exit_layer_nums) > 0: sys.argv.append('--exit-layer-nums') for layer_num in md.exit_layer_nums: sys.argv.append(str(layer_num)) @@ -143,6 +145,12 @@ def check_message(msg): sys.argv.append(str(layer_weight)) if md.use_exit_mlp: sys.argv.append("--use-exit-mlp") + if md.use_exit_block: + sys.argv.append("--use-exit-block") + if md.use_exit_norm: + sys.argv.append("--use-exit-norm") + if md.pre_exit: + sys.argv.append("--pre-exit") margs = parse_args() @@ -177,7 +185,7 @@ def check_message(msg): validate_args(margs) - set_global_variables(margs, build_tokenizer=False) + set_global_variables(margs, build_tokenizer=False, init_wandb=False) # margs = megatron args margs = get_args() @@ -197,8 +205,8 @@ def check_message(msg): elif md.model_type == 'BERT': from pretrain_bert import model_provider margs.model_type = ModelType.encoder_or_decoder - elif md.model_type == 'MultiExitGPT': - from pretrain_multi_exit_gpt import model_provider + elif md.model_type == 'EarlyExitGPT': + from pretrain_early_exit_gpt import model_provider margs.model_type = ModelType.encoder_or_decoder else: raise Exception(f'unrecognized model type: {args.model_type}') @@ -212,9 +220,10 @@ def get_models(count, dtype, pre_process, post_process): mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size) mpu.set_tensor_model_parallel_rank(0) mpu.set_pipeline_model_parallel_rank(0) - if len(md.exit_layer_nums) > 0: + if hasattr(md, 'exit_layer_nums') and len(md.exit_layer_nums) > 0: layer_per_stage = md.num_layers / args.target_pipeline_parallel_size mpu.set_early_exit_layer_nums(list(filter(lambda x: 0 < x <= layer_per_stage, md.exit_layer_nums))) + mpu.set_early_exit_stages(list(set(map(lambda layer_num: int((layer_num - 1) // layer_per_stage), md.exit_layer_nums)))) fused_kernels.load(margs) # Embeddings @@ -259,7 +268,7 @@ def get_models(count, dtype, pre_process, post_process): # Make models for first pipeline stage and fill in embeddings mpu.set_pipeline_model_parallel_rank(0) - if len(md.exit_layer_nums) > 0: + if hasattr(md, 'exit_layer_nums') and len(md.exit_layer_nums) > 0: layer_per_stage = md.num_layers / args.target_pipeline_parallel_size mpu.set_early_exit_layer_nums(list(filter(lambda x: 0 < x <= layer_per_stage, md.exit_layer_nums))) post_process = args.target_pipeline_parallel_size == 1 @@ -278,16 +287,21 @@ def get_models(count, dtype, pre_process, post_process): # For later pipeline parallel ranks, make the new models if pp_rank > 0: mpu.set_pipeline_model_parallel_rank(pp_rank) - if len(md.exit_layer_nums) > 0: + if hasattr(md, 'exit_layer_nums') and len(md.exit_layer_nums) > 0: mpu.set_early_exit_layer_nums(list(filter(lambda x: (layer_per_stage * pp_rank) < x <= (layer_per_stage * (pp_rank + 1)), md.exit_layer_nums))) post_process = pp_rank == args.target_pipeline_parallel_size - 1 models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process) + pre_process = pp_rank == 0 + + is_early_exit_stage = mpu.has_early_exit() for layer in range(len(models[0].language_model.encoder.layers)): msg = queue_get(f"transformer layer {total_layer_num}") layer_num = models[tp_rank].language_model.encoder.layers[layer].layer_number - has_early_exit = layer_num in md.exit_layer_nums - use_exit_mlp = has_early_exit and md.use_exit_mlp + is_early_exit_layer = layer_num in md.exit_layer_nums if hasattr(md, 'exit_layer_nums') else False + use_exit_mlp = is_early_exit_layer and hasattr(md, 'use_exit_mlp') and md.use_exit_mlp + use_exit_block = is_early_exit_layer and hasattr(md, 'use_exit_block') and md.use_exit_block + use_exit_norm = is_early_exit_layer and hasattr(md, 'use_exit_norm') and md.use_exit_norm # duplicated tensors input_norm_weight = msg.pop("input norm weight") @@ -296,11 +310,24 @@ def get_models(count, dtype, pre_process, post_process): post_norm_weight = msg.pop("post norm weight") if md.norm_has_bias: post_norm_bias = msg.pop("post norm bias") + if use_exit_norm: + exit_norm_weight = msg.pop("exit norm weight") + if md.norm_has_bias: + exit_norm_bias = msg.pop("exit norm bias") if md.linear_bias: dense_bias = msg.pop("dense bias") mlp_l1_bias = msg.pop("mlp l1 bias") if use_exit_mlp: mlp_l1_exit_bias = msg.pop("mlp l1 exit bias") + if use_exit_block: + exit_block_input_norm_weight = msg.pop("exit block input norm weight") + exit_block_post_norm_weight = msg.pop("exit block post norm weight") + if md.norm_has_bias: + exit_block_input_norm_bias = msg.pop("exit block input norm bias") + exit_block_post_norm_bias = msg.pop("exit block post norm bias") + if md.linear_bias: + exit_block_dense_bias = msg.pop("exit block dense bias") + exit_block_mlp_l1_bias = msg.pop("exit block mlp l1 bias") # Split up the parallel tensors qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0) @@ -324,7 +351,7 @@ def get_models(count, dtype, pre_process, post_process): if use_exit_mlp: mlp_l0_exit_weight = torch.chunk(msg.pop("mlp l0 exit weight"), args.target_tensor_parallel_size, dim=0) - if has_early_exit and md.untie_exit_output_weights: + if is_early_exit_layer and md.untie_exit_output_weights: exit_output_weight = torch.chunk(msg.pop("exit output weight"), args.target_tensor_parallel_size, dim=0) if md.linear_bias: qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0) @@ -340,18 +367,40 @@ def get_models(count, dtype, pre_process, post_process): mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0) if use_exit_mlp: mlp_l0_exit_bias = torch.chunk(msg.pop("mlp l0 exit bias"), args.target_tensor_parallel_size, dim=0) + if use_exit_block: + # Split up the parallel tensors + exit_block_qkv_weight = torch.chunk(msg.pop("exit block qkv weight"), args.target_tensor_parallel_size, dim=0) + exit_block_dense_weight = torch.chunk(msg.pop("exit block dense weight"), args.target_tensor_parallel_size, dim=1) + exit_block_mlp_l1_weight = torch.chunk(msg.pop("exit block mlp l1 weight"), args.target_tensor_parallel_size, dim=1) + if md.swiglu: + exit_block_mlp_l0_weight_W = torch.chunk(msg.pop("exit block mlp l0 weight W"), args.target_tensor_parallel_size, dim=0) + exit_block_mlp_l0_weight_V = torch.chunk(msg.pop("exit block mlp l0 weight V"), args.target_tensor_parallel_size, dim=0) + exit_block_mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(exit_block_mlp_l0_weight_W, exit_block_mlp_l0_weight_V)] + else: + exit_block_mlp_l0_weight = torch.chunk(msg.pop("exit block mlp l0 weight"), args.target_tensor_parallel_size, dim=0) + if md.linear_bias: + exit_block_qkv_bias = torch.chunk(msg.pop("exit block qkv bias"), args.target_tensor_parallel_size, dim=0) + if md.swiglu: + exit_block_mlp_l0_bias_W = torch.chunk(msg.pop("exit block mlp l0 bias W"), args.target_tensor_parallel_size, dim=0) + exit_block_mlp_l0_bias_V = torch.chunk(msg.pop("exit block mlp l0 bias V"), args.target_tensor_parallel_size, dim=0) + exit_block_mlp_l0_bias = [torch.cat(bias, dim=0) for bias in zip(exit_block_mlp_l0_bias_W, exit_block_mlp_l0_bias_V)] + else: + exit_block_mlp_l0_bias = torch.chunk(msg.pop("exit block mlp l0 bias"), args.target_tensor_parallel_size, dim=0) # Save them to the model for tp_rank in range(args.target_tensor_parallel_size): l = models[tp_rank].language_model.encoder.layers[layer] l.input_norm.weight.data.copy_(input_norm_weight) + l.post_attention_norm.weight.data.copy_(post_norm_weight) if md.norm_has_bias: l.input_norm.bias.data.copy_(input_norm_bias) + l.post_attention_norm.bias.data.copy_(post_norm_bias) l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank]) l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank]) - l.post_attention_norm.weight.data.copy_(post_norm_weight) - if md.norm_has_bias: - l.post_attention_norm.bias.data.copy_(post_norm_bias) + if use_exit_norm: + l.exit_norm.weight.data.copy_(exit_norm_weight) + if md.norm_has_bias: + l.exit_norm.bias.data.copy_(exit_norm_bias) if use_exit_mlp: l.mlp.trunk.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank]) l.mlp.trunk.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank]) @@ -360,7 +409,7 @@ def get_models(count, dtype, pre_process, post_process): else: l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank]) l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank]) - if has_early_exit and md.untie_exit_output_weights: + if is_early_exit_layer and md.untie_exit_output_weights: models[tp_rank].language_model.encoder.exit_output_weights[layer_num].data.copy_(exit_output_weight[tp_rank]) if md.linear_bias: l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank]) @@ -373,10 +422,29 @@ def get_models(count, dtype, pre_process, post_process): else: l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank]) l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias) + if use_exit_block: + l.exit_block.input_norm.weight.data.copy_(exit_block_input_norm_weight) + l.exit_block.post_attention_norm.weight.data.copy_(exit_block_post_norm_weight) + l.exit_block.self_attention.query_key_value.weight.data.copy_(exit_block_qkv_weight[tp_rank]) + l.exit_block.self_attention.dense.weight.data.copy_(exit_block_dense_weight[tp_rank]) + l.exit_block.mlp.dense_h_to_4h.weight.data.copy_(exit_block_mlp_l0_weight[tp_rank]) + l.exit_block.mlp.dense_4h_to_h.weight.data.copy_(exit_block_mlp_l1_weight[tp_rank]) + if md.norm_has_bias: + l.exit_block.input_norm.bias.data.copy_(exit_block_input_norm_bias) + l.exit_block.post_attention_norm.bias.data.copy_(exit_block_post_norm_bias) + if md.linear_bias: + l.exit_block.self_attention.query_key_value.bias.data.copy_(exit_block_qkv_bias[tp_rank]) + l.exit_block.self_attention.dense.bias.data.copy_(exit_block_dense_bias) + l.exit_block.mlp.dense_4h_to_h.bias.data.copy_(exit_block_mlp_l1_bias) + l.exit_block.mlp.dense_h_to_4h.bias.data.copy_(exit_block_mlp_l0_bias[tp_rank]) total_layer_num = total_layer_num + 1 check_message(msg) + if not md.output_layer and is_early_exit_stage and not (pre_process or post_process): + for tp_rank in range(args.target_tensor_parallel_size): + models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank]) + if post_process: msg = queue_get("final norm") final_norm_weight = msg.pop("weight") diff --git a/tools/checkpoint/util.py b/tools/checkpoint/util.py index d0219177..e323e954 100644 --- a/tools/checkpoint/util.py +++ b/tools/checkpoint/util.py @@ -111,7 +111,7 @@ def main(): allow_abbrev=False, conflict_handler='resolve') parser.add_argument('--model-type', type=str, required=True, - choices=['GPT', 'BERT', 'MultiExitGPT'], + choices=['GPT', 'BERT', 'EarlyExitGPT'], help='Type of the model') parser.add_argument('--loader', type=str, default='megatron', help='Module name to load checkpoint, should be on python path') diff --git a/tools/met_server.sh b/tools/met_server.sh deleted file mode 100755 index d7995b6c..00000000 --- a/tools/met_server.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -# example script for tongyi checkpoint - -PROJECT_NAME=MET_TEXT_GENERATION_SERVER -GROUP_NAME=1F1B-1B-2-EXIT - -export OMP_NUM_THREADS=8 -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -CHECKPOINT_PATH=/home/data/shared/checkpoints/MET-EXP/1F1B-MET-1B-2-EXIT-6-12-0.25-0.5-300B/convert -# CHECKPOINT_PATH=/home/data/shared/checkpoints/MET-EXP/MET-7B-8-16-0.25-0.5-untie/convert -# CHECKPOINT_PATH=/home/data/shared/checkpoints/MET-EXP/1F1B-MET-1B-1-EXIT-6-0.25-mlp-300B/convert -# CHECKPOINT_PATH=/home/data/shared/checkpoints/MET-EXP/1F1B-MET-1B-untie-embeddings/convert -# CHECKPOINT_PATH=/home/data/shared/checkpoints/MET-EXP/1F1B-MET-1B-1-EXIT/convert -TOKENIZER_PATH=/home/data/panxuchen.pxc/code/Megatron-LM/tokenizer/tokenizer.model - -TP=1 -PP=1 -SEQ=4096 - -MASTER_ADDR=127.0.0.1 -MASTER_PORT=5950 -NPROC_PER_NODE=$(( $TP * $PP )) -LOAD_ITERATION=0 - -DIST_ARGS=" - --master_addr $MASTER_ADDR \ - --master_port $MASTER_PORT \ - --nproc_per_node $NPROC_PER_NODE \ - --nnodes 1 \ - --node_rank 0 \ - " - -SERVER_ARGS=" - --use-checkpoint-args \ - --tokenizer-model $TOKENIZER_PATH \ - --load $CHECKPOINT_PATH \ - --load-iteration $LOAD_ITERATION \ - --port 5000 -" - -export CUDA_VISIBLE_DEVICES=1 && torchrun $DIST_ARGS \ - run_multi_exit_text_generation_server.py \ - $SERVER_ARGS diff --git a/tools/prompt_example.jsonl b/tools/prompt_example.jsonl new file mode 100644 index 00000000..ec478f62 --- /dev/null +++ b/tools/prompt_example.jsonl @@ -0,0 +1,2 @@ +{"text": "Artificial General Intelligence is"} +{"text": "The capital of China is"} \ No newline at end of file diff --git a/tools/request_client.py b/tools/request_client.py index 51096429..ce70f7b6 100644 --- a/tools/request_client.py +++ b/tools/request_client.py @@ -1,59 +1,67 @@ import requests import json import time -import numpy as np -PROMPTS_PATH = 'tools/prompts.json' -URL = 'http://localhost:5000/api' +URL = "http://localhost:5000/api" HEADER = { - 'Content-Type': 'application/json; charset=UTF-8', + "Content-Type": "application/json; charset=UTF-8", } -def request(prompts, tokens_to_generate, early_exit_thres): - length = len(prompts) - for i in range(length): - data = { - 'prompts': [prompts[i]], - 'tokens_to_generate': tokens_to_generate, - 'top_k': 1, - 'logprobs': True, - 'random_seed': int(time.time_ns()) % 16384, - 'echo_prompts': False, - 'early_exit_thres': early_exit_thres - } - start_time = time.time() - response = requests.put(URL, headers=HEADER, data=json.dumps(data)) - end_time = time.time() - print('Request:-------------------------------------------------') - print(f'{prompts[i]}') - print(f'Response:------------------({end_time - start_time:.4f}s)-------------------') - try: - print(f'{response.json()["text"][0]}') - except Exception as e: - print(response.json()) - # print(response.json()) - print(f'segments len: {len(response.json()["segments"][0])}') - print(f'logprobs len: {len(response.json()["logprobs"][0])}') - print(f'Response segment: {response.json()["segments"][0]}') - print(f'Response logprobs: {[np.exp(p) for p in response.json()["logprobs"][0]]}') - print('----------------------------------------------------------') +def request( + prompts, + tokens_to_generate=100, + use_early_exit=True, + early_exit_thres=0.8, + print_max_prob=False, +): + length = len(prompts) + for i in range(length): + data = { + "prompts": [prompts[i]], + "tokens_to_generate": tokens_to_generate, + "top_k": 1, + "logprobs": True, + "random_seed": int(time.time_ns()) % 16384, + "echo_prompts": False, + "early_exit_thres": early_exit_thres, + } + if use_early_exit: + data["use_early_exit"] = True + if print_max_prob: + data["print_max_prob"] = True + start_time = time.time() + response = requests.put(URL, headers=HEADER, data=json.dumps(data)) + end_time = time.time() + print("Request:-------------------------------------------------") + print(f"{prompts[i]}") + print( + f"Response:------------------({end_time - start_time:.4f}s)-------------------" + ) + try: + print(f'{response.json()["text"][0]}') + except Exception as e: + print(response) + print("----------------------------------------------------------") -def main(file_name, tokens_to_generate, early_exit_thres): - prompts = [] - with open(file_name, 'r') as f: - for line in f.readlines(): - prompts.append(json.loads(line)['text']) - request(prompts=prompts, tokens_to_generate=tokens_to_generate, early_exit_thres=early_exit_thres) +def main( + file_name, tokens_to_generate, use_early_exit, early_exit_thres, print_max_prob +): + prompts = [] + with open(file_name, "r") as f: + for line in f.readlines(): + prompts.append(json.loads(line)["text"]) + request( + prompts, tokens_to_generate, use_early_exit, early_exit_thres, print_max_prob + ) if __name__ == "__main__": - main('tools/benchmark/test_text.jsonl', tokens_to_generate=100, early_exit_thres=0.8) - -# 99.99% -0.0001 -# 99.9% -0.0010 -# 99% -0.0101 -# 90% -0.1054 -# 80% -0.2231 -# 75% -0.2877 \ No newline at end of file + main( + "tools/prompt_example.jsonl", + tokens_to_generate=100, + use_early_exit=True, + early_exit_thres=0.8, + print_max_prob=False, + ) diff --git a/tools/run_multi_exit_text_generation_server.py b/tools/run_early_exit_text_generation_server.py similarity index 72% rename from tools/run_multi_exit_text_generation_server.py rename to tools/run_early_exit_text_generation_server.py index c3d0fbd4..636fcbe9 100644 --- a/tools/run_multi_exit_text_generation_server.py +++ b/tools/run_early_exit_text_generation_server.py @@ -1,6 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -"""Sample Generate GPT""" +"""Run inference for Early-exit GPT""" import os import sys sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), @@ -10,10 +8,10 @@ from megatron.core import mpu from megatron.checkpointing import load_checkpoint from megatron.initialize import initialize_megatron -from megatron.model import MultiExitGPTModel +from megatron.model import EarlyExitGPTModel from megatron.training import get_model from megatron.arguments import core_transformer_config_from_args -from megatron.multi_exit_text_generation_server import MegatronServer +from megatron.early_exit_text_generation_server import MegatronServer from megatron.text_generation import generate_and_post_process from megatron.text_generation import beam_search_and_post_process import torch @@ -23,23 +21,13 @@ def model_provider(pre_process=True, post_process=True): config = core_transformer_config_from_args(get_args()) - print_rank_0('building MultiExitGPT model ...') - model = MultiExitGPTModel(config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) + print_rank_0('building EarlyExitGPT model ...') + model = EarlyExitGPTModel(config, num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) return model def add_text_generate_args(parser): group = parser.add_argument_group(title='text generation') - group.add_argument("--temperature", type=float, default=1.0, - help='Sampling temperature.') - group.add_argument("--top_p", type=float, default=0.0, - help='Top p sampling.') - group.add_argument("--top_k", type=int, default=0, - help='Top k sampling.') - group.add_argument("--out-seq-length", type=int, default=1024, - help='Size of the output generated text.') - group.add_argument("--early-exit-thres", type=float, default=None, - help='threshold of early exit logits') group.add_argument("--port", type=int, default=5000, help='Text generation server port.') return parser @@ -68,7 +56,7 @@ def add_text_generate_args(parser): model = model[0] if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: server = MegatronServer(model) - server.run("0.0.0.0",port=args.port) + server.run("0.0.0.0", port=args.port) while True: choice = torch.cuda.LongTensor(1)