
added gemma2 9b and 27b vllm with streaming #318

Merged: 11 commits, merged on Jul 16, 2024
55 changes: 55 additions & 0 deletions gemma2/gemma2-27b-it-vllm/README.md
@@ -0,0 +1,55 @@
# Gemma 2 27B

This is a [Truss](https://truss.baseten.co/) for Gemma 2 27B Instruct. This README will walk you through how to deploy this Truss on Baseten to get your own instance of Gemma 2 27B Instruct.

## Gemma 2 27B Instruct Implementation

This implementation of Gemma 2 uses [vLLM](https://github.com/vllm-project/vllm).

Since Gemma 2 is a gated model, you will also need to provide your Hugging Face access token after making sure you have access to [the model](https://huggingface.co/google/gemma-2-27b-it). Follow [this guide](https://docs.baseten.co/deploy/guides/secrets) to add your Hugging Face access token as a secret.

## Deployment

First, clone this repository:

```sh
git clone https://github.com/basetenlabs/truss-examples/
cd gemma2/gemma2-27b-it-vllm
```

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `gemma2/gemma2-27b-it-vllm` as your working directory, you can deploy the model with:

```sh
truss push --trusted
```

Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

## Gemma 2 27B Instruct API documentation

This section provides an overview of the Gemma 2 27B Instruct API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided prompt.

### API route: `predict`

The predict route is the primary method for generating text completions from a given prompt. It takes several parameters:

- __prompt__: The input text that you want the model to generate a response for.
- __max_tokens__: The maximum number of output tokens to generate.
- __stream__: Whether to stream tokens back as they are generated. Defaults to `true`.

Any other fields in the request body are passed through to vLLM as sampling parameters (for example, `temperature` or `top_p`).

## Example usage

You can also invoke your model via a REST API:

```sh
curl -X POST "https://app.baseten.co/model_versions/YOUR_MODEL_VERSION_ID/predict" \
     -H "Content-Type: application/json" \
     -H 'Authorization: Api-Key {YOUR_API_KEY}' \
     -d '{"prompt": "what came before, the chicken or the egg?", "max_tokens": 64}'
```
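
Because `stream` defaults to `true` in this Truss, you can also consume the output as it is generated. Below is a minimal sketch using Python and `requests`; the model version ID and API key are the same placeholders as in the curl example above, and the exact chunk framing may differ depending on your Baseten deployment.

```python
import requests

# Placeholders -- substitute your own deployment's values.
MODEL_VERSION_ID = "YOUR_MODEL_VERSION_ID"
API_KEY = "YOUR_API_KEY"

resp = requests.post(
    f"https://app.baseten.co/model_versions/{MODEL_VERSION_ID}/predict",
    headers={"Authorization": f"Api-Key {API_KEY}"},
    json={"prompt": "what came before, the chicken or the egg?", "max_tokens": 64},
    stream=True,  # keep the HTTP connection open so chunks arrive as they are generated
)
resp.raise_for_status()

# Print each text delta as the model produces it.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```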
17 changes: 17 additions & 0 deletions gemma2/gemma2-27b-it-vllm/config.yaml
@@ -0,0 +1,17 @@
model_name: "Gemma 2 27B Instruct VLLM"
python_version: py311
model_metadata:
  example_model_input: {"prompt": "what is the meaning of life"}
  main_model: google/gemma-2-27b-it
  tensor_parallel: 2
  max_num_seqs: 16
requirements:
  - vllm==0.5.2
  - https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp311-cp311-linux_x86_64.whl
resources:
  accelerator: A100:2
  use_gpu: true
runtime:
  predict_concurrency: 128
secrets:
  hf_access_token: null
Empty file.
91 changes: 91 additions & 0 deletions gemma2/gemma2-27b-it-vllm/model/model.py
@@ -0,0 +1,91 @@
import logging
import subprocess
import uuid
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import os
from transformers import AutoTokenizer

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"  # for multiprocessing to work with CUDA

logger = logging.getLogger(__name__)


class Model:
    def __init__(self, **kwargs):
        self._config = kwargs["config"]
        self.model = None
        self.llm_engine = None
        self.model_args = None
        self.hf_secret_token = kwargs["secrets"]["hf_access_token"]
        os.environ["HF_TOKEN"] = self.hf_secret_token

    def load(self):
        # Log GPU state before the engine is created
        try:
            result = subprocess.run(
                ["nvidia-smi"], capture_output=True, text=True, check=True
            )
            print(result.stdout)
        except subprocess.CalledProcessError as e:
            print(f"Command failed with code {e.returncode}: {e.stderr}")

        model_metadata = self._config["model_metadata"]
        logger.info(f"main model: {model_metadata['main_model']}")
        logger.info(f"tensor parallelism: {model_metadata['tensor_parallel']}")
        logger.info(f"max num seqs: {model_metadata['max_num_seqs']}")

        self.model_args = AsyncEngineArgs(
Review thread on this line:

Contributor: potential improvement is to move everything to config, e.g. as in this example: https://github.com/vshulman/truss-examples/tree/main/ultravox-vllm

Contributor: we can merge without this change as other vLLM examples also support a partial list of arguments

Contributor (author): Looking through it, it looks like that example uses the vLLM OpenAI server instead of explicitly instantiating the vLLM AsyncLLMEngine for the model.

Contributor: 100% -- I just think the same kwargs pattern can apply here. The benefit I see is that, going forward, all it would take to pass a new argument into vLLM, whether through the standalone OpenAI server or the Python API above, is adding it to config.yaml.
            model=model_metadata["main_model"],
            trust_remote_code=True,
            tensor_parallel_size=model_metadata["tensor_parallel"],
            max_num_seqs=model_metadata["max_num_seqs"],
            dtype="auto",
            use_v2_block_manager=True,
            enforce_eager=True,
        )
        self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args)

        # create tokenizer for gemma 2 to apply chat template to prompts
        self.tokenizer = AutoTokenizer.from_pretrained(model_metadata["main_model"])

        # Log GPU state again after the engine has claimed its memory
        try:
            result = subprocess.run(
                ["nvidia-smi"], capture_output=True, text=True, check=True
            )
            print(result.stdout)
        except subprocess.CalledProcessError as e:
            print(f"Command failed with code {e.returncode}: {e.stderr}")

    async def predict(self, model_input):
        prompt = model_input.pop("prompt")
        stream = model_input.pop("stream", True)

        # since we accept any valid vLLM sampling parameters, we can just pass them through
        sampling_params = SamplingParams(**model_input)
        idx = str(uuid.uuid4().hex)
        chat = [
            {"role": "user", "content": prompt},
        ]
        # templatize the input to the model
        input = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        vllm_generator = self.llm_engine.generate(input, sampling_params, idx)

        async def generator():
            # vLLM yields the full text generated so far; emit only the new delta
            full_text = ""
            async for output in vllm_generator:
                text = output.outputs[0].text
                delta = text[len(full_text) :]
                full_text = text
                yield delta

        if stream:
            return generator()
        else:
            full_text = ""
            async for delta in generator():
                full_text += delta
            return {"text": full_text}
55 changes: 55 additions & 0 deletions gemma2/gemma2-9b-it-vllm/README.md
@@ -0,0 +1,55 @@
# Gemma 2 9B

This is a [Truss](https://truss.baseten.co/) for Gemma 2 9B Instruct. This README will walk you through how to deploy this Truss on Baseten to get your own instance of Gemma 2 9B Instruct.

## Gemma 2 9B Instruct Implementation

This implementation of Gemma 2 uses [vLLM](https://github.com/vllm-project/vllm).

Since Gemma 2 is a gated model, you will also need to provide your Hugging Face access token after making sure you have access to [the model](https://huggingface.co/google/gemma-2-9b-it). Follow [this guide](https://docs.baseten.co/deploy/guides/secrets) to add your Hugging Face access token as a secret.

## Deployment

First, clone this repository:

```sh
git clone https://github.com/basetenlabs/truss-examples/
cd gemma2/gemma2-9b-it-vllm
```

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `gemma2/gemma2-9b-it-vllm` as your working directory, you can deploy the model with:

```sh
truss push --trusted
```

Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

## Gemma 2 9B Instruct API documentation

This section provides an overview of the Gemma 2 9B Instruct API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided prompt.

### API route: `predict`

The predict route is the primary method for generating text completions from a given prompt. It takes several parameters:

- __prompt__: The input text that you want the model to generate a response for.
- __max_tokens__: The maximum number of output tokens to generate.
- __stream__: Whether to stream tokens back as they are generated. Defaults to `true`.

Any other fields in the request body are passed through to vLLM as sampling parameters (for example, `temperature` or `top_p`).

## Example usage

You can also invoke your model via a REST API:

```sh
curl -X POST "https://app.baseten.co/model_versions/YOUR_MODEL_VERSION_ID/predict" \
     -H "Content-Type: application/json" \
     -H 'Authorization: Api-Key {YOUR_API_KEY}' \
     -d '{"prompt": "what came before, the chicken or the egg?", "max_tokens": 64}'
```
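
Since the `predict` implementation forwards any extra request fields to vLLM's `SamplingParams`, you can tune decoding without code changes. A sketch of a non-streaming call in Python, using the same placeholders as the curl example above:

```python
import requests

# Placeholders -- substitute your own deployment's values.
MODEL_VERSION_ID = "YOUR_MODEL_VERSION_ID"
API_KEY = "YOUR_API_KEY"

resp = requests.post(
    f"https://app.baseten.co/model_versions/{MODEL_VERSION_ID}/predict",
    headers={"Authorization": f"Api-Key {API_KEY}"},
    json={
        "prompt": "Write a haiku about GPUs.",
        "max_tokens": 64,
        "stream": False,     # handled by the Truss, not passed to vLLM
        "temperature": 0.7,  # forwarded to vLLM SamplingParams
        "top_p": 0.9,        # forwarded to vLLM SamplingParams
    },
)
print(resp.json()["text"])
```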
17 changes: 17 additions & 0 deletions gemma2/gemma2-9b-it-vllm/config.yaml
@@ -0,0 +1,17 @@
model_name: "Gemma 2 9B Instruct VLLM"
python_version: py311
model_metadata:
  example_model_input: {"prompt": "what is the meaning of life"}
  main_model: google/gemma-2-9b-it
  tensor_parallel: 1
  max_num_seqs: 16
requirements:
  - vllm==0.5.2
  - https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp311-cp311-linux_x86_64.whl
resources:
  accelerator: L4
  use_gpu: true
runtime:
  predict_concurrency: 128
secrets:
  hf_access_token: null
Empty file.
91 changes: 91 additions & 0 deletions gemma2/gemma2-9b-it-vllm/model/model.py
@@ -0,0 +1,91 @@
import logging
import subprocess
import uuid
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
import os
from transformers import AutoTokenizer

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

logger = logging.getLogger(__name__)


class Model:
    def __init__(self, **kwargs):
        self._config = kwargs["config"]
        self.model = None
        self.llm_engine = None
        self.model_args = None
        self.hf_secret_token = kwargs["secrets"]["hf_access_token"]
        os.environ["HF_TOKEN"] = self.hf_secret_token

    def load(self):
        # Log GPU state before the engine is created
        try:
            result = subprocess.run(
                ["nvidia-smi"], capture_output=True, text=True, check=True
            )
            print(result.stdout)
        except subprocess.CalledProcessError as e:
            print(f"Command failed with code {e.returncode}: {e.stderr}")

        model_metadata = self._config["model_metadata"]
        logger.info(f"main model: {model_metadata['main_model']}")
        logger.info(f"tensor parallelism: {model_metadata['tensor_parallel']}")
        logger.info(f"max num seqs: {model_metadata['max_num_seqs']}")

        self.model_args = AsyncEngineArgs(
            model=model_metadata["main_model"],
            trust_remote_code=True,
            tensor_parallel_size=model_metadata["tensor_parallel"],
            max_num_seqs=model_metadata["max_num_seqs"],
            dtype="auto",
            use_v2_block_manager=True,
            enforce_eager=True,
        )
        self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args)

        # create tokenizer for gemma 2 to apply chat template to prompts
        self.tokenizer = AutoTokenizer.from_pretrained(model_metadata["main_model"])

        # Log GPU state again after the engine has claimed its memory
        try:
            result = subprocess.run(
                ["nvidia-smi"], capture_output=True, text=True, check=True
            )
            print(result.stdout)
        except subprocess.CalledProcessError as e:
            print(f"Command failed with code {e.returncode}: {e.stderr}")

    async def predict(self, model_input):
        prompt = model_input.pop("prompt")
        stream = model_input.pop("stream", True)

        # since we accept any valid vLLM sampling parameters, we can just pass them through
        sampling_params = SamplingParams(**model_input)
        idx = str(uuid.uuid4().hex)
        chat = [
            {"role": "user", "content": prompt},
        ]
        # templatize the input to the model
        input = self.tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )

        vllm_generator = self.llm_engine.generate(input, sampling_params, idx)

        async def generator():
            # vLLM yields the full text generated so far; emit only the new delta
            full_text = ""
            async for output in vllm_generator:
                text = output.outputs[0].text
                delta = text[len(full_text) :]
                full_text = text
                yield delta

        if stream:
            return generator()
        else:
            full_text = ""
            async for delta in generator():
                full_text += delta
            return {"text": full_text}