Skip to content

Commit

Permalink
fix(agent): Unbreak LLM status check on start-up
Browse files Browse the repository at this point in the history
Fixes #7508

- Amend `app/configurator.py:check_model(..)` to check multiple models at once and save duplicate API calls
- Amend `MultiProvider.get_available_providers()` to verify availability by fetching models and handle failure
  • Loading branch information
Pwuts committed Jul 23, 2024
1 parent aca7165 commit e7885f9
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
34 changes: 20 additions & 14 deletions autogpt/autogpt/app/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ async def apply_overrides_to_config(
raise click.UsageError("--continuous-limit can only be used with --continuous")

# Check availability of configured LLMs; fallback to other LLM if unavailable
config.fast_llm = await check_model(config.fast_llm, "fast_llm")
config.smart_llm = await check_model(config.smart_llm, "smart_llm")
config.fast_llm, config.smart_llm = await check_models(
(config.fast_llm, "fast_llm"), (config.smart_llm, "smart_llm")
)

if skip_reprompt:
config.skip_reprompt = True
Expand All @@ -61,17 +62,22 @@ async def apply_overrides_to_config(
config.skip_news = True


async def check_model(
model_name: ModelName, model_type: Literal["smart_llm", "fast_llm"]
) -> ModelName:
async def check_models(
*models: tuple[ModelName, Literal["smart_llm", "fast_llm"]]
) -> tuple[ModelName, ...]:
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
multi_provider = MultiProvider()
models = await multi_provider.get_available_chat_models()

if any(model_name == m.name for m in models):
return model_name

logger.warning(
f"You don't have access to {model_name}. Setting {model_type} to {GPT_3_MODEL}."
)
return GPT_3_MODEL
available_models = await multi_provider.get_available_chat_models()

checked_models: list[ModelName] = []
for model, model_type in models:
if any(model == m.name for m in available_models):
checked_models.append(model)
else:
logger.warning(
f"You don't have access to {model}. "
f"Setting {model_type} to {GPT_3_MODEL}."
)
checked_models.append(GPT_3_MODEL)

return tuple(checked_models)
16 changes: 10 additions & 6 deletions forge/forge/llm/providers/multi.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar, get_args
from typing import Any, AsyncIterator, Callable, Optional, Sequence, TypeVar, get_args

from pydantic import ValidationError

Expand Down Expand Up @@ -68,7 +68,7 @@ async def get_available_models(self) -> Sequence[ChatModelInfo[ModelName]]:

async def get_available_chat_models(self) -> Sequence[ChatModelInfo[ModelName]]:
models = []
for provider in self.get_available_providers():
async for provider in self.get_available_providers():
models.extend(await provider.get_available_chat_models())
return models

Expand Down Expand Up @@ -120,14 +120,18 @@ def get_model_provider(self, model: ModelName) -> ChatModelProvider:
model_info = CHAT_MODELS[model]
return self._get_provider(model_info.provider_name)

def get_available_providers(self) -> Iterator[ChatModelProvider]:
async def get_available_providers(self) -> AsyncIterator[ChatModelProvider]:
for provider_name in ModelProviderName:
self._logger.debug(f"Checking if {provider_name} is available...")
self._logger.debug(f"Checking if provider {provider_name} is available...")
try:
yield self._get_provider(provider_name)
self._logger.debug(f"{provider_name} is available!")
provider = self._get_provider(provider_name)
await provider.get_available_models() # check connection
yield provider
self._logger.debug(f"Provider '{provider_name}' is available!")
except ValueError:
pass
except Exception as e:
self._logger.debug(f"Provider '{provider_name}' is failing: {e}")

def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
_provider = self._provider_instances.get(provider_name)
Expand Down

0 comments on commit e7885f9

Please sign in to comment.