From d5e10de27120144fb9b8cb0a9ad87107839fd462 Mon Sep 17 00:00:00 2001 From: Ankith Gunapal Date: Tue, 17 Sep 2024 15:39:27 -0700 Subject: [PATCH] TRT LLM Integration with LORA (#3305) * TRT LLM Integration with LORA * TRT LLM Integration with LORA * TRT LLM Integration with LORA * TRT LLM Integration with LORA * Added launcher support for trt_llm * updated README * updated README * Using the API that supports async generate * Review comments * Apply suggestions from code review Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com> * addressed review comments * Addressed review comments * Updated the async logic based on review comments * Made max_batch_size and kv_cache size configurable for the launcher * fixing lint --------- Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com> --- README.md | 15 +- examples/large_models/trt_llm/llama/README.md | 33 ++-- .../trt_llm/llama/model-config.yaml | 7 +- .../large_models/trt_llm/llama/prompt.json | 3 +- .../trt_llm/llama/trt_llm_handler.py | 118 ------------ examples/large_models/trt_llm/lora/README.md | 83 ++++++++ .../trt_llm/lora/model-config.yaml | 13 ++ .../large_models/trt_llm/lora/prompt.json | 4 + .../model_archiver/model_packaging_utils.py | 1 + requirements/trt_llm.txt | 3 + ts/llm_launcher.py | 180 ++++++++++++++---- ts/torch_handler/trt_llm_handler.py | 109 +++++++++++ ts/utils/hf_utils.py | 30 +++ ts_scripts/spellcheck_conf/wordlist.txt | 2 + 14 files changed, 428 insertions(+), 173 deletions(-) delete mode 100644 examples/large_models/trt_llm/llama/trt_llm_handler.py create mode 100644 examples/large_models/trt_llm/lora/README.md create mode 100644 examples/large_models/trt_llm/lora/model-config.yaml create mode 100644 examples/large_models/trt_llm/lora/prompt.json create mode 100644 requirements/trt_llm.txt create mode 100644 ts/torch_handler/trt_llm_handler.py create mode 100644 ts/utils/hf_utils.py diff --git a/README.md b/README.md index 766b3f4e45..f629ce61e3 100644 --- a/README.md +++ b/README.md @@ -62,12 +62,23 @@ Refer to [torchserve docker](docker/README.md) for details. 
### 🤖 Quick Start LLM Deployment +#### VLLM Engine ```bash # Make sure to install torchserve with pip or conda as described above and login with `huggingface-cli login` -python -m ts.llm_launcher --model_id meta-llama/Meta-Llama-3-8B-Instruct --disable_token_auth +python -m ts.llm_launcher --model_id meta-llama/Meta-Llama-3.1-8B-Instruct --disable_token_auth # Try it out -curl -X POST -d '{"model":"meta-llama/Meta-Llama-3-8B-Instruct", "prompt":"Hello, my name is", "max_tokens": 200}' --header "Content-Type: application/json" "http://localhost:8080/predictions/model/1.0/v1/completions" +curl -X POST -d '{"model":"meta-llama/Meta-Llama-3.1-8B-Instruct", "prompt":"Hello, my name is", "max_tokens": 200}' --header "Content-Type: application/json" "http://localhost:8080/predictions/model/1.0/v1/completions" +``` + +#### TRT-LLM Engine +```bash +# Make sure to install torchserve with python venv as described above and login with `huggingface-cli login` +# pip install -U --use-deprecated=legacy-resolver -r requirements/trt_llm.txt +python -m ts.llm_launcher --model_id meta-llama/Meta-Llama-3.1-8B-Instruct --engine trt_llm --disable_token_auth + +# Try it out +curl -X POST -d '{"prompt":"count from 1 to 9 in french ", "max_tokens": 100}' --header "Content-Type: application/json" "http://localhost:8080/predictions/model" ``` ### 🚢 Quick Start LLM Deployment with Docker diff --git a/examples/large_models/trt_llm/llama/README.md b/examples/large_models/trt_llm/llama/README.md index e82433892c..0b883ec74c 100644 --- a/examples/large_models/trt_llm/llama/README.md +++ b/examples/large_models/trt_llm/llama/README.md @@ -4,19 +4,19 @@ ## Pre-requisites -TRT-LLM requires Python 3.10 +- TRT-LLM requires Python 3.10 +- TRT-LLM works well with python venv (vs conda) This example is tested with CUDA 12.1 Once TorchServe is installed, install TensorRT-LLM using the following. -This will downgrade the versions of PyTorch & Triton but this doesn't cause any issue. ``` -pip install tensorrt_llm==0.10.0 --extra-index-url https://pypi.nvidia.com -pip install tensorrt-cu12==10.1.0 +pip install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com +pip install transformers>=4.44.2 python -c "import tensorrt_llm" ``` shows ``` -[TensorRT-LLM] TensorRT-LLM version: 0.10.0 +[TensorRT-LLM] TensorRT-LLM version: 0.13.0.dev2024090300 ``` ## Download model from HuggingFace @@ -26,29 +26,32 @@ huggingface-cli login huggingface-cli login --token $HUGGINGFACE_TOKEN ``` ``` -python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct +python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --use_auth_token True ``` ## Create TensorRT-LLM Engine Clone TensorRT-LLM which will be used to create the TensorRT-LLM Engine ``` -git clone -b v0.10.0 https://github.com/NVIDIA/TensorRT-LLM.git +git clone https://github.com/NVIDIA/TensorRT-LLM.git ``` Compile the model into a TensorRT engine with model weights and a model definition written in the TensorRT-LLM Python API. 
``` -python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16 +python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f/ --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16 ``` + ``` -trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --output_dir ./llama-3-8b-engine +trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --max_batch_size 4 --output_dir ./llama-3.1-8b-engine ``` +If you have enough GPU memory, you can try increasing the `max_batch_size` You can test if TensorRT-LLM Engine has been compiled correctly by running the following ``` -python TensorRT-LLM/examples/run.py --engine_dir ./llama-3-8b-engine --max_output_len 100 --tokenizer_dir model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/ --input_text "How do I count to nine in French?" +python TensorRT-LLM/examples/run.py --engine_dir ./llama-3.1-8b-engine --max_output_len 100 --tokenizer_dir model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f/ --input_text "How do I count to nine in French?" ``` +If you are running into OOM, try reducing `kv_cache_free_gpu_memory_fraction` You should see an output as follows ``` @@ -70,17 +73,17 @@ That's it! You can now count to nine in French. Just remember that the numbers o ``` mkdir model_store -torch-model-archiver --model-name llama3-8b --version 1.0 --handler trt_llm_handler.py --config-file model-config.yaml --archive-format no-archive --export-path model_store -f -mv model model_store/llama3-8b/. -mv llama-3-8b-engine model_store/llama3-8b/. +torch-model-archiver --model-name llama3.1-8b --version 1.0 --handler trt_llm_handler --config-file model-config.yaml --archive-format no-archive --export-path model_store -f +mv model model_store/llama3.1-8b/. +mv llama-3.1-8b-engine model_store/llama3.1-8b/. 
``` ## Start TorchServe ``` -torchserve --start --ncs --model-store model_store --models llama3-8b --disable-token-auth +torchserve --start --ncs --model-store model_store --models llama3.1-8b --disable-token-auth ``` ## Run Inference ``` -python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3-8b --prompt-text "@prompt.json" --prompt-json +python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3.1-8b --prompt-text "@prompt.json" --prompt-json ``` diff --git a/examples/large_models/trt_llm/llama/model-config.yaml b/examples/large_models/trt_llm/llama/model-config.yaml index 8d914f45bf..38610e64de 100644 --- a/examples/large_models/trt_llm/llama/model-config.yaml +++ b/examples/large_models/trt_llm/llama/model-config.yaml @@ -7,6 +7,7 @@ deviceType: "gpu" asyncCommunication: true handler: - tokenizer_dir: "model/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/e1945c40cd546c78e41f1151f4db032b271faeaa/" - trt_llm_engine_config: - engine_dir: "llama-3-8b-engine" + tokenizer_dir: "model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f/" + engine_dir: "llama-3.1-8b-engine" + kv_cache_config: + free_gpu_memory_fraction: 0.1 diff --git a/examples/large_models/trt_llm/llama/prompt.json b/examples/large_models/trt_llm/llama/prompt.json index 74490ebc79..e1bbca37eb 100644 --- a/examples/large_models/trt_llm/llama/prompt.json +++ b/examples/large_models/trt_llm/llama/prompt.json @@ -1,3 +1,4 @@ {"prompt": "How is the climate in San Francisco?", "temperature":0.5, - "max_new_tokens": 200} + "max_tokens": 400, + "streaming": true} diff --git a/examples/large_models/trt_llm/llama/trt_llm_handler.py b/examples/large_models/trt_llm/llama/trt_llm_handler.py deleted file mode 100644 index 5aaa0596ad..0000000000 --- a/examples/large_models/trt_llm/llama/trt_llm_handler.py +++ /dev/null @@ -1,118 +0,0 @@ -import json -import logging -import time - -import torch -from tensorrt_llm.runtime import ModelRunner -from transformers import AutoTokenizer - -from ts.handler_utils.utils import send_intermediate_predict_response -from ts.torch_handler.base_handler import BaseHandler - -logger = logging.getLogger(__name__) - - -class TRTLLMHandler(BaseHandler): - def __init__(self): - super().__init__() - - self.trt_llm_engine = None - self.tokenizer = None - self.model = None - self.model_dir = None - self.lora_ids = {} - self.adapters = None - self.initialized = False - - def initialize(self, ctx): - self.model_dir = ctx.system_properties.get("model_dir") - - trt_llm_engine_config = ctx.model_yaml_config.get("handler").get( - "trt_llm_engine_config" - ) - - tokenizer_dir = ctx.model_yaml_config.get("handler").get("tokenizer_dir") - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer_dir, - legacy=False, - padding_side="left", - truncation_side="left", - trust_remote_code=True, - use_fast=True, - ) - - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - - self.trt_llm_engine = ModelRunner.from_dir(**trt_llm_engine_config) - self.initialized = True - - async def handle(self, data, context): - start_time = time.time() - - metrics = context.metrics - - data_preprocess = await self.preprocess(data) - output, input_batch = await self.inference(data_preprocess, context) - output = await self.postprocess(output, input_batch, context) - - stop_time = time.time() - metrics.add_time( - "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" - ) - return output - - 
async def preprocess(self, requests): - input_batch = [] - assert len(requests) == 1, "Expecting batch_size = 1" - for req_data in requests: - data = req_data.get("data") or req_data.get("body") - if isinstance(data, (bytes, bytearray)): - data = data.decode("utf-8") - - prompt = data.get("prompt") - temperature = data.get("temperature", 1.0) - max_new_tokens = data.get("max_new_tokens", 50) - input_ids = self.tokenizer.encode( - prompt, add_special_tokens=True, truncation=True - ) - input_batch.append(input_ids) - - input_batch = [torch.tensor(x, dtype=torch.int32) for x in input_batch] - - return (input_batch, temperature, max_new_tokens) - - async def inference(self, input_batch, context): - input_ids_batch, temperature, max_new_tokens = input_batch - - with torch.no_grad(): - outputs = self.trt_llm_engine.generate( - batch_input_ids=input_ids_batch, - temperature=temperature, - max_new_tokens=max_new_tokens, - end_id=self.tokenizer.eos_token_id, - pad_id=self.tokenizer.pad_token_id, - output_sequence_lengths=True, - streaming=True, - return_dict=True, - ) - return outputs, input_ids_batch - - async def postprocess(self, inference_outputs, input_batch, context): - for inference_output in inference_outputs: - output_ids = inference_output["output_ids"] - sequence_lengths = inference_output["sequence_lengths"] - - batch_size, _, _ = output_ids.size() - for batch_idx in range(batch_size): - output_end = sequence_lengths[batch_idx][0] - outputs = output_ids[batch_idx][0][output_end - 1 : output_end].tolist() - output_text = self.tokenizer.decode(outputs) - send_intermediate_predict_response( - [json.dumps({"text": output_text})], - context.request_ids, - "Result", - 200, - context, - ) - return [""] * len(input_batch) diff --git a/examples/large_models/trt_llm/lora/README.md b/examples/large_models/trt_llm/lora/README.md new file mode 100644 index 0000000000..e969d9364d --- /dev/null +++ b/examples/large_models/trt_llm/lora/README.md @@ -0,0 +1,83 @@ +# Llama TensorRT-LLM Engine + LoRA model integration with TorchServe + +[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) provides users with an option to build TensorRT engines for LLMs that contain state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs. + +## Pre-requisites + +- TRT-LLM requires Python 3.10 +- TRT-LLM works well with python venv (vs conda) +This example is tested with CUDA 12.1 +Once TorchServe is installed, install TensorRT-LLM using the following. + +``` +pip install tensorrt_llm -U --pre --extra-index-url https://pypi.nvidia.com +pip install transformers>=4.44.2 +python -c "import tensorrt_llm" +``` +shows +``` +[TensorRT-LLM] TensorRT-LLM version: 0.13.0.dev2024090300 +``` + +## Download Base model & LoRA adapter from Hugging Face +``` +huggingface-cli login +# or using an environment variable +huggingface-cli login --token $HUGGINGFACE_TOKEN +``` +``` +python ../../utils/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --use_auth_token True +python ../../utils/Download_model.py --model_path model --model_name llama-duo/llama3.1-8b-summarize-gpt4o-128k --use_auth_token True +``` + +## Create TensorRT-LLM Engine +Clone TensorRT-LLM which will be used to create the TensorRT-LLM Engine + +``` +git clone https://github.com/NVIDIA/TensorRT-LLM.git +``` + +Compile the model into a TensorRT engine with model weights and a model definition written in the TensorRT-LLM Python API. 
+ +``` +python TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir model/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/5206a32e0bd3067aef1ce90f5528ade7d866253f --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16 +``` + +``` +trtllm-build --checkpoint_dir tllm_checkpoint_1gpu_bf16 --gemm_plugin bfloat16 --gpt_attention_plugin bfloat16 --output_dir ./llama-3.1-8b-engine-lora --max_batch_size 4 --lora_dir model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825 --lora_plugin bfloat16 +``` +If you have enough GPU memory, you can try increasing the `max_batch_size` + +You can test if TensorRT-LLM Engine has been compiled correctly by running the following +``` +python TensorRT-LLM/examples/run.py --engine_dir ./llama-3.1-8b-engine-lora --max_output_len 100 --tokenizer_dir model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825 --input_text "Amanda: I baked cookies. Do you want some?\nJerry: Sure \nAmanda: I will bring you tomorrow :-)\n\nSummarize the dialog:" --lora_dir model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825 --kv_cache_free_gpu_memory_fraction 0.3 --use_py_session +``` +If you are running into OOM, try reducing `kv_cache_free_gpu_memory_fraction` + +You should see an output as follows +``` +Input [Text 0]: "<|begin_of_text|>Amanda: I baked cookies. Do you want some?\nJerry: Sure \nAmanda: I will bring you tomorrow :-)\n\nSummarize the dialog:" +Output [Text 0 Beam 0]: " Amanda offered Jerry cookies and said she would bring them to him tomorrow. +Amanda offered Jerry cookies and said she would bring them to him tomorrow. +The dialogue is between Amanda and Jerry. Amanda offers Jerry cookies and says she will bring them to him tomorrow. The dialogue is a simple exchange between two people, with no complex plot or themes. The tone is casual and friendly. The dialogue is a good example of a short, everyday conversation. +The dialogue is a good example of a short," +``` + +## Create model archive + +``` +mkdir model_store +torch-model-archiver --model-name llama3.1-8b --version 1.0 --handler trt_llm_handler --config-file model-config.yaml --archive-format no-archive --export-path model_store -f +mv model model_store/llama3.1-8b/. +mv llama-3.1-8b-engine-lora model_store/llama3.1-8b/. 
+``` + +## Start TorchServe +``` +torchserve --start --ncs --model-store model_store --models llama3.1-8b --disable-token-auth +``` + +## Run Inference +``` +python ../../utils/test_llm_streaming_response.py -o 50 -t 2 -n 4 -m llama3.1-8b --prompt-text "@prompt.json" --prompt-json +``` diff --git a/examples/large_models/trt_llm/lora/model-config.yaml b/examples/large_models/trt_llm/lora/model-config.yaml new file mode 100644 index 0000000000..98248b3502 --- /dev/null +++ b/examples/large_models/trt_llm/lora/model-config.yaml @@ -0,0 +1,13 @@ +# TorchServe frontend parameters +minWorkers: 1 +maxWorkers: 1 +maxBatchDelay: 100 +responseTimeout: 1200 +deviceType: "gpu" +asyncCommunication: true + +handler: + tokenizer_dir: "model/models--llama-duo--llama3.1-8b-summarize-gpt4o-128k/snapshots/4ba83353f24fa38946625c8cc49bf21c80a22825" + engine_dir: "llama-3.1-8b-engine-lora" + kv_cache_config: + free_gpu_memory_fraction: 0.1 diff --git a/examples/large_models/trt_llm/lora/prompt.json b/examples/large_models/trt_llm/lora/prompt.json new file mode 100644 index 0000000000..7adb05a610 --- /dev/null +++ b/examples/large_models/trt_llm/lora/prompt.json @@ -0,0 +1,4 @@ +{"prompt": "Amanda: I baked cookies. Do you want some?\nJerry: Sure \nAmanda: I will bring you tomorrow :-)\n\nSummarize the dialog:", + "temperature":0.0, + "max_new_tokens": 100, + "streaming": true} diff --git a/model-archiver/model_archiver/model_packaging_utils.py b/model-archiver/model_archiver/model_packaging_utils.py index eb5329e429..8b81af9f73 100644 --- a/model-archiver/model_archiver/model_packaging_utils.py +++ b/model-archiver/model_archiver/model_packaging_utils.py @@ -35,6 +35,7 @@ "image_segmenter": "vision", "dali_image_classifier": "vision", "vllm_handler": "text", + "trt_llm_handler": "text", } MODEL_SERVER_VERSION = "1.0" diff --git a/requirements/trt_llm.txt b/requirements/trt_llm.txt new file mode 100644 index 0000000000..68eb852697 --- /dev/null +++ b/requirements/trt_llm.txt @@ -0,0 +1,3 @@ +--pre --extra-index-url https://pypi.nvidia.com +tensorrt_llm +transformers>=4.44.2 diff --git a/ts/llm_launcher.py b/ts/llm_launcher.py index 89248ce9f4..1d95731e2f 100644 --- a/ts/llm_launcher.py +++ b/ts/llm_launcher.py @@ -1,6 +1,8 @@ import argparse import contextlib +import os import shutil +import subprocess from pathlib import Path from signal import pause @@ -10,14 +12,55 @@ from model_archiver.model_packaging import generate_model_archive from ts.launcher import start, stop +from ts.utils.hf_utils import download_model + + +def create_tensorrt_llm_engine( + model_store, model_name, dtype, snapshot_path, max_batch_size +): + if not Path("/tmp/TensorRT-LLM").exists(): + subprocess.run( + [ + "git", + "clone", + "https://github.com/NVIDIA/TensorRT-LLM.git", + "-b", + "v0.12.0", + "/tmp/TensorRT-LLM", + ] + ) + if not Path(f"{model_store}/{model_name}/tllm_checkpoint_1gpu_bf16").exists(): + subprocess.run( + [ + "python", + "/tmp/TensorRT-LLM/examples/llama/convert_checkpoint.py", + "--model_dir", + snapshot_path, + "--output_dir", + f"{model_store}/{model_name}/tllm_checkpoint_1gpu_bf16", + "--dtype", + dtype, + ] + ) + if not Path(f"{model_store}/{model_name}/{model_name}-engine").exists(): + subprocess.run( + [ + "trtllm-build", + "--checkpoint_dir", + f"{model_store}/{model_name}/tllm_checkpoint_1gpu_bf16", + "--gemm_plugin", + dtype, + "--gpt_attention_plugin", + dtype, + "--max_batch_size", + f"{max_batch_size}", + "--output_dir", + f"{model_store}/{model_name}/{model_name}-engine", + ] + ) -def 
get_model_config(args): - download_dir = getattr(args, "vllm_engine.download_dir") - download_dir = ( - Path(download_dir).resolve().as_posix() if download_dir else download_dir - ) - +def get_model_config(args, model_snapshot_path=None): model_config = { "minWorkers": 1, "maxWorkers": 1, @@ -26,58 +69,86 @@ def get_model_config(args): "responseTimeout": 1200, "deviceType": "gpu", "asyncCommunication": True, - "parallelLevel": torch.cuda.device_count() if torch.cuda.is_available else 1, - "handler": { - "model_path": args.model_id, - "vllm_engine_config": { - "max_num_seqs": getattr(args, "vllm_engine.max_num_seqs"), - "max_model_len": getattr(args, "vllm_engine.max_model_len"), - "download_dir": download_dir, - "tensor_parallel_size": torch.cuda.device_count() + } + + if args.engine == "vllm": + download_dir = getattr(args, "vllm_engine.download_dir") + download_dir = ( + Path(download_dir).resolve().as_posix() if download_dir else download_dir + ) + + model_config.update( + { + "parallelLevel": torch.cuda.device_count() if torch.cuda.is_available else 1, - }, - }, - } + "handler": { + "model_path": args.model_id, + "vllm_engine_config": { + "max_num_seqs": getattr(args, "vllm_engine.max_num_seqs"), + "max_model_len": getattr(args, "vllm_engine.max_model_len"), + "download_dir": download_dir, + "tensor_parallel_size": torch.cuda.device_count() + if torch.cuda.is_available + else 1, + }, + }, + } + ) - if hasattr(args, "lora_adapter_ids"): - raise NotImplementedError("Lora setting needs to be implemented") - lora_adapter_ids = args.lora_adapter_ids.split(";") + if hasattr(args, "lora_adapter_ids"): + raise NotImplementedError("Lora setting needs to be implemented") + lora_adapter_ids = args.lora_adapter_ids.split(";") - model_config["handler"]["vllm_engine_config"].update( + model_config["handler"]["vllm_engine_config"].update( + { + "enable_lora": True, + } + ) + + elif args.engine == "trt_llm": + model_config.update( { - "enable_lora": True, + "handler": { + "tokenizer_dir": os.path.join(os.getcwd(), model_snapshot_path), + "engine_dir": f"{args.model_name}-engine", + "kv_cache_config": { + "free_gpu_memory_fraction": getattr( + args, "trt_llm_engine.kv_cache_free_gpu_memory_fraction" + ), + }, + } } ) + else: + raise RuntimeError("Unsupported LLM Engine") return model_config @contextlib.contextmanager -def create_mar_file(args): - model_store_path = Path(args.model_store) - model_store_path.mkdir(parents=True, exist_ok=True) - - mar_file_path = model_store_path / args.model_name +def create_mar_file(args, model_snapshot_path=None): + mar_file_path = Path(args.model_store) / args.model_name model_config_yaml = Path(args.model_store) / "model-config.yaml" with model_config_yaml.open("w") as f: - yaml.dump(get_model_config(args), f) + yaml.dump(get_model_config(args, model_snapshot_path), f) config = ModelArchiverConfig( model_name=args.model_name, version="1.0", - handler="vllm_handler", + handler=f"{args.engine}_handler", serialized_file=None, export_path=args.model_store, requirements_file=None, runtime="python", - force=False, + force=True, config_file=model_config_yaml.as_posix(), archive_format="no-archive", ) - generate_model_archive(config) + if not mar_file_path.exists(): + generate_model_archive(config) model_config_yaml.unlink() @@ -85,7 +156,8 @@ def create_mar_file(args): yield mar_file_path.as_posix() - shutil.rmtree(mar_file_path) + if args.engine == "vllm": + shutil.rmtree(mar_file_path) def main(args): @@ -93,7 +165,20 @@ def main(args): Register the model in torchserve 
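+
+    For the trt_llm engine, the launcher first downloads the Hugging Face model
+    snapshot and builds a TensorRT-LLM engine inside the model store before
+    starting TorchServe.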
""" - with create_mar_file(args): + model_store_path = Path(args.model_store) + model_store_path.mkdir(parents=True, exist_ok=True) + if args.engine == "trt_llm": + model_snapshot_path = download_model(args.model_id) + + with create_mar_file(args, model_snapshot_path): + if args.engine == "trt_llm": + create_tensorrt_llm_engine( + args.model_store, + args.model_name, + args.dtype, + model_snapshot_path, + getattr(args, "trt_llm_engine.max_batch_size"), + ) try: start( model_store=args.model_store, @@ -129,7 +214,7 @@ def main(args): parser.add_argument( "--model_id", type=str, - default="meta-llama/Meta-Llama-3-8B-Instruct", + default="meta-llama/Meta-Llama-3.1-8B-Instruct", help="Model id", ) @@ -160,6 +245,33 @@ def main(args): help="Cache dir", ) + parser.add_argument( + "--engine", + type=str, + default="vllm", + help="LLM engine", + ) + + parser.add_argument( + "--dtype", + type=str, + default="bfloat16", + help="Data type", + ) + + parser.add_argument( + "--trt_llm_engine.max_batch_size", + type=int, + default=4, + help="Max batch size", + ) + + parser.add_argument( + "--trt_llm_engine.kv_cache_free_gpu_memory_fraction", + type=int, + default=0.1, + help="KV Cache free gpu memory fraction", + ) args = parser.parse_args() main(args) diff --git a/ts/torch_handler/trt_llm_handler.py b/ts/torch_handler/trt_llm_handler.py new file mode 100644 index 0000000000..cb705530d7 --- /dev/null +++ b/ts/torch_handler/trt_llm_handler.py @@ -0,0 +1,109 @@ +import json +import logging +import time + +from tensorrt_llm.hlapi import LLM, KvCacheConfig, SamplingParams +from transformers import AutoTokenizer + +from ts.handler_utils.utils import send_intermediate_predict_response +from ts.torch_handler.base_handler import BaseHandler + +logger = logging.getLogger(__name__) + + +class TRTLLMHandler(BaseHandler): + def __init__(self): + super().__init__() + + self.trt_llm_engine = None + self.tokenizer = None + self.model = None + self.model_dir = None + self.initialized = False + + def initialize(self, ctx): + self.model_dir = ctx.system_properties.get("model_dir") + + engine_dir = ctx.model_yaml_config.get("handler").get("engine_dir") + kv_cache_cfg = ctx.model_yaml_config.get("handler").get("kv_cache_config", {}) + + tokenizer_dir = ctx.model_yaml_config.get("handler").get("tokenizer_dir") + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, + legacy=False, + padding_side="left", + truncation_side="left", + trust_remote_code=True, + use_fast=True, + ) + + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + kv_cache_config = KvCacheConfig(**kv_cache_cfg) + + self.trt_llm_engine = LLM( + model=engine_dir, tokenizer=self.tokenizer, kv_cache_config=kv_cache_config + ) + self.initialized = True + + async def handle(self, data, context): + start_time = time.time() + + metrics = context.metrics + + data_preprocess = await self.preprocess(data) + output = await self.inference(data_preprocess, context) + output = await self.postprocess(output) + + stop_time = time.time() + metrics.add_time( + "HandlerTime", round((stop_time - start_time) * 1000, 2), None, "ms" + ) + return output + + async def preprocess(self, requests): + assert len(requests) == 1, "Expecting batch_size = 1" + req_data = requests[0] + data = req_data.get("data") or req_data.get("body") + if isinstance(data, (bytes, bytearray)): + data = data.decode("utf-8") + return data + + async def inference(self, data, context): + generate_kwargs = { + "end_id": self.tokenizer.eos_token_id, + 
"pad_id": self.tokenizer.pad_token_id, + } + prompt = data.get("prompt") + streaming = data.get("streaming", False) + del data["prompt"] + if "streaming" in data: + del data["streaming"] + generate_kwargs.update(data) + sampling_params = SamplingParams(**generate_kwargs) + + outputs = self.trt_llm_engine.generate_async( + prompt, streaming=streaming, sampling_params=sampling_params + ) + + async for output in outputs: + output_text, output_ids = ( + output.outputs[0].text, + output.outputs[0].token_ids, + ) + if not streaming: + return [output_text] + else: + output_text = self.tokenizer.decode([output_ids[-1]]) + send_intermediate_predict_response( + [json.dumps({"text": output_text})], + context.request_ids, + "Result", + 200, + context, + ) + return [""] + + async def postprocess(self, outputs): + return outputs diff --git a/ts/utils/hf_utils.py b/ts/utils/hf_utils.py new file mode 100644 index 0000000000..a92ab640c2 --- /dev/null +++ b/ts/utils/hf_utils.py @@ -0,0 +1,30 @@ +from huggingface_hub import snapshot_download + + +def download_model( + model_id="meta-llama/Meta-Llama-3.1-8B-Instruct", + revision="main", + model_path=".cache", + use_auth_token=True, +): + # Only download pytorch checkpoint files + allow_patterns = [ + "*.json", + "*.pt", + "*.bin", + "*.txt", + "*.model", + "*.pth", + "*.safetensors", + "original/*", + ] + + snapshot_path = snapshot_download( + repo_id=model_id, + revision=revision, + allow_patterns=allow_patterns, + cache_dir=model_path, + use_auth_token=use_auth_token, + ) + + return snapshot_path diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index 0242be57a1..3e055db67d 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -1299,3 +1299,5 @@ torchaudio ln OpenAI openai +kv +OOM