From 46d0ddae565909d8dd617e6448cbf40a4dc8c4d9 Mon Sep 17 00:00:00 2001 From: evangriffiths Date: Mon, 18 Mar 2024 12:46:30 +0000 Subject: [PATCH] Fix mypy --- main.py | 10 +++++----- prediction_market_agent/agents/autogen_agent.py | 2 +- prediction_market_agent/agents/langchain_agent.py | 10 ++++++++-- prediction_market_agent/tools/google_search.py | 6 +++++- prediction_market_agent/utils.py | 9 +++++---- 5 files changed, 24 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 63c0008f..e64794a6 100644 --- a/main.py +++ b/main.py @@ -1,10 +1,9 @@ from decimal import Decimal import typer -from prediction_market_agent_tooling.markets.markets import ( - MarketType, - get_binary_markets, -) +from prediction_market_agent_tooling.markets.agent_market import SortBy +from prediction_market_agent_tooling.markets.markets import MARKET_TYPE_MAP, MarketType +from prediction_market_agent_tooling.tools.utils import check_not_none import prediction_market_agent as pma from prediction_market_agent.agents.all_agents import AgentType, get_agent @@ -19,7 +18,8 @@ def main( Picks one market and answers it, optionally placing a bet. """ # Pick a market - market = get_binary_markets(market_type)[0] + cls = check_not_none(MARKET_TYPE_MAP.get(market_type)) + market = cls.get_binary_markets(limit=1, sort_by=SortBy.NEWEST)[0] # Create the agent and run it agent = get_agent(agent_type) diff --git a/prediction_market_agent/agents/autogen_agent.py b/prediction_market_agent/agents/autogen_agent.py index 6dc66270..5d61d386 100644 --- a/prediction_market_agent/agents/autogen_agent.py +++ b/prediction_market_agent/agents/autogen_agent.py @@ -18,7 +18,7 @@ def get_base_llm_config(self) -> dict[str, t.Any]: "config_list": [ { "model": "gpt-4", - "api_key": keys.openai_api_key, + "api_key": keys.openai_api_key.get_secret_value(), } ], "temperature": 0, diff --git a/prediction_market_agent/agents/langchain_agent.py b/prediction_market_agent/agents/langchain_agent.py index 997cca27..93bc7705 100644 --- a/prediction_market_agent/agents/langchain_agent.py +++ b/prediction_market_agent/agents/langchain_agent.py @@ -12,12 +12,18 @@ class LangChainAgent(AbstractAgent): def __init__(self, llm: Optional[BaseLLM] = None) -> None: keys = utils.APIKeys() - llm = OpenAI(openai_api_key=keys.openai_api_key) if not llm else llm + llm = ( + OpenAI(openai_api_key=keys.openai_api_key.get_secret_value()) + if not llm + else llm + ) # Can use pre-defined search tool # TODO: Tavily tool could give better results # https://docs.tavily.com/docs/tavily-api/langchain tools = load_tools( - ["serpapi", "llm-math"], llm=llm, serpapi_api_key=keys.serp_api_key + ["serpapi", "llm-math"], + llm=llm, + serpapi_api_key=keys.serp_api_key.get_secret_value(), ) self._agent = initialize_agent( tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True diff --git a/prediction_market_agent/tools/google_search.py b/prediction_market_agent/tools/google_search.py index edc6585e..47b86d08 100644 --- a/prediction_market_agent/tools/google_search.py +++ b/prediction_market_agent/tools/google_search.py @@ -4,7 +4,11 @@ def google_search(query: str) -> list[str]: - params = {"q": query, "api_key": utils.APIKeys().serp_api_key, "num": 4} + params = { + "q": query, + "api_key": utils.APIKeys().serp_api_key.get_secret_value(), + "num": 4, + } search = serpapi.GoogleSearch(params) urls = [result["link"] for result in search.get_dict()["organic_results"]] return urls diff --git a/prediction_market_agent/utils.py b/prediction_market_agent/utils.py index 570c580c..4ed88513 100644 --- a/prediction_market_agent/utils.py +++ b/prediction_market_agent/utils.py @@ -5,20 +5,21 @@ check_not_none, should_not_happen, ) +from pydantic import SecretStr class APIKeys(APIKeysBase): - SERP_API_KEY: t.Optional[str] = None - OPENAI_API_KEY: t.Optional[str] = None + SERP_API_KEY: t.Optional[SecretStr] = None + OPENAI_API_KEY: t.Optional[SecretStr] = None @property - def serp_api_key(self) -> str: + def serp_api_key(self) -> SecretStr: return check_not_none( self.SERP_API_KEY, "SERP_API_KEY missing in the environment." ) @property - def openai_api_key(self) -> str: + def openai_api_key(self) -> SecretStr: return check_not_none( self.OPENAI_API_KEY, "OPENAI_API_KEY missing in the environment." )