diff --git a/prediction_market_agent/agents/langchain_agent.py b/prediction_market_agent/agents/langchain_agent.py
index bae9964b..997cca27 100644
--- a/prediction_market_agent/agents/langchain_agent.py
+++ b/prediction_market_agent/agents/langchain_agent.py
@@ -1,5 +1,8 @@
+from typing import Optional
+
 from langchain.agents import AgentType, initialize_agent, load_tools
 from langchain_community.llms import OpenAI
+from langchain_core.language_models import BaseLLM
 from prediction_market_agent_tooling.markets.agent_market import AgentMarket
 
 from prediction_market_agent import utils
@@ -7,9 +10,9 @@
 
 
 class LangChainAgent(AbstractAgent):
-    def __init__(self) -> None:
+    def __init__(self, llm: Optional[BaseLLM] = None) -> None:
         keys = utils.APIKeys()
-        llm = OpenAI(openai_api_key=keys.openai_api_key)
+        llm = OpenAI(openai_api_key=keys.openai_api_key) if not llm else llm
         # Can use pre-defined search tool
         # TODO: Tavily tool could give better results
         # https://docs.tavily.com/docs/tavily-api/langchain
diff --git a/prediction_market_agent/agents/ollama_langchain_agent.py b/prediction_market_agent/agents/ollama_langchain_agent.py
new file mode 100644
index 00000000..7c6b65f6
--- /dev/null
+++ b/prediction_market_agent/agents/ollama_langchain_agent.py
@@ -0,0 +1,17 @@
+from langchain_community.llms.ollama import Ollama
+
+from prediction_market_agent.agents.langchain_agent import LangChainAgent
+from prediction_market_agent.tools.ollama_utils import is_ollama_running
+
+
+class OllamaLangChainAgent(LangChainAgent):
+    def __init__(self) -> None:
+        # Make sure Ollama is running locally
+        if not is_ollama_running():
+            raise EnvironmentError(
+                "Ollama is not running, cannot instantiate Ollama agent"
+            )
+        llm = Ollama(
+            model="mistral", base_url="http://localhost:11434"
+        )  # Mistral since it supports function calling
+        super().__init__(llm=llm)
diff --git a/prediction_market_agent/tools/ollama_utils.py b/prediction_market_agent/tools/ollama_utils.py
new file mode 100644
index 00000000..efe73d56
--- /dev/null
+++ b/prediction_market_agent/tools/ollama_utils.py
@@ -0,0 +1,6 @@
+import requests
+
+
+def is_ollama_running(base_url: str = "http://localhost:11434") -> bool:
+    r = requests.get(f"{base_url}/api/tags")
+    return r.status_code == 200
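
Note on the health check: as written, is_ollama_running() raises requests.exceptions.ConnectionError (rather than returning False) when nothing is listening on the port, which defeats the "raise a friendly EnvironmentError" intent in OllamaLangChainAgent, and with no timeout set an unresponsive server can hang the call. A minimal defensive sketch, not part of the patch above; the 5-second timeout is an assumed value, while GET /api/tags is the same Ollama endpoint the patch already queries:

import requests


def is_ollama_running(base_url: str = "http://localhost:11434") -> bool:
    # Any 200 response from /api/tags means the Ollama server is up;
    # treat connection failures and timeouts as "not running".
    try:
        r = requests.get(f"{base_url}/api/tags", timeout=5)  # timeout value is an assumption
        return r.status_code == 200
    except requests.exceptions.RequestException:
        return False

Expected usage of the new agent, assuming a local Ollama server with the mistral model already pulled (ollama pull mistral):

from prediction_market_agent.agents.ollama_langchain_agent import OllamaLangChainAgent

agent = OllamaLangChainAgent()  # raises EnvironmentError if the server is unreachable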