Add inference code
ptelang committed Nov 25, 2024
1 parent 1b711ff commit 191b78c
Showing 13 changed files with 241 additions and 2 deletions.
7 changes: 7 additions & 0 deletions config.yaml
@@ -10,3 +10,10 @@ log_level: "INFO" # One of: ERROR, WARNING, INFO, DEBUG
# Note: This configuration can be overridden by:
# 1. CLI arguments (--port, --host, --log-level)
# 2. Environment variables (CODEGATE_APP_PORT, CODEGATE_APP_HOST, CODEGATE_APP_LOG_LEVEL)


# Inference configuration
embedding_model_repo: "leliuga/all-MiniLM-L6-v2-GGUF"
embedding_model_file: "*Q8_0.gguf"

chat_model_path: "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
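
For orientation, the new keys could be read with a plain YAML load as sketched below; how the real codegate Config object exposes them is not shown in this diff, so treat the attribute names here as assumptions:

import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

embedding_model_repo = cfg["embedding_model_repo"]  # Hugging Face repo for the embedding GGUF
embedding_model_file = cfg["embedding_model_file"]  # glob pattern selecting the quantized file
chat_model_path = cfg["chat_model_path"]            # local path to the chat/completion model
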
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -14,6 +14,8 @@ fastapi = ">=0.115.5"
uvicorn = ">=0.32.1"

litellm = "^1.52.15"
llama_cpp_python = ">=0.3.2"

[tool.poetry.group.dev.dependencies]
pytest = ">=7.4.0"
pytest-cov = ">=4.1.0"
3 changes: 3 additions & 0 deletions src/codegate/inference/__init__.py
@@ -0,0 +1,3 @@
from .inference_engine import LlamaCppInferenceEngine

__all__ = ["LlamaCppInferenceEngine"]
41 changes: 41 additions & 0 deletions src/codegate/inference/inference_engine.py
@@ -0,0 +1,41 @@
from llama_cpp import Llama


class LlamaCppInferenceEngine:
    """Process-wide singleton wrapper around llama-cpp-python models."""

    _inference_engine = None

def __new__(cls):
if cls._inference_engine is None:
cls._inference_engine = super().__new__(cls)
return cls._inference_engine

    def __init__(self):
        # __init__ runs on every instantiation of the singleton, so only create
        # the model cache once. The attribute is name-mangled by Python to
        # _LlamaCppInferenceEngine__models, which is what hasattr must check.
        if not hasattr(self, "_LlamaCppInferenceEngine__models"):
            self.__models = {}

async def get_model(self, model_path, embedding=False, n_ctx=512):
if model_path not in self.__models:
            self.__models[model_path] = Llama(
                model_path=model_path,
                n_gpu_layers=0,
                verbose=False,
                n_ctx=n_ctx,
                embedding=embedding,
            )

return self.__models[model_path]

async def generate(self, model_path, prompt, stream=True):
model = await self.get_model(model_path=model_path, n_ctx=32768)

for chunk in model.create_completion(prompt=prompt, stream=stream):
yield chunk

async def chat(self, model_path, **chat_completion_request):
model = await self.get_model(model_path=model_path, n_ctx=32768)
return model.create_completion(**chat_completion_request)

async def embed(self, model_path, content):
model = await self.get_model(model_path=model_path, embedding=True)
return model.embed(content)

async def close_models(self):
        for _, model in self.__models.items():
if model._sampler:
model._sampler.close()
model.close()
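
A minimal usage sketch for the engine above, assuming the GGUF files referenced in config.yaml and the tests have already been downloaded into ./models (they are not part of this commit). Because the class is a singleton, repeated instantiations share one model cache:

import asyncio

from codegate.inference import LlamaCppInferenceEngine


async def main():
    engine = LlamaCppInferenceEngine()

    # Embedding with an embedding-capable GGUF model
    vector = await engine.embed(
        "./models/all-minilm-L6-v2-q5_k_m.gguf", "Can I use invokehttp in my project?")
    print(len(vector))

    # Streaming completion; chunks are llama.cpp completion dicts
    response = await engine.chat(
        "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf",
        prompt="<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n",
        stream=True,
        max_tokens=256,
    )
    for chunk in response:
        print(chunk["choices"][0]["text"], end="")

    await engine.close_models()


asyncio.run(main())
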
3 changes: 2 additions & 1 deletion src/codegate/providers/litellmshim/__init__.py
@@ -1,10 +1,11 @@
from .adapter import BaseAdapter
from .generators import anthropic_stream_generator, sse_stream_generator
from .generators import anthropic_stream_generator, sse_stream_generator, llamacpp_stream_generator
from .litellmshim import LiteLLmShim

__all__ = [
"sse_stream_generator",
"anthropic_stream_generator",
"llamacpp_stream_generator",
"LiteLLmShim",
"BaseAdapter",
]
19 changes: 19 additions & 0 deletions src/codegate/providers/litellmshim/generators.py
@@ -37,3 +37,22 @@ async def anthropic_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterato
yield f"event: {event_type}\ndata:{str(e)}\n\n"
except Exception as e:
yield f"data: {str(e)}\n\n"


async def llamacpp_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
    """OpenAI-style SSE format"""
    try:
        for chunk in stream:
            if hasattr(chunk, "model_dump_json"):
                # Pydantic-style chunk: make sure a content attribute exists,
                # then convert it to a plain dict before further processing.
                if not hasattr(chunk.choices[0], 'content'):
                    chunk.choices[0].content = 'foo'
                chunk = json.loads(
                    chunk.model_dump_json(exclude_none=True, exclude_unset=True))
            try:
                chunk['content'] = chunk['choices'][0]['text']
                yield f"data:{json.dumps(chunk)}\n\n"
            except Exception as e:
                yield f"data:{str(e)}\n\n"
    except Exception as e:
        yield f"data: {str(e)}\n\n"
    finally:
        yield "data: [DONE]\n\n"
Empty file.
33 changes: 33 additions & 0 deletions src/codegate/providers/llamacpp/adapter.py
@@ -0,0 +1,33 @@
from typing import Any, AsyncIterator, Dict, Optional

from litellm import ChatCompletionRequest, ModelResponse

from ..base import StreamGenerator
from ..litellmshim import llamacpp_stream_generator
from ..litellmshim.litellmshim import BaseAdapter


class LlamaCppAdapter(BaseAdapter):
"""
This is just a wrapper around LiteLLM's adapter class interface that passes
through the input and output as-is - LiteLLM's API expects OpenAI's API
format.
"""
def __init__(self, stream_generator: StreamGenerator = llamacpp_stream_generator):
super().__init__(stream_generator)

def translate_completion_input_params(
self, kwargs: Dict
) -> Optional[ChatCompletionRequest]:
try:
return ChatCompletionRequest(**kwargs)
except Exception as e:
raise ValueError(f"Invalid completion parameters: {str(e)}")

def translate_completion_output_params(self, response: ModelResponse) -> Any:
return response

def translate_completion_output_params_streaming(
self, completion_stream: AsyncIterator[ModelResponse]
) -> AsyncIterator[ModelResponse]:
return completion_stream
53 changes: 53 additions & 0 deletions src/codegate/providers/llamacpp/completion_handler.py
@@ -0,0 +1,53 @@
from typing import Any, AsyncIterator, Dict

from fastapi.responses import StreamingResponse
from litellm import ModelResponse

from ..base import BaseCompletionHandler
from .adapter import BaseAdapter
from codegate.inference.inference_engine import LlamaCppInferenceEngine


class LlamaCppCompletionHandler(BaseCompletionHandler):
def __init__(self, adapter: BaseAdapter):
self._adapter = adapter
self.inference_engine = LlamaCppInferenceEngine()

async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
"""
Translate the input parameters to LiteLLM's format using the adapter and
call the LiteLLM API. Then translate the response back to our format using
the adapter.
"""
completion_request = self._adapter.translate_completion_input_params(
data)
if completion_request is None:
raise Exception("Couldn't translate the request")

# Replace n_predict option with max_tokens
if 'n_predict' in completion_request:
completion_request['max_tokens'] = completion_request['n_predict']
del completion_request['n_predict']

        response = await self.inference_engine.chat(
            './models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf',
            **completion_request,
        )

if isinstance(response, ModelResponse):
return self._adapter.translate_completion_output_params(response)
return self._adapter.translate_completion_output_params_streaming(response)

def create_streaming_response(
self, stream: AsyncIterator[Any]
) -> StreamingResponse:
"""
Create a streaming response from a stream generator. The StreamingResponse
is the format that FastAPI expects for streaming responses.
"""
return StreamingResponse(
self._adapter.stream_generator(stream),
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
},
status_code=200,
)
30 changes: 30 additions & 0 deletions src/codegate/providers/llamacpp/provider.py
@@ -0,0 +1,30 @@
import json

from fastapi import Request

from ..base import BaseProvider
from .completion_handler import LlamaCppCompletionHandler
from .adapter import LlamaCppAdapter


class LlamaCppProvider(BaseProvider):
def __init__(self):
adapter = LlamaCppAdapter()
completion_handler = LlamaCppCompletionHandler(adapter)
super().__init__(completion_handler)

def _setup_routes(self):
"""
Sets up the /chat route for the provider as expected by the
Llama API. Extracts the API key from the "Authorization" header and
passes it to the completion handler.
"""
@self.router.post("/completion")
async def create_completion(
request: Request,
):
body = await request.body()
data = json.loads(body)

stream = await self.complete(data, None)
return self._completion_handler.create_streaming_response(stream)
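
Assuming the registry mounts this router under a /llamacpp prefix and the app listens on the configured host and port (both assumptions; neither the mount path nor the defaults appear in this diff), the new endpoint could be exercised roughly like this:

import requests  # any HTTP client works; requests is just an example

resp = requests.post(
    "http://localhost:8989/llamacpp/completion",  # host, port, and prefix are assumptions
    json={
        "prompt": "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n",
        "stream": True,
        "max_tokens": 256,
    },
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())
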
2 changes: 2 additions & 0 deletions src/codegate/server.py
@@ -4,6 +4,7 @@
from .providers.anthropic.provider import AnthropicProvider
from .providers.openai.provider import OpenAIProvider
from .providers.registry import ProviderRegistry
from .providers.llamacpp.provider import LlamaCppProvider


def init_app() -> FastAPI:
@@ -19,6 +20,7 @@ def init_app() -> FastAPI:
# Register all known providers
registry.add_provider("openai", OpenAIProvider())
registry.add_provider("anthropic", AnthropicProvider())
registry.add_provider("llamacpp", LlamaCppProvider())

# Create and add system routes
system_router = APIRouter(tags=["System"]) # Tags group endpoints in the docs
7 changes: 6 additions & 1 deletion tests/conftest.py
@@ -11,7 +11,7 @@
import yaml

from codegate.config import Config

from codegate.inference import LlamaCppInferenceEngine

@pytest.fixture
def temp_config_file(tmp_path: Path) -> Iterator[Path]:
@@ -94,3 +94,8 @@ def parse_json_log(log_line: str) -> dict[str, Any]:
return json.loads(log_line)
except json.JSONDecodeError as e:
pytest.fail(f"Invalid JSON log line: {e}")


@pytest.fixture
def inference_engine() -> LlamaCppInferenceEngine:
return LlamaCppInferenceEngine()
43 changes: 43 additions & 0 deletions tests/test_inference.py
@@ -0,0 +1,43 @@

import pytest


# @pytest.mark.asyncio
# async def test_generate(inference_engine) -> None:
# """Test code generation."""

# prompt = '''
# import requests

# # Function to call API over http
# def call_api(url):
# '''
# model_path = "./models/qwen2.5-coder-1.5B.q5_k_m.gguf"

# async for chunk in inference_engine.generate(model_path, prompt):
# print(chunk)


@pytest.mark.asyncio
async def test_chat(inference_engine) -> None:
"""Test chat completion."""

chat_request = {"prompt": "<|im_start|>user\\nhello<|im_end|>\\n<|im_start|>assistant\\n",
"stream": True, "max_tokens": 4096, "top_k": 50, "temperature": 0}

model_path = "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
response = await inference_engine.chat(model_path, **chat_request)

for chunk in response:
assert chunk['choices'][0]['text'] is not None


@pytest.mark.asyncio
async def test_embed(inference_engine) -> None:
"""Test content embedding."""

content = "Can I use invokehttp package in my project?"
model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
vector = await inference_engine.embed(model_path, content=content)
assert len(vector) == 384
