Adding ultravox to truss examples (#323)
htrivedi99 authored Jul 11, 2024
1 parent 920209f commit efb67e4
Showing 4 changed files with 208 additions and 0 deletions.
97 changes: 97 additions & 0 deletions ultravox/README.md
@@ -0,0 +1,97 @@
# Ultravox vLLM Truss

This is a [Truss](https://truss.baseten.co/) for Ultravox using vLLM's OpenAI-compatible server. It provides an efficient, scalable way to serve Ultravox (and other vLLM-supported models) behind an OpenAI-compatible API.

## OpenAI Bridge Compatibility

This Truss is compatible with a *custom* version of our [bridge endpoint for OpenAI ChatCompletion users](https://docs.baseten.co/api-reference/openai). This means you can easily integrate this model into your existing applications that use the OpenAI API format.

```python
import os
from openai import OpenAI

client = OpenAI(
    api_key=os.environ["BASETEN_API_KEY"],
    base_url=f"https://bridge.baseten.co/{model_id}/direct/v1"
)
```

## Truss

Truss is an open-source model serving framework developed by Baseten. It allows you to develop and deploy machine learning models onto Baseten (and other platforms like [AWS](https://truss.baseten.co/deploy/aws) or [GCP](https://truss.baseten.co/deploy/gcp)). Using Truss, you can develop a GPU model using [live-reload](https://baseten.co/blog/technical-deep-dive-truss-live-reload), package models and their associated code, create Docker containers, and deploy on Baseten.

## Deployment

First, clone this repository:

```sh
git clone https://github.com/basetenlabs/truss-examples.git
cd truss-examples/ultravox
```

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `ultravox` as your working directory, you can deploy the model with:

```sh
truss push
```

Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

## vLLM OpenAI Compatible Server

This Truss demonstrates how to start [vLLM's OpenAI compatible server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html). The Truss is primarily used to start the server and then route requests to it. It currently supports ChatCompletions only.
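
In essence, the `predict` method in `model/model.py` forwards the request body to the vLLM server it started on localhost. A simplified, non-streaming sketch of that flow:

```python
import httpx

# Simplified sketch of the non-streaming path in model/model.py:
# forward the request body to the vLLM server started on localhost.
async def predict(model_input: dict) -> dict:
    async with httpx.AsyncClient(timeout=None) as client:
        response = await client.post(
            "http://localhost:8000/v1/chat/completions",  # 8000 is vLLM's default port
            json=model_input,
        )
        return response.json()
```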

### Passing startup arguments to the server

In the config, any key-value pairs under `model_metadata: arguments:` are passed to the vLLM OpenAI-compatible server as command-line flags at startup, as illustrated below.
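
For example, the `arguments` in this example's `config.yaml` (`model` and `audio_token_id`) become CLI flags; the snippet below mirrors the conversion done in `model/model.py`:

```python
# Mirrors how model/model.py converts config arguments into vLLM CLI flags.
arguments = {"model": "fixie-ai/ultravox-v0.2", "audio_token_id": 128002}

command = ["python3", "-m", "vllm.entrypoints.openai.api_server"]
for key, value in arguments.items():
    command.append(f"--{key.replace('_', '-')}")
    command.append(str(value))

# Resulting command:
# python3 -m vllm.entrypoints.openai.api_server \
#   --model fixie-ai/ultravox-v0.2 --audio-token-id 128002
```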

### Base Image

You can use any vLLM-compatible base image. This example uses `vshulman/vllm-openai-fixie:latest`, set under `base_image` in `config.yaml`.

## API Documentation

The API follows the OpenAI ChatCompletion format. You can interact with the model using the standard ChatCompletion interface.

Example usage:

```python
from openai import OpenAI

client = OpenAI(
    api_key="YOUR-API-KEY",
    base_url="https://bridge.baseten.co/MODEL-ID/v1"
)

response = client.chat.completions.create(
    model="fixie-ai/ultravox-v0.2",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Summarize the following: <|audio|>"},
            {"type": "image_url", "image_url": {"url": f"data:audio/wav;base64,{base64_wav}"}},
        ],
    }],
    stream=True,
)

for chunk in response:
    print(chunk.choices[0].delta)
```
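
The example assumes `base64_wav` already contains a base64-encoded WAV file. One way to produce it with the standard library (the file path below is just a placeholder):

```python
import base64

# Encode a local WAV file for the data URL above; "speech.wav" is a placeholder path.
with open("speech.wav", "rb") as f:
    base64_wav = base64.b64encode(f.read()).decode("utf-8")
```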

## Future Improvements

We are actively working on enhancing this Truss. Some planned improvements include:

- Adding support for [distributed serving with Ray](https://docs.vllm.ai/en/latest/serving/distributed_serving.html)
- Implementing model caching for improved performance

Stay tuned for updates!

## Support

If you have any questions or need assistance, please open an issue in this repository or contact our support team.
21 changes: 21 additions & 0 deletions ultravox/config.yaml
@@ -0,0 +1,21 @@
base_image:
  image: vshulman/vllm-openai-fixie:latest
  python_executable_path: /usr/bin/python3
model_metadata:
  arguments:
    model: fixie-ai/ultravox-v0.2
    audio_token_id: 128002
environment_variables: {}
external_package_dirs: []
model_name: Ultravox v0.2
python_version: py310
runtime:
  predict_concurrency: 512
requirements:
  - httpx
resources:
  accelerator: A100
  use_gpu: true
secrets: {}
system_packages:
  - python3.10-venv
Empty file added ultravox/model/__init__.py
90 changes: 90 additions & 0 deletions ultravox/model/model.py
@@ -0,0 +1,90 @@
import json
import subprocess
import time
from typing import Any, Dict, List

import httpx  # Changed from aiohttp to httpx


class Model:
    MAX_FAILED_SECONDS = 600  # 10 minutes; the reason this would take this long is mostly if we download a large model

    def __init__(self, data_dir, config, secrets):
        self._secrets = secrets
        self._config = config
        self.vllm_base_url = None

        # TODO: uncomment for multi-GPU support
        # command = "ray start --head"
        # subprocess.check_output(command, shell=True, text=True)

    def load(self):
        self._client = httpx.AsyncClient(timeout=None)

        self._vllm_config = self._config["model_metadata"]["arguments"]

        command = ["python3", "-m", "vllm.entrypoints.openai.api_server"]
        for key, value in self._vllm_config.items():
            command.append(f"--{key.replace('_', '-')}")
            command.append(str(value))

        subprocess.Popen(command)

        if "port" in self._vllm_config:
            self._vllm_port = self._vllm_config["port"]
        else:
            self._vllm_port = 8000

        self.vllm_base_url = f"http://localhost:{self._vllm_port}"

        # Polling to check if the server is up
        server_up = False
        start_time = time.time()
        while time.time() - start_time < self.MAX_FAILED_SECONDS:
            try:
                response = httpx.get(f"{self.vllm_base_url}/health")
                if response.status_code == 200:
                    server_up = True
                    break
            except httpx.RequestError:
                time.sleep(1)  # Wait for 1 second before retrying

        if not server_up:
            raise RuntimeError(
                "Server failed to start within the maximum allowed time."
            )

    async def predict(self, model_input):
        # if the key metrics: true is present, let's return the vLLM /metrics endpoint
        if model_input.get("metrics", False):
            response = await self._client.get(f"{self.vllm_base_url}/metrics")
            return response.text

        # convenience for Baseten bridge
        if "model" not in model_input and "model" in self._vllm_config:
            print(
                f"model_input missing model due to Baseten bridge, using {self._vllm_config['model']}"
            )
            model_input["model"] = self._vllm_config["model"]

        stream = model_input.get("stream", False)
        if stream:

            async def generator():
                async with self._client.stream(
                    "POST",
                    f"{self.vllm_base_url}/v1/chat/completions",
                    json=model_input,
                ) as response:
                    async for chunk in response.aiter_bytes():
                        if chunk:
                            yield chunk

            return generator()
        else:
            response = await self._client.post(
                f"{self.vllm_base_url}/v1/chat/completions",
                json=model_input,
            )
            return response.json()
