-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding ultravox to truss examples (#323)
- Loading branch information
1 parent
920209f
commit efb67e4
Showing
4 changed files
with
208 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Ultravox vLLM Truss | ||
|
||
This is a [Truss](https://truss.baseten.co/) for Ultravox using the vLLM OpenAI Compatible server. This Truss is designed to provide an efficient and scalable way to serve Ultravox and other models in an OpenAI compatible way using vLLM. | ||
|
||
## OpenAI Bridge Compatibility | ||
|
||
This Truss is compatible with a *custom* version of our [bridge endpoint for OpenAI ChatCompletion users](https://docs.baseten.co/api-reference/openai). This means you can easily integrate this model into your existing applications that use the OpenAI API format. | ||
|
||
``` | ||
client = OpenAI( | ||
api_key=os.environ["BASETEN_API_KEY"], | ||
base_url=f"https://bridge.baseten.co/{model_id}/direct/v1" | ||
) | ||
``` | ||
|
||
## Truss | ||
|
||
Truss is an open-source model serving framework developed by Baseten. It allows you to develop and deploy machine learning models onto Baseten (and other platforms like [AWS](https://truss.baseten.co/deploy/aws) or [GCP](https://truss.baseten.co/deploy/gcp)). Using Truss, you can develop a GPU model using [live-reload](https://baseten.co/blog/technical-deep-dive-truss-live-reload), package models and their associated code, create Docker containers, and deploy on Baseten. | ||
|
||
## Deployment | ||
|
||
First, clone this repository: | ||
|
||
```sh | ||
git clone https://github.com/basetenlabs/truss-examples.git | ||
cd ultravox | ||
``` | ||
|
||
Before deployment: | ||
|
||
1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys). | ||
2. Install the latest version of Truss: `pip install --upgrade truss` | ||
|
||
With `ultravox` as your working directory, you can deploy the model with: | ||
|
||
```sh | ||
truss push | ||
``` | ||
|
||
Paste your Baseten API key if prompted. | ||
|
||
For more information, see [Truss documentation](https://truss.baseten.co). | ||
|
||
## vLLM OpenAI Compatible Server | ||
|
||
This Truss demonstrates how to start [vLLM's OpenAI compatible server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html). The Truss is primarily used to start the server and then route requests to it. It currently supports ChatCompletions only. | ||
|
||
### Passing startup arguments to the server | ||
|
||
In the config any key-values under `model_metadata: arguments:` will be passed to the vLLM OpenAI-compatible server at startup. | ||
|
||
### Base Image | ||
|
||
You can use any vLLM compatible base image. | ||
|
||
## API Documentation | ||
|
||
The API follows the OpenAI ChatCompletion format. You can interact with the model using the standard ChatCompletion interface. | ||
|
||
Example usage: | ||
|
||
```python | ||
from openai import OpenAI | ||
|
||
client = OpenAI( | ||
api_key="YOUR-API-KEY", | ||
base_url="https://bridge.baseten.co/MODEL-ID/v1" | ||
) | ||
|
||
response = client.chat.completions.create( | ||
model="fixie-ai/ultravox-v0.2", | ||
messages=[{ | ||
"role": "user", | ||
"content": [ | ||
{"type": "text", "text": "Summarize the following: <|audio|>"}, | ||
{"type": "image_url", "image_url": {"url": f"data:audio/wav;base64,{base64_wav}"}} | ||
] | ||
}] | ||
stream=True | ||
) | ||
|
||
for chunk in response: | ||
print(chunk.choices[0].delta) | ||
``` | ||
|
||
## Future Improvements | ||
|
||
We are actively working on enhancing this Truss. Some planned improvements include: | ||
|
||
- Adding support for distributed serving with Ray (https://docs.vllm.ai/en/latest/serving/distributed_serving.html) | ||
- Implementing model caching for improved performance | ||
|
||
Stay tuned for updates! | ||
|
||
## Support | ||
|
||
If you have any questions or need assistance, please open an issue in this repository or contact our support team. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
base_image: | ||
image: vshulman/vllm-openai-fixie:latest | ||
python_executable_path: /usr/bin/python3 | ||
model_metadata: | ||
arguments: | ||
model: fixie-ai/ultravox-v0.2 | ||
audio_token_id: 128002 | ||
environment_variables: {} | ||
external_package_dirs: [] | ||
model_name: Ultravox v0.2 | ||
python_version: py310 | ||
runtime: | ||
predict_concurrency: 512 | ||
requirements: | ||
- httpx | ||
resources: | ||
accelerator: A100 | ||
use_gpu: true | ||
secrets: {} | ||
system_packages: | ||
- python3.10-venv |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import json | ||
import subprocess | ||
import time | ||
from typing import Any, Dict, List | ||
|
||
import httpx # Changed from aiohttp to httpx | ||
|
||
|
||
class Model: | ||
MAX_FAILED_SECONDS = 600 # 10 minutes; the reason this would take this long is mostly if we download a large model | ||
|
||
def __init__(self, data_dir, config, secrets): | ||
self._secrets = secrets | ||
self._config = config | ||
self.vllm_base_url = None | ||
|
||
# TODO: uncomment for multi-GPU support | ||
# command = "ray start --head" | ||
# subprocess.check_output(command, shell=True, text=True) | ||
|
||
def load(self): | ||
self._client = httpx.AsyncClient(timeout=None) | ||
|
||
self._vllm_config = self._config["model_metadata"]["arguments"] | ||
|
||
command = ["python3", "-m", "vllm.entrypoints.openai.api_server"] | ||
for key, value in self._vllm_config.items(): | ||
command.append(f"--{key.replace('_', '-')}") | ||
command.append(str(value)) | ||
|
||
subprocess.Popen(command) | ||
|
||
if "port" in self._vllm_config: | ||
self._vllm_port = self._vllm_config["port"] | ||
else: | ||
self._vllm_port = 8000 | ||
|
||
self.vllm_base_url = f"http://localhost:{self._vllm_port}" | ||
|
||
# Polling to check if the server is up | ||
server_up = False | ||
start_time = time.time() | ||
while time.time() - start_time < self.MAX_FAILED_SECONDS: | ||
try: | ||
response = httpx.get(f"{self.vllm_base_url}/health") | ||
if response.status_code == 200: | ||
server_up = True | ||
break | ||
except httpx.RequestError: | ||
time.sleep(1) # Wait for 1 second before retrying | ||
|
||
if not server_up: | ||
raise RuntimeError( | ||
"Server failed to start within the maximum allowed time." | ||
) | ||
|
||
async def predict(self, model_input): | ||
|
||
# if the key metrics: true is present, let's return the vLLM /metrics endpoint | ||
if model_input.get("metrics", False): | ||
response = await self._client.get(f"{self.vllm_base_url}/metrics") | ||
return response.text | ||
|
||
# convenience for Baseten bridge | ||
if "model" not in model_input and "model" in self._vllm_config: | ||
print( | ||
f"model_input missing model due to Baseten bridge, using {self._vllm_config['model']}" | ||
) | ||
model_input["model"] = self._vllm_config["model"] | ||
|
||
stream = model_input.get("stream", False) | ||
if stream: | ||
|
||
async def generator(): | ||
async with self._client.stream( | ||
"POST", | ||
f"{self.vllm_base_url}/v1/chat/completions", | ||
json=model_input, | ||
) as response: | ||
async for chunk in response.aiter_bytes(): | ||
if chunk: | ||
yield chunk | ||
|
||
return generator() | ||
else: | ||
response = await self._client.post( | ||
f"{self.vllm_base_url}/v1/chat/completions", | ||
json=model_input, | ||
) | ||
return response.json() |