From 20f6e8f3ac17b4fb29ed55173e72a5c5fa95b5ac Mon Sep 17 00:00:00 2001 From: "J.J. Allaire" Date: Wed, 7 Aug 2024 17:17:16 -0400 Subject: [PATCH] pin to mistal 0.4.2 for now --- pyproject.toml | 2 +- src/inspect_ai/_util/error.py | 8 ++++++++ src/inspect_ai/_util/version.py | 7 ++++++- src/inspect_ai/model/_providers/providers.py | 6 +++--- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 97f41cc98..6fb29464c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,7 +114,7 @@ dev = [ "ipywidgets", "langchain", "langchainhub", - "mistralai", + "mistralai==0.4.2", "mypy", "nbformat", "openai", diff --git a/src/inspect_ai/_util/error.py b/src/inspect_ai/_util/error.py index e99000518..d12eda933 100644 --- a/src/inspect_ai/_util/error.py +++ b/src/inspect_ai/_util/error.py @@ -24,6 +24,14 @@ def module_version_error( ) +def module_max_version_error(feature: str, package: str, max_version: str) -> Exception: + return PrerequisiteError( + f"[bold]ERROR[/bold]: {feature} supports only version {max_version} and earlier of package {package} " + f"(you have version {version(package)} installed).\n\n" + f"Install the older version with with:\n\n[bold]pip install {package}=={max_version}[/bold]\n" + ) + + def exception_message(ex: BaseException) -> str: return getattr(ex, "message", repr(ex)) diff --git a/src/inspect_ai/_util/version.py b/src/inspect_ai/_util/version.py index 44e21d8f9..6c22c46a9 100644 --- a/src/inspect_ai/_util/version.py +++ b/src/inspect_ai/_util/version.py @@ -2,7 +2,7 @@ import semver -from .error import module_version_error +from .error import module_max_version_error, module_version_error def verify_required_version(feature: str, package: str, version: str) -> None: @@ -10,6 +10,11 @@ def verify_required_version(feature: str, package: str, version: str) -> None: raise module_version_error(feature, package, version) +def verify_max_version(feature: str, package: str, max_version: str) -> None: + if semver.Version.parse(version(package)).compare(max_version) > 0: + raise module_max_version_error(feature, package, max_version) + + def has_required_version(package: str, required_version: str) -> bool: if semver.Version.parse(version(package)).compare(required_version) >= 0: return True diff --git a/src/inspect_ai/model/_providers/providers.py b/src/inspect_ai/model/_providers/providers.py index 7835436d5..585019742 100644 --- a/src/inspect_ai/model/_providers/providers.py +++ b/src/inspect_ai/model/_providers/providers.py @@ -1,5 +1,5 @@ from inspect_ai._util.error import pip_dependency_error -from inspect_ai._util.version import verify_required_version +from inspect_ai._util.version import verify_max_version, verify_required_version from .._model import ModelAPI from .._registry import modelapi @@ -138,7 +138,7 @@ def cf() -> type[ModelAPI]: def mistral() -> type[ModelAPI]: FEATURE = "Mistral API" PACKAGE = "mistralai" - MIN_VERSION = "0.1.3" + MAX_VERSION = "0.4.2" # verify we have the package try: @@ -147,7 +147,7 @@ def mistral() -> type[ModelAPI]: raise pip_dependency_error(FEATURE, [PACKAGE]) # verify version - verify_required_version(FEATURE, PACKAGE, MIN_VERSION) + verify_max_version(FEATURE, PACKAGE, MAX_VERSION) # in the clear from .mistral import MistralAPI