feat: support ainvoke, astream (#2)
samuelint authored Jul 8, 2024
1 parent 58ad941 commit c710736
Showing 6 changed files with 208 additions and 39 deletions.
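The commit adds async entry points to LlamaChatModel: ainvoke awaits a single chat completion and astream yields response chunks, mirroring the existing invoke/stream API. A minimal usage sketch follows; the GGUF model path and the asyncio driver are illustrative assumptions, not part of this diff.

import asyncio

from llama_cpp import Llama

from langchain_llamacpp_chat_model import LlamaChatModel


async def main():
    # Hypothetical local model path; any chat-capable GGUF model should work.
    llama = Llama(model_path="./models/llama-3-8b-instruct.Q4_K_M.gguf")
    chat_model = LlamaChatModel(llama=llama)

    # Single awaited completion (new in this commit).
    result = await chat_model.ainvoke("Say Hi!")
    print(result.content)

    # Token-by-token streaming (new in this commit).
    async for chunk in chat_model.astream("Say Hi!"):
        print(chunk.content, end="", flush=True)


asyncio.run(main())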
47 changes: 8 additions & 39 deletions langchain_llamacpp_chat_model/llama_chat_model.py
@@ -1,52 +1,21 @@
from llama_cpp import Llama

from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import Field


class LLamaOpenAIClientProxy:
def __init__(self, llama: Llama):
self.llama = llama

def create(self, **kwargs):
proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
if "stream" in kwargs and kwargs["stream"] is True:
return proxy
else:
return proxy()


class LlamaCreateContextManager:

def __init__(self, llama: Llama, **kwargs):
self.llama = llama
self.kwargs = kwargs
self.response = None

def __call__(self):
self.kwargs.pop("n", None)
self.kwargs.pop(
"parallel_tool_calls", None
) # LLamaCPP does not support parallel tool calls for now

self.response = self.llama.create_chat_completion(**self.kwargs)
return self.response

def __enter__(self):
return self()

def __exit__(self, exception_type, exception_value, exception_traceback):
if hasattr(self.response, "close"):
self.response.close()
return False
from .llama_client_proxy import LLamaOpenAIClientProxy
from .llama_client_async_proxy import LLamaOpenAIClientAsyncProxy


class LlamaChatModel(BaseChatOpenAI):
model_name: str = Field(default="", alias="model")
model_name: str = "unknown"

def __init__(
self,
llama: Llama,
**kwargs,
):
super().__init__(**kwargs, client=LLamaOpenAIClientProxy(llama=llama))
super().__init__(
**kwargs,
client=LLamaOpenAIClientProxy(llama=llama),
async_client=LLamaOpenAIClientAsyncProxy(llama=llama),
)
52 changes: 52 additions & 0 deletions langchain_llamacpp_chat_model/llama_client_async_proxy.py
@@ -0,0 +1,52 @@
from llama_cpp import Llama


async def to_async_iterator(iterator):
for item in iterator:
yield item


class LlamaCreateAsyncContextManager:

def __init__(self, llama: Llama, **kwargs):
self.llama = llama
self.kwargs = kwargs
self.response = None

def __aiter__(self):
return self

async def __anext__(self):
try:
return next(self.response)
except Exception:
raise StopAsyncIteration()

def __call__(self):
self.kwargs.pop("n", None)
self.kwargs.pop(
"parallel_tool_calls", None
) # LLamaCPP does not support parallel tool calls for now

self.response = self.llama.create_chat_completion(**self.kwargs)
return self.response

async def __aenter__(self):
return self()

async def __aexit__(self, exception_type, exception_value, exception_traceback):
if hasattr(self.response, "close"):
self.response.close()
return False


class LLamaOpenAIClientAsyncProxy:
def __init__(self, llama: Llama):
self.llama = llama

async def create(self, **kwargs):
proxy = LlamaCreateAsyncContextManager(llama=self.llama, **kwargs)
if "stream" in kwargs and kwargs["stream"] is True:
return proxy
else:
return proxy()
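The async proxy does not make llama-cpp itself asynchronous; __anext__ calls next() on the synchronous chunk generator and translates exhaustion into StopAsyncIteration. A rough sketch of how the streaming path can be consumed, assuming an already-loaded llama instance and the OpenAI-style chunk layout that llama-cpp-python emits:

from langchain_llamacpp_chat_model.llama_client_async_proxy import (
    LLamaOpenAIClientAsyncProxy,
)


async def stream_reply(llama):
    proxy = LLamaOpenAIClientAsyncProxy(llama)
    # With stream=True, create() returns the async context manager above
    # instead of a finished completion dict.
    manager = await proxy.create(
        messages=[{"role": "user", "content": "Say Hi!"}],
        stream=True,
    )
    async with manager:
        async for chunk in manager:
            # OpenAI-style streaming chunk produced by llama-cpp-python.
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="")

In practice this path is driven by BaseChatOpenAI through the async_client wired up in llama_chat_model.py; the sketch only shows what the proxy hands back.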
38 changes: 38 additions & 0 deletions langchain_llamacpp_chat_model/llama_client_proxy.py
@@ -0,0 +1,38 @@
from llama_cpp import Llama


class LlamaCreateContextManager:

def __init__(self, llama: Llama, **kwargs):
self.llama = llama
self.kwargs = kwargs
self.response = None

def __call__(self):
self.kwargs.pop("n", None)
self.kwargs.pop(
"parallel_tool_calls", None
) # LLamaCPP does not support parallel tool calls for now

self.response = self.llama.create_chat_completion(**self.kwargs)
return self.response

def __enter__(self):
return self()

def __exit__(self, exception_type, exception_value, exception_traceback):
if hasattr(self.response, "close"):
self.response.close()
return False


class LLamaOpenAIClientProxy:
def __init__(self, llama: Llama):
self.llama = llama

def create(self, **kwargs):
proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
if "stream" in kwargs and kwargs["stream"] is True:
return proxy
else:
return proxy()
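This module is the same synchronous proxy that previously lived inside llama_chat_model.py, extracted so the sync and async variants sit side by side; it stands in for the client object whose create() BaseChatOpenAI is expected to call. A minimal sketch of the synchronous streaming path, under the same assumptions as above:

from langchain_llamacpp_chat_model.llama_client_proxy import LLamaOpenAIClientProxy


def stream_reply(llama):
    proxy = LLamaOpenAIClientProxy(llama)
    # stream=True returns LlamaCreateContextManager; __enter__ starts the
    # completion and __exit__ closes the underlying generator if possible.
    with proxy.create(
        messages=[{"role": "user", "content": "Say Hi!"}],
        stream=True,
    ) as chunks:
        for chunk in chunks:
            delta = chunk["choices"][0]["delta"]
            print(delta.get("content", ""), end="")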
44 changes: 44 additions & 0 deletions tests/test_functional/test_ainvoke.py
@@ -0,0 +1,44 @@
from llama_cpp import Llama
import pytest
from langchain_core.messages import AIMessage, HumanMessage

from langchain_llamacpp_chat_model import LlamaChatModel

from langchain_core.pydantic_v1 import BaseModel, Field
from tests.test_functional.models_configuration import create_llama, models_to_test


class Joke(BaseModel):
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")


class TestAInvoke:

@pytest.fixture(
params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
)
def llama(self, request) -> Llama:
return create_llama(request)

@pytest.fixture
def instance(self, llama):
return LlamaChatModel(llama=llama)

@pytest.mark.asyncio
async def test_ainvoke(self, instance: LlamaChatModel):
result = await instance.ainvoke("Say Hi!")

assert len(result.content) > 0

@pytest.mark.asyncio
async def test_conversation_memory(self, instance: LlamaChatModel):
result = await instance.ainvoke(
input=[
HumanMessage(content="Remember that I like bananas"),
AIMessage(content="Okay"),
HumanMessage(content="What do I like?"),
]
)

assert "banana" in result.content
47 changes: 47 additions & 0 deletions tests/test_functional/test_astream.py
@@ -0,0 +1,47 @@
from llama_cpp import Llama
import pytest
from langchain_core.messages import AIMessage, HumanMessage

from langchain_llamacpp_chat_model import LlamaChatModel
from tests.test_functional.models_configuration import create_llama, models_to_test


class TestAStream:

@pytest.fixture(
params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
)
def llama(self, request) -> Llama:
return create_llama(request)

@pytest.fixture
def instance(self, llama):
return LlamaChatModel(llama=llama)

@pytest.mark.asyncio
async def test_astream(self, instance: LlamaChatModel):

chunks = []
async for chunk in instance.astream("Say Hi!"):
chunks.append(chunk)

final_content = "".join(chunk.content for chunk in chunks)

assert len(final_content) > 0

@pytest.mark.asyncio
async def test_conversation_memory(self, instance: LlamaChatModel):
stream = instance.astream(
input=[
HumanMessage(content="Remember that I like bananas"),
AIMessage(content="Okay"),
HumanMessage(content="What do I like?"),
]
)

final_content = ""
async for token in stream:
final_content += token.content

assert len(final_content) > 0
assert "banana" in final_content
19 changes: 19 additions & 0 deletions tests/test_functional/test_invoke.py
@@ -67,3 +67,22 @@ def magic_number_tool(input: int) -> int:
result = llm_with_tool.invoke("What is the magic number of 2?")

assert result.tool_calls[0]["name"] == "magic_number_tool"


class TestAInvoke:

@pytest.fixture(
params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
)
def llama(self, request) -> Llama:
return create_llama(request)

@pytest.fixture
def instance(self, llama):
return LlamaChatModel(llama=llama)

@pytest.mark.asyncio
async def test_ainvoke(self, instance: LlamaChatModel):
result = await instance.ainvoke("Say Hi!")

assert len(result.content) > 0
