From 191b78c20052293a6de7d3fefe809a5c04d2ba87 Mon Sep 17 00:00:00 2001
From: Pankaj Telang
Date: Mon, 25 Nov 2024 10:30:12 -0500
Subject: [PATCH] Add inference code

---
 config.yaml                                   |  7 +++
 pyproject.toml                                |  2 +
 src/codegate/inference/__init__.py            |  3 ++
 src/codegate/inference/inference_engine.py    | 41 ++++++++++++++
 .../providers/litellmshim/__init__.py         |  3 +-
 .../providers/litellmshim/generators.py       | 19 +++++++
 src/codegate/providers/llamacpp/__init__.py   |  0
 src/codegate/providers/llamacpp/adapter.py    | 33 ++++++++++++
 .../providers/llamacpp/completion_handler.py  | 53 +++++++++++++++++++
 src/codegate/providers/llamacpp/provider.py   | 30 +++++++++++
 src/codegate/server.py                        |  2 +
 tests/conftest.py                             |  7 ++-
 tests/test_inference.py                       | 42 ++++++++++++++
 13 files changed, 240 insertions(+), 2 deletions(-)
 create mode 100644 src/codegate/inference/__init__.py
 create mode 100644 src/codegate/inference/inference_engine.py
 create mode 100644 src/codegate/providers/llamacpp/__init__.py
 create mode 100644 src/codegate/providers/llamacpp/adapter.py
 create mode 100644 src/codegate/providers/llamacpp/completion_handler.py
 create mode 100644 src/codegate/providers/llamacpp/provider.py
 create mode 100644 tests/test_inference.py

diff --git a/config.yaml b/config.yaml
index 113b6aa4..822732b1 100644
--- a/config.yaml
+++ b/config.yaml
@@ -10,3 +10,10 @@ log_level: "INFO" # One of: ERROR, WARNING, INFO, DEBUG
 # Note: This configuration can be overridden by:
 # 1. CLI arguments (--port, --host, --log-level)
 # 2. Environment variables (CODEGATE_APP_PORT, CODEGATE_APP_HOST, CODEGATE_APP_LOG_LEVEL)
+
+
+# Inference configuration
+embedding_model_repo: "leliuga/all-MiniLM-L6-v2-GGUF"
+embedding_model_file: "*Q8_0.gguf"
+
+chat_model_path: "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index d889f6fa..68113519 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,8 @@ fastapi = ">=0.115.5"
 uvicorn = ">=0.32.1"
 litellm = "^1.52.15"
+llama_cpp_python = ">=0.3.2"
+
 
 [tool.poetry.group.dev.dependencies]
 pytest = ">=7.4.0"
 pytest-cov = ">=4.1.0"
diff --git a/src/codegate/inference/__init__.py b/src/codegate/inference/__init__.py
new file mode 100644
index 00000000..2278aee6
--- /dev/null
+++ b/src/codegate/inference/__init__.py
@@ -0,0 +1,3 @@
+from .inference_engine import LlamaCppInferenceEngine
+
+__all__ = ["LlamaCppInferenceEngine"]
\ No newline at end of file
diff --git a/src/codegate/inference/inference_engine.py b/src/codegate/inference/inference_engine.py
new file mode 100644
index 00000000..7e82c3dd
--- /dev/null
+++ b/src/codegate/inference/inference_engine.py
@@ -0,0 +1,41 @@
+from llama_cpp import Llama
+
+
+class LlamaCppInferenceEngine():
+    _inference_engine = None
+
+    def __new__(cls):
+        if cls._inference_engine is None:
+            cls._inference_engine = super().__new__(cls)
+        return cls._inference_engine
+
+    def __init__(self):
+        if not hasattr(self, '_LlamaCppInferenceEngine__models'):
+            self.__models = {}
+
+    async def get_model(self, model_path, embedding=False, n_ctx=512):
+        if model_path not in self.__models:
+            self.__models[model_path] = Llama(
+                model_path=model_path, n_gpu_layers=0, verbose=False, n_ctx=n_ctx, embedding=embedding)
+
+        return self.__models[model_path]
+
+    async def generate(self, model_path, prompt, stream=True):
+        model = await self.get_model(model_path=model_path, n_ctx=32768)
+
+        for chunk in model.create_completion(prompt=prompt, stream=stream):
+            yield chunk
+
+    async def chat(self, model_path, **chat_completion_request):
+        model = await self.get_model(model_path=model_path, n_ctx=32768)
+        return model.create_completion(**chat_completion_request)
+
+    async def embed(self, model_path, content):
+        model = await self.get_model(model_path=model_path, embedding=True)
+        return model.embed(content)
+
+    async def close_models(self):
+        for _, model in self.__models.items():
+            if model._sampler:
+                model._sampler.close()
+            model.close()
diff --git a/src/codegate/providers/litellmshim/__init__.py b/src/codegate/providers/litellmshim/__init__.py
index ec191270..e23e34db 100644
--- a/src/codegate/providers/litellmshim/__init__.py
+++ b/src/codegate/providers/litellmshim/__init__.py
@@ -1,10 +1,11 @@
 from .adapter import BaseAdapter
-from .generators import anthropic_stream_generator, sse_stream_generator
+from .generators import anthropic_stream_generator, sse_stream_generator, llamacpp_stream_generator
 from .litellmshim import LiteLLmShim
 
 __all__ = [
     "sse_stream_generator",
     "anthropic_stream_generator",
+    "llamacpp_stream_generator",
     "LiteLLmShim",
     "BaseAdapter",
 ]
diff --git a/src/codegate/providers/litellmshim/generators.py b/src/codegate/providers/litellmshim/generators.py
index 306f1900..4bf5f76c 100644
--- a/src/codegate/providers/litellmshim/generators.py
+++ b/src/codegate/providers/litellmshim/generators.py
@@ -37,3 +37,22 @@ async def anthropic_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterato
                 yield f"event: {event_type}\ndata:{str(e)}\n\n"
     except Exception as e:
         yield f"data: {str(e)}\n\n"
+
+
+async def llamacpp_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
+    """OpenAI-style SSE format"""
+    try:
+        for chunk in stream:
+            if hasattr(chunk, "model_dump"):
+                if not hasattr(chunk.choices[0], 'content'):
+                    chunk.choices[0].content = 'foo'
+                chunk = chunk.model_dump(exclude_none=True, exclude_unset=True)
+            try:
+                chunk['content'] = chunk['choices'][0]['text']
+                yield f"data:{json.dumps(chunk)}\n\n"
+            except Exception as e:
+                yield f"data:{str(e)}\n\n"
+    except Exception as e:
+        yield f"data: {str(e)}\n\n"
+    finally:
+        yield "data: [DONE]\n\n"
diff --git a/src/codegate/providers/llamacpp/__init__.py b/src/codegate/providers/llamacpp/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/codegate/providers/llamacpp/adapter.py b/src/codegate/providers/llamacpp/adapter.py
new file mode 100644
index 00000000..eb59b338
--- /dev/null
+++ b/src/codegate/providers/llamacpp/adapter.py
@@ -0,0 +1,33 @@
+from typing import Any, AsyncIterator, Dict, Optional
+
+from litellm import ChatCompletionRequest, ModelResponse
+
+from ..base import StreamGenerator
+from ..litellmshim import llamacpp_stream_generator
+from ..litellmshim.litellmshim import BaseAdapter
+
+
+class LlamaCppAdapter(BaseAdapter):
+    """
+    A thin wrapper around LiteLLM's adapter class interface that passes the
+    input and output through as-is: the llama.cpp engine, like LiteLLM,
+    expects OpenAI's API format.
+    """
+    def __init__(self, stream_generator: StreamGenerator = llamacpp_stream_generator):
+        super().__init__(stream_generator)
+
+    def translate_completion_input_params(
+        self, kwargs: Dict
+    ) -> Optional[ChatCompletionRequest]:
+        try:
+            return ChatCompletionRequest(**kwargs)
+        except Exception as e:
+            raise ValueError(f"Invalid completion parameters: {str(e)}")
+
+    def translate_completion_output_params(self, response: ModelResponse) -> Any:
+        return response
+
+    def translate_completion_output_params_streaming(
+        self, completion_stream: AsyncIterator[ModelResponse]
+    ) -> AsyncIterator[ModelResponse]:
+        return completion_stream
diff --git a/src/codegate/providers/llamacpp/completion_handler.py b/src/codegate/providers/llamacpp/completion_handler.py
new file mode 100644
index 00000000..99cce6cd
--- /dev/null
+++ b/src/codegate/providers/llamacpp/completion_handler.py
@@ -0,0 +1,53 @@
+from typing import Any, AsyncIterator, Dict
+
+from fastapi.responses import StreamingResponse
+from litellm import ModelResponse, acompletion
+
+from ..base import BaseCompletionHandler
+from .adapter import BaseAdapter
+from codegate.inference.inference_engine import LlamaCppInferenceEngine
+
+
+class LlamaCppCompletionHandler(BaseCompletionHandler):
+    def __init__(self, adapter: BaseAdapter):
+        self._adapter = adapter
+        self.inference_engine = LlamaCppInferenceEngine()
+
+    async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
+        """
+        Translate the input parameters to the OpenAI-style format using the
+        adapter and run the request against the local llama.cpp engine. Then
+        translate the response back to our format using the adapter.
+        """
+        completion_request = self._adapter.translate_completion_input_params(
+            data)
+        if completion_request is None:
+            raise Exception("Couldn't translate the request")
+
+        # Replace n_predict option with max_tokens
+        if 'n_predict' in completion_request:
+            completion_request['max_tokens'] = completion_request['n_predict']
+            del completion_request['n_predict']
+
+        response = await self.inference_engine.chat('./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf', **completion_request)
+
+        if isinstance(response, ModelResponse):
+            return self._adapter.translate_completion_output_params(response)
+        return self._adapter.translate_completion_output_params_streaming(response)
+
+    def create_streaming_response(
+        self, stream: AsyncIterator[Any]
+    ) -> StreamingResponse:
+        """
+        Create a streaming response from a stream generator. The StreamingResponse
+        is the format that FastAPI expects for streaming responses.
+        """
+        return StreamingResponse(
+            self._adapter.stream_generator(stream),
+            headers={
+                "Cache-Control": "no-cache",
+                "Connection": "keep-alive",
+                "Transfer-Encoding": "chunked",
+            },
+            status_code=200,
+        )
diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py
new file mode 100644
index 00000000..b036b000
--- /dev/null
+++ b/src/codegate/providers/llamacpp/provider.py
@@ -0,0 +1,30 @@
+import json
+
+from fastapi import Header, HTTPException, Request
+
+from ..base import BaseProvider
+from .completion_handler import LlamaCppCompletionHandler
+from .adapter import LlamaCppAdapter
+
+
+class LlamaCppProvider(BaseProvider):
+    def __init__(self):
+        adapter = LlamaCppAdapter()
+        completion_handler = LlamaCppCompletionHandler(adapter)
+        super().__init__(completion_handler)
+
+    def _setup_routes(self):
+        """
+        Sets up the /completion route for the provider, as expected by the
+        llama.cpp server API. No API key is required, so the completion
+        handler is invoked with None.
+        """
+        @self.router.post("/completion")
+        async def create_completion(
+            request: Request,
+        ):
+            body = await request.body()
+            data = json.loads(body)
+
+            stream = await self.complete(data, None)
+            return self._completion_handler.create_streaming_response(stream)
diff --git a/src/codegate/server.py b/src/codegate/server.py
index a9c203e2..18730411 100644
--- a/src/codegate/server.py
+++ b/src/codegate/server.py
@@ -4,6 +4,7 @@
 from .providers.anthropic.provider import AnthropicProvider
 from .providers.openai.provider import OpenAIProvider
 from .providers.registry import ProviderRegistry
+from .providers.llamacpp.provider import LlamaCppProvider
 
 
 def init_app() -> FastAPI:
@@ -19,6 +20,7 @@
     # Register all known providers
     registry.add_provider("openai", OpenAIProvider())
     registry.add_provider("anthropic", AnthropicProvider())
+    registry.add_provider("llamacpp", LlamaCppProvider())
 
     # Create and add system routes
     system_router = APIRouter(tags=["System"])  # Tags group endpoints in the docs
diff --git a/tests/conftest.py b/tests/conftest.py
index afbb5956..c0b46042 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,7 +11,7 @@
 import yaml
 
 from codegate.config import Config
-
+from codegate.inference import LlamaCppInferenceEngine
 
 @pytest.fixture
 def temp_config_file(tmp_path: Path) -> Iterator[Path]:
@@ -94,3 +94,8 @@ def parse_json_log(log_line: str) -> dict[str, Any]:
         return json.loads(log_line)
     except json.JSONDecodeError as e:
         pytest.fail(f"Invalid JSON log line: {e}")
+
+
+@pytest.fixture
+def inference_engine() -> LlamaCppInferenceEngine:
+    return LlamaCppInferenceEngine()
\ No newline at end of file
diff --git a/tests/test_inference.py b/tests/test_inference.py
new file mode 100644
index 00000000..a8523059
--- /dev/null
+++ b/tests/test_inference.py
@@ -0,0 +1,42 @@
+
+import pytest
+
+
+# @pytest.mark.asyncio
+# async def test_generate(inference_engine) -> None:
+#     """Test code generation."""
+
+#     prompt = '''
+#     import requests
+
+#     # Function to call API over http
+#     def call_api(url):
+#     '''
+#     model_path = "./models/qwen2.5-coder-1.5B.q5_k_m.gguf"
+
+#     async for chunk in inference_engine.generate(model_path, prompt):
+#         print(chunk)
+
+
+@pytest.mark.asyncio
+async def test_chat(inference_engine) -> None:
+    """Test chat completion."""
+
+    chat_request = {"prompt": "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n",
+                    "stream": True, "max_tokens": 4096, "top_k": 50, "temperature": 0}
+
+    model_path = "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
+    response = await inference_engine.chat(model_path, **chat_request)
+
+    for chunk in response:
+        assert chunk['choices'][0]['text'] is not None
+
+
+@pytest.mark.asyncio
+async def test_embed(inference_engine) -> None:
+    """Test content embedding."""
+
+    content = "Can I use invokehttp package in my project?"
+    model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
+    vector = await inference_engine.embed(model_path, content=content)
+    assert len(vector) == 384
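
For reference, the new route can be exercised end to end once the server is running and the GGUF model referenced in config.yaml is present under ./models. The snippet below is a minimal illustrative sketch rather than part of the patch: it assumes the ProviderRegistry (not shown in this diff) mounts LlamaCppProvider under a /llamacpp URL prefix, that the server listens on localhost:8989 (adjust host and port to your configuration), and it uses the third-party requests package purely for convenience.

# Sketch of a client for the new llama.cpp completion route.
# Assumptions (not confirmed by this patch): the registry exposes the provider
# under /llamacpp and the server listens on http://localhost:8989.
import json

import requests

payload = {
    "prompt": "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n",
    "stream": True,
    "n_predict": 128,  # renamed to max_tokens by LlamaCppCompletionHandler
    "temperature": 0,
}

with requests.post(
    "http://localhost:8989/llamacpp/completion", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        # llamacpp_stream_generator emits SSE lines of the form "data:{json}"
        # and closes the stream with "data: [DONE]".
        if not line or not line.startswith(b"data:"):
            continue
        data = line[len(b"data:"):].strip()
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        print(chunk["choices"][0]["text"], end="", flush=True)

The n_predict field follows the llama.cpp server convention; LlamaCppCompletionHandler renames it to max_tokens before invoking the local engine, so either spelling can be sent to the endpoint.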