* adding TP llama example
* clean up
* adding check point converter
* clean up
* addressingt the comments
* fixing the error handling
* clean up
* fixing the batch issue for chat
* adding min max gpus for inference
* adding max seq len and batch size explanations
* clean up
* lowering the new token gen number
* clean up
* fixing the hard coded tp_degree
* adding flash attention v2
* adding note for tp size
* fixing spell checks
* update packaging step
* update packaing step
* addressing comments on configs
* address comments on model config
* Update REAME.md
* Update REAME.md
* Update REAME.md
* Update REAME.md
---------
Co-authored-by: lxning <[email protected]>
1 parent e346a93 · commit f10a071 · Showing 11 changed files with 1,618 additions and 0 deletions.
# Serving Llama2 with PyTorch Native Tensor Parallelism

This document describes serving [Llama 2](https://huggingface.co/meta-llama), as presented in the original [Llama repo](https://github.com/facebookresearch/llama/tree/main), using PyTorch (PT) Tensor Parallel (TP) APIs, which under the hood make use of DTensors. It takes a sharding plan for the linear layers in the MLP and Attention blocks of the Llama2 model and turns it into a TP model distributed over multiple GPUs. In the following, we show the steps for using this to serve the Llama2 7B-70B models with TorchServe.
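To give a sense of what such a sharding plan looks like, the PT-native TP API applies it with `parallelize_module`. The sketch below is illustrative only and is not the example's actual code: the `Attention`, `FeedForward`, and `Block` classes are toy stand-ins whose module names follow Meta's Llama block naming (`attention.wq/wk/wv/wo`, `feed_forward.w1/w2/w3`).

```python
# Illustrative sketch of a per-block TP sharding plan; run under torchrun so
# LOCAL_RANK/WORLD_SIZE are set and one GPU per rank is available.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class Attention(nn.Module):
    # Toy stand-in for the Llama2-7B attention projections (dim 4096).
    def __init__(self, dim=4096):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)


class FeedForward(nn.Module):
    # Toy stand-in for the Llama2-7B SwiGLU MLP (hidden dim 11008).
    def __init__(self, dim=4096, hidden_dim=11008):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = Attention()
        self.feed_forward = FeedForward()


dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

# Column-shard the input projections, row-shard the output projections.
block = parallelize_module(
    Block().cuda(),
    tp_mesh,
    {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
    },
)
```

Sharding the input projections column-wise and the output projections row-wise keeps communication down to one all-reduce in the attention path and one in the feed-forward path of each block.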
Here we convert the Meta Llama2 model, which is based on Fairscale TP layers, to PyTorch Distributed (PT-D) compliant checkpoints, and use the PT TP (DTensor) API to run distributed inference.
**Note** The following has been tested so far on A100 GPUs with 40 GB of memory.
### How to use it?

1- Make sure you have access to the Llama2 weights on the [HF model hub](https://huggingface.co/meta-llama); there is a form you need to fill out, and within a few minutes you will get access. Any Llama2 model name on the hub **without -hf** is a Meta/FAIR weight.

Make sure you are signed up on HF as well; you will need your API token, which can be accessed [here](https://huggingface.co/settings/tokens). Make sure to use the same email for accessing the weights as the one you signed up to HF with.

Once you have access, log in to HF from your terminal:

```
huggingface-cli login --token YOUR_TOKEN
```
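If you prefer to do this from Python rather than the CLI, the `huggingface_hub` package provides an equivalent `login()` helper (a minimal sketch; `YOUR_TOKEN` is a placeholder):

```python
# Programmatic alternative to `huggingface-cli login`; the token string is a placeholder.
from huggingface_hub import login

login(token="YOUR_TOKEN")
```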
### Step 1: Install requirements

Make sure to have PyTorch nightlies installed.

```
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install transformers
```
### Step 2: Download model

Log in to the HuggingFace hub with your token by running the command below. **Make sure to specify the right name for the Llama2 model from the [HuggingFace (HF) model hub](https://huggingface.co/meta-llama); any model name on the model hub without -hf is a Meta original model/checkpoint, and we need those, not the HF-converted versions.**

```bash
huggingface-cli login
```
Paste the token generated from the HuggingFace hub. Make sure `use_auth_token=True` is set in the [Download script](../utils/Download_model.py).

```bash
python ../utils/Download_model.py --model_name meta-llama/Llama-2-7b
```
The script prints the path where the model is downloaded, as below.

`model/models--meta-llama--Llama-2-7b/snapshots/365ffa8f1a6c455d3e2028ae658236b4b85ba824`
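For reference, download helpers like this one typically wrap `huggingface_hub.snapshot_download`; the sketch below shows the general idea under that assumption and is not the actual `Download_model.py`:

```python
# Illustrative only: fetch the Meta (non -hf) weights into a local cache directory.
from huggingface_hub import snapshot_download

path = snapshot_download(
    repo_id="meta-llama/Llama-2-7b",
    use_auth_token=True,  # requires the `huggingface-cli login` above
    cache_dir="model",
)
print(path)
```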
### Step 3: Convert the "Meta" checkpoints to PyTorch Distributed compliant checkpoints

Convert the checkpoints to PT-D compliant checkpoints as follows. Note that for 7B you use `--model_parallel_size 1`, for 13B it would be `--model_parallel_size 2`, and for 70B `--model_parallel_size 8`; you can set `--nproc_per_node` accordingly. PT-D compliant checkpoints support a flexible world size when loading them back into the TP(-parallelized) model.

You can use a larger number of processes / a larger TP size when loading the model back. For example, if you converted the `13B` checkpoints with `--nproc_per_node 2`, during inference you can set `--nproc_per_node` to any value in `[2, max_num_available_gpu]`, which changes the world size and effectively the TP size. The recommendation is to keep the TP size as shown above for the respective model size, 7B (TP size = 1), 13B (TP size = 2), 70B (TP size = 8), unless your benchmarks and your batch size / compute load compensate for the communication cost.

This will save the model args in `model_args.json`; during the inference step you need to pass this JSON file to build the model. Make sure to set `--max_seq_len`, the maximum sequence length for input text (context length), and `--max_batch_size`, the maximum batch size for inference, to the respective values. These two values are used to construct the KV cache.

```
torchrun --nnodes 1 --nproc_per_node 8 convert_checkpoints.py --original_ckpt_dir PATH/TO/MODEL/CHECKPOINTS --tokenizer_path PATH/TO/MODEL/CHECKPOINTS/tokenizer.model --model_parallel_size 1 --save_checkpoint_dir converted_checkpoints --max_seq_len 512 --max_batch_size 2
```
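To get a feel for what these two flags cost, here is a back-of-the-envelope KV cache estimate. The layer/head counts are the standard Llama2-7B values and fp16 storage is assumed; the simplification also ignores the grouped-query attention used by the 70B variant.

```python
# Rough KV cache size: 2 (K and V) x batch x seq_len x n_heads x head_dim bytes
# per layer, summed over layers, assuming fp16 (2 bytes per element).
def kv_cache_bytes(max_batch_size, max_seq_len, n_layers, n_heads, head_dim, bytes_per_el=2):
    per_layer = 2 * max_batch_size * max_seq_len * n_heads * head_dim * bytes_per_el
    return n_layers * per_layer

# Llama2-7B: 32 layers, 32 heads, head_dim 128
print(kv_cache_bytes(2, 512, 32, 32, 128) / 2**20, "MiB")  # ~512 MiB
```
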
### Step 4: Set up the configs

Let's set up the configs in `model-config.yaml`:

```
#frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 300
parallelType: "tp"
deviceType: "gpu"
torchrun:
    nproc-per-node: 8 # TP size
handler:
    converted_ckpt_dir: "converted_checkpoints"
    tokenizer_path: "tokenizer.model"
    model_args_path: "model_args.json"
    max_seq_len: 512
    max_batch_size: 6
    max_new_tokens: 50
    temperature: 0.6
    top_p: 0.9
    manual_seed: 40
    mode: "text_completion" # choices are text_completion, chat
```
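For reference, the `handler` section above is surfaced to the custom handler at initialization time through the context's `model_yaml_config`. The sketch below shows that pattern only; it is not the example's actual `llama-handler.py`.

```python
# Minimal sketch of reading model-config.yaml values inside a custom handler.
from ts.torch_handler.base_handler import BaseHandler


class LlamaHandlerSketch(BaseHandler):
    def initialize(self, ctx):
        # model_yaml_config mirrors model-config.yaml; the keys below are the
        # ones defined in the `handler` section above.
        cfg = ctx.model_yaml_config["handler"]
        self.max_new_tokens = int(cfg["max_new_tokens"])
        self.temperature = float(cfg["temperature"])
        self.top_p = float(cfg["top_p"])
        self.mode = cfg.get("mode", "text_completion")
```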
### Step 5: Create the MAR file

Create the MAR file using the following commands:

```
torch-model-archiver --model-name llama --version 1.0 --handler llama-handler.py --config-file model-config.yaml --archive-format no-archive --extra-files "llama2.py,llama2_tokenizer.py,generate.py,checkpoint_converter.py"
mv converted_checkpoints llama
mv PATH/TO/MODEL/CHECKPOINTS/tokenizer.model llama
mv model_args.json llama
```
### Step 6: Serve the model

```
torchserve --ncs --start --model-store model_store --models llama
```
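Once the server is up, you can optionally confirm that the model registered by querying the TorchServe management API (default port 8081), e.g. `curl http://localhost:8081/models`.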
### Step 7: Send inference requests

Text completion example:

```bash
curl -v "http://localhost:8080/predictions/llama" -T sample_text.txt
```

Chat example:

```bash
curl -v "http://localhost:8080/predictions/llama" -T dialogs.txt
```
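The same requests can also be sent from Python; a minimal sketch, assuming the server from Step 6 is running locally and `sample_text.txt` is in the current directory:

```python
# Equivalent of the text-completion curl call above, using the requests library.
import requests

with open("sample_text.txt", "rb") as f:
    resp = requests.post("http://localhost:8080/predictions/llama", data=f)
print(resp.status_code, resp.text)
```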

`checkpoint_converter.py`:

import logging
from dataclasses import dataclass
from typing import Dict, List, Union

import torch
import torch.distributed as dist
from torch import nn, Tensor
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.fsdp._fsdp_extensions import (
    _ext_chunk_dtensor,
    _ext_chunk_tensor,
)


def _verify_fqn_across_ranks(fqn, grp_gloo):
    olist = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(olist, fqn, group=grp_gloo)
    assert len(set(olist)) == 1
    assert olist[0] == fqn


def _all_gather_into_list(data_tensor, model_parallel_group):
    tensor_list = [
        torch.zeros_like(data_tensor).cuda()
        for _ in range(dist.get_world_size(model_parallel_group))
    ]
    dist.all_gather(tensor_list, data_tensor.cuda(), group=model_parallel_group)
    return tensor_list

def _is_tp_sharded(fqn: str) -> bool:
    """
    Returns whether a tensor given by the fqn is tensor parallel sharded.
    NOTE: this is currently done by inspection of the MF model and is quite
    brittle and would need to be updated if the MF sharding changes.
    """
    return (
        "attention" in fqn
        or "feed_forward" in fqn
        or "output" in fqn
        or "tok_embeddings" in fqn
    )


def _unshard_param(
    ref_state_dict,
    fqn,
    model_parallel_group,
    grp_gloo,
    data_tensor,
    tp_sharded_shape,
):
    """
    Unshards the row or col-wise sharded parameter.
    For rowwise, this is done by reshaping into the local shape, allgathering,
    and stacking rows. For colwise, the only difference is we stack columns.
    This is done via vstack and column_stack respectively.
    """
    mp_size = dist.get_world_size(model_parallel_group)

    ref_shape = ref_state_dict[fqn].shape
    assert (
        ref_shape[0] == tp_sharded_shape[0] or ref_shape[1] == tp_sharded_shape[1]
    ), f"Expected sharded shape to match either row or col-wise, but does not: {ref_shape} {tp_sharded_shape}"
    _verify_fqn_across_ranks(fqn, grp_gloo)
    if ref_shape[0] != tp_sharded_shape[0]:
        assert ref_shape[0] == tp_sharded_shape[0] * mp_size
        # reshape the flat data_tensor into the rowwise shape
        data_tensor = data_tensor.reshape(tp_sharded_shape)
        # now, all_gather such tensors
        tensor_list = _all_gather_into_list(data_tensor, model_parallel_group)
        # stack rowwise to produce the final unsharded tensor
        data_tensor = torch.vstack(tensor_list).cpu()
        assert data_tensor.shape == ref_shape
        full_shape = data_tensor.shape
    elif (
        len(ref_shape) > 1
        and len(tp_sharded_shape) > 1
        and ref_shape[1] != tp_sharded_shape[1]
    ):
        assert ref_shape[1] == mp_size * tp_sharded_shape[1]
        # first, reshape the flat data_tensor into the colwise shape
        data_tensor = data_tensor.reshape(tp_sharded_shape)
        tensor_list = _all_gather_into_list(data_tensor, model_parallel_group)
        data_tensor = torch.column_stack(tensor_list).cpu()
        assert data_tensor.shape == ref_shape, f"{data_tensor.shape} vs {ref_shape}"
        full_shape = data_tensor.shape
    else:
        assert ref_shape == tp_sharded_shape  # not tensor parallel sharded
        full_shape = tp_sharded_shape
        logging.warning(f"{fqn} {ref_shape} {full_shape} - not sharded")
    return data_tensor, full_shape

def build_distributed_state_dict_from_consolidated(
    model: nn.Module,
    consolidated_state_dict: Dict[str, Tensor],
    model_parallel_world_size: int,
    offload_to_cpu: bool = False,
    use_dtensor: bool = False,
) -> Dict[str, Union[Tensor, DTensor, ShardedTensor]]:
    """
    Main API that takes a model (with no parallelism applied) and a fairscale checkpoint
    and builds a PT-D compliant distributed state dict. Note that this expects a consolidated
    checkpoint.

    Args:
        model (torch.nn.Module): module with no parallelism applied (i.e. result of `build_model` with parallel_impl=ParallelImpl.NONE)
        consolidated_state_dict (Dict[str, Tensor]): Fairscale consolidated state dict
        offload_to_cpu (bool): Whether to offload the resulting state_dict to CPU (default: False)
        use_dtensor (bool): Whether to use PyTorch Distributed Tensor instead of ShardedTensor (default: False)
            (this will eventually default to True)
        model_parallel_world_size: Model parallel world size that was used to create the consolidated checkpoint.
            This can be obtained by checking the number of consolidated0x.pth files in the checkpoint directory.

    Example usage::
        ```
        MODEL_PARALLEL_SIZE = 8
        ckpt_path = get_consolidated_ckpt_path(
            ckpt_dir=PTH_65b, mp_rank=local_rank, mp_size=MODEL_PARALLEL_SIZE
        )
        state_dict = torch.load(ckpt_path)
        # Build a local LLaMA with no parallelism
        model = build_model(...)
        sharded_state_dict = build_distributed_state_dict_from_consolidated(
            model, state_dict, model_parallel_world_size=MODEL_PARALLEL_SIZE,
        )
        # Wrap model with PT-native APIs + load
        model = FSDP(model)
        FSDP.set_state_dict_type(StateDictType.SHARDED_STATE_DICT)
        model.load_state_dict(sharded_state_dict)
        ```

    Note: Please make sure to pass an unsharded model as the model arg! Otherwise, things will not
    work.

    This distributed state dict is a mapping of FQN: ShardedTensor/DTensor. It will be replaced with
    DTensor once DTensor 2D checkpoint format is fully rolled out.

    Note: This has only been tested for loading state_dict into PT-D FSDP sharded_state_dict for now.
    """
    torch._C._log_api_usage_once("build_distributed_state_dict")
    dist_state_dict = {}
    ref_state_dict = model.state_dict()
    grp_gloo = dist.new_group(backend="gloo")
    # TODO: this should be the FSDP device mesh
    mesh = (
        DeviceMesh(
            device_type="cuda",
            mesh=list(range(dist.get_world_size())),
        )
        if use_dtensor
        else None
    )
    input_dtypes = {v.dtype for v in consolidated_state_dict.values()}
    logging.warning(f"input_dtypes {input_dtypes}")
    model_parallel_group, _ = dist.new_subgroups(group_size=model_parallel_world_size)
    for fqn, tensor in consolidated_state_dict.items():
        # Hack for buffer
        if "rope.freqs" in fqn:
            dist_state_dict[fqn] = tensor.clone()
            continue
        if _is_tp_sharded(fqn):
            tensor, _ = _unshard_param(
                ref_state_dict,
                fqn,
                model_parallel_group,
                grp_gloo,
                tensor,
                tensor.shape,
            )
        if use_dtensor:
            assert mesh is not None
            tensor = _ext_chunk_dtensor(
                tensor=tensor.contiguous(),
                rank=dist.get_rank(),
                device_mesh=mesh,
            )
        else:
            tensor = _ext_chunk_tensor(
                tensor=tensor.contiguous(),
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                num_devices_per_node=torch.cuda.device_count(),  # TODO: this is not accurate if user set CUDA_VISIBLE_DEVICES
                pg=dist.distributed_c10d._get_default_group(),  # TODO: this should be the FSDP process group
            )
        dist_state_dict[fqn] = tensor
    dtypes = {v.dtype for v in dist_state_dict.values()}
    logging.warning(f"Made dist_state_dict with dtypes {dtypes}")
    return dist_state_dict

`convert_checkpoints.py`:

import torch
from llama2 import Llama
import torch.distributed as dist
from typing import Any, Callable, Dict, List, Optional, Tuple
import abc
import fire


def convert_checkpoints(
    original_ckpt_dir: str,
    save_checkpoint_dir: str,
    tokenizer_path: str,
    model_parallel_size: int,
    max_seq_len: int = 512,
    max_batch_size: int = 4,
):
    dist.init_process_group("nccl")

    Llama.convert_checkpoints(
        original_ckpt_dir=original_ckpt_dir,
        save_checkpoint_dir=save_checkpoint_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        model_parallel_size=model_parallel_size,
    )


if __name__ == "__main__":
    fire.Fire(convert_checkpoints)

`dialogs.txt`:

[
    [
        {
            "role": "user",
            "content": "what is the recipe of mayonnaise?"
        }
    ]
]