Skip to content

Commit

Permalink
Merge pull request #106 from vendi-ai/introduce-middlewares-to-chat
Browse files Browse the repository at this point in the history
introduce middlewares to chat - Closes #105
  • Loading branch information
matankley authored Aug 28, 2023
2 parents ef4821b + 28aac35 commit 4480539
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 30 deletions.
38 changes: 19 additions & 19 deletions src/declarai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Chat(BaseChat, metaclass=ChatMeta):
is_declarai (bool): A class-level attribute indicating if the chat is of type 'declarai'. Always set to `True`.
llm_response (LLMResponse): The response from the LLM (Language Model).
This attribute is set during the execution of the chat.
_kwargs (Dict[str, Any]): A dictionary to store additional keyword arguments, used for passing kwargs between
_call_kwargs (Dict[str, Any]): A dictionary to store additional keyword arguments, used for passing kwargs between
the execution of the chat and the execution of the middlewares.
middlewares (List[TaskMiddleware] or None): Middlewares used for every iteration of the chat.
operator (BaseChatOperator): The operator used for the chat.
Expand All @@ -95,13 +95,13 @@ class Chat(BaseChat, metaclass=ChatMeta):

is_declarai = True
llm_response: LLMResponse
_kwargs: Dict[str, Any]
_call_kwargs: Dict[str, Any]

