diff --git a/autogpt/autogpt/app/configurator.py b/autogpt/autogpt/app/configurator.py index 51f7afc0a363..923bc2147755 100644 --- a/autogpt/autogpt/app/configurator.py +++ b/autogpt/autogpt/app/configurator.py @@ -51,8 +51,9 @@ async def apply_overrides_to_config( raise click.UsageError("--continuous-limit can only be used with --continuous") # Check availability of configured LLMs; fallback to other LLM if unavailable - config.fast_llm = await check_model(config.fast_llm, "fast_llm") - config.smart_llm = await check_model(config.smart_llm, "smart_llm") + config.fast_llm, config.smart_llm = await check_models( + (config.fast_llm, "fast_llm"), (config.smart_llm, "smart_llm") + ) if skip_reprompt: config.skip_reprompt = True @@ -61,17 +62,22 @@ async def apply_overrides_to_config( config.skip_news = True -async def check_model( - model_name: ModelName, model_type: Literal["smart_llm", "fast_llm"] -) -> ModelName: +async def check_models( + *models: tuple[ModelName, Literal["smart_llm", "fast_llm"]] +) -> tuple[ModelName, ...]: """Check if model is available for use. If not, return gpt-3.5-turbo.""" multi_provider = MultiProvider() - models = await multi_provider.get_available_chat_models() - - if any(model_name == m.name for m in models): - return model_name - - logger.warning( - f"You don't have access to {model_name}. Setting {model_type} to {GPT_3_MODEL}." - ) - return GPT_3_MODEL + available_models = await multi_provider.get_available_chat_models() + + checked_models: list[ModelName] = [] + for model, model_type in models: + if any(model == m.name for m in available_models): + checked_models.append(model) + else: + logger.warning( + f"You don't have access to {model}. " + f"Setting {model_type} to {GPT_3_MODEL}." + ) + checked_models.append(GPT_3_MODEL) + + return tuple(checked_models) diff --git a/forge/forge/llm/providers/multi.py b/forge/forge/llm/providers/multi.py index e0b08352299b..e6accfff7906 100644 --- a/forge/forge/llm/providers/multi.py +++ b/forge/forge/llm/providers/multi.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Iterator, Optional, Sequence, TypeVar, get_args +from typing import Any, AsyncIterator, Callable, Optional, Sequence, TypeVar, get_args from pydantic import ValidationError @@ -68,7 +68,7 @@ async def get_available_models(self) -> Sequence[ChatModelInfo[ModelName]]: async def get_available_chat_models(self) -> Sequence[ChatModelInfo[ModelName]]: models = [] - for provider in self.get_available_providers(): + async for provider in self.get_available_providers(): models.extend(await provider.get_available_chat_models()) return models @@ -120,14 +120,18 @@ def get_model_provider(self, model: ModelName) -> ChatModelProvider: model_info = CHAT_MODELS[model] return self._get_provider(model_info.provider_name) - def get_available_providers(self) -> Iterator[ChatModelProvider]: + async def get_available_providers(self) -> AsyncIterator[ChatModelProvider]: for provider_name in ModelProviderName: - self._logger.debug(f"Checking if {provider_name} is available...") + self._logger.debug(f"Checking if provider {provider_name} is available...") try: - yield self._get_provider(provider_name) - self._logger.debug(f"{provider_name} is available!") + provider = self._get_provider(provider_name) + await provider.get_available_models() # check connection + yield provider + self._logger.debug(f"Provider '{provider_name}' is available!") except ValueError: pass + except Exception as e: + self._logger.debug(f"Provider '{provider_name}' is failing: {e}") def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider: _provider = self._provider_instances.get(provider_name)