Skip to content

Commit

Permalink
chore: Fix openai assistants with new openai version (#66)
Browse files Browse the repository at this point in the history
* chore: Fix openai assistants with new openai version

* chore: Fix version check

* chore: Fix langchain history issue

* chore: Fix langchain input issue
  • Loading branch information
valeriosofi authored Apr 30, 2024
1 parent d98db51 commit bf1b4e5
Show file tree
Hide file tree
Showing 9 changed files with 1,379 additions and 1,196 deletions.
2 changes: 1 addition & 1 deletion nebuly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .init import init

__all__ = ["init", "new_interaction"]
__version__ = "0.3.27"
__version__ = "0.3.28"
14 changes: 13 additions & 1 deletion nebuly/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
),
Package(
"openai",
SupportedVersion("1.0.0"),
SupportedVersion("1.0.0", "1.21.0"),
(
"resources.chat.completions.Completions.create",
"resources.chat.completions.AsyncCompletions.create",
Expand All @@ -36,6 +36,18 @@
"resources.beta.threads.messages.messages.AsyncMessages.list",
),
),
Package(
"openai",
SupportedVersion("1.21.0"),
(
"resources.chat.completions.Completions.create",
"resources.chat.completions.AsyncCompletions.create",
"resources.completions.Completions.create",
"resources.completions.AsyncCompletions.create",
"resources.beta.threads.messages.Messages.list",
"resources.beta.threads.messages.AsyncMessages.list",
),
),
Package(
"cohere",
SupportedVersion("4.0.0"),
Expand Down
29 changes: 19 additions & 10 deletions nebuly/providers/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,16 @@ def _get_input_and_history(chain: Chain, inputs: dict[str, Any] | Any) -> ModelI
def _get_input_and_history_runnable_seq(
sequence: RunnableSequence[Any, Any], inputs: dict[str, Any] | Any
) -> ModelInput:
first = getattr(sequence, "first", None)
steps = getattr(sequence, "steps", None)

if isinstance(first, PromptTemplate):
return ModelInput(prompt=_process_prompt_template(inputs, first))
if steps is not None and len(steps) > 0:
for step in steps:
if isinstance(step, PromptTemplate):
return ModelInput(prompt=_process_prompt_template(inputs, step))

if isinstance(first, ChatPromptTemplate):
prompt, history = _process_chat_prompt_template(inputs, first)
return ModelInput(prompt=prompt, history=history)
if isinstance(step, ChatPromptTemplate):
prompt, history = _process_chat_prompt_template(inputs, step)
return ModelInput(prompt=prompt, history=history)

return ModelInput(prompt="")

Expand Down Expand Up @@ -163,11 +165,18 @@ def _parse_langchain_history(inputs: dict[str, Any]) -> list[HistoryEntry]:
HistoryEntry(user=str(message[0]), assistant=str(message[1]))
for message in history
]
return [
HistoryEntry(
user=str(history[i].content), assistant=str(history[i + 1].content)

human_messages = [message for message in history if message.type == "human"]
ai_messages = [message for message in history if message.type == "ai"]

if len(human_messages) != len(ai_messages):
logger.warning(
"Unequal number of human and AI messages in chat history"
)
for i in range(0, len(history), 2)

return [
HistoryEntry(user=str(human.content), assistant=str(ai.content))
for human, ai in zip(human_messages, ai_messages)
]
return []

Expand Down
19 changes: 15 additions & 4 deletions nebuly/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from nebuly.entities import HistoryEntry, ModelInput
from nebuly.providers.base import PicklerHandler, ProviderDataExtractor
from nebuly.utils import is_module_version_less_than

try:
# This import is valid only starting from openai==1.8.0
Expand All @@ -37,6 +38,16 @@ class LegacyAPIResponse:
pass


if is_module_version_less_than("openai", "1.21.0"):
assistant_messages_list = "resources.beta.threads.messages.messages.Messages.list"
assistant_messages_async_list = (
"resources.beta.threads.messages.messages.AsyncMessages.list"
)
else:
assistant_messages_list = "resources.beta.threads.messages.Messages.list"
assistant_messages_async_list = "resources.beta.threads.messages.AsyncMessages.list"


try:
# This import is valid only starting from openai==1.8.0
from openai._response import AsyncAPIResponse # type: ignore # noqa: E501
Expand Down Expand Up @@ -176,8 +187,8 @@ def extract_input_and_history(self, outputs: Any) -> ModelInput:
history = self._extract_history()
return ModelInput(prompt=prompt, history=history)
if self.function_name in [
"resources.beta.threads.messages.messages.Messages.list",
"resources.beta.threads.messages.messages.AsyncMessages.list",
assistant_messages_list,
assistant_messages_async_list,
]:
if outputs.has_more:
return ModelInput(prompt="")
Expand Down Expand Up @@ -262,8 +273,8 @@ def extract_output( # pylint: disable=too-many-return-statements
payload = ChatCompletion(**payload_dict)
return cast(str, cast(Choice, payload.choices[0]).message.content)
if self.function_name in [
"resources.beta.threads.messages.messages.Messages.list",
"resources.beta.threads.messages.messages.AsyncMessages.list",
assistant_messages_list,
assistant_messages_async_list,
]:
if outputs.has_more:
return ""
Expand Down
24 changes: 24 additions & 0 deletions nebuly/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import importlib.metadata
import logging

from packaging import version


def is_module_version_less_than(module_name: str, version_spec: str) -> bool:
    """Check whether the installed version of a module is less than a given version.

    Parameters:
        module_name (str): The name of the installed distribution to check.
        version_spec (str): The version to compare against, in the format
            'major.minor.patch'.

    Returns:
        bool: True if the installed version is strictly less than
        ``version_spec``; False if it is greater or equal, or if the module
        is not installed.
    """
    try:
        installed_version = importlib.metadata.version(module_name)
    except importlib.metadata.PackageNotFoundError:
        # A library should not write to stdout; route the warning through
        # logging instead of print(). Treat a missing package as "not less
        # than": callers use this check to pick a patching strategy, and an
        # absent module needs no patching.
        logging.getLogger(__name__).warning(
            "Module '%s' is not installed.", module_name
        )
        return False
    # packaging's Version comparison handles pre-releases and multi-digit
    # components correctly, unlike naive string comparison.
    return version.parse(installed_version) < version.parse(version_spec)
Loading

0 comments on commit bf1b4e5

Please sign in to comment.