Use vLLM to load LLMs (#230)
* Upgrade to use vLLM
* Improve API
* Isolated docker image for LLMs
kyriediculous authored Dec 26, 2024
1 parent 140006a commit b81f898
Showing 13 changed files with 698 additions and 453 deletions.
430 changes: 269 additions & 161 deletions runner/app/pipelines/llm.py

Large diffs are not rendered by default.
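Since the rewritten runner/app/pipelines/llm.py is not rendered here, the following is a minimal orientation sketch of the shape a vLLM-backed pipeline of this kind usually takes. The class name, prompt handling, and chunk fields are assumptions for illustration, not the actual contents of the diff.

# Illustrative sketch only -- not the contents of runner/app/pipelines/llm.py.
import uuid

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


class LLMPipeline:  # assumed name
    def __init__(self, model_id: str):
        self.model_id = model_id
        self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model=model_id))

    async def __call__(self, messages, temperature=0.7, max_tokens=256, top_p=1.0, top_k=-1):
        # Flatten chat messages into one prompt; the real pipeline most likely
        # applies the model's chat template instead.
        prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
        params = SamplingParams(temperature=temperature, max_tokens=max_tokens,
                                top_p=top_p, top_k=top_k)
        previous = ""
        async for output in self.engine.generate(prompt, params, str(uuid.uuid4())):
            text = output.outputs[0].text  # cumulative generated text so far
            delta, previous = text[len(previous):], text
            # OpenAI-style chunks are what the route handler below consumes;
            # id/model/created/usage fields are omitted here for brevity.
            yield {"choices": [{"delta": {"content": delta},
                                "finish_reason": "stop" if output.finished else None}]}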

1 change: 1 addition & 0 deletions runner/app/pipelines/utils/__init__.py
@@ -14,4 +14,5 @@
is_turbo_model,
split_prompt,
validate_torch_device,
get_max_memory,
)
23 changes: 23 additions & 0 deletions runner/app/pipelines/utils/utils.py
@@ -4,6 +4,7 @@
import logging
import os
import re
import psutil
from pathlib import Path
from typing import Any, Dict, List, Optional

@@ -365,3 +366,25 @@ def enable_loras(self) -> None:
if not self.loras_enabled:
self.pipeline.enable_lora()
self.loras_enabled = True


class MemoryInfo:
def __init__(self, gpu_memory, cpu_memory, num_gpus):
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.num_gpus = num_gpus

def __repr__(self):
return f"<MemoryInfo: GPUs={self.num_gpus}, CPU Memory={self.cpu_memory}, GPU Memory={self.gpu_memory}>"


def get_max_memory() -> MemoryInfo:
num_gpus = torch.cuda.device_count()
gpu_memory = {
i: f"{torch.cuda.get_device_properties(i).total_memory // 1024**3}GiB" for i in range(num_gpus)}
cpu_memory = f"{psutil.virtual_memory().available // 1024**3}GiB"

memory_info = MemoryInfo(gpu_memory=gpu_memory,
cpu_memory=cpu_memory, num_gpus=num_gpus)

return memory_info
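A quick illustration of how the new helper can be consumed when sizing a vLLM engine; get_max_memory is exported from app.pipelines.utils per the __init__.py change above, while the variable names and example output are illustrative only.

from app.pipelines.utils import get_max_memory

mem = get_max_memory()
print(mem)  # e.g. <MemoryInfo: GPUs=2, CPU Memory=120GiB, GPU Memory={0: '23GiB', 1: '23GiB'}>

# One plausible use: spread the model across every visible GPU.
tensor_parallel_size = max(mem.num_gpus, 1)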
81 changes: 38 additions & 43 deletions runner/app/routes/llm.py
@@ -1,12 +1,12 @@
import logging
import os
from typing import Annotated
from fastapi import APIRouter, Depends, Form, status
import time
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import HTTPError, LLMResponse, http_error
from app.routes.utils import HTTPError, LLMRequest, LLMResponse, http_error
import json

router = APIRouter()
@@ -33,13 +33,7 @@
)
@router.post("/llm/", response_model=LLMResponse, responses=RESPONSES, include_in_schema=False)
async def llm(
prompt: Annotated[str, Form()],
model_id: Annotated[str, Form()] = "",
system_msg: Annotated[str, Form()] = "",
temperature: Annotated[float, Form()] = 0.7,
max_tokens: Annotated[int, Form()] = 256,
history: Annotated[str, Form()] = "[]", # We'll parse this as JSON
stream: Annotated[bool, Form()] = False,
request: LLMRequest,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
@@ -52,50 +46,50 @@ async def llm(
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
if request.model != "" and request.model != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
f"pipeline configured with {pipeline.model_id} but called with {request.model}"
),
)

try:
history_list = json.loads(history)
if not isinstance(history_list, list):
raise ValueError("History must be a JSON array")

generator = pipeline(
prompt=prompt,
history=history_list,
system_msg=system_msg if system_msg else None,
temperature=temperature,
max_tokens=max_tokens
messages=[msg.dict() for msg in request.messages],
temperature=request.temperature,
max_tokens=request.max_tokens,
top_p=request.top_p,
top_k=request.top_k
)

if stream:
return StreamingResponse(stream_generator(generator), media_type="text/event-stream")
if request.stream:
return StreamingResponse(
stream_generator(generator),
media_type="text/event-stream"
)
else:
full_response = ""
last_chunk = None

async for chunk in generator:
if isinstance(chunk, dict):
tokens_used = chunk["tokens_used"]
break
full_response += chunk
if "choices" in chunk:
if "delta" in chunk["choices"][0]:
full_response += chunk["choices"][0]["delta"].get(
"content", "")
last_chunk = chunk

return LLMResponse(response=full_response, tokens_used=tokens_used)
usage = last_chunk.get("usage", {})

return LLMResponse(
response=full_response,
tokens_used=usage.get("total_tokens", 0),
id=last_chunk.get("id", ""),
model=last_chunk.get("model", pipeline.model_id),
created=last_chunk.get("created", int(time.time()))
)

except json.JSONDecodeError:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": "Invalid JSON format for history"}
)
except ValueError as ve:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(ve)}
)
except Exception as e:
logger.error(f"LLM processing error: {str(e)}")
return JSONResponse(
@@ -107,11 +101,12 @@ async def llm(
async def stream_generator(generator):
try:
async for chunk in generator:
if isinstance(chunk, dict): # This is the final result
yield f"data: {json.dumps(chunk)}\n\n"
break
else:
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
if isinstance(chunk, dict):
if "choices" in chunk:
# Regular streaming chunk or final chunk
yield f"data: {json.dumps(chunk)}\n\n"
if chunk["choices"][0].get("finish_reason") == "stop":
break
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Streaming error: {str(e)}")
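With this change the endpoint takes a JSON body in the LLMRequest shape instead of form fields, and streams OpenAI-style SSE chunks terminated by "data: [DONE]". A hedged client sketch follows; the host, port, path, and bearer token are placeholders to adjust for your deployment.

import json
import requests  # any HTTP client works

payload = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ],
    "temperature": 0.7,
    "max_tokens": 256,
    "stream": True,
}
headers = {"Authorization": "Bearer <token>"}  # placeholder token

with requests.post("http://localhost:8000/llm", json=payload,
                   headers=headers, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # Each chunk carries an incremental delta, as in the OpenAI streaming format.
        print(chunk["choices"][0]["delta"].get("content", ""), end="")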
19 changes: 19 additions & 0 deletions runner/app/routes/utils.py
@@ -72,9 +72,27 @@ class TextResponse(BaseModel):
chunks: List[Chunk] = Field(..., description="The generated text chunks.")


class LLMMessage(BaseModel):
role: str
content: str


class LLMRequest(BaseModel):
messages: List[LLMMessage]
model: str = ""
temperature: float = 0.7
max_tokens: int = 256
top_p: float = 1.0
top_k: int = -1
stream: bool = False


class LLMResponse(BaseModel):
response: str
tokens_used: int
id: str
model: str
created: int


class ImageToTextResponse(BaseModel):
@@ -101,6 +119,7 @@ class LiveVideoToVideoResponse(BaseModel):
description="URL for subscribing to events for pipeline status and logs",
)


class APIError(BaseModel):
"""API error response model."""

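A small sketch of building the same request body with the new pydantic models; serialization method names differ between pydantic v1 and v2, so treat the last line as an assumption about the version in use.

from app.routes.utils import LLMMessage, LLMRequest

req = LLMRequest(
    messages=[
        LLMMessage(role="system", content="You are a helpful assistant."),
        LLMMessage(role="user", content="Say hello."),
    ],
    temperature=0.2,
    stream=False,
)
print(req.json())  # pydantic v1; use req.model_dump_json() on pydantic v2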
74 changes: 74 additions & 0 deletions runner/docker/Dockerfile.llm
@@ -0,0 +1,74 @@
# Based on https://github.com/huggingface/api-inference-community/blob/main/docker_images/diffusers/Dockerfile

FROM nvidia/cuda:12.1.1-cudnn8-runtime-ubuntu20.04
LABEL maintainer="Yondon Fu <[email protected]>"

# Add any system dependency here
# RUN apt-get update -y && apt-get install libXXX -y

ENV DEBIAN_FRONTEND=noninteractive

# Install prerequisites
RUN apt-get update && \
apt-get install -y build-essential libssl-dev zlib1g-dev libbz2-dev \
libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \
xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git \
ffmpeg

# Install pyenv
RUN curl https://pyenv.run | bash

# Set environment variables for pyenv
ENV PYENV_ROOT /root/.pyenv
ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH

# Install your desired Python version
ARG PYTHON_VERSION=3.11
RUN pyenv install $PYTHON_VERSION && \
pyenv global $PYTHON_VERSION && \
pyenv rehash

# Upgrade pip and install your desired packages
ARG PIP_VERSION=24.2
RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0 && \
pip install --no-cache-dir torch==2.4.0 torchvision torchaudio pip-tools

WORKDIR /app

COPY ./requirements.llm.in /app
RUN pip-compile requirements.llm.in -o requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Most DL models are quite large in terms of memory, using workers is a HUGE
# slowdown because of the fork and GIL with python.
# Using multiple pods seems like a better default strategy.
# Feel free to override if it does not make sense for your library.
ARG max_workers=1
ENV MAX_WORKERS=$max_workers
ENV HUGGINGFACE_HUB_CACHE=/models
ENV DIFFUSERS_CACHE=/models
ENV MODEL_DIR=/models
# This ensures compatibility with how GPUs are addressed within go-livepeer
ENV CUDA_DEVICE_ORDER=PCI_BUS_ID

# vLLM configuration
ENV USE_8BIT=false
ENV MAX_NUM_BATCHED_TOKENS=8192
ENV MAX_NUM_SEQS=128
ENV MAX_MODEL_LEN=8192
ENV GPU_MEMORY_UTILIZATION=0.85
ENV TENSOR_PARALLEL_SIZE=1
ENV PIPELINE_PARALLEL_SIZE=1
# To use multiple GPUs, set TENSOR_PARALLEL_SIZE and PIPELINE_PARALLEL_SIZE
# Total GPUs used = TENSOR_PARALLEL_SIZE × PIPELINE_PARALLEL_SIZE
# Example for 4 GPUs:
# - Option 1: TENSOR_PARALLEL_SIZE=2, PIPELINE_PARALLEL_SIZE=2
# - Option 2: TENSOR_PARALLEL_SIZE=4, PIPELINE_PARALLEL_SIZE=1
# - Option 3: TENSOR_PARALLEL_SIZE=1, PIPELINE_PARALLEL_SIZE=4

COPY app/ /app/app
COPY images/ /app/images
COPY bench.py /app/bench.py
COPY example_data/ /app/example_data

CMD ["uvicorn", "app.main:app", "--log-config", "app/cfg/uvicorn_logging_config.json", "--host", "0.0.0.0", "--port", "8000"]
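The vLLM-related environment variables above are presumably read by the pipeline at start-up. A hedged sketch of that plumbing is shown below, including the rule from the comments that TENSOR_PARALLEL_SIZE × PIPELINE_PARALLEL_SIZE must not exceed the visible GPU count; how USE_8BIT maps to a quantization backend is a guess, not something confirmed by this diff.

import os

from vllm import AsyncEngineArgs

from app.pipelines.utils import get_max_memory


def engine_args_from_env(model_id: str) -> AsyncEngineArgs:
    tp = int(os.getenv("TENSOR_PARALLEL_SIZE", "1"))
    pp = int(os.getenv("PIPELINE_PARALLEL_SIZE", "1"))
    if tp * pp > max(get_max_memory().num_gpus, 1):
        raise ValueError("TENSOR_PARALLEL_SIZE x PIPELINE_PARALLEL_SIZE exceeds available GPUs")
    return AsyncEngineArgs(
        model=model_id,
        # Assumed mapping: vLLM supports e.g. quantization="bitsandbytes" for 8-bit loading.
        quantization="bitsandbytes" if os.getenv("USE_8BIT", "false").lower() == "true" else None,
        max_num_batched_tokens=int(os.getenv("MAX_NUM_BATCHED_TOKENS", "8192")),
        max_num_seqs=int(os.getenv("MAX_NUM_SEQS", "128")),
        max_model_len=int(os.getenv("MAX_MODEL_LEN", "8192")),
        gpu_memory_utilization=float(os.getenv("GPU_MEMORY_UTILIZATION", "0.85")),
        tensor_parallel_size=tp,
        pipeline_parallel_size=pp,
        download_dir=os.getenv("MODEL_DIR", "/models"),
    )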