adding TP llama example (#2623)
* adding TP llama example

* clean up

* adding check point converter

* clean up

* addressing the comments

* fixing the error handling

* clean up

* fixing the batch issue for chat

* adding min max gpus for inference

* adding max seq len and batch size explanations

* clean up

* lowering the new token gen number

* clean up

* fixing the hard coded tp_degree

* adding flash attention v2

* adding note for tp size

* fixing spell checks

* update packaging step

* update packaging step

* addressing comments on configs

* address comments on model config

* Update REAME.md

* Update REAME.md

* Update REAME.md

* Update REAME.md

---------

Co-authored-by: lxning <[email protected]>
HamidShojanazeri and lxning authored Oct 5, 2023
1 parent e346a93 commit f10a071
Showing 11 changed files with 1,618 additions and 0 deletions.
139 changes: 139 additions & 0 deletions examples/large_models/tp_llama/REAME.md
@@ -0,0 +1,139 @@
# Serving Llama2 with PyTorch Native Tensor Parallelism

This document describes how to serve [Llama 2](https://huggingface.co/meta-llama), as presented in the original [Llama repo](https://github.com/facebookresearch/llama/tree/main), using the PyTorch (PT) Tensor Parallel (TP) APIs, which use DTensors under the hood. In essence, a sharding plan is applied to the linear layers in the MLP and Attention blocks of the Llama2 model, producing a TP model distributed over multiple GPUs. The following sections show the steps to serve the Llama2 7B-70B models with TorchServe.

Here we convert the Meta Llama2 checkpoints, which are based on Fairscale TP layers, to PyTorch Distributed (PT-D) compliant checkpoints, and use the PT TP (DTensor) API to run distributed inference.

**Note** The following has been tested so far on A100 GPUs with 40 GB of memory.


### How to use it?


1- Make sure you have access to the Llama2 weights on the [HF model hub](https://huggingface.co/meta-llama); there is a form you need to fill out, and within a few minutes you will get access. Any Llama2 model name on the hub **without -hf** is a Meta/FAIR weight.

Make sure you are signed up on HF as well; you will need your API token, which can be accessed from [here](https://huggingface.co/settings/tokens). Make sure to use the same email for accessing the weights as the email you signed up to HF with.

Once you have access, log in to HF from your terminal:

```bash
huggingface-cli login --token YOUR_TOKEN
```

### Step 1: Install requirements

Make sure to have the PyTorch nightlies installed.

```bash
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
pip install transformers
```
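
Optionally, you can run a quick sanity check that the nightly build and distributed support are picked up; this is a minimal sketch and not part of the original instructions:

```python
import torch

# Quick sanity check after installing the nightly build (illustrative only).
print(torch.__version__)                 # e.g. a 2.x dev/nightly version string
print(torch.cuda.is_available())         # True on a GPU machine with CUDA 11.8 drivers
print(torch.distributed.is_available())  # True when the distributed package is built in
```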

### Step 2: Download model

Log in to the HuggingFace hub with your token by running the command below. **Make sure to specify the right name for the Llama2 model from the [HuggingFace (HF) model hub](https://huggingface.co/meta-llama); any model name on the model hub without -hf is an original Meta model/checkpoint, and we need those, not the HF-converted versions.**



```bash
huggingface-cli login
```
Paste the token generated from the HuggingFace hub. Make sure `use_auth_token=True` is set in the [Download script](../utils/Download_model.py).

```bash
python ../utils/Download_model.py --model_name meta-llama/Llama-2-7b
```
The script prints the path where the model is downloaded, as shown below.

`model/models--meta-llama--Llama-2-7b/snapshots/365ffa8f1a6c455d3e2028ae658236b4b85ba824`


### Step 3: Convert the "Meta" checkpoints to PyTorch Distributed compliant checkpoints

Convert the checkpoints to PT-D compliant checkpoints as follows. Note that for 7B `--model_parallel_size` is 1, for 13B it would be `--model_parallel_size 2`, and for 70B `--model_parallel_size 8`; you can also set `--nproc_per_node` accordingly. PT-D compliant checkpoints support a flexible world_size when loading them back into a TP(-parallelized) model.

You will be able to use a larger number of processes / a larger TP size when loading the model back. For example, if you have converted the `13B` checkpoints with `--nproc_per_node 2`, during inference you can set `--nproc_per_node` to any value in `[2, max_num_available_gpus]`, which changes the world_size and effectively the TP size. The recommendation is to keep the TP size as shown above for the respective model size, 7B (TP size = 1), 13B (TP size = 2), 70B (TP size = 8), unless your benchmarks and your batch size / compute load compensate for the communication cost.


This will save the model args in `model_args.json`; during the inference step you need to pass this JSON file to build the model. Make sure to set `--max_seq_len`, which is the maximum sequence length for input text (context length), and `--max_batch_size`, which is the maximum batch size for inference, to the respective values. These two values are used to construct the KV cache (see the rough sizing sketch after the command below).

```bash
torchrun --nnodes 1 --nproc_per_node 8 convert_checkpoints.py --original_ckpt_dir PATH/TO/MODEL/CHECKPOINTS --tokenizer_path PATH/TO/MODEL/CHECKPOINTS/tokenizer.model --model_parallel_size 1 --save_checkpoint_dir converted_checkpoints --max_seq_len 512 --max_batch_size 2
```
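
For intuition, here is a rough back-of-the-envelope sketch of how `--max_seq_len` and `--max_batch_size` drive KV cache memory. The Llama-2-7B dimensions (32 layers, 32 KV heads, head dim 128) and an fp16 cache are assumptions used only for illustration:

```python
def kv_cache_bytes(max_batch_size: int, max_seq_len: int,
                   n_layers: int = 32, n_kv_heads: int = 32,
                   head_dim: int = 128, bytes_per_elem: int = 2) -> int:
    """Approximate KV cache size: keys + values cached for every layer."""
    per_layer = 2 * max_batch_size * max_seq_len * n_kv_heads * head_dim * bytes_per_elem
    return n_layers * per_layer

# Example: the conversion command above (--max_seq_len 512 --max_batch_size 2)
print(kv_cache_bytes(2, 512) / 2**30, "GiB")  # ~0.5 GiB for 7B-like dimensions
```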



### Step 4: Set up the configs:

Let's set up the configs in `model-config.yaml`:

```yaml
# frontend settings
minWorkers: 1
maxWorkers: 1
maxBatchDelay: 200
responseTimeout: 300
parallelType: "tp"
deviceType: "gpu"
torchrun:
    nproc-per-node: 8 # TP size
handler:
    converted_ckpt_dir: "converted_checkpoints"
    tokenizer_path: "tokenizer.model"
    model_args_path: "model_args.json"
    max_seq_len: 512
    max_batch_size: 6
    max_new_tokens: 50
    temperature: 0.6
    top_p: 0.9
    manual_seed: 40
    mode: "text_completion" # choices are text_completion, chat
```
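
For reference, a TorchServe handler can read the `handler` section of this file from its context during `initialize`. The sketch below assumes the `ctx.model_yaml_config` accessor that recent TorchServe releases expose for `model-config.yaml`; it is an illustrative outline, not a copy of this example's actual handler:

```python
class LlamaHandlerSketch:
    """Minimal sketch of reading the YAML above from the handler context."""

    def initialize(self, ctx):
        # ctx.model_yaml_config mirrors model-config.yaml; "handler" holds the custom keys.
        cfg = ctx.model_yaml_config.get("handler", {})
        self.converted_ckpt_dir = cfg.get("converted_ckpt_dir", "converted_checkpoints")
        self.max_seq_len = int(cfg.get("max_seq_len", 512))
        self.max_batch_size = int(cfg.get("max_batch_size", 6))
        self.max_new_tokens = int(cfg.get("max_new_tokens", 50))
        self.temperature = float(cfg.get("temperature", 0.6))
        self.top_p = float(cfg.get("top_p", 0.9))
        self.mode = cfg.get("mode", "text_completion")
```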

### Step 5: Create the mar file:
Create the mar file using the following commands.

```bash
torch-model-archiver --model-name llama --version 1.0 --handler llama-handler.py --config-file model-config.yaml --archive-format no-archive --extra-files "llama2.py,llama2_tokenizer.py,generate.py,checkpoint_converter.py"
mv converted_checkpoints llama
mv PATH/TO/MODEL/CHECKPOINTS/tokenizer.model llama
mv model_args.json llama
```

### Step 6: Serve the model:

```bash
torchserve --ncs --start --model-store model_store --models llama
```
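
Once the server is up, you can optionally confirm that the model registered correctly via TorchServe's management API (default port 8081). A small Python sketch using only the standard library:

```python
import json
import urllib.request

# Describe-model call against the management API (default port 8081).
with urllib.request.urlopen("http://localhost:8081/models/llama") as resp:
    print(json.dumps(json.loads(resp.read()), indent=2))
```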

### Step 7: Send inference request:

Text completion example:

```bash
curl -v "http://localhost:8080/predictions/llama" -T sample_text.txt
```


Chat example:

```bash
curl -v "http://localhost:8080/predictions/llama" -T dialogs.txt
```
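
If you prefer sending the request from Python instead of curl, a minimal equivalent of the chat call above (assuming the default inference port 8080) could look like this:

```python
import urllib.request

# POST the dialogs file to the inference API, mirroring the curl command above.
with open("dialogs.txt", "rb") as f:
    payload = f.read()

req = urllib.request.Request(
    "http://localhost:8080/predictions/llama", data=payload, method="POST"
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())
```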
195 changes: 195 additions & 0 deletions examples/large_models/tp_llama/checkpoint_converter.py
@@ -0,0 +1,195 @@
import logging
from dataclasses import dataclass
from typing import Dict, List, Union

import torch
import torch.distributed as dist
from torch import nn, Tensor
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed.fsdp._fsdp_extensions import (
    _ext_chunk_dtensor,
    _ext_chunk_tensor,
)

def _verify_fqn_across_ranks(fqn, grp_gloo):
    # Sanity check: every rank must be processing the same FQN at this point.
    olist = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(olist, fqn, group=grp_gloo)
    assert len(set(olist)) == 1
    assert olist[0] == fqn

def _all_gather_into_list(data_tensor, model_parallel_group):
    # Gather each rank's shard of this tensor into a list, ordered by rank.
    tensor_list = [
        torch.zeros_like(data_tensor).cuda()
        for _ in range(dist.get_world_size(model_parallel_group))
    ]
    dist.all_gather(tensor_list, data_tensor.cuda(), group=model_parallel_group)
    return tensor_list


def _is_tp_sharded(fqn: str) -> bool:
    """
    Returns whether a tensor given by the fqn is tensor parallel sharded.
    NOTE: this is currently done by inspection of the MF model and is quite
    brittle and would need to be updated if the MF sharding changes.
    """
    return (
        "attention" in fqn
        or "feed_forward" in fqn
        or "output" in fqn
        or "tok_embeddings" in fqn
    )

def _unshard_param(
    ref_state_dict,
    fqn,
    model_parallel_group,
    grp_gloo,
    data_tensor,
    tp_sharded_shape,
):
    """
    Unshards the row or col-wise sharded parameter.
    For rowwise, this is done by reshaping into the local shape, allgathering,
    and stacking rows. For colwise, the only difference is we stack columns.
    This is done via vstack and column_stack respectively.
    """
    mp_size = dist.get_world_size(model_parallel_group)

    ref_shape = ref_state_dict[fqn].shape
    assert (
        ref_shape[0] == tp_sharded_shape[0] or ref_shape[1] == tp_sharded_shape[1]
    ), f"Expected sharded shape to match either row or col-wise, but does not: {ref_shape} {tp_sharded_shape}"
    _verify_fqn_across_ranks(fqn, grp_gloo)
    if ref_shape[0] != tp_sharded_shape[0]:
        assert ref_shape[0] == tp_sharded_shape[0] * mp_size
        # reshape the flat data_tensor into the rowwise shape
        data_tensor = data_tensor.reshape(tp_sharded_shape)
        # now, all_gather such tensors
        tensor_list = _all_gather_into_list(data_tensor, model_parallel_group)
        # stack rowwise to produce the final unsharded tensor
        data_tensor = torch.vstack(tensor_list).cpu()
        assert data_tensor.shape == ref_shape
        full_shape = data_tensor.shape
    elif (
        len(ref_shape) > 1
        and len(tp_sharded_shape) > 1
        and ref_shape[1] != tp_sharded_shape[1]
    ):
        assert ref_shape[1] == mp_size * tp_sharded_shape[1]
        # first, reshape the flat data_tensor into the colwise shape
        data_tensor = data_tensor.reshape(tp_sharded_shape)
        tensor_list = _all_gather_into_list(data_tensor, model_parallel_group)
        data_tensor = torch.column_stack(tensor_list).cpu()
        assert data_tensor.shape == ref_shape, f"{data_tensor.shape} vs {ref_shape}"
        full_shape = data_tensor.shape
    else:
        assert ref_shape == tp_sharded_shape  # not tensor parallel sharded
        full_shape = tp_sharded_shape
        logging.warning(f"{fqn} {ref_shape} {full_shape} - not sharded")
    return data_tensor, full_shape


def build_distributed_state_dict_from_consolidated(
    model: nn.Module,
    consolidated_state_dict: Dict[str, Tensor],
    model_parallel_world_size: int,
    offload_to_cpu: bool = False,
    use_dtensor: bool = False,
) -> Dict[str, Union[Tensor, DTensor, ShardedTensor]]:
    """
    Main API that takes a model (with no parallelism applied) and a fairscale checkpoint
    and builds a PT-D compliant distributed state dict. Note that this expects a consolidated
    checkpoint.
    Args:
        model (torch.nn.Module): module with no parallelism applied (i.e. result of `build_model` with parallel_impl=ParallelImpl.NONE)
        consolidated_state_dict (Dict[str, Tensor]): Fairscale consolidated checkpoint
        offload_to_cpu (bool): Whether to offload the resulting state_dict to CPU (default: False)
        use_dtensor (bool): Whether to use PyTorch Distributed Tensor instead of ShardedTensor (default: False)
            (this will eventually default to True)
        model_parallel_world_size: Model parallel world size that was used to create the consolidated checkpoint.
            This can be obtained by checking the number of consolidated.0x.pth files in the checkpoint directory.
    Example usage::
        ```
        MODEL_PARALLEL_SIZE = 8
        ckpt_path = get_consolidated_ckpt_path(
            ckpt_dir=PTH_65b, mp_rank=local_rank, mp_size=MODEL_PARALLEL_SIZE
        )
        state_dict = torch.load(ckpt_path)
        # Build a local LLaMA with no parallelism
        model = build_model(...)
        sharded_state_dict = build_distributed_state_dict_from_consolidated(
            model, state_dict, model_parallel_world_size=MODEL_PARALLEL_SIZE,
        )
        # Wrap model with PT-native APIs + load
        model = FSDP(model)
        FSDP.set_state_dict_type(StateDictType.SHARDED_STATE_DICT)
        model.load_state_dict(sharded_state_dict)
        ```
    Note: Please make sure to pass an unsharded model as the model arg! Otherwise, things will not
    work.
    This distributed state dict is a mapping of FQN: ShardedTensor/DTensor. It will be replaced with
    DTensor once DTensor 2D checkpoint format is fully rolled out.
    Note: This has only been tested for loading state_dict into PT-D FSDP sharded_state_dict for now.
    """
    torch._C._log_api_usage_once("build_distributed_state_dict")
    dist_state_dict = {}
    ref_state_dict = model.state_dict()
    grp_gloo = dist.new_group(backend="gloo")
    # TODO: this should be the FSDP device mesh
    mesh = (
        DeviceMesh(
            device_type="cuda",
            mesh=list(range(dist.get_world_size())),
        )
        if use_dtensor
        else None
    )
    input_dtypes = {v.dtype for v in consolidated_state_dict.values()}
    logging.warning(f"input_dtypes {input_dtypes}")
    model_parallel_group, _ = dist.new_subgroups(group_size=model_parallel_world_size)
    for fqn, tensor in consolidated_state_dict.items():
        # Hack for buffer
        if "rope.freqs" in fqn:
            dist_state_dict[fqn] = tensor.clone()
            continue
        if _is_tp_sharded(fqn):
            # Reassemble the full parameter from the fairscale TP shards.
            tensor, _ = _unshard_param(
                ref_state_dict,
                fqn,
                model_parallel_group,
                grp_gloo,
                tensor,
                tensor.shape,
            )
        if use_dtensor:
            assert mesh is not None
            tensor = _ext_chunk_dtensor(
                tensor=tensor.contiguous(),
                rank=dist.get_rank(),
                device_mesh=mesh,
            )
        else:
            tensor = _ext_chunk_tensor(
                tensor=tensor.contiguous(),
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
                num_devices_per_node=torch.cuda.device_count(),  # TODO: this is not accurate if user set CUDA_VISIBLE_DEVICES
                pg=dist.distributed_c10d._get_default_group(),  # TODO: this should be the FSDP process group
            )
        dist_state_dict[fqn] = tensor
    dtypes = {v.dtype for v in dist_state_dict.values()}
    logging.warning(f"Made dist_state_dict with dtypes {dtypes}")
    return dist_state_dict

31 changes: 31 additions & 0 deletions examples/large_models/tp_llama/convert_checkpoints.py
@@ -0,0 +1,31 @@
import torch.distributed as dist
import fire

from llama2 import Llama


def convert_checkpoints(
    original_ckpt_dir: str,
    save_checkpoint_dir: str,
    tokenizer_path: str,
    model_parallel_size: int,
    max_seq_len: int = 512,
    max_batch_size: int = 4,
):
    dist.init_process_group("nccl")

    Llama.convert_checkpoints(
        original_ckpt_dir=original_ckpt_dir,
        save_checkpoint_dir=save_checkpoint_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        model_parallel_size=model_parallel_size,
    )


if __name__ == "__main__":
    fire.Fire(convert_checkpoints)

9 changes: 9 additions & 0 deletions examples/large_models/tp_llama/dialogs.txt
@@ -0,0 +1,9 @@
[
    [
        {
            "role": "user",
            "content": "what is the recipe of mayonnaise?"
        }
    ]
]