Skip to content

Commit

Permalink
pin to mistal 0.4.2 for now
Browse files Browse the repository at this point in the history
  • Loading branch information
jjallaire committed Aug 7, 2024
1 parent 226320c commit 20f6e8f
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ dev = [
"ipywidgets",
"langchain",
"langchainhub",
"mistralai",
"mistralai==0.4.2",
"mypy",
"nbformat",
"openai",
Expand Down
8 changes: 8 additions & 0 deletions src/inspect_ai/_util/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def module_version_error(
)


def module_max_version_error(feature: str, package: str, max_version: str) -> Exception:
return PrerequisiteError(
f"[bold]ERROR[/bold]: {feature} supports only version {max_version} and earlier of package {package} "
f"(you have version {version(package)} installed).\n\n"
f"Install the older version with with:\n\n[bold]pip install {package}=={max_version}[/bold]\n"
)


def exception_message(ex: BaseException) -> str:
return getattr(ex, "message", repr(ex))

Expand Down
7 changes: 6 additions & 1 deletion src/inspect_ai/_util/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@

import semver

from .error import module_version_error
from .error import module_max_version_error, module_version_error


def verify_required_version(feature: str, package: str, version: str) -> None:
if not has_required_version(package, version):
raise module_version_error(feature, package, version)


def verify_max_version(feature: str, package: str, max_version: str) -> None:
if semver.Version.parse(version(package)).compare(max_version) > 0:
raise module_max_version_error(feature, package, max_version)


def has_required_version(package: str, required_version: str) -> bool:
if semver.Version.parse(version(package)).compare(required_version) >= 0:
return True
Expand Down
6 changes: 3 additions & 3 deletions src/inspect_ai/model/_providers/providers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from inspect_ai._util.error import pip_dependency_error
from inspect_ai._util.version import verify_required_version
from inspect_ai._util.version import verify_max_version, verify_required_version

from .._model import ModelAPI
from .._registry import modelapi
Expand Down Expand Up @@ -138,7 +138,7 @@ def cf() -> type[ModelAPI]:
def mistral() -> type[ModelAPI]:
FEATURE = "Mistral API"
PACKAGE = "mistralai"
MIN_VERSION = "0.1.3"
MAX_VERSION = "0.4.2"

# verify we have the package
try:
Expand All @@ -147,7 +147,7 @@ def mistral() -> type[ModelAPI]:
raise pip_dependency_error(FEATURE, [PACKAGE])

# verify version
verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
verify_max_version(FEATURE, PACKAGE, MAX_VERSION)

# in the clear
from .mistral import MistralAPI
Expand Down

0 comments on commit 20f6e8f

Please sign in to comment.