feat: support ainvoke, astream #2

Merged: 1 commit, Jul 8, 2024
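
For orientation, this is roughly how the new async entry points are used once this change lands. It is a minimal sketch, not part of the diff: the GGUF path is a placeholder, and ainvoke/astream are the standard LangChain Runnable methods that this PR wires up to llama-cpp.

import asyncio

from llama_cpp import Llama
from langchain_llamacpp_chat_model import LlamaChatModel


async def main():
    llama = Llama(model_path="./model.gguf")  # placeholder path to any chat-capable GGUF
    chat_model = LlamaChatModel(llama=llama)

    # ainvoke: a single awaited chat completion.
    result = await chat_model.ainvoke("Say Hi!")
    print(result.content)

    # astream: chunks arrive through an async iterator.
    async for chunk in chat_model.astream("Tell me a short joke"):
        print(chunk.content, end="", flush=True)


asyncio.run(main())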
47 changes: 8 additions & 39 deletions langchain_llamacpp_chat_model/llama_chat_model.py
@@ -1,52 +1,21 @@
 from llama_cpp import Llama

 from langchain_openai.chat_models.base import BaseChatOpenAI
-from pydantic import Field
-
-
-class LLamaOpenAIClientProxy:
-    def __init__(self, llama: Llama):
-        self.llama = llama
-
-    def create(self, **kwargs):
-        proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
-        if "stream" in kwargs and kwargs["stream"] is True:
-            return proxy
-        else:
-            return proxy()
-
-
-class LlamaCreateContextManager:
-
-    def __init__(self, llama: Llama, **kwargs):
-        self.llama = llama
-        self.kwargs = kwargs
-        self.response = None
-
-    def __call__(self):
-        self.kwargs.pop("n", None)
-        self.kwargs.pop(
-            "parallel_tool_calls", None
-        )  # LLamaCPP does not support parallel tool calls for now
-
-        self.response = self.llama.create_chat_completion(**self.kwargs)
-        return self.response
-
-    def __enter__(self):
-        return self()
-
-    def __exit__(self, exception_type, exception_value, exception_traceback):
-        if hasattr(self.response, "close"):
-            self.response.close()
-        return False
+from .llama_client_proxy import LLamaOpenAIClientProxy
+from .llama_client_async_proxy import LLamaOpenAIClientAsyncProxy

 
 class LlamaChatModel(BaseChatOpenAI):
-    model_name: str = Field(default="", alias="model")
+    model_name: str = "unknown"

     def __init__(
         self,
         llama: Llama,
         **kwargs,
     ):
-        super().__init__(**kwargs, client=LLamaOpenAIClientProxy(llama=llama))
+        super().__init__(
+            **kwargs,
+            client=LLamaOpenAIClientProxy(llama=llama),
+            async_client=LLamaOpenAIClientAsyncProxy(llama=llama),
+        )
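
The constructor change above works because BaseChatOpenAI treats client and async_client as opaque objects and, for chat completions, only calls .create(**payload) on them, the same surface as openai's chat.completions; that is all the two proxies implement. A rough illustration of that contract, calling the proxies directly the way the chat model does internally (sketch only; the model path is a placeholder):

import asyncio

from llama_cpp import Llama
from langchain_llamacpp_chat_model.llama_client_proxy import LLamaOpenAIClientProxy
from langchain_llamacpp_chat_model.llama_client_async_proxy import LLamaOpenAIClientAsyncProxy

llama = Llama(model_path="./model.gguf")  # placeholder path
payload = {"messages": [{"role": "user", "content": "Say Hi!"}]}

# Sync path: invoke() ends up doing client.create(**payload).
completion = LLamaOpenAIClientProxy(llama=llama).create(**payload)
print(completion["choices"][0]["message"]["content"])


# Async path: ainvoke() ends up doing await async_client.create(**payload).
async def ademo():
    return await LLamaOpenAIClientAsyncProxy(llama=llama).create(**payload)

print(asyncio.run(ademo())["choices"][0]["message"]["content"])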
52 changes: 52 additions & 0 deletions langchain_llamacpp_chat_model/llama_client_async_proxy.py
@@ -0,0 +1,52 @@
from llama_cpp import Llama


async def to_async_iterator(iterator):
    for item in iterator:
        yield item


class LlamaCreateAsyncContextManager:

    def __init__(self, llama: Llama, **kwargs):
        self.llama = llama
        self.kwargs = kwargs
        self.response = None

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            return next(self.response)
        except Exception:
            raise StopAsyncIteration()

    def __call__(self):
        self.kwargs.pop("n", None)
        self.kwargs.pop(
            "parallel_tool_calls", None
        )  # LLamaCPP does not support parallel tool calls for now

        self.response = self.llama.create_chat_completion(**self.kwargs)
        return self.response

    async def __aenter__(self):
        return self()

    async def __aexit__(self, exception_type, exception_value, exception_traceback):
        if hasattr(self.response, "close"):
            self.response.close()
        return False


class LLamaOpenAIClientAsyncProxy:
    def __init__(self, llama: Llama):
        self.llama = llama

    async def create(self, **kwargs):
        proxy = LlamaCreateAsyncContextManager(llama=self.llama, **kwargs)
        if "stream" in kwargs and kwargs["stream"] is True:
            return proxy
        else:
            return proxy()
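
With stream=True, the async create() above returns the LlamaCreateAsyncContextManager itself rather than a finished completion; entering it starts create_chat_completion(stream=True), and __anext__ then steps llama-cpp's synchronous chunk generator, translating exhaustion into StopAsyncIteration. A sketch of consuming it directly, with a placeholder model path:

import asyncio

from llama_cpp import Llama
from langchain_llamacpp_chat_model.llama_client_async_proxy import LLamaOpenAIClientAsyncProxy


async def stream_demo():
    llama = Llama(model_path="./model.gguf")  # placeholder path
    client = LLamaOpenAIClientAsyncProxy(llama=llama)

    manager = await client.create(
        messages=[{"role": "user", "content": "Say Hi!"}],
        stream=True,
    )

    async with manager:  # __aenter__ kicks off create_chat_completion(stream=True)
        async for chunk in manager:  # __anext__ pulls from the underlying sync generator
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="", flush=True)


asyncio.run(stream_demo())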
38 changes: 38 additions & 0 deletions langchain_llamacpp_chat_model/llama_client_proxy.py
@@ -0,0 +1,38 @@
from llama_cpp import Llama


class LlamaCreateContextManager:

    def __init__(self, llama: Llama, **kwargs):
        self.llama = llama
        self.kwargs = kwargs
        self.response = None

    def __call__(self):
        self.kwargs.pop("n", None)
        self.kwargs.pop(
            "parallel_tool_calls", None
        )  # LLamaCPP does not support parallel tool calls for now

        self.response = self.llama.create_chat_completion(**self.kwargs)
        return self.response

    def __enter__(self):
        return self()

    def __exit__(self, exception_type, exception_value, exception_traceback):
        if hasattr(self.response, "close"):
            self.response.close()
        return False


class LLamaOpenAIClientProxy:
    def __init__(self, llama: Llama):
        self.llama = llama

    def create(self, **kwargs):
        proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
        if "stream" in kwargs and kwargs["stream"] is True:
            return proxy
        else:
            return proxy()
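
The sync proxy mirrors the async one: without stream, create() immediately returns the completion dict; with stream=True it hands back the context manager, whose __enter__ starts the completion and whose __exit__ closes the chunk generator if it exposes close(). A sketch of the streaming path (placeholder model path again):

from llama_cpp import Llama
from langchain_llamacpp_chat_model.llama_client_proxy import LLamaOpenAIClientProxy

llama = Llama(model_path="./model.gguf")  # placeholder path
client = LLamaOpenAIClientProxy(llama=llama)

with client.create(
    messages=[{"role": "user", "content": "Say Hi!"}],
    stream=True,
) as chunks:  # __enter__ returns llama-cpp's chunk generator
    for chunk in chunks:
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)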
44 changes: 44 additions & 0 deletions tests/test_functional/test_ainvoke.py
@@ -0,0 +1,44 @@
from llama_cpp import Llama
import pytest
from langchain_core.messages import AIMessage, HumanMessage

from langchain_llamacpp_chat_model import LlamaChatModel

from langchain_core.pydantic_v1 import BaseModel, Field
from tests.test_functional.models_configuration import create_llama, models_to_test


class Joke(BaseModel):
    setup: str = Field(description="The setup of the joke")
    punchline: str = Field(description="The punchline to the joke")


class TestAInvoke:

    @pytest.fixture(
        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
    )
    def llama(self, request) -> Llama:
        return create_llama(request)

    @pytest.fixture
    def instance(self, llama):
        return LlamaChatModel(llama=llama)

    @pytest.mark.asyncio
    async def test_ainvoke(self, instance: LlamaChatModel):
        result = await instance.ainvoke("Say Hi!")

        assert len(result.content) > 0

    @pytest.mark.asyncio
    async def test_conversation_memory(self, instance: LlamaChatModel):
        result = await instance.ainvoke(
            input=[
                HumanMessage(content="Remember that I like bananas"),
                AIMessage(content="Okay"),
                HumanMessage(content="What do I like?"),
            ]
        )

        assert "banana" in result.content
47 changes: 47 additions & 0 deletions tests/test_functional/test_astream.py
@@ -0,0 +1,47 @@
from llama_cpp import Llama
import pytest
from langchain_core.messages import AIMessage, HumanMessage

from langchain_llamacpp_chat_model import LlamaChatModel
from tests.test_functional.models_configuration import create_llama, models_to_test


class TestAStream:

    @pytest.fixture(
        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
    )
    def llama(self, request) -> Llama:
        return create_llama(request)

    @pytest.fixture
    def instance(self, llama):
        return LlamaChatModel(llama=llama)

    @pytest.mark.asyncio
    async def test_astream(self, instance: LlamaChatModel):

        chunks = []
        async for chunk in instance.astream("Say Hi!"):
            chunks.append(chunk)

        final_content = "".join(chunk.content for chunk in chunks)

        assert len(final_content) > 0

    @pytest.mark.asyncio
    async def test_conversation_memory(self, instance: LlamaChatModel):
        stream = instance.astream(
            input=[
                HumanMessage(content="Remember that I like bananas"),
                AIMessage(content="Okay"),
                HumanMessage(content="What do I like?"),
            ]
        )

        final_content = ""
        async for token in stream:
            final_content += token.content

        assert len(final_content) > 0
        assert "banana" in final_content
19 changes: 19 additions & 0 deletions tests/test_functional/test_invoke.py
@@ -67,3 +67,22 @@ def magic_number_tool(input: int) -> int:
        result = llm_with_tool.invoke("What is the magic mumber of 2?")

        assert result.tool_calls[0]["name"] == "magic_number_tool"


class TestAInvoke:

    @pytest.fixture(
        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
    )
    def llama(self, request) -> Llama:
        return create_llama(request)

    @pytest.fixture
    def instance(self, llama):
        return LlamaChatModel(llama=llama)

    @pytest.mark.asyncio
    async def test_ainvoke(self, instance: LlamaChatModel):
        result = await instance.ainvoke("Say Hi!")

        assert len(result.content) > 0