Commit
Showing 6 changed files with 208 additions and 39 deletions. This commit moves the synchronous LLamaOpenAIClientProxy into its own module, adds an asynchronous LLamaOpenAIClientAsyncProxy, wires both into LlamaChatModel as client and async_client, and adds functional tests for ainvoke and astream.
Modified file (defines LlamaChatModel):

@@ -1,52 +1,21 @@
 from llama_cpp import Llama

 from langchain_openai.chat_models.base import BaseChatOpenAI
-from pydantic import Field
-
-
-class LLamaOpenAIClientProxy:
-    def __init__(self, llama: Llama):
-        self.llama = llama
-
-    def create(self, **kwargs):
-        proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
-        if "stream" in kwargs and kwargs["stream"] is True:
-            return proxy
-        else:
-            return proxy()
-
-
-class LlamaCreateContextManager:
-
-    def __init__(self, llama: Llama, **kwargs):
-        self.llama = llama
-        self.kwargs = kwargs
-        self.response = None
-
-    def __call__(self):
-        self.kwargs.pop("n", None)
-        self.kwargs.pop(
-            "parallel_tool_calls", None
-        )  # LLamaCPP does not support parallel tool calls for now
-
-        self.response = self.llama.create_chat_completion(**self.kwargs)
-        return self.response
-
-    def __enter__(self):
-        return self()
-
-    def __exit__(self, exception_type, exception_value, exception_traceback):
-        if hasattr(self.response, "close"):
-            self.response.close()
-        return False
+from .llama_client_proxy import LLamaOpenAIClientProxy
+from .llama_client_async_proxy import LLamaOpenAIClientAsyncProxy


 class LlamaChatModel(BaseChatOpenAI):
-    model_name: str = Field(default="", alias="model")
+    model_name: str = "unknown"

     def __init__(
         self,
         llama: Llama,
         **kwargs,
     ):
-        super().__init__(**kwargs, client=LLamaOpenAIClientProxy(llama=llama))
+        super().__init__(
+            **kwargs,
+            client=LLamaOpenAIClientProxy(llama=llama),
+            async_client=LLamaOpenAIClientAsyncProxy(llama=llama),
+        )
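With this change, LlamaChatModel hands BaseChatOpenAI both a synchronous and an asynchronous client that wrap the same in-process llama_cpp.Llama, so invoke/stream and ainvoke/astream all work without an OpenAI server. A minimal usage sketch; the model path and n_ctx below are illustrative assumptions, not values from this repository:

from llama_cpp import Llama
from langchain_llamacpp_chat_model import LlamaChatModel

# Any locally available GGUF chat model works here (the path is an example).
llama = Llama(
    model_path="./models/llama-3-8b-instruct.Q4_K_M.gguf",
    n_ctx=4096,
)
chat_model = LlamaChatModel(llama=llama)

# Synchronous path: routed through LLamaOpenAIClientProxy.
print(chat_model.invoke("Say Hi!").content)

# Asynchronous path: routed through the new LLamaOpenAIClientAsyncProxy.
# import asyncio; print(asyncio.run(chat_model.ainvoke("Say Hi!")).content)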
New file (llama_client_async_proxy, as imported by the chat model above):

@@ -0,0 +1,52 @@
from llama_cpp import Llama


async def to_async_iterator(iterator):
    for item in iterator:
        yield item


class LlamaCreateAsyncContextManager:

    def __init__(self, llama: Llama, **kwargs):
        self.llama = llama
        self.kwargs = kwargs
        self.response = None

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self.response)
        except Exception:
            raise StopAsyncIteration()

    def __call__(self):
        self.kwargs.pop("n", None)
        self.kwargs.pop(
            "parallel_tool_calls", None
        )  # LLamaCPP does not support parallel tool calls for now

        self.response = self.llama.create_chat_completion(**self.kwargs)
        return self.response

    async def __aenter__(self):
        return self()

    async def __aexit__(self, exception_type, exception_value, exception_traceback):
        if hasattr(self.response, "close"):
            self.response.close()
        return False


class LLamaOpenAIClientAsyncProxy:
    def __init__(self, llama: Llama):
        self.llama = llama

    async def create(self, **kwargs):
        proxy = LlamaCreateAsyncContextManager(llama=self.llama, **kwargs)
        if "stream" in kwargs and kwargs["stream"] is True:
            return proxy
        else:
            return proxy()
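When BaseChatOpenAI awaits async_client.create(...), a non-streaming call returns the finished chat-completion dict, while a streaming call returns the LlamaCreateAsyncContextManager itself; entering (or calling) it is what actually starts llama.create_chat_completion. A rough sketch of driving the proxy by hand, assuming llama is an already-constructed llama_cpp.Llama (normally LlamaChatModel does this for you):

async def demo(llama):
    proxy = LLamaOpenAIClientAsyncProxy(llama=llama)
    messages = [{"role": "user", "content": "Say Hi!"}]

    # Non-streaming: awaiting create() yields the complete chat-completion dict.
    completion = await proxy.create(messages=messages)
    print(completion["choices"][0]["message"]["content"])

    # Streaming: create() returns the async context manager; __aenter__ runs the
    # underlying create_chat_completion(stream=True) call and returns its chunk iterator.
    async with await proxy.create(messages=messages, stream=True) as chunks:
        for chunk in chunks:
            print(chunk["choices"][0]["delta"].get("content", ""), end="")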
New file (llama_client_proxy, as imported by the chat model above):

@@ -0,0 +1,38 @@
from llama_cpp import Llama


class LlamaCreateContextManager:

    def __init__(self, llama: Llama, **kwargs):
        self.llama = llama
        self.kwargs = kwargs
        self.response = None

    def __call__(self):
        self.kwargs.pop("n", None)
        self.kwargs.pop(
            "parallel_tool_calls", None
        )  # LLamaCPP does not support parallel tool calls for now

        self.response = self.llama.create_chat_completion(**self.kwargs)
        return self.response

    def __enter__(self):
        return self()

    def __exit__(self, exception_type, exception_value, exception_traceback):
        if hasattr(self.response, "close"):
            self.response.close()
        return False


class LLamaOpenAIClientProxy:
    def __init__(self, llama: Llama):
        self.llama = llama

    def create(self, **kwargs):
        proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
        if "stream" in kwargs and kwargs["stream"] is True:
            return proxy
        else:
            return proxy()
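This is the synchronous counterpart: BaseChatOpenAI normally points client at the OpenAI SDK's chat.completions object, and create() is the only method it needs, so the proxy emulates just that on top of llama.create_chat_completion, dropping the n and parallel_tool_calls arguments that llama-cpp's API does not accept. A short sketch of direct use, again assuming an existing llama_cpp.Llama instance named llama:

proxy = LLamaOpenAIClientProxy(llama=llama)
messages = [{"role": "user", "content": "Say Hi!"}]

# Non-streaming call returns the chat-completion dict directly.
print(proxy.create(messages=messages)["choices"][0]["message"]["content"])

# Streaming call returns LlamaCreateContextManager; __enter__ starts the stream
# and __exit__ closes the underlying iterator if it exposes close().
with proxy.create(messages=messages, stream=True) as chunks:
    for chunk in chunks:
        print(chunk["choices"][0]["delta"].get("content", ""), end="")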
New file (functional tests for ainvoke):

@@ -0,0 +1,44 @@
from llama_cpp import Llama
import pytest
from langchain_core.messages import AIMessage, HumanMessage

from langchain_llamacpp_chat_model import LlamaChatModel

from langchain_core.pydantic_v1 import BaseModel, Field
from tests.test_functional.models_configuration import create_llama, models_to_test


class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline to the joke")


class TestAInvoke:

    @pytest.fixture(
        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
    )
    def llama(self, request) -> Llama:
        return create_llama(request)

    @pytest.fixture
    def instance(self, llama):
        return LlamaChatModel(llama=llama)

    @pytest.mark.asyncio
    async def test_ainvoke(self, instance: LlamaChatModel):
        result = await instance.ainvoke("Say Hi!")

        assert len(result.content) > 0

    @pytest.mark.asyncio
    async def test_conversation_memory(self, instance: LlamaChatModel):
        result = await instance.ainvoke(
            input=[
                HumanMessage(content="Remember that I like bananas"),
                AIMessage(content="Okay"),
                HumanMessage(content="What do I like?"),
            ]
        )

        assert "banana" in result.content
New file (functional tests for astream):

@@ -0,0 +1,47 @@
from llama_cpp import Llama
import pytest
from langchain_core.messages import AIMessage, HumanMessage

from langchain_llamacpp_chat_model import LlamaChatModel
from tests.test_functional.models_configuration import create_llama, models_to_test


class TestAStream:

    @pytest.fixture(
        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
    )
    def llama(self, request) -> Llama:
        return create_llama(request)

    @pytest.fixture
    def instance(self, llama):
        return LlamaChatModel(llama=llama)

    @pytest.mark.asyncio
    async def test_astream(self, instance: LlamaChatModel):

        chunks = []
        async for chunk in instance.astream("Say Hi!"):
            chunks.append(chunk)

        final_content = "".join(chunk.content for chunk in chunks)

        assert len(final_content) > 0

    @pytest.mark.asyncio
    async def test_conversation_memory(self, instance: LlamaChatModel):
        stream = instance.astream(
            input=[
                HumanMessage(content="Remember that I like bananas"),
                AIMessage(content="Okay"),
                HumanMessage(content="What do I like?"),
            ]
        )

        final_content = ""
        async for token in stream:
            final_content += token.content

        assert len(final_content) > 0
        assert "banana" in final_content