def __init__(
self,
*,
operator: BaseChatOperator,
middlewares: List[TaskMiddleware] = None,
middlewares: List[Type[TaskMiddleware]] = None,
chat_history: BaseChatMessageHistory = None,
greeting: str = None,
system: str = None,
Expand Down Expand Up @@ -136,7 +136,8 @@ def compile(self, **kwargs) -> List[Message]:
Returns: List[Message] - The compiled messages that will be sent to the LLM.
"""
compiled = self.operator.compile(messages=self._chat_history.history, **kwargs)
messages = kwargs.pop("messages", None) or self._chat_history.history
compiled = self.operator.compile(messages=messages, **kwargs)
return compiled

def add_message(self, message: str, role: MessageRole) -> None:
Expand All @@ -159,22 +160,19 @@ def _exec(self, kwargs) -> LLMResponse:
The raw response from the LLM, together with the metadata.
"""
self.llm_response = self.operator.predict(**kwargs)
self.add_message(self.llm_response.response, role=MessageRole.assistant)
if self.operator.parsed_send_func:
return self.operator.parsed_send_func.parse(self.llm_response.response)
return self.llm_response

def _exec_with_message_state(self, kwargs) -> Any:
"""
Executes the call to the LLM and adds the response to the chat history as an assistant message.
Args:
kwargs: Keyword arguments to pass to the LLM like `temperature`, `max_tokens`, etc.
Returns:
The parsed response from the LLM.
"""
raw_response = self._exec(kwargs).response
self.add_message(raw_response, role=MessageRole.assistant)
if self.operator.parsed_send_func:
return self.operator.parsed_send_func.parse(raw_response)
return raw_response
def _exec_middlewares(self, kwargs) -> Any:
    """Execute the chat call, routing through middlewares when configured.

    Each entry in ``self.middlewares`` is instantiated with this chat
    instance and ``self._call_kwargs``; the resulting object's ``__call__``
    then drives the execution. When no middlewares are configured, the call
    falls straight through to ``self._exec``.

    Args:
        kwargs: Keyword arguments forwarded to the LLM call (e.g.
            ``llm_params``) — the same dict previously stored on
            ``self._call_kwargs`` by ``__call__``.

    Returns:
        The result of the middleware-wrapped execution, or the raw
        ``self._exec(kwargs)`` result when there are no middlewares.
    """
    if self.middlewares:
        # NOTE(review): `exec_with_middlewares` is rebound on every loop
        # iteration, so only the LAST middleware in the list is actually
        # invoked below. Presumably the middlewares are meant to compose
        # (each wrapping the next) — confirm against
        # TaskMiddleware.__call__ before relying on multiple middlewares.
        exec_with_middlewares = None
        for middleware in self.middlewares:
            exec_with_middlewares = middleware(self, self._call_kwargs)
        if exec_with_middlewares:
            return exec_with_middlewares()
    # No middlewares (or the instantiation above produced a falsy value):
    # execute the LLM call directly.
    return self._exec(kwargs)

def __call__(
self, *, messages: List[Message], llm_params: LLMParamsType = None, **kwargs
Expand All @@ -198,7 +196,9 @@ def __call__(
) # order is important! We prioritize runtime params that
if runtime_llm_params:
kwargs["llm_params"] = runtime_llm_params
return self._exec_with_message_state(kwargs)

self._call_kwargs = kwargs
return self._exec_middlewares(kwargs)

def send(
self,
Expand Down
2 changes: 2 additions & 0 deletions src/declarai/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
Middlewares are used to extend the functionality of the Declarai execution flow.
"""

from .internal import LoggingMiddleware
4 changes: 2 additions & 2 deletions src/declarai/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Base class for task middlewares.
"""
from abc import abstractmethod # pylint: disable=E0611
from typing import Any
from typing import Any, Dict

from declarai._base import TaskType

Expand All @@ -20,7 +20,7 @@ class TaskMiddleware:
_kwargs: The keyword arguments to pass to the task
"""

def __init__(self, task: TaskType, kwargs):
def __init__(self, task: TaskType, kwargs: Dict[str, Any] = None):
self._task = task
self._kwargs = kwargs

Expand Down
5 changes: 3 additions & 2 deletions src/declarai/middleware/internal/log_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ def after(self, task: TaskType):
"task_name": task.__name__,
"llm_model": task.llm_response.model,
"template": str(task.compile()),
"call_kwargs": str(task._kwargs),
"compiled_template": str(task.compile(**task._kwargs)),
"call_kwargs": str(self._kwargs),
"compiled_template": str(task.compile(**self._kwargs)),
"result": task.llm_response.response,
"time": end_time,
}
logger.info(log_record)
print(log_record)
2 changes: 1 addition & 1 deletion src/declarai/middleware/third_party/wandb_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def after(self, task):
},
start_time_ms=self._start_time_ms,
end_time_ms=end_time_ms,
inputs={"query": task.compile(**task.call_kwargs)},
inputs={"query": task.compile(**self._kwargs)},
outputs={"response": task.llm_response.response},
)

Expand Down
13 changes: 7 additions & 6 deletions src/declarai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ class Task(BaseTask):
Attributes:
operator: the operator to use to interact with the LLM
llm_response (LLMResponse): the response from the LLM
_kwargs: the kwargs that were passed to the task are set as attributes on the task and passed to the middlewares
_call_kwargs: the kwargs that were passed to the task are set as attributes on the task and passed to the middlewares
"""

is_declarai = True
llm_response: LLMResponse
_kwargs: Dict[str, Any]
_call_kwargs: Dict[str, Any]

def __init__(
self, operator: BaseOperator, middlewares: List[TaskMiddleware] = None
self, operator: BaseOperator, middlewares: List[Type[TaskMiddleware]] = None
):
self.middlewares = middlewares
self.operator = operator
Expand Down Expand Up @@ -147,7 +147,7 @@ def _exec_middlewares(self, kwargs) -> Any:
if self.middlewares:
exec_with_middlewares = None
for middleware in self.middlewares:
exec_with_middlewares = middleware(self, self._kwargs)
exec_with_middlewares = middleware(self, self._call_kwargs)
if exec_with_middlewares:
return exec_with_middlewares()
return self._exec(kwargs)
Expand All @@ -162,13 +162,14 @@ def __call__(self, *, llm_params: LLMParamsType = None, **kwargs) -> Any:
Returns: the user defined return type of the task
"""
self._kwargs = kwargs
runtime_llm_params = (
llm_params or self.llm_params
) # order is important! We prioritize runtime params that
# were passed
if runtime_llm_params:
self._kwargs["llm_params"] = runtime_llm_params
kwargs["llm_params"] = runtime_llm_params

self._call_kwargs = kwargs
return self._exec_middlewares(kwargs)


Expand Down

0 comments on commit 4480539

Please sign in to comment.