From f10a07149caeae765d774ed98abe2d57ac01c679 Mon Sep 17 00:00:00 2001 From: Hamid Shojanazeri Date: Wed, 4 Oct 2023 18:18:45 -0700 Subject: [PATCH] adding TP llama example (#2623) * adding TP llama example * clean up * adding check point converter * clean up * addressingt the comments * fixing the error handling * clean up * fixing the batch issue for chat * adding min max gpus for inference * adding max seq len and batch size explanations * clean up * lowering the new token gen number * clean up * fixing the hard coded tp_degree * adding flash attention v2 * adding note for tp size * fixing spell checks * update packaging step * update packaing step * addressing comments on configs * address comments on model config * Update REAME.md * Update REAME.md * Update REAME.md * Update REAME.md --------- Co-authored-by: lxning <23464292+lxning@users.noreply.github.com> --- examples/large_models/tp_llama/REAME.md | 139 ++++ .../tp_llama/checkpoint_converter.py | 195 ++++++ .../tp_llama/convert_checkpoints.py | 31 + examples/large_models/tp_llama/dialogs.txt | 9 + examples/large_models/tp_llama/generate.py | 341 +++++++++ .../large_models/tp_llama/llama-handler.py | 169 +++++ examples/large_models/tp_llama/llama2.py | 653 ++++++++++++++++++ .../large_models/tp_llama/llama2_tokenizer.py | 44 ++ .../large_models/tp_llama/model-config.yaml | 23 + .../large_models/tp_llama/sample_text.txt | 1 + ts_scripts/spellcheck_conf/wordlist.txt | 13 + 11 files changed, 1618 insertions(+) create mode 100644 examples/large_models/tp_llama/REAME.md create mode 100644 examples/large_models/tp_llama/checkpoint_converter.py create mode 100644 examples/large_models/tp_llama/convert_checkpoints.py create mode 100644 examples/large_models/tp_llama/dialogs.txt create mode 100644 examples/large_models/tp_llama/generate.py create mode 100644 examples/large_models/tp_llama/llama-handler.py create mode 100644 examples/large_models/tp_llama/llama2.py create mode 100644 examples/large_models/tp_llama/llama2_tokenizer.py create mode 100644 examples/large_models/tp_llama/model-config.yaml create mode 100644 examples/large_models/tp_llama/sample_text.txt diff --git a/examples/large_models/tp_llama/REAME.md b/examples/large_models/tp_llama/REAME.md new file mode 100644 index 0000000000..9f98b25481 --- /dev/null +++ b/examples/large_models/tp_llama/REAME.md @@ -0,0 +1,139 @@ +# Serving Llama2 with PyTorch Native Tensor Parallelism + +This document briefs on serving the [Llama 2](https://huggingface.co/meta-llama) as presented in the original [Llama repo](https://github.com/facebookresearch/llama/tree/main) using PyTorch(PT) Tensor Parallel (TP) APIs, which under the hood make use of DTensors. It basically, takes a sharding plan for linear layers in MLP and Attention blocks of Llama2 model and make a TP model distributed over multiple GPUs. In the following, we show the steps how to use this and serve the Llama2 7-70B model with Torchserve. + +Here we convert the Meta Llama2 model, which is based on Fairscale TP layers to PT distributed compliant checkpoints and use PT TP (DTensor) API to run the Distributed inference. + +**Note** The following has been tested on A100 GPUs with 40 GB memory so far. + + +### How to use it? + + +1- Make sure you have access to Llama2 weights on [HF model hub](https://huggingface.co/meta-llama), there is a form you need to fill up and within few mins you will get access. Any Llama2 model name on the hub **without -hf** is Meta/FAIR weight. 
+
+Make sure you are signed up on HF as well; you will need your API token, which can be obtained from [here](https://huggingface.co/settings/tokens). Make sure to use the same email address for requesting the weights as the one you used to sign up for HF.
+
+Once you have access, log in to HF from your terminal:
+
+```
+huggingface-cli login --token YOUR_TOKEN
+
+```
+
+### Step 1: Install requirements
+
+Make sure to have PyTorch Nightlies installed.
+
+```
+pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
+
+pip install transformers
+
+```
+
+### Step 2: Download model
+
+Log in to the HuggingFace hub with your token by running the command below. **Make sure to specify the right name for the Llama2 model from the [HuggingFace (HF) model hub](https://huggingface.co/meta-llama); any model name on the model hub without `-hf` is an original Meta model/checkpoint, and those are the ones we need, not the HF-converted versions.**
+
+```bash
+huggingface-cli login
+```
+Paste the token generated on the HuggingFace hub. Make sure `use_auth_token=True` is set in the [Download script](../utils/Download_model.py).
+
+```bash
+python ../utils/Download_model.py --model_name meta-llama/Llama-2-7b
+```
+The script prints the path where the model is downloaded, as below.
+
+`model/models--meta-llama--Llama-2-7b/snapshots/365ffa8f1a6c455d3e2028ae658236b4b85ba824`
+
+
+### Step 3: Convert the "Meta" checkpoints to PyTorch Distributed compliant checkpoints
+
+Convert the checkpoints to PT-D compliant checkpoints as follows. Note that for 7B you use `--model_parallel_size 1`, for 13B `--model_parallel_size 2`, and for 70B `--model_parallel_size 8`; set `--nproc_per_node` accordingly. PT-D compliant checkpoints support a flexible world size when loading them back into the TP(-parallelized) model.
+
+You can use a larger number of processes / a larger TP size when loading the model back. For example, if you converted the `13B` checkpoints with `--nproc_per_node 2`, during inference you can set `--nproc_per_node` to any value in `[2, max_num_available_gpu]`, which changes the world size and effectively the TP size. The recommendation is to keep the TP size matched to the model size as shown above, 7B (TP size = 1), 13B (TP size = 2), 70B (TP size = 8), unless your benchmarks show that your batch size / compute load compensates for the communication cost.
+
+
+This will save the model args in `model_args.json`; during the inference step you need to pass this JSON file to build the model. Make sure to set `--max_seq_len` (the maximum sequence length for input text, i.e. the context length) and `--max_batch_size` (the maximum batch size for inference) to appropriate values. These two values are used to construct the KV cache (see the sizing sketch below).
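+
+As a rough guide, the following sketch estimates the per-GPU KV cache size implied by these two flags. The 7B shapes (32 layers, 32 KV heads, head dim 128) and fp16 storage are assumptions for illustration; substitute the numbers from your own `model_args.json`.
+
+```python
+# Illustrative KV cache estimate per GPU (assumes Llama2-7B shapes and fp16 storage).
+n_layers, n_kv_heads, head_dim = 32, 32, 128   # from params.json for 7B (assumed)
+tp_size = 1                                    # KV heads are split across TP ranks
+max_batch_size, max_seq_len = 2, 512           # values passed to convert_checkpoints.py
+bytes_per_elem = 2                             # fp16
+
+# Two cache tensors (K and V) per layer, each of shape
+# (max_batch_size, max_seq_len, n_kv_heads // tp_size, head_dim).
+kv_bytes = 2 * n_layers * max_batch_size * max_seq_len * (n_kv_heads // tp_size) * head_dim * bytes_per_elem
+print(f"KV cache ~ {kv_bytes / 1024**2:.0f} MiB per GPU")  # ~512 MiB for these values
+```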
+ +``` +torchrun --nnodes 1 --nproc_per_node 8 convert_checkpoints.py --original_ckpt_dir PATH/TO/MODEL/CHECKPOINTS --tokenizer_path PATH/TO/MODEL/CHECKPOINTS/tokenizer.model --model_parallel_size 1 --save_checkpoint_dir converted_checkpoints --max_seq_len 512 --max_batch_size 2 + +``` + + + +### Step 4: set up the configs: + +Lets setup configs in `model-config.yaml` + +``` +#frontend settings +minWorkers: 1 +maxWorkers: 1 +maxBatchDelay: 200 +responseTimeout: 300 +parallelType: "tp" +deviceType: "gpu" + +torchrun: + nproc-per-node: 8 # TP size + +handler: + converted_ckpt_dir: "converted_checkpoints" + tokenizer_path: "tokenizer.model" + model_args_path: "model_args.json" + max_seq_len: 512 + max_batch_size: 6 + max_new_tokens: 50 + temperature: 0.6 + top_p: 0.9 + manual_seed: 40 + mode: "text_completion" #choices are text_completion, chat +``` + +### step 5: Create the mar file: +Create the mar file using the following command here. + +``` +torch-model-archiver --model-name llama --version 1.0 --handler llama-handler.py --config-file model-config.yaml --archive-format no-archive --extra-files "llama2.py,llama2_tokenizer.py,generate.py,checkpoint_converter.py" + +mv converted_checkpoints llama + +mv PATH/TO/MODEL/CHECKPOINTS/tokenizer.model llama + +mv model_args.json llama + +``` + +### Step 6: Serve the model: + +``` +torchserve --ncs --start --model-store model_store --models llama + +``` + +### Step 6: Send inference request: + +Text completion example : + + +```bash + +curl -v "http://localhost:8080/predictions/llama" -T sample_text.txt + +``` + + +Chat example : + + +```bash + +curl -v "http://localhost:8080/predictions/llama" -T dialogs.txt + +``` diff --git a/examples/large_models/tp_llama/checkpoint_converter.py b/examples/large_models/tp_llama/checkpoint_converter.py new file mode 100644 index 0000000000..2e012c5c55 --- /dev/null +++ b/examples/large_models/tp_llama/checkpoint_converter.py @@ -0,0 +1,195 @@ +import logging +from dataclasses import dataclass +from typing import Dict, List, Union + +import torch +import torch.distributed as dist +from torch import nn, Tensor +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._tensor import DeviceMesh, DTensor +from torch.distributed.fsdp._fsdp_extensions import ( + _ext_chunk_dtensor, + _ext_chunk_tensor, +) + +def _verify_fqn_across_ranks(fqn, grp_gloo): + olist = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(olist, fqn, group=grp_gloo) + assert len(set(olist)) == 1 + assert olist[0] == fqn + +def _all_gather_into_list(data_tensor, model_parallel_group): + tensor_list = [ + torch.zeros_like(data_tensor).cuda() + for _ in range(dist.get_world_size(model_parallel_group)) + ] + dist.all_gather(tensor_list, data_tensor.cuda(), group=model_parallel_group) + return tensor_list + + +def _is_tp_sharded(fqn: str) -> bool: + """ + Returns whether a tensor given by the fqn is tensor parallel sharded. + NOTE: this is currently done by inspection of the MF model and is quite + brittle and would need to be updated if the MF sharding changes. + """ + return ( + "attention" in fqn + or "feed_forward" in fqn + or "output" in fqn + or "tok_embeddings" in fqn + ) + +def _unshard_param( + ref_state_dict, + fqn, + model_parallel_group, + grp_gloo, + data_tensor, + tp_sharded_shape, + ): + """ + Unshards the row or col-wise sharded parameter. + For rowwise, this is done by reshaping into the local shape, allgathering, + and stacking rows. 
For colwise, the only difference is we stack columns. + This is done via vstack and column_stack respectively. + """ + mp_size = dist.get_world_size(model_parallel_group) + + ref_shape = ref_state_dict[fqn].shape + assert ( + ref_shape[0] == tp_sharded_shape[0] or ref_shape[1] == tp_sharded_shape[1] + ), f"Expected sharded shape to match either row or col-wise, but does not: {ref_shape} {tp_sharded_shape}" + _verify_fqn_across_ranks(fqn, grp_gloo) + if ref_shape[0] != tp_sharded_shape[0]: + assert ref_shape[0] == tp_sharded_shape[0] * mp_size + # reshape the flat data_tensor into the rowwise shape + data_tensor = data_tensor.reshape(tp_sharded_shape) + # now, all_gather such tensors + tensor_list = _all_gather_into_list(data_tensor, model_parallel_group) + # stack rowwise to produce the final unsharded tensor + data_tensor = torch.vstack(tensor_list).cpu() + assert data_tensor.shape == ref_shape + full_shape = data_tensor.shape + elif ( + len(ref_shape) > 1 + and len(tp_sharded_shape) > 1 + and ref_shape[1] != tp_sharded_shape[1] + ): + assert ref_shape[1] == mp_size * tp_sharded_shape[1] + # first, reshape the flat data_tensor into the colwise shape + data_tensor = data_tensor.reshape(tp_sharded_shape) + tensor_list = _all_gather_into_list(data_tensor, model_parallel_group) + data_tensor = torch.column_stack(tensor_list).cpu() + assert data_tensor.shape == ref_shape, f"{data_tensor.shape} vs {ref_shape}" + full_shape = data_tensor.shape + else: + assert ref_shape == tp_sharded_shape # not tensor parallel sharded + full_shape = tp_sharded_shape + logging.warning(f"{fqn} {ref_shape} {full_shape} - not sharded") + return data_tensor, full_shape + + +def build_distributed_state_dict_from_consolidated( + model: nn.Module, + consolidated_state_dict: Dict[str, Tensor], + model_parallel_world_size: int, + offload_to_cpu: bool = False, + use_dtensor: bool = False, +) -> Dict[str, Union[Tensor, DTensor, ShardedTensor]]: + """ + Main API that takes a model (with no parallelism applied) and a fairscale checkpoint + and builds a PT-D compliant distributed state dict. Note that this expects a consolidated + checkpoint. + + Args: + model (torch.nn.Module): module with no parallelism applied (i.e. result of `build_model` with parallel_impl=ParallelImpl.NONE) + fs_state_dict (Dict[str, Any]): Fairscale consolidated + offload_to_cpu (bool): Whether to offload the resulting state_dict to CPU (default: False) + use_dtensor (bool): Whether to use PyTorch Distributed Tensor instead of ShardedTensor (default: False) + (this will eventually default to True) + model_parallel_world_size: Model parallel world size that was used to create the consolidated checkpoint. + This can be obtained by checking the number of consolidated0x.pth files in the checkpoint directory. + + Example usage:: + ``` + + MODEL_PARALLEL_SIZE = 8 + ckpt_path = get_consolidated_ckpt_path( + ckpt_dir=PTH_65b, mp_rank=local_rank, mp_size=MODEL_PARALLEL_SIZE + ) + state_dict = torch.load(ckpt_path) + # Build a local LLaMA with no parallelism + model = build_model(...) + sharded_state_dict = build_distributed_state_dict_from_consolidated( + model, state_dict, model_parallel_world_size=MODEL_PARALLEL_SIZE, + ) + # Wrap model with PT-native APIs + load + model = FSDP(model) + FSDP.set_state_dict_type(StateDictType.SHARDED_STATE_DICT) + model.load_state_dict(sharded_state_dict) + ``` + + Note: Please make sure to pass an unsharded model as the model arg! Otherwise, things will not + work. 
+ + This distributed state dict is a mapping of FQN: ShardedTensor/DTensor. It will be replaced with + DTensor once DTensor 2D checkpoint format is fully rolled out. + + Note: This has only been tested for loading state_dict into PT-D FSDP sharded_state_dict for now. + """ + torch._C._log_api_usage_once("build_distributed_state_dict") + dist_state_dict = {} + ref_state_dict = model.state_dict() + grp_gloo = dist.new_group(backend="gloo") + # TODO: this should be the FSDP device mesh + mesh = ( + DeviceMesh( + device_type="cuda", + mesh=list(range(dist.get_world_size())), + ) + if use_dtensor + else None + ) + input_dtypes = {v.dtype for v in consolidated_state_dict.values()} + logging.warning(f"input_dtypes {input_dtypes}") + model_parallel_group, _ = dist.new_subgroups(group_size=model_parallel_world_size) + for fqn, tensor in consolidated_state_dict.items(): + # Hack for buffer + if "rope.freqs" in fqn: + dist_state_dict[fqn] = tensor.clone() + continue + if _is_tp_sharded(fqn): + + tensor, _ = _unshard_param( + ref_state_dict, + fqn, + model_parallel_group, + grp_gloo, + tensor, + tensor.shape, + ) + if use_dtensor: + + assert mesh is not None + tensor = _ext_chunk_dtensor( + tensor=tensor.contiguous(), + rank=dist.get_rank(), + device_mesh=mesh, + ) + + else: + + tensor = _ext_chunk_tensor( + tensor=tensor.contiguous(), + rank=dist.get_rank(), + world_size=dist.get_world_size(), + num_devices_per_node=torch.cuda.device_count(), # TODO: this is not accurate if user set CUDA_VISIBLE_DEVICES + pg=dist.distributed_c10d._get_default_group(), # TODO: this should be the FSDP process group + ) + + dist_state_dict[fqn] = tensor + dtypes = {v.dtype for v in dist_state_dict.values()} + logging.warning(f"Made dist_state_dict with dtypes {dtypes}") + return dist_state_dict + diff --git a/examples/large_models/tp_llama/convert_checkpoints.py b/examples/large_models/tp_llama/convert_checkpoints.py new file mode 100644 index 0000000000..f281590b45 --- /dev/null +++ b/examples/large_models/tp_llama/convert_checkpoints.py @@ -0,0 +1,31 @@ +import torch +from llama2 import Llama +import torch.distributed as dist +from typing import Any, Callable, Dict, List, Optional, Tuple +import abc +import fire + + +def convert_checkpoints( + original_ckpt_dir: str, + save_checkpoint_dir: str, + tokenizer_path: str, + model_parallel_size: int, + max_seq_len: int=512, + max_batch_size: int=4, + ): + dist.init_process_group("nccl") + + Llama.convert_checkpoints( + original_ckpt_dir=original_ckpt_dir, + save_checkpoint_dir=save_checkpoint_dir, + tokenizer_path= tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + model_parallel_size=model_parallel_size, + ) + + +if __name__ == "__main__": + fire.Fire(convert_checkpoints) + \ No newline at end of file diff --git a/examples/large_models/tp_llama/dialogs.txt b/examples/large_models/tp_llama/dialogs.txt new file mode 100644 index 0000000000..72ac614f4c --- /dev/null +++ b/examples/large_models/tp_llama/dialogs.txt @@ -0,0 +1,9 @@ +[ + [ + { + "role": "user", + "content": "what is the recipe of mayonnaise?" 
+ } + ] + +] \ No newline at end of file diff --git a/examples/large_models/tp_llama/generate.py b/examples/large_models/tp_llama/generate.py new file mode 100644 index 0000000000..4a6a625ccc --- /dev/null +++ b/examples/large_models/tp_llama/generate.py @@ -0,0 +1,341 @@ +import torch +from llama2 import Llama +import torch.distributed as dist +from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import List, Literal, Optional, Tuple, TypedDict +import abc +import os +import sys +import fire +current_working_directory = os.getcwd() +sys.path.insert(0,current_working_directory) + +Role = Literal["system", "user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +class CompletionPrediction(TypedDict, total=False): + generation: str + tokens: List[str] # not required + logprobs: List[float] # not required + + +class ChatPrediction(TypedDict, total=False): + generation: Message + tokens: List[str] # not required + logprobs: List[float] # not required + + +Dialog = List[Message] + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +SPECIAL_TAGS = [B_INST, E_INST, "<>", "<>"] +UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt." + +def sample_top_p(probs, p): + """ + Perform top-p (nucleus) sampling on a probability distribution. + + Args: + probs (torch.Tensor): Probability distribution tensor. + p (float): Probability threshold for top-p sampling. + + Returns: + torch.Tensor: Sampled token indices. + + Note: + Top-p sampling selects the smallest set of tokens whose cumulative probability mass + exceeds the threshold p. The distribution is renormalized based on the selected tokens. + + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +# @torch.inference_mode() +with torch.no_grad(): + def generate(model, + tokenizer, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + ) -> Tuple[List[List[int]], Optional[List[List[float]]]]: + """ + Generate text sequences based on provided prompts using the language generation model. + + Args: + prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers. + max_gen_len (int): Maximum length of the generated text sequence. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities. + + Note: + This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. 
+ + """ + bsz = len(prompt_tokens) + assert bsz <= model.max_batch_size, (bsz, model.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= model.max_seq_len + total_len = min(model.max_seq_len, max_gen_len + max_prompt_len) + + pad_id = tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + if min_prompt_len == total_len: + logits = model.forward(tokens, prev_pos) + token_logprobs = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens, + reduction="none", + ignore_index=pad_id, + ) + + for cur_pos in range(min_prompt_len, total_len): + logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + next_token == tokenizer.eos_id + ) + prev_pos = cur_pos + if all(eos_reached): + break + + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + probs = None + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to eos tok if any + if tokenizer.eos_id in toks: + eos_idx = toks.index(tokenizer.eos_id) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + out_tokens.append(toks) + out_logprobs.append(probs) + return (out_tokens, out_logprobs if logprobs else None) + + +def text_completion( + model, + tokenizer, + prompts: List[str], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, +) -> List[CompletionPrediction]: + """ + Perform text completion for a list of prompts using the language generation model. + + Args: + prompts (List[str]): List of text prompts for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. 
+ + Returns: + List[CompletionPrediction]: List of completion predictions, each containing the generated text completion. + + Note: + This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness. + If logprobs is True, token log probabilities are computed for each generated token. + + """ + if max_gen_len is None: + max_gen_len = model.max_seq_len - 1 + prompt_tokens = [tokenizer.encode(x, bos=True, eos=False) for x in prompts] + generation_tokens, generation_logprobs = generate( + model, + tokenizer, + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ) + if logprobs: + return [ + { + "generation": tokenizer.decode(t), + "tokens": [tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [{"generation": tokenizer.decode(t)} for t in generation_tokens] + +def chat_completion( + model, + tokenizer, + dialogs: List[Dialog], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, +) -> List[ChatPrediction]: + """ + Generate assistant responses for a list of conversational dialogs using the language generation model. + + Args: + dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + max_gen_len (Optional[int], optional): Maximum length of the generated response sequence. + If not provided, it's set to the model's maximum sequence length minus 1. + logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False. + + Returns: + List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response. + + Raises: + AssertionError: If the last message in a dialog is not from the user. + AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order. + + Note: + This method generates assistant responses for the provided conversational dialogs. + It employs nucleus sampling to introduce controlled randomness in text generation. + If logprobs is True, token log probabilities are computed for each generated token. 
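+
+    Example (illustrative; assumes `model` and `tokenizer` come from `Llama.build`):
+        dialogs = [[{"role": "user", "content": "What is the recipe of mayonnaise?"}]]
+        preds = chat_completion(model, tokenizer, dialogs, max_gen_len=64)
+        print(preds[0]["generation"]["content"])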
+ + """ + if max_gen_len is None: + max_gen_len = model.max_seq_len - 1 + prompt_tokens = [] + unsafe_requests = [] + for dialog in dialogs: + unsafe_requests.append( + any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog]) + ) + if dialog[0]["role"] == "system": + dialog = [ + { + "role": dialog[1]["role"], + "content": B_SYS + + dialog[0]["content"] + + E_SYS + + dialog[1]["content"], + } + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system', 'user' and 'assistant' roles, " + "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" + ) + dialog_tokens: List[int] = sum( + [ + tokenizer.encode( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + bos=True, + eos=True, + ) + for prompt, answer in zip( + dialog[::2], + dialog[1::2], + ) + ], + [], + ) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += tokenizer.encode( + f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", + bos=True, + eos=False, + ) + prompt_tokens.append(dialog_tokens) + + generation_tokens, generation_logprobs = generate( + model, + tokenizer, + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + ) + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": tokenizer.decode(t) + if not unsafe + else UNSAFE_ERROR, + }, + "tokens": [tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i, unsafe in zip( + generation_tokens, generation_logprobs, unsafe_requests + ) + ] + return [ + { + "generation": { + "role": "assistant", + "content": tokenizer.decode(t) if not unsafe else UNSAFE_ERROR, + } + } + for t, unsafe in zip(generation_tokens, unsafe_requests) + ] + + \ No newline at end of file diff --git a/examples/large_models/tp_llama/llama-handler.py b/examples/large_models/tp_llama/llama-handler.py new file mode 100644 index 0000000000..83ec208f7b --- /dev/null +++ b/examples/large_models/tp_llama/llama-handler.py @@ -0,0 +1,169 @@ +import logging +import time +from abc import ABC +import json +import os +import sys +import importlib.util + +import packaging.version +import requests +import torch +import transformers +from transformers import AutoModelForCausalLM, AutoTokenizer + +from ts.torch_handler.base_handler import BaseHandler +current_working_directory = os.getcwd() +sys.path.insert(0,current_working_directory) +from llama2 import Llama +from generate import chat_completion, text_completion, Dialog + +logger = logging.getLogger(__name__) +logger.info("Transformers version %s", transformers.__version__) +if packaging.version.parse(torch.__version__) >= packaging.version.parse("2.0.0"): + logger.info("PyTorch version is 2.0.0 or greater") +else: + logger.info( + "PyTorch version is less than 2.0.0, initializing with meta device needs PyTorch 2.0.0 and greater" + ) + + + +class LlamaHandler(BaseHandler,ABC): + """ + Transformers handler class for sequence, token classification and question answering. + """ + + def __init__(self): + super(LlamaHandler, self).__init__() + self.initialized = False + + def initialize(self, ctx): + """ + In this initialize function, the llama model is loaded using Fairscale and + partitioned into multiple stages each on one device using PiPPy. 
+ Args: + ctx (context): It is a JSON Object containing information + pertaining to the model artefacts parameters. + """ + # super().initialize(ctx) + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group("nccl") + + self.manifest = ctx.manifest + properties = ctx.system_properties + model_dir = properties.get("model_dir") + + seed = ctx.model_yaml_config["handler"]["manual_seed"] + self.mode = ctx.model_yaml_config["handler"]["mode"] + self.max_new_tokens = ctx.model_yaml_config["handler"]["max_new_tokens"] + self.temperature = ctx.model_yaml_config["handler"]["temperature"] + self.top_p = ctx.model_yaml_config["handler"]["top_p"] + + torch.manual_seed(seed) + + logger.info("Instantiating Llama model") + model_load_start = time.perf_counter() + llama_model_and_tok= Llama.build( + model_args=f'{model_dir}/{ctx.model_yaml_config["handler"]["model_args_path"]}', + converted_ckpt_dir=f'{model_dir}/{ctx.model_yaml_config["handler"]["converted_ckpt_dir"]}', + tokenizer_path= f'{model_dir}/{ctx.model_yaml_config["handler"]["tokenizer_path"]}', + ) + load_time = time.perf_counter()-model_load_start + self.model = llama_model_and_tok.model + + + self.tokenizer = llama_model_and_tok.tokenizer + + logger.info(f"Llama model from path {model_dir} loaded successfully in {load_time} seconds") + + self.initialized = True + + def preprocess(self, requests): + """ + Basic text preprocessing, based on the user's choice of application mode. + Args: + requests (list): A list of dictionaries with a "data" or "body" field, each + containing the input text to be processed. + Returns: + tuple: A tuple with two tensors: the batch of input ids and the batch of + attention masks. + """ + input_texts = [data.get("data") or data.get("body") for data in requests] + + if self.mode == "chat": + if input_texts: + + try: + dialog = json.loads(input_texts[0]) + except json.JSONDecodeError: + raise ValueError(f"Invalid JSON format in text: {input_texts[0]}") + return dialog + + elif self.mode == "text_completion": + try: + return [self.prep_input_text(text) for text in input_texts] + except TypeError: + raise ValueError("Expected input_texts to contain text (string) values.") + else: + raise NotImplementedError("Unsupported mode. Please select a valid mode.") + + + def prep_input_text(self, input_text): + """ + preparing a single input text using the tokenizer. + Args: + input_text (str): The input text to be encoded. + Returns: + decoded input text + """ + if isinstance(input_text, (bytes, bytearray)): + input_text = input_text.decode("utf-8") + logger.debug("Received text: '%s'", input_text) + + return input_text + + def inference(self, input_batch): + """ + Generate tokens based on prompts + Args: + input_batch : a batch of input texts + Returns: + list: A list of strings with the predicted values for each input text in the batch. + """ + + if self.mode == "chat": + results = chat_completion( + self.model, + self.tokenizer, + input_batch, + max_gen_len=self.max_new_tokens, + temperature=self.temperature, + top_p=self.top_p, + ) + + elif self.mode == "text_completion": + results = text_completion( + self.model, + self.tokenizer, + input_batch, + max_gen_len=self.max_new_tokens, + temperature=self.temperature, + top_p=self.top_p, + ) + + + + return results + + def postprocess(self, inference_output): + """Post Process Function converts the predicted response into Torchserve readable format. + Args: + inference_output (list): It contains the predicted response of the input text. 
+ Returns: + (list): Returns a list of the Predictions and Explanations. + """ + + logger.info("Generated text: %s", inference_output) + + return [inference_output] diff --git a/examples/large_models/tp_llama/llama2.py b/examples/large_models/tp_llama/llama2.py new file mode 100644 index 0000000000..f30930548e --- /dev/null +++ b/examples/large_models/tp_llama/llama2.py @@ -0,0 +1,653 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +import json +import math +import logging +import os +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple, Union + +from torch.distributed._tensor.placement_types import Replicate, Shard +import torch +import torch.distributed as dist +import torch.distributed.checkpoint as dist_cp + +import torch.nn.functional as F +from checkpoint_converter import build_distributed_state_dict_from_consolidated +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import StateDictType +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from llama2_tokenizer import Tokenizer +from torch.distributed._tensor import DeviceMesh, DTensor +from torch.distributed.tensor.parallel import ( + PairwiseParallel, + parallelize_module, + ColwiseParallel, + RowwiseParallel, + ) + +from copy import deepcopy +from dataclasses import dataclass, asdict, fields +current_working_directory = os.getcwd() +sys.path.insert(0,current_working_directory) + +log = logging.getLogger(__name__) + +def dataclass_to_json(dc): + return json.dumps(asdict(dc)) + +def json_to_dataclass(json_str, dataclass_type): + data = json.loads(json_str) + return dataclass_type(**data) + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +# TODO: update this to use RMSNorm in MultiModal +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = 
torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + + +class Attention(nn.Module): + def __init__( + self, + n_heads: int, + n_kv_heads: int, + dim: int, + max_batch_size: int, + max_seq_len: int, + ): + super().__init__() + tp_degree = int(os.environ["WORLD_SIZE"]) + self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads + self.n_local_heads = n_heads//tp_degree + self.n_local_kv_heads = self.n_kv_heads//tp_degree + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads + + self.wq = nn.Linear( + dim, + n_heads * self.head_dim, + bias=False, + ) + self.wk = nn.Linear( + dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wv = nn.Linear( + dim, + self.n_kv_heads * self.head_dim, + bias=False, + ) + self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False) + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + self._init_cache_k() + self._init_cache_v() + + def _init_cache_k(self): + self.cache_k = torch.zeros( + ( + self.max_batch_size, + self.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ) + + def _init_cache_v(self): + self.cache_v = torch.zeros( + ( + self.max_batch_size, + self.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + + bsz, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv + + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + #calling PT SDPA to enable using Flash Attention 2 and Xformer memory efficient kernels. 
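+        # Note: scaled_dot_product_attention builds the causal mask internally here (is_causal=True),
+        # so the `mask` argument received by forward() is not passed to this call.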
+ output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, is_causal=True) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + a= self.w1(x) + b =F.silu(a) + c= self.w3(x) + return self.w2(b*c) + + +class TransformerBlock(nn.Module): + def __init__( + self, + layer_id: int, + n_heads: int, + n_kv_heads: int, + dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + max_batch_size: int, + max_seq_len: int, + norm_eps: float, + ): + super().__init__() + self.n_heads = n_heads + self.dim = dim + self.head_dim = dim // n_heads + self.attention = Attention( + n_heads, n_kv_heads, dim, max_batch_size, max_seq_len + ) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(dim, eps=norm_eps) + self.ffn_norm = RMSNorm(dim, eps=norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Transformer(nn.Module): + """ + LLama2 implementation, free of any coupling to parallelism implementations, heavily drawn from + https://github.com/facebookresearch/llama. 
+ """ + + def __init__( + self, + vocab_size: int, + n_layers: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + max_batch_size: int, + max_seq_len: int, + norm_eps: float, + ): + super().__init__() + self.vocab_size = vocab_size + self.n_layers = n_layers + self.dim = dim + self.n_heads = n_heads + self.max_seq_len = max_seq_len + self.tok_embeddings = nn.Embedding(vocab_size, dim) + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + + self.layers = torch.nn.ModuleList() + for layer_id in range(n_layers): + self.layers.append( + TransformerBlock( + layer_id, + n_heads, + n_kv_heads, + dim, + multiple_of, + ffn_dim_multiplier, + max_batch_size, + max_seq_len, + norm_eps, + ) + ) + + self.norm = RMSNorm(dim, eps=norm_eps) + self.output = nn.Linear(dim, vocab_size, bias=False) + + self.freqs_cis = precompute_freqs_cis( + self.dim // self.n_heads, self.max_seq_len * 2 + ) + + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + # print( + # f"RV: before embedding lookup, input {tokens}, start:{start_pos}", + # flush=True, + # ) + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + + for layer in self.layers: + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + output = self.output(h).float() + return output + + @torch.no_grad() + def reset_parameters(self): + for layer in self.layers: + for submodule in layer.modules(): + for param_name, param in submodule.named_parameters(recurse=False): + if param.is_meta: + materialized_param = nn.Parameter( + torch.empty_like(param, dtype=torch.bfloat16, device=torch.device("cuda")) + ) + nn.init.uniform_after(materialized_param) + setattr(submodule, param_name, materialized_param) + + +### --- Utilities for model creation / loading ---- #### + + +def _build_model_args(ckpt_dir: str, max_seq_len, max_batch_size) -> ModelArgs: + """ + Reads params.json from checkpoint and builds ModelArgs to initialize + model with. + """ + params_path = os.path.join(ckpt_dir, "params.json") + with open(params_path, "r") as f: + params = json.loads(f.read()) + + # Some checkpoints have other details besides "model", fix this up and use a + # clearly specified format. 
+ model_params = params.get("model", params) + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + dim=model_params["dim"], + n_layers=model_params["n_layers"], + n_heads=model_params["n_heads"], + n_kv_heads=model_params.get("n_kv_heads", model_params["n_heads"]), + multiple_of=model_params["multiple_of"], + ffn_dim_multiplier=model_params.get("ffn_dim_multiplier", None), + norm_eps=model_params["norm_eps"], + ) + return model_args + + +def _create_tokenizer(tokenizer_path: str) -> Tokenizer: + local_tokenizer_path = tokenizer_path + log.debug(f"successfully saved tokenizer to {local_tokenizer_path}") + tokenizer = Tokenizer(model_path=local_tokenizer_path) + return tokenizer + + +def _init_local_model(model_args: ModelArgs) -> Transformer: + with torch.device("meta"): + model = Transformer( + model_args.vocab_size, + model_args.n_layers, + model_args.dim, + model_args.n_heads, + model_args.n_kv_heads, # pyre-ignore[6] + model_args.multiple_of, + model_args.ffn_dim_multiplier, + model_args.max_batch_size, + model_args.max_seq_len, + model_args.norm_eps, + ) + + model.freqs_cis = precompute_freqs_cis( + model.dim // model.n_heads, model.max_seq_len * 2 + ) + for tformer_block in model.layers: + tformer_block.attention._init_cache_k() + tformer_block.attention._init_cache_v() + + return model + + +def get_consolidated_ckpt_path( + ckpt_dir: Union[str, Path], mp_rank: int = 0, mp_size: int = 1 +) -> Union[str, Path]: + + if mp_size == 1: + assert mp_rank == 0 + filename = "consolidated.00.pth" + else: + filename = f"consolidated.{mp_rank:02d}.pth" + if isinstance(ckpt_dir, Path): + return ckpt_dir / filename + else: + return os.path.join(ckpt_dir, filename) + +def _convert_fairscale_checkpoints(meta_model, model_parallel_size: int, original_ckpt_dir: str, save_checkpoint_dir: str): + mp_group, _ = dist.new_subgroups(group_size=model_parallel_size) + local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if local_rank == -1: + raise RuntimeError("Expected local_rank to be set, but it is not!") + mp_rank = local_rank % model_parallel_size + + state_dict_pth = get_consolidated_ckpt_path( + ckpt_dir=original_ckpt_dir, mp_rank=mp_rank, mp_size=model_parallel_size + ) + state_dict = torch.load(state_dict_pth) + dist_state_dict = build_distributed_state_dict_from_consolidated( + meta_model, state_dict, model_parallel_world_size=model_parallel_size,use_dtensor=True + ) + dist_cp.save_state_dict( + state_dict=dist_state_dict, + storage_writer=dist_cp.FileSystemWriter(save_checkpoint_dir), + ) + + +def _load_checkpoint(mesh, model, meta_model, model_parallel_size: int, ckpt_dir: str) -> None: + mp_group, _ = dist.new_subgroups(group_size=model_parallel_size) + local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if local_rank == -1: + raise RuntimeError("Expected local_rank to be set, but it is not!") + mp_rank = local_rank % model_parallel_size + state_dict_pth = get_consolidated_ckpt_path( + ckpt_dir=ckpt_dir, mp_rank=mp_rank, mp_size=model_parallel_size + ) + state_dict = torch.load(state_dict_pth) + dist_state_dict = build_distributed_state_dict_from_consolidated( + meta_model, state_dict, model_parallel_world_size=model_parallel_size,use_dtensor=True + ) + CHECKPOINT_DIR="converted_checkpoints" + dist_cp.save_state_dict( + state_dict=dist_state_dict, + storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR), + ) + + converting_Dtensor_to_tensor(model,dist_state_dict,mesh ) + check_dtensor(dist_state_dict) + log.debug("build distributed_state_dict") + 
missing_keys, unexpected_keys = model.load_state_dict(dist_state_dict, strict=False) + assert not missing_keys + assert len(unexpected_keys) == 1 and "freqs" in unexpected_keys[0] + +def _load_tp_checkpoints(tp_model,CHECKPOINT_DIR): + + local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if local_rank == -1: + raise RuntimeError("Expected local_rank to be set, but it is not!") + tp_state_dict = tp_model.state_dict() + dist_cp.load_state_dict( + state_dict=tp_state_dict, + storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR), + ) + tp_model.load_state_dict(tp_state_dict) + +def parallelize_llama_MLP_block(model, module_path, mesh): + block = model.get_submodule(module_path) + parallelized_block = parallelize_module( + module=block, + device_mesh=mesh, + parallelize_plan={ + "w1": ColwiseParallel(), + "w2": RowwiseParallel(), + "w3": ColwiseParallel(), + }, + # tp_mesh_dim=0, + ) + return parallelized_block + +def parallelize_llama_attn_block(model, module_path, twod_mesh): + block = model.get_submodule(module_path) + parallelized_block = parallelize_module( + module=block, + device_mesh=twod_mesh, + parallelize_plan={ + "wq": ColwiseParallel(), + "wk": ColwiseParallel(), + "wv": ColwiseParallel(), + "wo": RowwiseParallel(), + }, + # tp_mesh_dim=0, + ) + return parallelized_block + +def tp_llama(model, mesh): + for i in range(model.n_layers): + # print(f" i number of layers {i}*********************") + block = parallelize_llama_MLP_block(model, f"layers.{i}.feed_forward", mesh) + block = parallelize_llama_attn_block(model, f"layers.{i}.attention", mesh) + +def print_submodules(model): + for name, module in model.named_modules(): + print(f"Module name: {name}") + # print(module) + print() + +def check_dtensor(state_dict): + for fqn, tensor in state_dict.items(): + try: + is_dtensor = isinstance(tensor, DTensor) + except: + is_dtensor = False + + print(f"The model FQN: {fqn}, is DTensor {is_dtensor}") + +def converting_Dtensor_to_tensor(model_tp, dist_state_dict, mesh): +# Make sure this covers all non DTensor FQNs. + # model is the tp_model + for fqn in model_tp.state_dict(): + if not isinstance(model_tp.state_dict()[fqn], DTensor): + # # # Convert dist_state_dict[fqn] into non-DTensor + + if isinstance(dist_state_dict[fqn], DTensor): + # Not sure best way to materialize full DTensor on each rank Doing it by + # redistributing it to a world_size = 1 DeviceMesh, and then to_local. + unsharded_dt = dist_state_dict[fqn].redistribute(device_mesh=mesh, placements=[Replicate()]) + dist_state_dict[fqn] = unsharded_dt.to_local() + +class Llama: + @staticmethod + def build( + model_args: str, + converted_ckpt_dir:str, + tokenizer_path: str, + ) -> "Llama": + """ + Heavily motivated from https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L51, + and adapted for native parallelism APIs. 
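+
+        Example (illustrative; assumes `torch.distributed` is initialized and `LOCAL_RANK` is set,
+        with paths pointing at the files produced by the packaging step):
+            llama = Llama.build(
+                model_args="model_args.json",
+                converted_ckpt_dir="converted_checkpoints",
+                tokenizer_path="tokenizer.model",
+            )
+            model, tokenizer = llama.model, llama.tokenizer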
+ """ + start = time.time() + torch.set_default_tensor_type(torch.cuda.HalfTensor) + local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if local_rank == -1: + raise RuntimeError("Expected local_rank to be set, but it is not!") + + torch.cuda.set_device(local_rank) + # model_args = _build_model_args(ckpt_dir, max_seq_len, max_batch_size) + # file_path = os.path.join(converted_ckpt_dir, 'model_args.json') + + with open(model_args, 'r') as file: + loaded_json = file.read() + + model_args = json_to_dataclass(loaded_json, ModelArgs) + + tokenizer = _create_tokenizer(tokenizer_path=tokenizer_path) + model_args.vocab_size = tokenizer.n_words + + model = _init_local_model(model_args) + mesh = ( + DeviceMesh( + device_type="cuda", + mesh=list(range(dist.get_world_size())), + )) + + tp_llama(model, mesh) + + model.to_empty(device='cuda') + model.reset_parameters() + log.debug(f"Rank {dist.get_rank()}: created FSDP model {model}") + + _load_tp_checkpoints(model,converted_ckpt_dir) + param_numel = sum(p.numel() for p in model.parameters()) + log.debug( + f"Loaded {param_numel * dist.get_world_size()} params (across all workers) in {time.time() - start:.2f} seconds" + ) + return Llama(model, tokenizer) + + @staticmethod + def convert_checkpoints( + original_ckpt_dir: str, + save_checkpoint_dir:str, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + model_parallel_size: int, + ) -> "Llama": + """ + Heavily motivated from https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L51, + and adapted for native parallelism APIs. + """ + start = time.time() + torch.set_default_tensor_type(torch.cuda.HalfTensor) + local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if local_rank == -1: + raise RuntimeError("Expected local_rank to be set, but it is not!") + + torch.cuda.set_device(local_rank) + model_args = _build_model_args(original_ckpt_dir, max_seq_len, max_batch_size) + tokenizer = _create_tokenizer(tokenizer_path=tokenizer_path) + model_args.vocab_size = tokenizer.n_words + json_args = dataclass_to_json(model_args) + with open('model_args.json', 'w') as file: + file.write(json_args) + + model = _init_local_model(model_args) + + _convert_fairscale_checkpoints(model, model_parallel_size=model_parallel_size, original_ckpt_dir=original_ckpt_dir, save_checkpoint_dir=save_checkpoint_dir) + + log.debug( + f"the checkpoints have been converted to PTD compliant checkpoint and saved in {save_checkpoint_dir}" + ) + + + def __init__(self, model: Union[FSDP, Transformer], tokenizer: Tokenizer): + self.model = model + self.tokenizer = tokenizer \ No newline at end of file diff --git a/examples/large_models/tp_llama/llama2_tokenizer.py b/examples/large_models/tp_llama/llama2_tokenizer.py new file mode 100644 index 0000000000..917bbd6361 --- /dev/null +++ b/examples/large_models/tp_llama/llama2_tokenizer.py @@ -0,0 +1,44 @@ +""" +Tokenizer from https://github.com/facebookresearch/llama/blob/main/llama/tokenizer.py. +""" +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
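+#
+# Example usage (illustrative; the model path is a placeholder):
+#   tok = Tokenizer(model_path="tokenizer.model")
+#   ids = tok.encode("Hello world", bos=True, eos=False)
+#   text = tok.decode(ids)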
+ +from logging import getLogger +from typing import List + + +logger = getLogger() + + +class Tokenizer: + def __init__(self, model_path: str): + # reload tokenizer + from sentencepiece import SentencePieceProcessor + + self.sp_model = SentencePieceProcessor(model_file=model_path) # pyre-ignore[28] + logger.info(f"Reloaded SentencePiece model from {model_path}") + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + assert ( + self.sp_model.vocab_size() + == self.sp_model.get_piece_size() # pyre-ignore[16] + ) + + def encode(self, s: str, bos: bool, eos: bool) -> List[int]: + t = self.sp_model.encode(s) # pyre-ignore[16] + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + return self.sp_model.decode(t) # pyre-ignore[16] \ No newline at end of file diff --git a/examples/large_models/tp_llama/model-config.yaml b/examples/large_models/tp_llama/model-config.yaml new file mode 100644 index 0000000000..0174fe8be2 --- /dev/null +++ b/examples/large_models/tp_llama/model-config.yaml @@ -0,0 +1,23 @@ +#frontend settings +minWorkers: 1 +maxWorkers: 1 +maxBatchDelay: 200 +responseTimeout: 300 +parallelType: "tp" +deviceType: "gpu" + +torchrun: + nproc-per-node: 1 + +handler: + converted_ckpt_dir: "converted_checkpoints" + tokenizer_path: "tokenizer.model" + model_args_path: "model_args.json" + max_new_tokens: 50 + temperature: 0.6 + top_p: 0.9 + manual_seed: 40 + mode: "chat" #choices are text_completion, chat + + + diff --git a/examples/large_models/tp_llama/sample_text.txt b/examples/large_models/tp_llama/sample_text.txt new file mode 100644 index 0000000000..d5c3fdae71 --- /dev/null +++ b/examples/large_models/tp_llama/sample_text.txt @@ -0,0 +1 @@ +Hey, are you conscious? Can you talk to me? diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 2b1b907552..a7e3a176fa 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1102,6 +1102,19 @@ javac llamacpp streamlit tp +DTensor +DTensors +Fairscale +KV +MLP +Nighlies +ae +ba +ffa +lized +mins +sharding quantized Chatbot LLM +