From def5290caec8ee1387c90582d79c82cacbbe8d89 Mon Sep 17 00:00:00 2001
From: Samuel
Date: Mon, 8 Jul 2024 09:12:09 -0400
Subject: [PATCH] feat: support ainvoke, astream

---
 .../llama_chat_model.py                       | 47 +++--------
 .../llama_client_async_proxy.py               | 52 +++++++++++++++++++
 .../llama_client_proxy.py                     | 38 ++++++++++++++
 tests/test_functional/test_ainvoke.py         | 44 ++++++++++++++++
 tests/test_functional/test_astream.py         | 47 +++++++++++++++++
 tests/test_functional/test_invoke.py          | 19 +++++++
 6 files changed, 208 insertions(+), 39 deletions(-)
 create mode 100644 langchain_llamacpp_chat_model/llama_client_async_proxy.py
 create mode 100644 langchain_llamacpp_chat_model/llama_client_proxy.py
 create mode 100644 tests/test_functional/test_ainvoke.py
 create mode 100644 tests/test_functional/test_astream.py

diff --git a/langchain_llamacpp_chat_model/llama_chat_model.py b/langchain_llamacpp_chat_model/llama_chat_model.py
index e3324c0..0c03cc2 100644
--- a/langchain_llamacpp_chat_model/llama_chat_model.py
+++ b/langchain_llamacpp_chat_model/llama_chat_model.py
@@ -1,52 +1,21 @@
 from llama_cpp import Llama
 from langchain_openai.chat_models.base import BaseChatOpenAI
-from pydantic import Field
-
-class LLamaOpenAIClientProxy:
-    def __init__(self, llama: Llama):
-        self.llama = llama
-
-    def create(self, **kwargs):
-        proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
-        if "stream" in kwargs and kwargs["stream"] is True:
-            return proxy
-        else:
-            return proxy()
-
-
-class LlamaCreateContextManager:
-
-    def __init__(self, llama: Llama, **kwargs):
-        self.llama = llama
-        self.kwargs = kwargs
-        self.response = None
-
-    def __call__(self):
-        self.kwargs.pop("n", None)
-        self.kwargs.pop(
-            "parallel_tool_calls", None
-        )  # LLamaCPP does not support parallel tool calls for now
-
-        self.response = self.llama.create_chat_completion(**self.kwargs)
-        return self.response
-
-    def __enter__(self):
-        return self()
-
-    def __exit__(self, exception_type, exception_value, exception_traceback):
-        if hasattr(self.response, "close"):
-            self.response.close()
-        return False
+from .llama_client_proxy import LLamaOpenAIClientProxy
+from .llama_client_async_proxy import LLamaOpenAIClientAsyncProxy
 
 
 class LlamaChatModel(BaseChatOpenAI):
-    model_name: str = Field(default="", alias="model")
+    model_name: str = "unknown"
 
     def __init__(
         self,
        llama: Llama,
         **kwargs,
     ):
-        super().__init__(**kwargs, client=LLamaOpenAIClientProxy(llama=llama))
+        super().__init__(
+            **kwargs,
+            client=LLamaOpenAIClientProxy(llama=llama),
+            async_client=LLamaOpenAIClientAsyncProxy(llama=llama),
+        )
diff --git a/langchain_llamacpp_chat_model/llama_client_async_proxy.py b/langchain_llamacpp_chat_model/llama_client_async_proxy.py
new file mode 100644
index 0000000..f4fb1cc
--- /dev/null
+++ b/langchain_llamacpp_chat_model/llama_client_async_proxy.py
@@ -0,0 +1,52 @@
+from llama_cpp import Llama
+
+
+async def to_async_iterator(iterator):
+    for item in iterator:
+        yield item
+
+
+class LlamaCreateAsyncContextManager:
+
+    def __init__(self, llama: Llama, **kwargs):
+        self.llama = llama
+        self.kwargs = kwargs
+        self.response = None
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            return next(self.response)
+        except Exception:
+            raise StopAsyncIteration()
+
+    def __call__(self):
+        self.kwargs.pop("n", None)
+        self.kwargs.pop(
+            "parallel_tool_calls", None
+        )  # LLamaCPP does not support parallel tool calls for now
+
+        self.response = self.llama.create_chat_completion(**self.kwargs)
+        return self.response
+
+    async def __aenter__(self):
+        return self()
+
+    async def __aexit__(self, exception_type, exception_value, exception_traceback):
+        if hasattr(self.response, "close"):
+            self.response.close()
+        return False
+
+
+class LLamaOpenAIClientAsyncProxy:
+    def __init__(self, llama: Llama):
+        self.llama = llama
+
+    async def create(self, **kwargs):
+        proxy = LlamaCreateAsyncContextManager(llama=self.llama, **kwargs)
+        if "stream" in kwargs and kwargs["stream"] is True:
+            return proxy
+        else:
+            return proxy()
diff --git a/langchain_llamacpp_chat_model/llama_client_proxy.py b/langchain_llamacpp_chat_model/llama_client_proxy.py
new file mode 100644
index 0000000..54b2d2d
--- /dev/null
+++ b/langchain_llamacpp_chat_model/llama_client_proxy.py
@@ -0,0 +1,38 @@
+from llama_cpp import Llama
+
+
+class LlamaCreateContextManager:
+
+    def __init__(self, llama: Llama, **kwargs):
+        self.llama = llama
+        self.kwargs = kwargs
+        self.response = None
+
+    def __call__(self):
+        self.kwargs.pop("n", None)
+        self.kwargs.pop(
+            "parallel_tool_calls", None
+        )  # LLamaCPP does not support parallel tool calls for now
+
+        self.response = self.llama.create_chat_completion(**self.kwargs)
+        return self.response
+
+    def __enter__(self):
+        return self()
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        if hasattr(self.response, "close"):
+            self.response.close()
+        return False
+
+
+class LLamaOpenAIClientProxy:
+    def __init__(self, llama: Llama):
+        self.llama = llama
+
+    def create(self, **kwargs):
+        proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
+        if "stream" in kwargs and kwargs["stream"] is True:
+            return proxy
+        else:
+            return proxy()
diff --git a/tests/test_functional/test_ainvoke.py b/tests/test_functional/test_ainvoke.py
new file mode 100644
index 0000000..c35af5f
--- /dev/null
+++ b/tests/test_functional/test_ainvoke.py
@@ -0,0 +1,44 @@
+from llama_cpp import Llama
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain_llamacpp_chat_model import LlamaChatModel
+
+from langchain_core.pydantic_v1 import BaseModel, Field
+from tests.test_functional.models_configuration import create_llama, models_to_test
+
+
+class Joke(BaseModel):
+    setup: str = Field(description="The setup of the joke")
+    punchline: str = Field(description="The punchline to the joke")
+
+
+class TestAInvoke:
+
+    @pytest.fixture(
+        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
+    )
+    def llama(self, request) -> Llama:
+        return create_llama(request)
+
+    @pytest.fixture
+    def instance(self, llama):
+        return LlamaChatModel(llama=llama)
+
+    @pytest.mark.asyncio
+    async def test_ainvoke(self, instance: LlamaChatModel):
+        result = await instance.ainvoke("Say Hi!")
+
+        assert len(result.content) > 0
+
+    @pytest.mark.asyncio
+    async def test_conversation_memory(self, instance: LlamaChatModel):
+        result = await instance.ainvoke(
+            input=[
+                HumanMessage(content="Remember that I like bananas"),
+                AIMessage(content="Okay"),
+                HumanMessage(content="What do I like?"),
+            ]
+        )
+
+        assert "banana" in result.content
diff --git a/tests/test_functional/test_astream.py b/tests/test_functional/test_astream.py
new file mode 100644
index 0000000..c2dcd1a
--- /dev/null
+++ b/tests/test_functional/test_astream.py
@@ -0,0 +1,47 @@
+from llama_cpp import Llama
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain_llamacpp_chat_model import LlamaChatModel
+from tests.test_functional.models_configuration import create_llama, models_to_test
+
+
+class TestAStream:
+
+    @pytest.fixture(
+        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
+    )
+    def llama(self, request) -> Llama:
+        return create_llama(request)
+
+    @pytest.fixture
+    def instance(self, llama):
+        return LlamaChatModel(llama=llama)
+
+    @pytest.mark.asyncio
+    async def test_astream(self, instance: LlamaChatModel):
+
+        chunks = []
+        async for chunk in instance.astream("Say Hi!"):
+            chunks.append(chunk)
+
+        final_content = "".join(chunk.content for chunk in chunks)
+
+        assert len(final_content) > 0
+
+    @pytest.mark.asyncio
+    async def test_conversation_memory(self, instance: LlamaChatModel):
+        stream = instance.astream(
+            input=[
+                HumanMessage(content="Remember that I like bananas"),
+                AIMessage(content="Okay"),
+                HumanMessage(content="What do I like?"),
+            ]
+        )
+
+        final_content = ""
+        async for token in stream:
+            final_content += token.content
+
+        assert len(final_content) > 0
+        assert "banana" in final_content
diff --git a/tests/test_functional/test_invoke.py b/tests/test_functional/test_invoke.py
index 55d841e..827ed57 100644
--- a/tests/test_functional/test_invoke.py
+++ b/tests/test_functional/test_invoke.py
@@ -67,3 +67,22 @@ def magic_number_tool(input: int) -> int:
         result = llm_with_tool.invoke("What is the magic mumber of 2?")
 
         assert result.tool_calls[0]["name"] == "magic_number_tool"
+
+
+class TestAInvoke:
+
+    @pytest.fixture(
+        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
+    )
+    def llama(self, request) -> Llama:
+        return create_llama(request)
+
+    @pytest.fixture
+    def instance(self, llama):
+        return LlamaChatModel(llama=llama)
+
+    @pytest.mark.asyncio
+    async def test_ainvoke(self, instance: LlamaChatModel):
+        result = await instance.ainvoke("Say Hi!")
+
+        assert len(result.content) > 0
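
Usage sketch (not part of the diff): with this patch applied, the async entry points inherited from BaseChatOpenAI can be exercised roughly as in the new tests. The model path below is a placeholder; any locally downloaded GGUF chat model should work.

    import asyncio

    from llama_cpp import Llama
    from langchain_llamacpp_chat_model import LlamaChatModel


    async def main():
        # Placeholder path; swap in a real GGUF model file.
        llama = Llama(model_path="path/to/model.gguf", n_ctx=4096)
        model = LlamaChatModel(llama=llama)

        # Non-streaming async call (ainvoke) returns an AIMessage.
        result = await model.ainvoke("Say Hi!")
        print(result.content)

        # Streaming async call (astream) yields message chunks.
        async for chunk in model.astream("Say Hi!"):
            print(chunk.content, end="", flush=True)


    asyncio.run(main())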