feat: implements azure-openai
Make sure the following environment variables are set:
AZURE_OPENAI_API_KEY
AZURE_OPENAI_ENDPOINT

- Slack error messages are now sent in the thread
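For reference, a minimal sketch of how these variables are consumed: AzureChatOpenAI in langchain_openai reads AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT from the environment when they are not passed explicitly. The deployment name below is hypothetical, not part of this commit.

import os

from langchain_openai import AzureChatOpenAI

# Assumed setup: the key and endpoint come from the environment.
os.environ.setdefault("AZURE_OPENAI_API_KEY", "<your-key>")
os.environ.setdefault("AZURE_OPENAI_ENDPOINT", "https://<resource>.openai.azure.com/")

llm = AzureChatOpenAI(
    deployment_name="my-gpt-4o",       # hypothetical Azure deployment name
    api_version="2024-05-01-preview",  # the default this commit falls back to
    temperature=0.9,
)
print(llm.invoke("ping").content)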
Erez Sharim committed Oct 29, 2024
1 parent b1e2ee4 commit 41a0b36
Showing 4 changed files with 45 additions and 38 deletions.
38 changes: 7 additions & 31 deletions app/llm/agents.py
@@ -1,37 +1,25 @@
 from typing import Optional

-from langchain.agents import AgentExecutor
-from langchain.agents.format_scratchpad import format_to_openai_function_messages
-from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
+from langchain.agents import AgentExecutor, create_openai_tools_agent
 from langchain.prompts import MessagesPlaceholder
 from langchain.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
 from langchain.prompts.prompt import PromptTemplate
 from langchain.tools import Tool
-from langchain_core.runnables import RunnablePassthrough
-from langchain_core.utils.function_calling import convert_to_openai_function

 from ..settings import settings
 from .model import create_model, get_model_from_uri
 from .prompts import MEMORY_KEY


-def create_llm_function_converter(type: str):
-    if type == "openai":
-        return convert_to_openai_function
+def agent_creator(type: str):
+    if type == "openai" or type == "azure-openai":
+        return create_openai_tools_agent
     else:
-        raise NotImplementedError(f"LLM model {type} is not supported.")
-
-
-def create_function_messages_formatter(type: str):
-    if type == "openai":
-        return format_to_openai_function_messages
-    else:
-        raise NotImplementedError(f"LLM model {type} is not supported.")
+        raise NotImplementedError(f"LLM agent creator {type} is not supported.")


 def create_agent(
     prompt: PromptTemplate,
     # data_context: dict[str, Any],
     tools: list[Tool] = [],
     name="",
     streaming=True,
@@ -48,20 +36,8 @@ def create_agent(
     )

     model_type, _, _ = get_model_from_uri(settings.LLM_MODEL)
-    model = create_model(temperature=0, streaming=streaming, name=name, model=model)
+    llm = create_model(temperature=0, streaming=streaming, name=name, model=model)

-    converter = create_llm_function_converter(model_type)
-    llm_with_tools = model.bind(functions=[converter(t) for t in tools])
-
-    function_formatter = create_function_messages_formatter(model_type)
-
-    agent = (
-        RunnablePassthrough.assign(
-            agent_scratchpad=lambda x: function_formatter(x["intermediate_steps"])
-        )
-        | prompt
-        | llm_with_tools
-        | OpenAIFunctionsAgentOutputParser()
-    )
+    agent = agent_creator(model_type)(llm=llm, tools=tools, prompt=prompt)
     agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)
     return agent_executor
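The net effect of this refactor: the hand-rolled scratchpad/bind/output-parser chain is replaced by LangChain's create_openai_tools_agent, which works for both OpenAI and Azure OpenAI chat models. A hedged sketch of the equivalent call path — the model and tools here are illustrative stand-ins for what create_agent builds:

from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import MessagesPlaceholder
from langchain.prompts.chat import ChatPromptTemplate
from langchain_openai import AzureChatOpenAI

# Illustrative model; in this repo the llm comes from create_model(...).
llm = AzureChatOpenAI(deployment_name="my-gpt-4o", api_version="2024-05-01-preview")
tools = []  # in create_agent these are the caller-supplied Tool objects

# create_openai_tools_agent requires an `agent_scratchpad` MessagesPlaceholder
# in the prompt; it raises if the placeholder is missing.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant."),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

agent = create_openai_tools_agent(llm=llm, tools=tools, prompt=prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)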
30 changes: 27 additions & 3 deletions app/llm/model.py
@@ -4,13 +4,21 @@
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
 )
-from langchain_openai import ChatOpenAI
+from langchain_openai import AzureChatOpenAI, ChatOpenAI

 from ..settings import settings
 from .fake_model import FakeChatOpenAI

 LLM_ALLOWED_ARGS = {
     "openai": ["model", "temperature", "streaming", "name"],
+    "azure-openai": [
+        "temperature",
+        "streaming",
+        "name",
+        "model",
+        "api_version",
+        "deployment_name",
+    ],
     "fake": ["responses", "sleep"],
 }

@@ -27,15 +35,15 @@ def create_model(model: Optional[str] = None, **args) -> BaseChatModel:
     if model is None:
         model = settings.LLM_MODEL
     # TODO convert qp as model args
-    type, model, qp = get_model_from_uri(model)
+    type, model_name, qp = get_model_from_uri(model)
     if type == "fake":
         filtered_args = {
             key: value for key, value in args.items() if key in LLM_ALLOWED_ARGS["fake"]
         }

         return FakeChatOpenAI(**filtered_args)
     elif type.startswith("openai"):
-        args["model"] = model
+        args["model"] = model_name
         if "temperature" not in args:
             args["temperature"] = 0.9
         args["streaming"] = True
@@ -46,5 +54,21 @@
         }

         return ChatOpenAI(**filtered_args)
+    elif type.startswith("azure-openai"):
+        args["model"] = model_name
+        if "deployment_name" not in args:
+            args["deployment_name"] = model_name
+        if "temperature" not in args:
+            args["temperature"] = 0.9
+        if "api_version" not in args:
+            args["api_version"] = "2024-05-01-preview"
+        args["streaming"] = True
+        filtered_args = {
+            key: value
+            for key, value in args.items()
+            if key in LLM_ALLOWED_ARGS["azure-openai"]
+        }
+
+        return AzureChatOpenAI(**filtered_args)
     else:
         raise ValueError(f"LLM_MODEL env var with type '{type}' is not supported")
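The URI format accepted by get_model_from_uri is not shown in this diff, so the following configuration is an assumption — a "<type>://<model-name>" scheme with hypothetical values:

# Hypothetical env configuration, assuming get_model_from_uri parses
# "<type>://<model-name>"; the parser itself is not part of this diff:
#   LLM_MODEL=azure-openai://my-gpt-4o
from app.llm.model import create_model

llm = create_model()  # defaults to settings.LLM_MODEL and returns an AzureChatOpenAI
reply = llm.invoke("Hello from Azure OpenAI")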
2 changes: 1 addition & 1 deletion app/llm/tools/app_extra_intructions_tool.py
@@ -48,6 +48,6 @@ async def _find_app(workspace_id: str, app_name: Optional[str] = "") -> str:
     name="find_app",
     description="finds the app name in the messages",
     args_schema=FindAppInput,
-    return_direct=True,
+    return_direct=False,
     handle_tool_error=True,
 )
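Flipping return_direct to False changes the control flow: the tool's result is handed back to the agent as an observation for further reasoning, instead of being returned to the user verbatim. A generic illustration of the flag — not this file's exact constructor:

from langchain.tools import Tool

echo = Tool(
    name="echo",
    description="echoes its input",
    func=lambda text: text,
    # False: the AgentExecutor feeds the result back to the LLM as an observation.
    # True: the run would stop and this tool's output would be the final answer.
    return_direct=False,
)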
13 changes: 10 additions & 3 deletions app/slack/event_message.py
@@ -60,13 +60,15 @@ def answer(client: WebClient, event, logger, say, context):
     # TODO: consider letting the user choose the ws to work with in Slack
     ws_ext_id = workspaces[0]
     ws_store = factory_ws_store()
-    conversation: Conversation = None
-    ws: Workspace = None
+    conversation: Conversation | None = None
+    ws: Workspace | None = None
     conversation_store = service_registry().get(ConversationStore)
     with SQLAlchemyTransactionContext().manage() as tx_context:
         ws = ws_store.get_by_id(workspace_id=ws_ext_id, tx_context=tx_context)
         if ws is None:
-            say(text=f"We could not find workspace {ws_ext_id}")
+            say(
+                text=f"We could not find workspace {ws_ext_id}", thread_ts=thread_ts
+            )
             return

         conversation = conversation_store.get_by_external_id(
@@ -120,5 +122,10 @@ def answer(client: WebClient, event, logger, say, context):
     except Exception as e:
         import traceback

+        say(
+            f"Fatal bot error: {str(e)}. Please follow up with the system admin",
+            thread_ts=thread_ts,
+        )
+
         logger.error(f"Error answering message: {e}")
         logger.error(traceback.format_exc())
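Both say() calls now pass thread_ts, so workspace-lookup failures and fatal errors land in the user's thread rather than at the channel root. A minimal slack_bolt sketch of the pattern — tokens and the handle() call are illustrative, not from this repo:

from slack_bolt import App

app = App(token="xoxb-...", signing_secret="...")  # placeholder credentials

@app.event("message")
def answer(event, say, logger):
    # Reply in the existing thread, or start one anchored at the user's message.
    thread_ts = event.get("thread_ts", event["ts"])
    try:
        handle(event)  # hypothetical message handler
    except Exception as e:
        say(
            f"Fatal bot error: {str(e)}. Please follow up with the system admin",
            thread_ts=thread_ts,
        )
        logger.error(f"Error answering message: {e}")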
