Commit
add streaming option for openai
matankley committed Aug 29, 2023
1 parent 4480539 commit 5d58f7d
Showing 6 changed files with 114 additions and 21 deletions.
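For orientation, a minimal usage sketch of the option this commit adds, assuming the package's usual `from declarai import Declarai` entry point; the model name and token are placeholders, and the assertion mirrors the updated test at the bottom of this diff:

from declarai import Declarai

# Sketch only, not part of the diff: enable streaming when configuring the OpenAI LLM.
# "gpt-3.5-turbo" and the token below are placeholder values.
declarai = Declarai.openai(
    model="gpt-3.5-turbo",
    openai_token="<OPENAI_TOKEN>",
    stream=True,
)

assert declarai.llm.streaming is True  # mirrors test_declarai_openai below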
10 changes: 5 additions & 5 deletions src/declarai/chat.py
@@ -149,7 +149,7 @@ def add_message(self, message: str, role: MessageRole) -> None:
"""
self._chat_history.add_message(Message(message=message, role=role))

def _exec(self, kwargs) -> LLMResponse:
def _exec(self, kwargs) -> Any:
"""
Executes the call to the LLM.
@@ -161,9 +161,11 @@ def _exec(self, kwargs) -> LLMResponse:
"""
self.llm_response = self.operator.predict(**kwargs)
self.add_message(self.llm_response.response, role=MessageRole.assistant)

if self.operator.parsed_send_func:
return self.operator.parsed_send_func.parse(self.llm_response.response)
return self.llm_response

return self.llm_response.response

def _exec_middlewares(self, kwargs) -> Any:
if self.middlewares:
@@ -324,9 +326,7 @@ def wrap(cls) -> Type[Chat]:

_decorator_kwargs = dict(
operator=operator_type(
llm=self.llm,
parsed=parsed_cls,
llm_params=llm_params,
llm=self.llm, parsed=parsed_cls, llm_params=llm_params
),
middlewares=middlewares,
chat_history=chat_history,
2 changes: 2 additions & 0 deletions src/declarai/operators/llm.py
@@ -26,6 +26,8 @@ class LLMResponse(BaseModel):
prompt_tokens: Optional[int] = None
completion_tokens: Optional[int] = None
total_tokens: Optional[int] = None
role: str = "assistant"
raw_response: Optional[dict] = None


class BaseLLMParams(TypedDict):
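To illustrate the two fields added to `LLMResponse`, a small hand-written example (not taken from the commit); `response` and `model` are the pre-existing fields of the model, and the import path matches the one used in openai_llm.py below:

from declarai.operators import LLMResponse

# Illustration of the new fields: `role` defaults to "assistant",
# `raw_response` can carry the raw provider payload.
res = LLMResponse(
    response="Hello!",
    model="gpt-3.5-turbo",
    raw_response={"id": "chatcmpl-123", "object": "chat.completion"},
)
print(res.role)          # -> "assistant"
print(res.total_tokens)  # -> None (still optional)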
77 changes: 67 additions & 10 deletions src/declarai/operators/openai_operators/openai_llm.py
@@ -1,8 +1,9 @@
"""
LLM implementation for OpenAI
"""
from typing import List, Optional
from typing import List, Optional, Iterator, Union

from openai.openai_object import OpenAIObject
import openai

from declarai.operators import BaseLLM, BaseLLMParams, LLMResponse, Message
@@ -51,15 +52,24 @@ def __init__(
self._kwargs = {
"headers": headers,
"timeout": timeout,
"stream": stream,
"request_timeout": request_timeout,
**kwargs,
}
self.openai = openai
self.api_key = api_key
self.api_type = api_type
self.stream = stream
self.model = model_name

@property
def streaming(self) -> bool:
"""
Returns whether the LLM is streaming or not
Returns:
bool: True if the LLM is streaming, False otherwise
"""
return self.stream

def predict(
self,
messages: List[Message],
@@ -69,10 +79,12 @@ def predict(
top_p: float = 1,
frequency_penalty: int = 0,
presence_penalty: int = 0,
) -> LLMResponse:
stream: bool = None,
) -> Union[Iterator[LLMResponse], LLMResponse]:
"""
Predicts the next message using OpenAI
Args:
stream: whether to stream the response
messages: List of messages that are used as context for the prediction
model: the model to use for the prediction
temperature: the temperature to use for the prediction
@@ -85,6 +97,8 @@ def predict(
LLMResponse: The response from the LLM
"""
if stream is None:
stream = self.stream
openai_messages = [{"role": m.role, "content": m.message} for m in messages]
res = self.openai.ChatCompletion.create(
model=model or self.model,
@@ -96,15 +110,22 @@ def predict(
presence_penalty=presence_penalty,
api_key=self.api_key,
api_type=self.api_type,
stream=stream,
**self._kwargs,
)
return LLMResponse(
response=res.choices[0]["message"]["content"],
model=res.model,
prompt_tokens=res["usage"]["prompt_tokens"],
completion_tokens=res["usage"]["completion_tokens"],
total_tokens=res["usage"]["total_tokens"],
)

if stream:
return handle_streaming_response(res)

else:
return LLMResponse(
response=res.choices[0]["message"]["content"],
model=res.model,
prompt_tokens=res["usage"]["prompt_tokens"],
completion_tokens=res["usage"]["completion_tokens"],
total_tokens=res["usage"]["total_tokens"],
raw_response=res.to_dict_recursive(),
)


@register_llm(provider="openai")
@@ -214,3 +235,39 @@ def __init__(
api_version=api_version,
api_base=api_base,
)


def handle_streaming_response(api_response: OpenAIObject) -> Iterator[LLMResponse]:
"""
Accumulate chunk deltas into a running response and yield an LLMResponse for each received chunk.
"""
response = {"role": None, "response": "", "raw_response": "", "data": {}}  # "data" collects any function_call deltas

for r in api_response: # noqa
response["raw_response"] = r.to_dict_recursive()

delta = r.choices[0]["delta"]
response["model"] = r.model
if r.usage:
response["prompt_tokens"] = r.usage["prompt_tokens"]
response["completion_tokens"] = r.usage["completion_tokens"]
response["total_tokens"] = r.usage["total_tokens"]

if "role" in delta:
response["role"] = delta["role"]

if delta.get("function_call"):
fn_call = delta.get("function_call")
if "function_call" not in response["data"]:
response["data"]["function_call"] = {"name": None, "arguments": ""}
if "name" in fn_call:
response["data"]["function_call"]["name"] = fn_call.name
if "arguments" in fn_call:
response["data"]["function_call"]["arguments"] += (
fn_call.arguments or ""
)

if "content" in delta:
response["response"] += delta.content or ""

yield LLMResponse(**response)
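A rough consumption sketch for the streaming branch of `predict`; `llm` is assumed to be an already-configured `OpenAILLM`, and `MessageRole` is assumed to be importable alongside `Message`:

from declarai.operators import Message, MessageRole

# Sketch: iterate the generator returned when stream=True.
# Each yielded LLMResponse carries the text accumulated so far.
chunks = llm.predict(
    messages=[Message(message="Tell me a short joke", role=MessageRole.user)],
    stream=True,
)

final = None
for partial in chunks:
    print(partial.response)  # progressively longer accumulated text
    final = partial

if final is not None:
    print("full answer:", final.response)
    print("raw last chunk:", final.raw_response)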
32 changes: 30 additions & 2 deletions src/declarai/operators/operator.py
@@ -35,11 +35,28 @@ def __init__(
llm: LLM,
parsed: PythonParser,
llm_params: LLMParamsType = None,
streaming: bool = None,
**kwargs: Dict,
):
self.llm = llm
self.parsed = parsed
self.llm_params = llm_params or {}
self._call_streaming = streaming

@property
def streaming(self) -> bool:
"""
Returns whether the operator is streaming or not
Returns:
bool: True if streaming was requested for this operator or is enabled on the underlying LLM, False otherwise
"""
if self._call_streaming is not None:
return self._call_streaming

if hasattr(self.llm, "streaming"):
return self.llm.streaming

return False

@abstractmethod
def compile(self, **kwargs) -> CompiledTemplate:
@@ -70,6 +87,8 @@ def predict(
The response from the LLM
"""
llm_params = llm_params or self.llm_params # Order is important -
if self.streaming is not None:
llm_params["stream"] = self.streaming # streaming should be the last param
# provided params during execution should override the ones provided during initialization
return self.llm.predict(**self.compile(**kwargs), **llm_params)

@@ -103,13 +122,22 @@ class BaseChatOperator(BaseOperator):
"""

def __init__(
self, system: Optional[str] = None, greeting: Optional[str] = None, **kwargs
self,
system: Optional[str] = None,
greeting: Optional[str] = None,
parsed: PythonParser = None,
**kwargs,
):
super().__init__(**kwargs)
super().__init__(parsed=parsed, **kwargs)
self.system = system or self.parsed.docstring_freeform
self.greeting = greeting or getattr(self.parsed.decorated, "greeting", None)
self.parsed_send_func = (
PythonParser(self.parsed.decorated.send)
if getattr(self.parsed.decorated, "send", None)
else None
)

if self.streaming:
raise ValueError(
"Streaming is not supported for chat operators. Please disable streaming."
)
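To make the precedence of the new flag concrete, a hypothetical sketch — `MyTaskOperator` stands in for any concrete `BaseOperator` subclass, `parsed` for its parser, and `llm` for an LLM created with `stream=True`:

# Hypothetical sketch of BaseOperator.streaming precedence (names are illustrative).
op = MyTaskOperator(llm=llm, parsed=parsed)
assert op.streaming is True   # no explicit flag -> falls back to llm.streaming

op = MyTaskOperator(llm=llm, parsed=parsed, streaming=False)
assert op.streaming is False  # per-operator flag overrides the LLM default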
8 changes: 7 additions & 1 deletion src/declarai/task.py
@@ -141,7 +141,9 @@ def plan(self, **kwargs) -> FutureTask:

def _exec(self, kwargs) -> Any:
self.llm_response = self.operator.predict(**kwargs)
return self.operator.parse_output(self.llm_response.response)
if not self.operator.streaming:
return self.operator.parse_output(self.llm_response.response)
return self.llm_response

def _exec_middlewares(self, kwargs) -> Any:
if self.middlewares:
Expand Down Expand Up @@ -199,6 +201,7 @@ def task(
*,
middlewares: List[Type[TaskMiddleware]] = None,
llm_params: LLMParamsType = None,
streaming: bool = None,
**kwargs,
) -> Callable[[Callable], Task]:
...
@@ -209,13 +212,15 @@ def task(
*,
middlewares: List[Type[TaskMiddleware]] = None,
llm_params: LLMParamsType = None,
streaming: bool = None,
):
"""
The decorator that creates the task
Args:
func: the function to decorate that represents the task
middlewares: middleware to use while executing the task
llm_params: llm_params to use when calling the llm
streaming: whether to stream the response from the llm or not
Returns:
(Task): the task that was created
@@ -228,6 +233,7 @@ def wrap(_func: Callable) -> Task:
parsed=PythonParser(_func),
llm=self.llm,
llm_params=llm_params,
streaming=streaming,
)
llm_task = Task(operator=operator, middlewares=middlewares)
llm_task.__name__ = _func.__name__
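A sketch of the new per-task flag; the docstring-driven task style is declarai's existing pattern and is not introduced here, and the streamed call is assumed to return the iterator of `LLMResponse` chunks produced in `_exec` above:

# Sketch only: opt a single task into streaming.
@declarai.task(streaming=True)
def generate_poem(title: str) -> str:
    """
    Write a short poem about the given title
    :param title: the title to write the poem about
    """

for chunk in generate_poem(title="streaming in declarai"):
    print(chunk.response)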
6 changes: 3 additions & 3 deletions tests/test_declarai.py
@@ -24,18 +24,19 @@ def test_declarai(mocked_task_decorator, mocked_resolve_llm):
def test_declarai_openai():
kwargs = {
"model": "davinci",
"openai_token": "test_token"
"openai_token": "test_token",
"stream": True,
}
declarai = Declarai.openai(
**kwargs
)

assert declarai.llm.streaming is True
assert declarai.llm.provider == "openai"
assert declarai.llm.model == "davinci"
assert declarai.llm.api_key == "test_token"



def test_declarai_azure_openai():
kwargs = {
"deployment_name": "test",
@@ -52,4 +53,3 @@ def test_declarai_azure_openai():
assert declarai.llm.api_key == "123"
assert declarai.llm._kwargs["api_base"] == "456"
assert declarai.llm._kwargs["api_version"] == "789"
