From 057c40022b9ed13fdf4c738e0bb898bc0999be1e Mon Sep 17 00:00:00 2001 From: Robert Szefler Date: Tue, 11 Jun 2024 13:43:36 +0200 Subject: [PATCH] More LLM provider config refactors --- holmes.py | 41 ++++--- holmes/common/env_vars.py | 5 - holmes/config.py | 224 ++++++-------------------------------- holmes/core/provider.py | 167 ++++++++++++++++++++++++++++ server.py | 8 +- 5 files changed, 226 insertions(+), 219 deletions(-) create mode 100644 holmes/core/provider.py diff --git a/holmes.py b/holmes.py index 32303500..2eef5e32 100644 --- a/holmes.py +++ b/holmes.py @@ -5,7 +5,7 @@ import re import warnings from pathlib import Path -from typing import List, Optional, Pattern +from typing import List, Optional import typer from rich.console import Console @@ -13,7 +13,8 @@ from rich.markdown import Markdown from rich.rule import Rule -from holmes.config import LLMConfig, LLMProviderType +from holmes.config import BaseLLMConfig, LLMProviderType +from holmes.core.provider import LLMProvider from holmes.plugins.destinations import DestinationType from holmes.plugins.prompts import load_prompt from holmes import get_version @@ -29,9 +30,10 @@ # Common cli options +llm_provider_names = ", ".join(str(tp) for tp in LLMProviderType) opt_llm: Optional[LLMProviderType] = typer.Option( LLMProviderType.OPENAI, - help="LLM provider ('openai' or 'azure')", # TODO list all + help="LLM provider (supported values: {llm_provider_names})" ) opt_api_key: Optional[str] = typer.Option( None, @@ -136,7 +138,7 @@ def ask( Ask any question and answer using available tools """ console = init_logging(verbose) - config = LLMConfig.load_from_file( + config = BaseLLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -145,15 +147,17 @@ def ask( max_steps=max_steps, custom_toolsets=custom_toolsets, ) + provider = LLMProvider(config) system_prompt = load_prompt(system_prompt) - ai = config.create_toolcalling_llm(console, allowed_toolsets) + ai = provider.create_toolcalling_llm(console, allowed_toolsets) console.print("[bold yellow]User:[/bold yellow] " + prompt) response = ai.call(system_prompt, prompt) text_result = Markdown(response.result) if show_tool_output and response.tool_calls: for tool_call in response.tool_calls: console.print(f"[bold magenta]Used Tool:[/bold magenta]", end="") - # we need to print this separately with markup=False because it contains arbitrary text and we don't want console.print to interpret it + # we need to print this separately with markup=False because it contains arbitrary text + # and we don't want console.print to interpret it console.print(f"{tool_call.description}. Output=\n{tool_call.result}", markup=False) console.print(f"[bold green]AI:[/bold green]", end=" ") console.print(text_result, soft_wrap=True) @@ -195,7 +199,7 @@ def alertmanager( Investigate a Prometheus/Alertmanager alert """ console = init_logging(verbose) - config = LLMConfig.load_from_file( + config = BaseLLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -210,17 +214,18 @@ def alertmanager( custom_toolsets=custom_toolsets, custom_runbooks=custom_runbooks ) - + provider = LLMProvider(config) + if alertname: alertname = re.compile(alertname) system_prompt = load_prompt(system_prompt) - ai = config.create_issue_investigator(console, allowed_toolsets) + ai = provider.create_issue_investigator(console, allowed_toolsets) - source = config.create_alertmanager_source() + source = provider.create_alertmanager_source() if destination == DestinationType.SLACK: - slack = config.create_slack_destination() + slack = provider.create_slack_destination() try: issues = source.fetch_issues(alertname) @@ -291,7 +296,7 @@ def jira( Investigate a Jira ticket """ console = init_logging(verbose) - config = LLMConfig.load_from_file( + config = BaseLLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -305,10 +310,11 @@ def jira( custom_toolsets=custom_toolsets, custom_runbooks=custom_runbooks ) + provider = LLMProvider(config) system_prompt = load_prompt(system_prompt) - ai = config.create_issue_investigator(console, allowed_toolsets) - source = config.create_jira_source() + ai = provider.create_issue_investigator(console, allowed_toolsets) + source = provider.create_jira_source() try: # TODO: allow passing issue ID issues = source.fetch_issues() @@ -380,7 +386,7 @@ def github( Investigate a GitHub issue """ console = init_logging(verbose) - config = LLMConfig.load_from_file( + config = BaseLLMConfig.load_from_file( config_file, api_key=api_key, llm=llm, @@ -395,10 +401,11 @@ def github( custom_toolsets=custom_toolsets, custom_runbooks=custom_runbooks ) + provider = LLMProvider(config) system_prompt = load_prompt(system_prompt) - ai = config.create_issue_investigator(console, allowed_toolsets) - source = config.create_github_source() + ai = provider.create_issue_investigator(console, allowed_toolsets) + source = provider.create_github_source() try: issues = source.fetch_issues() except Exception as e: diff --git a/holmes/common/env_vars.py b/holmes/common/env_vars.py index bed9e4df..1487135b 100644 --- a/holmes/common/env_vars.py +++ b/holmes/common/env_vars.py @@ -11,8 +11,3 @@ STORE_API_KEY = os.environ.get("STORE_API_KEY", "") STORE_EMAIL = os.environ.get("STORE_EMAIL", "") STORE_PASSWORD = os.environ.get("STORE_PASSWORD", "") - -# Currently supports BUILTIN and ROBUSTA_AI -AI_AGENT = os.environ.get("AI_AGENT", "BUILTIN") - -ROBUSTA_AI_URL = os.environ.get("ROBUSTA_AI_URL", "") diff --git a/holmes/config.py b/holmes/config.py index a98e982c..b165f513 100644 --- a/holmes/config.py +++ b/holmes/config.py @@ -4,24 +4,7 @@ from strenum import StrEnum from typing import Annotated, Any, Dict, List, Optional, get_args, get_type_hints -from openai import AzureOpenAI, OpenAI from pydantic import SecretStr, FilePath -from pydash.arrays import concat -from rich.console import Console - -from holmes.core.runbooks import RunbookManager -from holmes.core.tool_calling_llm import ( - IssueInvestigator, - ToolCallingLLM, - YAMLToolExecutor, -) -from holmes.core.tools import ToolsetPattern, get_matching_toolsets -from holmes.plugins.destinations.slack import SlackDestination -from holmes.plugins.runbooks import load_builtin_runbooks, load_runbooks_from_file -from holmes.plugins.sources.jira import JiraSource -from holmes.plugins.sources.github import GitHubSource -from holmes.plugins.sources.prometheus.plugin import AlertManagerSource -from holmes.plugins.toolsets import load_builtin_toolsets, load_toolsets_from_file from holmes.utils.pydantic_utils import BaseConfig, load_model_from_file, EnvVarName @@ -32,7 +15,7 @@ class LLMProviderType(StrEnum): class BaseLLMConfig(BaseConfig): - llm: LLMProviderType = LLMProviderType.OPENAI + llm_provider: LLMProviderType = LLMProviderType.OPENAI # FIXME: the following settings do not belong here. They define the # configuration of specific actions, and not of the LLM provider. @@ -70,7 +53,7 @@ def _collect_env_vars(cls) -> Dict[str, Any]: vars_dict = {} hints = get_type_hints(cls, include_extras=True) for field_name in cls.model_fields: - if field_name == "llm": + if field_name == "llm_provider": # Handled in load_from_env continue tp_obj = hints[field_name] @@ -81,191 +64,29 @@ def _collect_env_vars(cls) -> Dict[str, Any]: else: env_var_name = field_name.upper() if env_var_name in os.environ: + # TODO parse lists (custom_runbooks/custom_toolsets) vars_dict[field_name] = os.environ[env_var_name] return vars_dict @classmethod def load_from_env(cls) -> "BaseLLMConfig": - llm_name = os.getenv("LLM_PROVIDER", "OPENAI").lower() - llm_provider_type = LLMProviderType(llm_name) - if llm_provider_type == LLMProviderType.AZURE: + llm_name = os.environ.get("LLM_PROVIDER", "OPENAI").lower() + llm_provider = LLMProviderType(llm_name) + if llm_provider == LLMProviderType.AZURE: final_class = AzureLLMConfig - elif llm_provider_type == LLMProviderType.OPENAI: + elif llm_provider == LLMProviderType.OPENAI: final_class = OpenAILLMConfig - elif llm_provider_type == LLMProviderType.ROBUSTA: + elif llm_provider == LLMProviderType.ROBUSTA: final_class = RobustaLLMConfig else: raise NotImplementedError(f"Unknown LLM {llm_name}") kwargs = final_class._collect_env_vars() - ret = final_class(**kwargs) - return ret - - -class BaseOpenAIConfig(BaseLLMConfig): - model: Optional[str] = "gpt-4o" - max_steps: Optional[int] = 10 - - -class OpenAILLMConfig(BaseOpenAIConfig): - api_key: Optional[SecretStr] - - -class AzureLLMConfig(BaseOpenAIConfig): - api_key: Optional[SecretStr] - endpoint: Optional[str] - azure_api_version: Optional[str] = "2024-02-01" - - -class RobustaLLMConfig(BaseLLMConfig): - url: Annotated[str, EnvVarName("ROBUSTA_AI_URL")] - - -# TODO refactor everything below - - -class LLMConfig(BaseLLMConfig): - - def create_llm(self) -> OpenAI: - if self.llm == LLMProviderType.OPENAI: - return OpenAI( - api_key=self.api_key.get_secret_value() if self.api_key else None, - ) - elif self.llm == LLMProviderType.AZURE: - return AzureOpenAI( - api_key=self.api_key.get_secret_value() if self.api_key else None, - azure_endpoint=self.azure_endpoint, - api_version=self.azure_api_version, - ) - else: - raise ValueError(f"Unknown LLM type: {self.llm}") - - def _create_tool_executor( - self, console: Console, allowed_toolsets: ToolsetPattern - ) -> YAMLToolExecutor: - all_toolsets = load_builtin_toolsets() - for ts_path in self.custom_toolsets: - all_toolsets.extend(load_toolsets_from_file(ts_path)) - - if allowed_toolsets == "*": - matching_toolsets = all_toolsets - else: - matching_toolsets = get_matching_toolsets( - all_toolsets, allowed_toolsets.split(",") - ) - - enabled_toolsets = [ts for ts in matching_toolsets if ts.is_enabled()] - for ts in all_toolsets: - if ts not in matching_toolsets: - console.print( - f"[yellow]Disabling toolset {ts.name} [/yellow] from {ts.get_path()}" - ) - elif ts not in enabled_toolsets: - console.print( - f"[yellow]Not loading toolset {ts.name}[/yellow] ({ts.get_disabled_reason()})" - ) - #console.print(f"[red]The following tools will be disabled: {[t.name for t in ts.tools]}[/red])") - else: - logging.debug(f"Loaded toolset {ts.name} from {ts.get_path()}") - # console.print(f"[green]Loaded toolset {ts.name}[/green] from {ts.get_path()}") - - enabled_tools = concat(*[ts.tools for ts in enabled_toolsets]) - logging.debug( - f"Starting AI session with tools: {[t.name for t in enabled_tools]}" - ) - return YAMLToolExecutor(enabled_toolsets) - - def create_toolcalling_llm( - self, console: Console, allowed_toolsets: ToolsetPattern - ) -> ToolCallingLLM: - tool_executor = self._create_tool_executor(console, allowed_toolsets) - return ToolCallingLLM( - self.create_llm(), - self.model, - tool_executor, - self.max_steps, - ) - - def create_issue_investigator( - self, console: Console, allowed_toolsets: ToolsetPattern - ) -> IssueInvestigator: - all_runbooks = load_builtin_runbooks() - for runbook_path in self.custom_runbooks: - all_runbooks.extend(load_runbooks_from_file(runbook_path)) - - runbook_manager = RunbookManager(all_runbooks) - tool_executor = self._create_tool_executor(console, allowed_toolsets) - return IssueInvestigator( - self.create_llm(), - self.model, - tool_executor, - runbook_manager, - self.max_steps, - ) - - def create_jira_source(self) -> JiraSource: - if self.jira_url is None: - raise ValueError("--jira-url must be specified") - if not ( - self.jira_url.startswith("http://") or self.jira_url.startswith("https://") - ): - raise ValueError("--jira-url must start with http:// or https://") - if self.jira_username is None: - raise ValueError("--jira-username must be specified") - if self.jira_api_key is None: - raise ValueError("--jira-api-key must be specified") - - return JiraSource( - url=self.jira_url, - username=self.jira_username, - api_key=self.jira_api_key.get_secret_value(), - jql_query=self.jira_query, - ) - - def create_github_source(self) -> GitHubSource: - if not ( - self.github_url.startswith( - "http://") or self.github_url.startswith("https://") - ): - raise ValueError("--github-url must start with http:// or https://") - if self.github_owner is None: - raise ValueError("--github-owner must be specified") - if self.github_repository is None: - raise ValueError("--github-repository must be specified") - if self.github_pat is None: - raise ValueError("--github-pat must be specified") - - return GitHubSource( - url=self.github_url, - owner=self.github_owner, - pat=self.github_pat.get_secret_value(), - repository=self.github_repository, - query=self.github_query, - ) - - def create_alertmanager_source(self) -> AlertManagerSource: - if self.alertmanager_url is None: - raise ValueError("--alertmanager-url must be specified") - if not ( - self.alertmanager_url.startswith("http://") - or self.alertmanager_url.startswith("https://") - ): - raise ValueError("--alertmanager-url must start with http:// or https://") - - return AlertManagerSource( - url=self.alertmanager_url, - username=self.alertmanager_username, - password=self.alertmanager_password, - ) - - def create_slack_destination(self): - if self.slack_token is None: - raise ValueError("--slack-token must be specified") - if self.slack_channel is None: - raise ValueError("--slack-channel must be specified") - return SlackDestination(self.slack_token.get_secret_value(), self.slack_channel) + kwargs["llm_provider"] = llm_provider + return final_class(**kwargs) @classmethod - def load_from_file(cls, config_file: Optional[str], **kwargs) -> "Config": + def load_from_file(cls, config_file: Optional[str], **kwargs) -> "BaseLLMConfig": + # FIXME! if config_file is not None: logging.debug("Loading config from file %s", config_file) config_from_file = load_model_from_file(cls, config_file) @@ -283,7 +104,26 @@ def load_from_file(cls, config_file: Optional[str], **kwargs) -> "Config": merged_config = config_from_file.dict() # remove Nones to avoid overriding config file with empty cli args cli_overrides = { - k: v for k, v in config_from_cli.dict().items() if v is not None and v != [] + k: v for k, v in config_from_cli.model_dump().items() if v is not None and v != [] } merged_config.update(cli_overrides) return cls(**merged_config) + + +class BaseOpenAIConfig(BaseLLMConfig): + model: Annotated[Optional[str], EnvVarName("AI_MODEL")] = "gpt-4o" + max_steps: Optional[int] = 10 + + +class OpenAILLMConfig(BaseOpenAIConfig): + api_key: Annotated[Optional[SecretStr], EnvVarName("OPENAI_API_KEY")] + + +class AzureLLMConfig(BaseOpenAIConfig): + api_key: Annotated[Optional[SecretStr], EnvVarName("AZURE_API_KEY")] + endpoint: Annotated[Optional[str], EnvVarName("AZURE_ENDPOINT")] + azure_api_version: Optional[str] = "2024-02-01" + + +class RobustaLLMConfig(BaseLLMConfig): + url: Annotated[str, EnvVarName("ROBUSTA_AI_URL")] diff --git a/holmes/core/provider.py b/holmes/core/provider.py new file mode 100644 index 00000000..8fd61863 --- /dev/null +++ b/holmes/core/provider.py @@ -0,0 +1,167 @@ +import logging + +from openai import AzureOpenAI, OpenAI +from pydash.arrays import concat +from rich.console import Console + +from holmes.config import BaseLLMConfig, LLMProviderType +from holmes.core.runbooks import RunbookManager +from holmes.core.tool_calling_llm import ( + IssueInvestigator, + ToolCallingLLM, + YAMLToolExecutor, +) +from holmes.core.tools import ToolsetPattern, get_matching_toolsets +from holmes.plugins.destinations.slack import SlackDestination +from holmes.plugins.runbooks import load_builtin_runbooks, load_runbooks_from_file +from holmes.plugins.sources.jira import JiraSource +from holmes.plugins.sources.github import GitHubSource +from holmes.plugins.sources.prometheus.plugin import AlertManagerSource +from holmes.plugins.toolsets import load_builtin_toolsets, load_toolsets_from_file + + +class LLMProvider: + def __init__(self, config: BaseLLMConfig): + self.config = config + + def create_llm(self) -> OpenAI: + if self.config.llm_provider == LLMProviderType.OPENAI: + return OpenAI( + api_key=self.config.api_key.get_secret_value() if self.config.api_key else None, + ) + elif self.config.llm_provider == LLMProviderType.AZURE: + return AzureOpenAI( + api_key=self.config.api_key.get_secret_value() if self.config.api_key else None, + azure_endpoint=self.config.azure_endpoint, + api_version=self.config.azure_api_version, + ) + elif self.config.llm_provider == LLMProviderType.ROBUSTA: + # TODO + pass + else: + raise ValueError(f"Unknown LLM type: {self.config.llm_provider}") + + def create_toolcalling_llm( + self, console: Console, allowed_toolsets: ToolsetPattern + ) -> ToolCallingLLM: + tool_executor = self._create_tool_executor(console, allowed_toolsets) + return ToolCallingLLM( + self.create_llm(), + self.config.model, + tool_executor, + self.config.max_steps, + ) + + def create_issue_investigator( + self, console: Console, allowed_toolsets: ToolsetPattern + ) -> IssueInvestigator: + all_runbooks = load_builtin_runbooks() + for runbook_path in self.config.custom_runbooks: + all_runbooks.extend(load_runbooks_from_file(runbook_path)) + + runbook_manager = RunbookManager(all_runbooks) + tool_executor = self._create_tool_executor(console, allowed_toolsets) + return IssueInvestigator( + self.create_llm(), + self.config.model, + tool_executor, + runbook_manager, + self.config.max_steps, + ) + + def create_jira_source(self) -> JiraSource: + if self.config.jira_url is None: + raise ValueError("--jira-url must be specified") + if not ( + self.config.jira_url.startswith("http://") or self.config.jira_url.startswith("https://") + ): + raise ValueError("--jira-url must start with http:// or https://") + if self.config.jira_username is None: + raise ValueError("--jira-username must be specified") + if self.config.jira_api_key is None: + raise ValueError("--jira-api-key must be specified") + + return JiraSource( + url=self.config.jira_url, + username=self.config.jira_username, + api_key=self.config.jira_api_key.get_secret_value(), + jql_query=self.config.jira_query, + ) + + def create_github_source(self) -> GitHubSource: + if not ( + self.config.github_url.startswith("http://") + or self.config.github_url.startswith("https://") + ): + raise ValueError("--github-url must start with http:// or https://") + if self.config.github_owner is None: + raise ValueError("--github-owner must be specified") + if self.config.github_repository is None: + raise ValueError("--github-repository must be specified") + if self.config.github_pat is None: + raise ValueError("--github-pat must be specified") + + return GitHubSource( + url=self.config.github_url, + owner=self.config.github_owner, + pat=self.config.github_pat.get_secret_value(), + repository=self.config.github_repository, + query=self.config.github_query, + ) + + def create_alertmanager_source(self) -> AlertManagerSource: + if self.config.alertmanager_url is None: + raise ValueError("--alertmanager-url must be specified") + if not ( + self.config.alertmanager_url.startswith("http://") + or self.config.alertmanager_url.startswith("https://") + ): + raise ValueError("--alertmanager-url must start with http:// or https://") + + return AlertManagerSource( + url=self.config.alertmanager_url, + username=self.config.alertmanager_username, + password=self.config.alertmanager_password, + ) + + def create_slack_destination(self): + if self.config.slack_token is None: + raise ValueError("--slack-token must be specified") + if self.config.slack_channel is None: + raise ValueError("--slack-channel must be specified") + return SlackDestination(self.config.slack_token.get_secret_value(), self.config.slack_channel) + + def _create_tool_executor( + self, console: Console, allowed_toolsets: ToolsetPattern + ) -> YAMLToolExecutor: + all_toolsets = load_builtin_toolsets() + for ts_path in self.config.custom_toolsets: + all_toolsets.extend(load_toolsets_from_file(ts_path)) + + if allowed_toolsets == "*": + matching_toolsets = all_toolsets + else: + matching_toolsets = get_matching_toolsets( + all_toolsets, allowed_toolsets.split(",") + ) + + enabled_toolsets = [ts for ts in matching_toolsets if ts.is_enabled()] + for ts in all_toolsets: + if ts not in matching_toolsets: + console.print( + f"[yellow]Disabling toolset {ts.name} [/yellow] from {ts.get_path()}" + ) + elif ts not in enabled_toolsets: + console.print( + f"[yellow]Not loading toolset {ts.name}[/yellow] ({ts.get_disabled_reason()})" + ) + #console.print(f"[red]The following tools will be disabled: {[t.name for t in ts.tools]}[/red])") + else: + logging.debug(f"Loaded toolset {ts.name} from {ts.get_path()}") + # console.print(f"[green]Loaded toolset {ts.name}[/green] from {ts.get_path()}") + + enabled_tools = concat(*[ts.tools for ts in enabled_toolsets]) + logging.debug( + f"Starting AI session with tools: {[t.name for t in enabled_tools]}" + ) + return YAMLToolExecutor(enabled_toolsets) diff --git a/server.py b/server.py index d73e8e20..692d13bb 100644 --- a/server.py +++ b/server.py @@ -23,15 +23,13 @@ from rich.console import Console from holmes.common.env_vars import ( - AI_AGENT, ALLOWED_TOOLSETS, HOLMES_HOST, HOLMES_PORT, - ROBUSTA_AI_URL, ) from holmes.config import LLMConfig from holmes.core.issue import Issue -from holmes.core.supabase_dal import AuthToken, SupabaseDal +from holmes.core.supabase_dal import SupabaseDal from holmes.plugins.prompts import load_prompt @@ -69,13 +67,13 @@ def init_logging(): init_logging() -logging.info(f"Starting AI server with {AI_AGENT=}, {ROBUSTA_AI_URL=}") +config = LLMConfig.load_from_env() +logging.info(f"Starting AI server with config: {config}") dal = SupabaseDal() session_manager = SessionManager(dal, "RelayHolmes") app = FastAPI() console = Console() -config = LLMConfig.load_from_env() @app.post("/api/investigate")