Update config to add singleton instance and add inference unit tests
ptelang committed Nov 27, 2024
1 parent f0c0b38 commit 458deb8
Showing 7 changed files with 78 additions and 42 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/ci.yml
@@ -14,7 +14,13 @@ jobs:
python-version: ["3.11", "3.12"]

steps:
- uses: actions/checkout@v4
- name: Checkout github repo
uses: actions/checkout@v4
with:
lfs: true

- name: Checkout LFS objects
run: git lfs pull

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
11 changes: 11 additions & 0 deletions src/codegate/config.py
@@ -18,6 +18,10 @@
class Config:
"""Application configuration with priority resolution."""

# Singleton instance of Config which is set in Config.load().
# All consumers can call: Config.get_config() to get the config.
__config = None

port: int = 8989
host: str = "localhost"
log_level: LogLevel = LogLevel.INFO
@@ -208,4 +212,11 @@ def load(
if prompts_path is not None:
config.prompts = PromptConfig.from_file(prompts_path)

# Set the __config class attribute
Config.__config = config

return config

@classmethod
def get_config(cls):
return cls.__config
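
For context, a minimal usage sketch of the new accessor, assuming load() can be called with its defaults (its full signature is collapsed in this hunk):

from codegate.config import Config

# Load once at application startup; load() now stores the instance on the class.
config = Config.load()

# Any other module can read that same instance back without re-parsing files.
assert Config.get_config() is config
print(Config.get_config().port)  # 8989 unless overridden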
60 changes: 39 additions & 21 deletions src/codegate/inference/inference_engine.py
@@ -1,20 +1,41 @@
from llama_cpp import Llama

from codegate.codegate_logging import setup_logging


class LlamaCppInferenceEngine:
_inference_engine = None
"""
A wrapper class for llama.cpp models
Attributes:
__inference_engine: Singleton instance of this class
"""

__inference_engine = None

def __new__(cls):
if cls._inference_engine is None:
cls._inference_engine = super().__new__(cls)
return cls._inference_engine
if cls.__inference_engine is None:
cls.__inference_engine = super().__new__(cls)
return cls.__inference_engine

def __init__(self):
if not hasattr(self, "models"):
self.__models = {}
self.__logger = setup_logging()

async def get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0):
def __del__(self):
self.__close_models()

async def __get_model(self, model_path, embedding=False, n_ctx=512, n_gpu_layers=0):
"""
Returns Llama model object from __models if present. Otherwise, the model
is loaded and added to __models and returned.
"""
if model_path not in self.__models:
self.__logger.info(
f"Loading model from {model_path} with parameters "
f"n_gpu_layers={n_gpu_layers} and n_ctx={n_ctx}"
)
self.__models[model_path] = Llama(
model_path=model_path,
n_gpu_layers=n_gpu_layers,
@@ -25,29 +46,26 @@ async def get_model(model_path, embedding=False, n_ctx=512, n_gpu_layers=0

return self.__models[model_path]

async def generate(
self, model_path, prompt, n_ctx=512, n_gpu_layers=0, stream=True
):
model = await self.get_model(
model_path=model_path, n_ctx=n_ctx, n_gpu_layers=n_gpu_layers
)

for chunk in model.create_completion(prompt=prompt, stream=stream):
yield chunk

async def chat(
self, model_path, n_ctx=512, n_gpu_layers=0, **chat_completion_request
):
model = await self.get_model(
async def chat(self, model_path, n_ctx=512, n_gpu_layers=0, **chat_completion_request):
"""
Generates a chat completion using the specified model and request parameters.
"""
model = await self.__get_model(
model_path=model_path, n_ctx=n_ctx, n_gpu_layers=n_gpu_layers
)
return model.create_completion(**chat_completion_request)

async def embed(self, model_path, content):
model = await self.get_model(model_path=model_path, embedding=True)
"""
Generates an embedding for the given content using the specified model.
"""
model = await self.__get_model(model_path=model_path, embedding=True)
return model.embed(content)

async def close_models(self):
async def __close_models(self):
"""
Closes all open models and samplers
"""
for _, model in self.__models:
if model._sampler:
model._sampler.close()
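
As an aside, a short sketch of how the singleton engine behaves from a caller's point of view; it assumes the qwen2.5-coder model referenced by the tests below is available locally, and otherwise uses only the chat() signature shown above:

import asyncio

from codegate.inference.inference_engine import LlamaCppInferenceEngine

async def main():
    engine_a = LlamaCppInferenceEngine()
    engine_b = LlamaCppInferenceEngine()
    # __new__ returns the cached instance, so both names refer to one engine.
    assert engine_a is engine_b

    # chat() loads and caches the model on first use via the private __get_model().
    response = await engine_a.chat(
        "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf",
        prompt="<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n",
        stream=True,
        max_tokens=64,
        temperature=0,
    )
    for chunk in response:
        print(chunk["choices"][0]["text"], end="")

asyncio.run(main())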
4 changes: 3 additions & 1 deletion src/codegate/providers/litellmshim/generators.py
@@ -1,5 +1,6 @@
import json
from typing import Any, AsyncIterator, Iterator
import asyncio

from pydantic import BaseModel

@@ -46,8 +47,9 @@ async def llamacpp_stream_generator(stream: Iterator[Any]) -> AsyncIterator[str]
if hasattr(chunk, "model_dump_json"):
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
try:
chunk['content'] = chunk['choices'][0]['text']
chunk["content"] = chunk["choices"][0]["text"]
yield f"data:{json.dumps(chunk)}\n\n"
await asyncio.sleep(0)
except Exception as e:
yield f"data:{str(e)}\n\n"
except Exception as e:
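
For what it's worth, the added await asyncio.sleep(0) matters because llama.cpp's completion stream is a plain synchronous iterator; without yielding control between chunks, the async generator would starve other tasks on the event loop for the whole stream. A minimal illustration of the same pattern (function name and chunk shape are illustrative only):

import asyncio
import json
from typing import Any, AsyncIterator, Iterator

async def stream_as_sse(stream: Iterator[Any]) -> AsyncIterator[str]:
    """Wrap a blocking iterator as server-sent events, yielding to the loop each step."""
    for chunk in stream:
        yield f"data:{json.dumps(chunk)}\n\n"
        await asyncio.sleep(0)  # give other coroutines a chance to run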
7 changes: 3 additions & 4 deletions src/codegate/providers/llamacpp/completion_handler.py
@@ -11,7 +11,6 @@

class LlamaCppCompletionHandler(BaseCompletionHandler):
def __init__(self, adapter: BaseAdapter):
self._config = Config.from_file('./config.yaml')
self._adapter = adapter
self.inference_engine = LlamaCppInferenceEngine()

@@ -53,9 +52,9 @@ async def execute_completion(
"""
Execute the completion request with LiteLLM's API
"""
response = await self.inference_engine.chat(self._config.chat_model_path,
self._config.chat_model_n_ctx,
self._config.chat_model_n_gpu_layers,
response = await self.inference_engine.chat(Config.get_config().chat_model_path,
Config.get_config().chat_model_n_ctx,
Config.get_config().chat_model_n_gpu_layers,
**request)
return response

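
The handler now resolves settings through Config.get_config() at request time rather than loading './config.yaml' in its constructor, so it always sees whatever Config.load() produced at startup. A rough sketch of that pattern, using the same attribute names as the call above (the class itself is illustrative):

from codegate.config import Config
from codegate.inference.inference_engine import LlamaCppInferenceEngine

class ChatHandlerSketch:
    """Illustrative only; mirrors how execute_completion defers config lookups."""

    def __init__(self):
        self.inference_engine = LlamaCppInferenceEngine()

    async def complete(self, request: dict):
        cfg = Config.get_config()  # read per call, never cached at construction
        return await self.inference_engine.chat(
            cfg.chat_model_path,
            cfg.chat_model_n_ctx,
            cfg.chat_model_n_gpu_layers,
            **request,
        )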
1 change: 0 additions & 1 deletion tests/providers/litellmshim/test_litellmshim.py
@@ -90,7 +90,6 @@ async def mock_stream() -> AsyncIterator[ModelResponse]:
"model": "gpt-3.5-turbo",
"stream": True,
}
api_key = "test-key"

# Execute
result_stream = await litellm_shim.execute_completion(data)
29 changes: 15 additions & 14 deletions tests/test_inference.py
@@ -1,4 +1,3 @@

import pytest

# @pytest.mark.asyncio
@@ -20,25 +19,27 @@
@pytest.mark.asyncio
async def test_chat(inference_engine) -> None:
"""Test chat completion."""
pass

# chat_request = {"prompt":
# "<|im_start|>user\\nhello<|im_end|>\\n<|im_start|>assistant\\n",
# "stream": True, "max_tokens": 4096, "top_k": 50, "temperature": 0}
chat_request = {
"prompt": "<|im_start|>user\\nhello<|im_end|>\\n<|im_start|>assistant\\n",
"stream": True,
"max_tokens": 4096,
"top_k": 50,
"temperature": 0,
}

# model_path = "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
# response = await inference_engine.chat(model_path, **chat_request)
model_path = "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
response = await inference_engine.chat(model_path, **chat_request)

# for chunk in response:
# assert chunk['choices'][0]['text'] is not None
for chunk in response:
assert chunk["choices"][0]["text"] is not None


@pytest.mark.asyncio
async def test_embed(inference_engine) -> None:
"""Test content embedding."""
pass

# content = "Can I use invokehttp package in my project?"
# model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
# vector = await inference_engine.embed(model_path, content=content)
# assert len(vector) == 384
content = "Can I use invokehttp package in my project?"
model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
vector = await inference_engine.embed(model_path, content=content)
assert len(vector) == 384
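
The inference_engine fixture these tests depend on is collapsed in this view; a plausible minimal definition, assuming it does nothing more than construct the singleton engine, would be:

import pytest

from codegate.inference.inference_engine import LlamaCppInferenceEngine

@pytest.fixture
def inference_engine() -> LlamaCppInferenceEngine:
    # Construction is cheap; models are only loaded when chat()/embed() are awaited.
    return LlamaCppInferenceEngine()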
