diff --git a/gemma2/gemma2-27b-it-vllm/README.md b/gemma2/gemma2-27b-it-vllm/README.md
new file mode 100644
index 00000000..9c57a722
--- /dev/null
+++ b/gemma2/gemma2-27b-it-vllm/README.md
@@ -0,0 +1,55 @@
+# Gemma 2 27B
+
+This is a [Truss](https://truss.baseten.co/) for Gemma 2 27B Instruct. This README walks you through deploying this Truss on Baseten to get your own instance of Gemma 2 27B Instruct.
+
+## Gemma 2 27B Instruct Implementation
+
+This implementation of Gemma 2 uses [vLLM](https://github.com/vllm-project/vllm).
+
+Since Gemma 2 is a gated model, you will also need to provide your Hugging Face access token after making sure you have access to [the model](https://huggingface.co/google/gemma-2-27b-it). Please use the [following guide](https://docs.baseten.co/deploy/guides/secrets) to add your Hugging Face access token as a secret.
+
+## Deployment
+
+First, clone this repository:
+
+```sh
+git clone https://github.com/basetenlabs/truss-examples/
+cd gemma2/gemma2-27b-it-vllm
+```
+
+Before deployment:
+
+1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
+2. Install the latest version of Truss: `pip install --upgrade truss`
+
+With `gemma2/gemma2-27b-it-vllm` as your working directory, you can deploy the model with:
+
+```sh
+truss push --trusted
+```
+
+Paste your Baseten API key if prompted.
+
+For more information, see the [Truss documentation](https://truss.baseten.co).
+
+## Gemma 2 27B Instruct API documentation
+
+This section provides an overview of the Gemma 2 27B Instruct API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided prompt.
+
+### API route: `predict`
+
+The `predict` route is the primary method for generating text completions based on a given prompt. It takes several parameters (see the example request after this list):
+
+- __prompt__: The input text that you want the model to generate a response for.
+- __max_tokens__: The maximum number of tokens to generate.
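+
+A minimal Python sketch of a `predict` request, assuming the `requests` library and the same `YOUR_MODEL_VERSION_ID` / `YOUR_API_KEY` placeholders as the curl example below. Any extra fields (for example `temperature` or `top_p`) are passed straight through to vLLM's sampling parameters, and the optional `stream` field (default `true`) controls whether output is streamed back:
+
+```python
+import requests
+
+# Placeholders are illustrative -- substitute your own model version ID and API key.
+resp = requests.post(
+    "https://app.baseten.co/model_versions/YOUR_MODEL_VERSION_ID/predict",
+    headers={"Authorization": "Api-Key YOUR_API_KEY"},
+    json={
+        "prompt": "what is the meaning of life",
+        "max_tokens": 64,
+        "stream": False,  # return a single JSON body instead of a streamed response
+    },
+)
+print(resp.json())
+```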
+
+## Example usage
+
+You can also invoke your model via a REST API:
+
+```sh
+curl -X POST "https://app.baseten.co/model_versions/YOUR_MODEL_VERSION_ID/predict" \
+     -H "Content-Type: application/json" \
+     -H 'Authorization: Api-Key {YOUR_API_KEY}' \
+     -d '{"prompt": "what came before, the chicken or the egg?", "max_tokens": 64}'
+```
diff --git a/gemma2/gemma2-27b-it-vllm/config.yaml b/gemma2/gemma2-27b-it-vllm/config.yaml
new file mode 100644
index 00000000..fed5a820
--- /dev/null
+++ b/gemma2/gemma2-27b-it-vllm/config.yaml
@@ -0,0 +1,17 @@
+model_name: "Gemma 2 27B Instruct VLLM"
+python_version: py311
+model_metadata:
+  example_model_input: {"prompt": "what is the meaning of life"}
+  main_model: google/gemma-2-27b-it
+  tensor_parallel: 1
+  max_num_seqs: 16
+requirements:
+  - vllm==0.5.1
+  - https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp311-cp311-linux_x86_64.whl
+resources:
+  accelerator: A100
+  use_gpu: true
+runtime:
+  predict_concurrency: 128
+secrets:
+  hf_access_token: null
diff --git a/gemma2/gemma2-27b-it-vllm/model/__init__.py b/gemma2/gemma2-27b-it-vllm/model/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gemma2/gemma2-27b-it-vllm/model/model.py b/gemma2/gemma2-27b-it-vllm/model/model.py
new file mode 100644
index 00000000..8f2fd309
--- /dev/null
+++ b/gemma2/gemma2-27b-it-vllm/model/model.py
@@ -0,0 +1,94 @@
+import logging
+import os
+import subprocess
+import uuid
+
+from transformers import AutoTokenizer
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+
+os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # for multiprocessing to work with CUDA
+
+logger = logging.getLogger(__name__)
+
+
+class Model:
+    def __init__(self, **kwargs):
+        self._config = kwargs["config"]
+        self.model = None
+        self.llm_engine = None
+        self.model_args = None
+        self.hf_secret_token = kwargs["secrets"]["hf_access_token"]
+        os.environ["HF_TOKEN"] = self.hf_secret_token
+
+    def load(self):
+        # Log GPU state before the engine loads the weights.
+        try:
+            result = subprocess.run(
+                ["nvidia-smi"], capture_output=True, text=True, check=True
+            )
+            print(result.stdout)
+        except subprocess.CalledProcessError as e:
+            print(f"Command failed with code {e.returncode}: {e.stderr}")
+
+        model_metadata = self._config["model_metadata"]
+        logger.info(f"main model: {model_metadata['main_model']}")
+        logger.info(f"tensor parallelism: {model_metadata['tensor_parallel']}")
+        logger.info(f"max num seqs: {model_metadata['max_num_seqs']}")
+
+        self.model_args = AsyncEngineArgs(
+            model=model_metadata["main_model"],
+            trust_remote_code=True,
+            tensor_parallel_size=model_metadata["tensor_parallel"],
+            max_num_seqs=model_metadata["max_num_seqs"],
+            dtype="auto",
+            use_v2_block_manager=True,
+            enforce_eager=True,
+        )
+        self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args)
+
+        # Create a tokenizer for Gemma 2 to apply its chat template to prompts.
+        self.tokenizer = AutoTokenizer.from_pretrained(model_metadata["main_model"])
+
+        # Log GPU state again once the engine is up.
+        try:
+            result = subprocess.run(
+                ["nvidia-smi"], capture_output=True, text=True, check=True
+            )
+            print(result.stdout)
+        except subprocess.CalledProcessError as e:
+            print(f"Command failed with code {e.returncode}: {e.stderr}")
+
+    async def predict(self, model_input):
+        prompt = model_input.pop("prompt")
+        stream = model_input.pop("stream", True)
+
+        # Any remaining fields are passed through as vLLM sampling parameters.
+        sampling_params = SamplingParams(**model_input)
+        idx = str(uuid.uuid4().hex)
+        chat = [
+            {"role": "user", "content": prompt},
+        ]
+        # Apply Gemma 2's chat template to the raw prompt.
+        input = self.tokenizer.apply_chat_template(
+            chat, tokenize=False, add_generation_prompt=True
+        )
+
+        vllm_generator = self.llm_engine.generate(input, sampling_params, idx)
+
+        async def generator():
+            full_text = ""
+            async for output in vllm_generator:
+                # vLLM returns the cumulative text so far; yield only the new suffix.
+                text = output.outputs[0].text
+                delta = text[len(full_text) :]
+                full_text = text
+                yield delta
+
+        if stream:
+            return generator()
+        else:
+            full_text = ""
+            async for delta in generator():
+                full_text += delta
+            return {"text": full_text}
diff --git a/gemma2/gemma2-9b-it-vllm/README.md b/gemma2/gemma2-9b-it-vllm/README.md
new file mode 100644
index 00000000..1effdca8
--- /dev/null
+++ b/gemma2/gemma2-9b-it-vllm/README.md
@@ -0,0 +1,55 @@
+# Gemma 2 9B
+
+This is a [Truss](https://truss.baseten.co/) for Gemma 2 9B Instruct. This README walks you through deploying this Truss on Baseten to get your own instance of Gemma 2 9B Instruct.
+
+## Gemma 2 9B Instruct Implementation
+
+This implementation of Gemma 2 uses [vLLM](https://github.com/vllm-project/vllm).
+
+Since Gemma 2 is a gated model, you will also need to provide your Hugging Face access token after making sure you have access to [the model](https://huggingface.co/google/gemma-2-9b-it). Please use the [following guide](https://docs.baseten.co/deploy/guides/secrets) to add your Hugging Face access token as a secret.
+
+## Deployment
+
+First, clone this repository:
+
+```sh
+git clone https://github.com/basetenlabs/truss-examples/
+cd gemma2/gemma2-9b-it-vllm
+```
+
+Before deployment:
+
+1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
+2. Install the latest version of Truss: `pip install --upgrade truss`
+
+With `gemma2/gemma2-9b-it-vllm` as your working directory, you can deploy the model with:
+
+```sh
+truss push --trusted
+```
+
+Paste your Baseten API key if prompted.
+
+For more information, see the [Truss documentation](https://truss.baseten.co).
+
+## Gemma 2 9B Instruct API documentation
+
+This section provides an overview of the Gemma 2 9B Instruct API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided prompt.
+
+### API route: `predict`
+
+The `predict` route is the primary method for generating text completions based on a given prompt. It takes several parameters (see the streaming example after this list):
+
+- __prompt__: The input text that you want the model to generate a response for.
+- __max_tokens__: The maximum number of tokens to generate.
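+
+Because the model streams by default and forwards any other fields to vLLM's sampling parameters, you can also consume the output incrementally. A minimal Python sketch, assuming the `requests` library and the same `YOUR_MODEL_VERSION_ID` / `YOUR_API_KEY` placeholders as the curl example below:
+
+```python
+import requests
+
+# Placeholders are illustrative -- substitute your own model version ID and API key.
+with requests.post(
+    "https://app.baseten.co/model_versions/YOUR_MODEL_VERSION_ID/predict",
+    headers={"Authorization": "Api-Key YOUR_API_KEY"},
+    json={"prompt": "what is the meaning of life", "max_tokens": 64, "temperature": 0.7},
+    stream=True,
+) as resp:
+    # The response arrives as plain-text chunks; print them as they are generated.
+    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
+        print(chunk, end="", flush=True)
+```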
+
+## Example usage
+
+You can also invoke your model via a REST API:
+
+```sh
+curl -X POST "https://app.baseten.co/model_versions/YOUR_MODEL_VERSION_ID/predict" \
+     -H "Content-Type: application/json" \
+     -H 'Authorization: Api-Key {YOUR_API_KEY}' \
+     -d '{"prompt": "what came before, the chicken or the egg?", "max_tokens": 64}'
+```
diff --git a/gemma2/gemma2-9b-it-vllm/config.yaml b/gemma2/gemma2-9b-it-vllm/config.yaml
new file mode 100644
index 00000000..fee3b4ec
--- /dev/null
+++ b/gemma2/gemma2-9b-it-vllm/config.yaml
@@ -0,0 +1,17 @@
+model_name: "Gemma 2 9B Instruct VLLM"
+python_version: py311
+model_metadata:
+  example_model_input: {"prompt": "what is the meaning of life"}
+  main_model: google/gemma-2-9b-it
+  tensor_parallel: 1
+  max_num_seqs: 16
+requirements:
+  - vllm==0.5.1
+  - https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp311-cp311-linux_x86_64.whl
+resources:
+  accelerator: L4
+  use_gpu: true
+runtime:
+  predict_concurrency: 128
+secrets:
+  hf_access_token: null
diff --git a/gemma2/gemma2-9b-it-vllm/model/__init__.py b/gemma2/gemma2-9b-it-vllm/model/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gemma2/gemma2-9b-it-vllm/model/model.py b/gemma2/gemma2-9b-it-vllm/model/model.py
new file mode 100644
index 00000000..58f57b42
--- /dev/null
+++ b/gemma2/gemma2-9b-it-vllm/model/model.py
@@ -0,0 +1,92 @@
+import logging
+import os
+import subprocess
+import uuid
+
+from transformers import AutoTokenizer
+from vllm import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+
+os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+logger = logging.getLogger(__name__)
+
+
+class Model:
+    def __init__(self, **kwargs):
+        self._config = kwargs["config"]
+        self.model = None
+        self.llm_engine = None
+        self.model_args = None
+        self.hf_secret_token = kwargs["secrets"]["hf_access_token"]
+        os.environ["HF_TOKEN"] = self.hf_secret_token
+
+    def load(self):
+        # Log GPU state before the engine loads the weights.
+        try:
+            result = subprocess.run(
+                ["nvidia-smi"], capture_output=True, text=True, check=True
+            )
+            print(result.stdout)
+        except subprocess.CalledProcessError as e:
+            print(f"Command failed with code {e.returncode}: {e.stderr}")
+
+        model_metadata = self._config["model_metadata"]
+        logger.info(f"main model: {model_metadata['main_model']}")
+        logger.info(f"tensor parallelism: {model_metadata['tensor_parallel']}")
+        logger.info(f"max num seqs: {model_metadata['max_num_seqs']}")
+
+        self.model_args = AsyncEngineArgs(
+            model=model_metadata["main_model"],
+            trust_remote_code=True,
+            tensor_parallel_size=model_metadata["tensor_parallel"],
+            max_num_seqs=model_metadata["max_num_seqs"],
+            dtype="auto",
+            use_v2_block_manager=True,
+            enforce_eager=True,
+        )
+        self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args)
+
+        # Create a tokenizer for Gemma 2 to apply its chat template to prompts.
+        self.tokenizer = AutoTokenizer.from_pretrained(model_metadata["main_model"])
+
+        # Log GPU state again once the engine is up.
+        try:
+            result = subprocess.run(
+                ["nvidia-smi"], capture_output=True, text=True, check=True
+            )
+            print(result.stdout)
+        except subprocess.CalledProcessError as e:
+            print(f"Command failed with code {e.returncode}: {e.stderr}")
+
+    async def predict(self, model_input):
+        prompt = model_input.pop("prompt")
+        stream = model_input.pop("stream", True)
+
+        # Any remaining fields are passed through as vLLM sampling parameters.
+        sampling_params = SamplingParams(**model_input)
+        idx = str(uuid.uuid4().hex)
+        chat = [
+            {"role": "user", "content": prompt},
+        ]
+        # Apply Gemma 2's chat template to the raw prompt.
+        input = self.tokenizer.apply_chat_template(
+            chat, tokenize=False, add_generation_prompt=True
+        )
+
+        vllm_generator = self.llm_engine.generate(input, sampling_params, idx)
+
+        async def generator():
+            full_text = ""
+            async for output in vllm_generator:
+                # vLLM returns the cumulative text so far; yield only the new suffix.
+                text = output.outputs[0].text
+                delta = text[len(full_text) :]
+                full_text = text
+                yield delta
+
+        if stream:
+            return generator()
+        else:
+            full_text = ""
+            async for delta in generator():
+                full_text += delta
+            return {"text": full_text}
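
Both `model.py` files stream by re-slicing vLLM's cumulative output: each engine step returns the full text generated so far, and the inner `generator()` yields only the suffix that has not been sent yet. A self-contained sketch of that delta computation (toy strings, no vLLM dependency):

```python
# Each element mimics the cumulative text vLLM reports after another generation step.
cumulative_outputs = ["Hello", "Hello, wor", "Hello, world!"]

full_text = ""
for text in cumulative_outputs:
    delta = text[len(full_text):]  # only the newly generated suffix
    full_text = text
    print(repr(delta))  # prints 'Hello', then ', wor', then 'ld!'
```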