feat: add python binding for rust llm modules #252

Open · wants to merge 3 commits into base: main
3 changes: 2 additions & 1 deletion .dockerignore
@@ -36,4 +36,5 @@
**/.github
**/*backup*/
.dockerignore
**/target/*
**/target/*
**/.venv/*
27 changes: 27 additions & 0 deletions examples/python_rs/llm/vllm/http/openai_service/curl.sh
@@ -0,0 +1,27 @@
#!/bin/bash

# list models
echo "\n\n### Listing models"
curl http://localhost:8000/v1/models

# create completion
echo "\n\n### Creating completions"
curl -X POST http://localhost:8000/v1/chat/completions \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "mock_model",
"messages": [
{
"role":"user",
"content":"Hello! How are you?"
}
],
"max_tokens": 64,
"stream": true,
"temperature": 0.7,
"top_p": 0.9,
"frequency_penalty": 0.1,
"presence_penalty": 0.2,
"top_k": 5
}'
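For readers who prefer Python over curl, here is a minimal sketch of the same streaming request using the `requests` package (an assumption: `requests` is installed separately and is not part of this PR); the `data:` prefix handling assumes the service emits OpenAI-style server-sent events:

import json

import requests  # assumption: available in the environment; not added by this PR

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "mock_model",
        "messages": [{"role": "user", "content": "Hello! How are you?"}],
        "max_tokens": 64,
        "stream": True,
    },
    stream=True,
    timeout=30,
)
for line in resp.iter_lines():
    if not line:
        continue
    payload = line.decode()
    # OpenAI-compatible servers typically frame streamed chunks as "data: {...}".
    if payload.startswith("data: "):
        payload = payload[len("data: "):]
    if payload.strip() == "[DONE]":
        break
    chunk = json.loads(payload)
    print(chunk["choices"][0]["delta"].get("content") or "", end="", flush=True)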
85 changes: 85 additions & 0 deletions examples/python_rs/llm/vllm/http/openai_service/server.py
@@ -0,0 +1,85 @@
import asyncio
import time
import uuid

import uvloop
from triton_distributed_rs import (
DistributedRuntime,
HttpAsyncEngine,
HttpService,
triton_worker,
)


class MockEngine:
    """Mock engine that streams a fixed number of OpenAI-style chat completion chunks."""

def __init__(self, model_name):
self.model_name = model_name

def generate(self, request):
id = f"chat-{uuid.uuid4()}"
created = int(time.time())
model = self.model_name
print(f"{created} | Received request: {request}")

async def generator():
num_chunks = 5
for i in range(num_chunks):
mock_content = f"chunk{i}"
finish_reason = "stop" if (i == num_chunks - 1) else None
chunk = {
"id": id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": i,
"delta": {"role": None, "content": mock_content},
"logprobs": None,
"finish_reason": finish_reason,
}
],
}
yield chunk

return generator()


@triton_worker()
async def worker(runtime: DistributedRuntime):
model: str = "mock_model"
served_model_name: str = "mock_model"

loop = asyncio.get_running_loop()
python_engine = MockEngine(model)
engine = HttpAsyncEngine(python_engine.generate, loop)

host: str = "localhost"
port: int = 8000
service: HttpService = HttpService(port=port)
service.add_chat_completions_model(served_model_name, engine)

print("Starting service...")
shutdown_signal = service.run(runtime.child_token())

try:
print(f"Serving endpoint: {host}:{port}/v1/models")
print(f"Serving endpoint: {host}:{port}/v1/chat/completions")
print(f"Serving the following models: {service.list_chat_completions_models()}")
# Block until shutdown signal received
await shutdown_signal
except KeyboardInterrupt:
# TODO: Handle KeyboardInterrupt gracefully in triton_worker
# TODO: Caught by DistributedRuntime or HttpService, so it's not caught here
pass
except Exception as e:
print(f"Unexpected error occurred: {e}")
finally:
print("Shutting down worker...")
runtime.shutdown()


if __name__ == "__main__":
uvloop.install()
# TODO: linter complains about lack of runtime arg passed
asyncio.run(worker())
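As a quick sanity check of the chunk format produced above, a minimal sketch that drives MockEngine directly, assuming server.py is importable as `server` (the module name is an assumption, not part of this PR):

import asyncio

from server import MockEngine  # assumption: server.py is importable as `server`


async def main():
    engine = MockEngine("mock_model")
    # `generate` returns an async generator of OpenAI-style chat.completion.chunk dicts.
    async for chunk in engine.generate({"messages": [{"role": "user", "content": "hi"}]}):
        choice = chunk["choices"][0]
        print(choice["delta"]["content"], choice["finish_reason"])


if __name__ == "__main__":
    asyncio.run(main())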
Empty file.
37 changes: 37 additions & 0 deletions examples/python_rs/llm/vllm/preprocessor_backend/client.py
@@ -0,0 +1,37 @@
import asyncio

import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker

uvloop.install()


@triton_worker()
async def worker(runtime: DistributedRuntime):
"""
Instantiate a `backend` client and call the `generate` endpoint
"""
# get endpoint
endpoint = (
runtime.namespace("triton-init").component("preprocessor").endpoint("generate")
)

# create client
client = await endpoint.client()

chat_completion_request = dict(
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
messages=[{"role": "user", "content": "what is deep learning?"}],
max_tokens=64,
stream=True,
)

# issue request
stream = await client.generate(chat_completion_request)

# process response
async for resp in stream:
print(resp)


asyncio.run(worker())
28 changes: 28 additions & 0 deletions examples/python_rs/llm/vllm/preprocessor_backend/curl.sh
@@ -0,0 +1,28 @@
#!/bin/bash

# list models
echo "\n\n### Listing models"
curl http://localhost:8000/v1/models

# create completion
echo "\n\n### Creating completions"
curl -X 'POST' \
'http://localhost:8000/v1/chat/completions' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
"messages": [
{
"role":"user",
"content":"what is deep learning?"
}
],
"max_tokens": 64,
"stream": true,
"temperature": 0.7,
"top_p": 0.9,
"frequency_penalty": 0.1,
"presence_penalty": 0.2,
"top_k": 5
}'
95 changes: 95 additions & 0 deletions examples/python_rs/llm/vllm/preprocessor_backend/http_openai.py
@@ -0,0 +1,95 @@
import argparse
import asyncio
import logging

import uvloop
from triton_distributed_rs import (
DistributedRuntime,
HttpAsyncEngine,
HttpService,
triton_worker,
)

logging.basicConfig(level=logging.INFO)


class OpenAIChatService:
    """OpenAI-compatible chat service that forwards requests to the preprocessor endpoint."""

def __init__(self, model_name, model_path, preprocessor):
self.model_name = model_name
self.model_path = model_path
self.preprocessor = preprocessor

async def generate(self, request):
print(f"Received request: {request}")
logging.info(f"Received request: {request}")
async for resp in await self.preprocessor.random(request):
logging.info(f"Sending response: {resp}")
yield resp["data"]


@triton_worker()
async def worker(
runtime: DistributedRuntime, model_name: str, model_path: str, port: int
):
loop = asyncio.get_running_loop()
preprocessor = (
await runtime.namespace("triton-init")
.component("preprocessor")
.endpoint("generate")
.client()
)
python_engine = OpenAIChatService(model_name, model_path, preprocessor)

engine = HttpAsyncEngine(python_engine.generate, loop)

host: str = "localhost"
service: HttpService = HttpService(port=port)
service.add_chat_completions_model(model_name, engine)

logging.info("Starting service...")
shutdown_signal = service.run(runtime.child_token())
try:
logging.info(f"Serving endpoint: {host}:{port}/v1/models")
# TODO: add completion endpoint
logging.info(
f"Serving chat completion endpoint: {host}:{port}/v1/chat/completions"
)
logging.info(
f"Serving the following models: {service.list_chat_completions_models()}"
)
# Block until shutdown signal received
await shutdown_signal
except KeyboardInterrupt:
# FIXME: Caught by DistributedRuntime or HttpService, so not caught here
pass
except Exception as e:
logging.error(f"Unexpected error occurred: {e}")
finally:
logging.info("Shutting down worker...")
runtime.shutdown()


def parse_args():
    """Parse the model name, model path, and HTTP service port from the command line."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
)
parser.add_argument(
"--model-path",
type=str,
default="~/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
)
# only used by http service
parser.add_argument("--port", type=int, default=8000)
return parser.parse_args()


if __name__ == "__main__":
uvloop.install()
args = parse_args()
asyncio.run(worker(args.model_name, args.model_path, args.port))
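Since the front end exposes the OpenAI chat-completions protocol, it can also be exercised with the official `openai` Python client rather than curl; a hedged sketch (the `openai` package and the placeholder API key are assumptions, not part of this change):

from openai import OpenAI  # assumption: `pip install openai` (v1+); not added by this PR

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")
stream = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    messages=[{"role": "user", "content": "what is deep learning?"}],
    max_tokens=64,
    stream=True,
)
for chunk in stream:
    # Each chunk mirrors the chat.completion.chunk objects streamed by the service.
    print(chunk.choices[0].delta.content or "", end="", flush=True)
print()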
@@ -0,0 +1,49 @@
import argparse
import asyncio

import uvloop
from triton_distributed_rs import (
DistributedRuntime,
ModelDeploymentCard,
OAIChatPreprocessor,
triton_worker,
)

uvloop.install()


@triton_worker()
async def preprocessor(runtime: DistributedRuntime, model_name: str, model_path: str):
    """
    Start an OAI chat preprocessor service that forwards preprocessed requests to the `backend` component.
    """
# create model deployment card
mdc = await ModelDeploymentCard.from_local_path(model_path, model_name)
# create preprocessor endpoint
component = runtime.namespace("triton-init").component("preprocessor")
await component.create_service()
endpoint = component.endpoint("generate")

# create backend endpoint
backend = runtime.namespace("triton-init").component("backend").endpoint("generate")

# start preprocessor service with next backend
chat = OAIChatPreprocessor(mdc, endpoint, next=backend)
await chat.start()


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
)
parser.add_argument(
"--model-path",
type=str,
default="~/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
)
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()
asyncio.run(preprocessor(args.model_name, args.model_path))
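The preprocessor above forwards to a `backend` component's `generate` endpoint that is not included in this diff. A minimal sketch of what such a worker could look like, following the binding patterns used in these examples; the `serve_endpoint` call and the handler shape are assumptions about the `triton_distributed_rs` API, not code from this PR:

import asyncio

import uvloop
from triton_distributed_rs import DistributedRuntime, triton_worker

uvloop.install()


class BackendHandler:
    async def generate(self, request):
        # Echo a few mock tokens for whatever the preprocessor sends downstream.
        for i in range(3):
            yield {"token": f"mock-{i}"}


@triton_worker()
async def worker(runtime: DistributedRuntime):
    component = runtime.namespace("triton-init").component("backend")
    await component.create_service()
    endpoint = component.endpoint("generate")
    # Assumption: the binding exposes serve_endpoint() to register the handler.
    await endpoint.serve_endpoint(BackendHandler().generate)


asyncio.run(worker())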