Showing 13 changed files with 241 additions and 2 deletions.
@@ -0,0 +1,3 @@
from .inference_engine import LlamaCppInferenceEngine

__all__ = ["LlamaCppInferenceEngine"]
@@ -0,0 +1,41 @@
from llama_cpp import Llama


class LlamaCppInferenceEngine:
    _inference_engine = None

    def __new__(cls):
        # Singleton: reuse the same engine instance across the process.
        if cls._inference_engine is None:
            cls._inference_engine = super().__new__(cls)
        return cls._inference_engine

    def __init__(self):
        # Only initialise the model cache once. Because of name mangling the
        # attribute is "_LlamaCppInferenceEngine__models", so check for that.
        if not hasattr(self, "_LlamaCppInferenceEngine__models"):
            self.__models = {}

    async def get_model(self, model_path, embedding=False, n_ctx=512):
        if model_path not in self.__models:
            self.__models[model_path] = Llama(
                model_path=model_path,
                n_gpu_layers=0,
                verbose=False,
                n_ctx=n_ctx,
                embedding=embedding,
            )

        return self.__models[model_path]

    async def generate(self, model_path, prompt, stream=True):
        model = await self.get_model(model_path=model_path, n_ctx=32768)

        for chunk in model.create_completion(prompt=prompt, stream=stream):
            yield chunk

    async def chat(self, model_path, **chat_completion_request):
        model = await self.get_model(model_path=model_path, n_ctx=32768)
        return model.create_completion(**chat_completion_request)

    async def embed(self, model_path, content):
        model = await self.get_model(model_path=model_path, embedding=True)
        return model.embed(content)

    async def close_models(self):
        # Iterate over the cached model objects, not just the dict keys.
        for _, model in self.__models.items():
            if model._sampler:
                model._sampler.close()
            model.close()
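For orientation, here is a minimal usage sketch of the engine defined above. The import path matches the one used by the completion handler later in this commit; the model path and prompt are illustrative and assume the GGUF file exists locally.

import asyncio

from codegate.inference.inference_engine import LlamaCppInferenceEngine

engine = LlamaCppInferenceEngine()  # singleton: repeated calls return the same instance


async def demo():
    # With stream=False, chat() returns a single llama.cpp completion dict.
    result = await engine.chat(
        "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf",  # illustrative model path
        prompt="<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n",
        stream=False,
        max_tokens=32,
    )
    print(result["choices"][0]["text"])


asyncio.run(demo())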
@@ -1,10 +1,11 @@
 from .adapter import BaseAdapter
-from .generators import anthropic_stream_generator, sse_stream_generator
+from .generators import anthropic_stream_generator, sse_stream_generator, llamacpp_stream_generator
 from .litellmshim import LiteLLmShim

 __all__ = [
     "sse_stream_generator",
     "anthropic_stream_generator",
+    "llamacpp_stream_generator",
     "LiteLLmShim",
     "BaseAdapter",
 ]
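llamacpp_stream_generator itself lives in generators.py, which is not visible in this excerpt. Purely as a hypothetical sketch of what an SSE-style generator over llama.cpp chunks could look like (not the repository's actual implementation):

import json
from typing import Any, AsyncIterator


async def llamacpp_stream_generator(stream: AsyncIterator[Any]) -> AsyncIterator[str]:
    # Hypothetical body: serialize each llama.cpp chunk as a server-sent event.
    async for chunk in stream:
        yield f"data: {json.dumps(chunk)}\n\n"
    yield "data: [DONE]\n\n"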
Empty file.
@@ -0,0 +1,33 @@
from typing import Any, AsyncIterator, Dict, Optional

from litellm import ChatCompletionRequest, ModelResponse

from ..base import StreamGenerator
from ..litellmshim import llamacpp_stream_generator
from ..litellmshim.litellmshim import BaseAdapter


class LlamaCppAdapter(BaseAdapter):
    """
    This is just a wrapper around LiteLLM's adapter class interface that passes
    through the input and output as-is; LiteLLM's API expects OpenAI's API format.
    """

    def __init__(self, stream_generator: StreamGenerator = llamacpp_stream_generator):
        super().__init__(stream_generator)

    def translate_completion_input_params(
        self, kwargs: Dict
    ) -> Optional[ChatCompletionRequest]:
        try:
            return ChatCompletionRequest(**kwargs)
        except Exception as e:
            raise ValueError(f"Invalid completion parameters: {str(e)}")

    def translate_completion_output_params(self, response: ModelResponse) -> Any:
        return response

    def translate_completion_output_params_streaming(
        self, completion_stream: AsyncIterator[ModelResponse]
    ) -> AsyncIterator[ModelResponse]:
        return completion_stream
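A short sketch of how the adapter round-trips a request; the payload is illustrative and the import path is an assumption based on the relative imports above:

from codegate.providers.llamacpp.adapter import LlamaCppAdapter  # assumed module path

adapter = LlamaCppAdapter()
# The input is packed into LiteLLM's ChatCompletionRequest shape (OpenAI format).
request = adapter.translate_completion_input_params(
    {"model": "local-model", "messages": [{"role": "user", "content": "hello"}]}
)
# Both output translators are pass-throughs, so responses and streams come back unchanged.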
@@ -0,0 +1,53 @@
from typing import Any, AsyncIterator, Dict

from fastapi.responses import StreamingResponse
from litellm import ModelResponse, acompletion

from ..base import BaseCompletionHandler
from .adapter import BaseAdapter
from codegate.inference.inference_engine import LlamaCppInferenceEngine


class LlamaCppCompletionHandler(BaseCompletionHandler):
    def __init__(self, adapter: BaseAdapter):
        self._adapter = adapter
        self.inference_engine = LlamaCppInferenceEngine()

    async def complete(self, data: Dict, api_key: str) -> AsyncIterator[Any]:
        """
        Translate the input parameters to LiteLLM's format using the adapter,
        run the request through the local llama.cpp inference engine, and then
        translate the response back to our format using the adapter.
        """
        completion_request = self._adapter.translate_completion_input_params(data)
        if completion_request is None:
            raise Exception("Couldn't translate the request")

        # Replace the llama.cpp-style n_predict option with max_tokens
        if 'n_predict' in completion_request:
            completion_request['max_tokens'] = completion_request['n_predict']
            del completion_request['n_predict']

        response = await self.inference_engine.chat(
            './models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf', **completion_request)

        if isinstance(response, ModelResponse):
            return self._adapter.translate_completion_output_params(response)
        return self._adapter.translate_completion_output_params_streaming(response)

    def create_streaming_response(
        self, stream: AsyncIterator[Any]
    ) -> StreamingResponse:
        """
        Create a streaming response from a stream generator. The StreamingResponse
        is the format that FastAPI expects for streaming responses.
        """
        return StreamingResponse(
            self._adapter.stream_generator(stream),
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Transfer-Encoding": "chunked",
            },
            status_code=200,
        )
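A rough sketch of how the handler is expected to be driven; the payload is an illustrative assumption and the import paths are assumed from the package layout implied by this commit:

from codegate.providers.llamacpp.adapter import LlamaCppAdapter  # assumed path
from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler  # assumed path

handler = LlamaCppCompletionHandler(LlamaCppAdapter())


async def run_example():
    # n_predict is rewritten to max_tokens inside complete() before reaching llama.cpp
    stream = await handler.complete(
        {"prompt": "def hello():", "stream": True, "n_predict": 64}, api_key=None
    )
    return handler.create_streaming_response(stream)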
@@ -0,0 +1,30 @@
import json

from fastapi import Request

from ..base import BaseProvider
from .completion_handler import LlamaCppCompletionHandler
from .adapter import LlamaCppAdapter


class LlamaCppProvider(BaseProvider):
    def __init__(self):
        adapter = LlamaCppAdapter()
        completion_handler = LlamaCppCompletionHandler(adapter)
        super().__init__(completion_handler)

    def _setup_routes(self):
        """
        Sets up the /completion route for the provider, as expected by the
        llama.cpp server API. The request body is parsed as JSON and handed
        to the completion handler; no API key is passed for local inference.
        """
        @self.router.post("/completion")
        async def create_completion(
            request: Request,
        ):
            body = await request.body()
            data = json.loads(body)

            stream = await self.complete(data, None)
            return self._completion_handler.create_streaming_response(stream)
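Assuming the provider's router is mounted under a llamacpp prefix on a locally running server (both the prefix and the port are assumptions; neither appears in this excerpt), exercising the route could look like:

import requests

# Hypothetical base URL and route prefix; adjust to however the router is actually mounted.
resp = requests.post(
    "http://localhost:8989/llamacpp/completion",
    json={"prompt": "def fib(n):", "stream": True, "n_predict": 64},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())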
@@ -0,0 +1,43 @@
import pytest


# @pytest.mark.asyncio
# async def test_generate(inference_engine) -> None:
#     """Test code generation."""

#     prompt = '''
#     import requests

#     # Function to call API over http
#     def call_api(url):
#     '''
#     model_path = "./models/qwen2.5-coder-1.5B.q5_k_m.gguf"

#     async for chunk in inference_engine.generate(model_path, prompt):
#         print(chunk)


@pytest.mark.asyncio
async def test_chat(inference_engine) -> None:
    """Test chat completion."""

    chat_request = {
        "prompt": "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n",
        "stream": True,
        "max_tokens": 4096,
        "top_k": 50,
        "temperature": 0,
    }

    model_path = "./models/qwen2.5-coder-1.5b-instruct-q5_k_m.gguf"
    response = await inference_engine.chat(model_path, **chat_request)

    for chunk in response:
        assert chunk['choices'][0]['text'] is not None


@pytest.mark.asyncio
async def test_embed(inference_engine) -> None:
    """Test content embedding."""

    content = "Can I use invokehttp package in my project?"
    model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
    vector = await inference_engine.embed(model_path, content=content)
    assert len(vector) == 384
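These tests rely on an inference_engine fixture that is not visible in this excerpt; a minimal sketch of what such a fixture could look like (an assumption, not the commit's actual conftest):

import pytest

from codegate.inference.inference_engine import LlamaCppInferenceEngine


@pytest.fixture
def inference_engine() -> LlamaCppInferenceEngine:
    # The engine is a singleton, so all tests share the same cached models.
    return LlamaCppInferenceEngine()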