-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable Azure Open AI function calling via Azure Functions #1133
base: main
Are you sure you want to change the base?
Changes from 5 commits
5d059d0
d9dd961
fc7cda2
57b70cb
c2f8947
5972249
212ae0b
a67a788
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ | |
convert_to_pf_format, | ||
format_pf_non_streaming_response, | ||
) | ||
import requests | ||
|
||
bp = Blueprint("routes", __name__, static_folder="static", template_folder="static") | ||
|
||
|
@@ -111,6 +112,9 @@ async def assets(path): | |
MS_DEFENDER_ENABLED = os.environ.get("MS_DEFENDER_ENABLED", "true").lower() == "true" | ||
|
||
|
||
azure_openai_tools = [] | ||
azure_openai_available_tools = [] | ||
|
||
# Initialize Azure OpenAI Client | ||
async def init_openai_client(): | ||
azure_openai_client = None | ||
|
@@ -159,6 +163,19 @@ async def init_openai_client(): | |
# Default Headers | ||
default_headers = {"x-ms-useragent": USER_AGENT} | ||
|
||
# Remote function calls | ||
if app_settings.azure_openai.function_call_azure_functions_enabled: | ||
azure_functions_tools_url = f"{app_settings.azure_openai.function_call_azure_functions_tools_base_url}?code={app_settings.azure_openai.function_call_azure_functions_tools_key}" | ||
response = requests.get(azure_functions_tools_url) | ||
response_status_code = response.status_code | ||
if response_status_code == requests.codes.ok: | ||
azure_openai_tools.extend(json.loads(response.text)) | ||
for tool in azure_openai_tools: | ||
azure_openai_available_tools.append(tool["function"]["name"]) | ||
else: | ||
logging.error(f"An error occurred while getting OpenAI Function Call tools metadata: {response.status_code}") | ||
|
||
|
||
azure_openai_client = AsyncAzureOpenAI( | ||
api_version=app_settings.azure_openai.preview_api_version, | ||
api_key=aoai_api_key, | ||
|
@@ -173,6 +190,20 @@ async def init_openai_client(): | |
azure_openai_client = None | ||
raise e | ||
|
||
def openai_remote_azure_function_call(function_name, function_args):
    """Invoke a remote Azure Function tool and return its raw response body.

    Args:
        function_name: Name of the tool to execute, as advertised by the
            remote tools endpoint.
        function_args: JSON-encoded argument string (as produced by the
            OpenAI tool-call response); decoded before being forwarded.

    Returns:
        The response body text on success, or None when remote function
        calling is disabled in settings.

    Raises:
        requests.HTTPError: if the Azure Function returns an error status.
        json.JSONDecodeError: if function_args is not valid JSON.
    """
    if app_settings.azure_openai.function_call_azure_functions_enabled is not True:
        return

    azure_functions_tool_url = f"{app_settings.azure_openai.function_call_azure_functions_tool_base_url}?code={app_settings.azure_openai.function_call_azure_functions_tool_key}"
    headers = {'content-type': 'application/json'}
    body = {
        "tool_name": function_name,
        "tool_arguments": json.loads(function_args)
    }
    # NOTE(review): this is a synchronous, blocking call inside an async app;
    # consider an async HTTP client (e.g. aiohttp/httpx) so the event loop is
    # not stalled while the function app runs — confirm with the maintainers.
    # A timeout is set so a hung function app cannot block a request forever.
    response = requests.post(azure_functions_tool_url, data=json.dumps(body), headers=headers, timeout=30)
    response.raise_for_status()

    return response.text
|
||
async def init_cosmosdb_client(): | ||
cosmos_conversation_client = None | ||
|
@@ -219,22 +250,28 @@ def prepare_model_args(request_body, request_headers): | |
|
||
for message in request_messages: | ||
if message: | ||
if message["role"] == "assistant" and "context" in message: | ||
context_obj = json.loads(message["context"]) | ||
messages.append( | ||
{ | ||
"role": message["role"], | ||
"content": message["content"], | ||
"context": context_obj | ||
} | ||
) | ||
else: | ||
messages.append( | ||
{ | ||
"role": message["role"], | ||
"content": message["content"] | ||
} | ||
) | ||
match message["role"]: | ||
case "user": | ||
messages.append( | ||
{ | ||
"role": message["role"], | ||
"content": message["content"] | ||
} | ||
) | ||
case "assistant" | "function" | "tool": | ||
messages_helper = {} | ||
messages_helper["role"] = message["role"] | ||
if "name" in message: | ||
messages_helper["name"] = message["name"] | ||
if "function_call" in message: | ||
messages_helper["function_call"] = message["function_call"] | ||
messages_helper["content"] = message["content"] | ||
if "context" in message: | ||
context_obj = json.loads(message["context"]) | ||
messages_helper["context"] = context_obj | ||
|
||
messages.append(messages_helper) | ||
|
||
|
||
user_json = None | ||
if (MS_DEFENDER_ENABLED): | ||
|
@@ -254,14 +291,18 @@ def prepare_model_args(request_body, request_headers): | |
"user": user_json | ||
} | ||
|
||
if app_settings.datasource: | ||
model_args["extra_body"] = { | ||
"data_sources": [ | ||
app_settings.datasource.construct_payload_configuration( | ||
request=request | ||
) | ||
] | ||
} | ||
if messages[-1]["role"] == "user": | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a safety check here that len(messages) > 0 would be good |
||
if app_settings.azure_openai.function_call_azure_functions_enabled and len(azure_openai_tools) > 0: | ||
model_args["tools"] = azure_openai_tools | ||
|
||
if app_settings.datasource: | ||
model_args["extra_body"] = { | ||
"data_sources": [ | ||
app_settings.datasource.construct_payload_configuration( | ||
request=request | ||
) | ||
] | ||
} | ||
|
||
model_args_clean = copy.deepcopy(model_args) | ||
if model_args_clean.get("extra_body"): | ||
|
@@ -335,6 +376,43 @@ async def promptflow_request(request): | |
logging.error(f"An error occurred while making promptflow_request: {e}") | ||
|
||
|
||
def process_function_call(response):
    """Translate tool calls in a chat completion into follow-up messages.

    For every tool call in the model's reply that matches a tool we
    advertised (azure_openai_available_tools), the remote Azure Function is
    executed and two messages are appended: the assistant's function_call
    request and the function's response.

    Returns:
        The list of messages to extend the conversation with, or None when
        the reply contains no tool calls at all.
    """
    reply = response.choices[0].message
    if not reply.tool_calls:
        return None

    follow_up = []
    for call in reply.tool_calls:
        # Skip tool names the model produced that we never advertised.
        if call.function.name not in azure_openai_available_tools:
            continue

        result = openai_remote_azure_function_call(call.function.name, call.function.arguments)

        # Record the assistant's request to run the function...
        follow_up.append(
            {
                "role": reply.role,
                "function_call": {
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                },
                "content": None,
            }
        )

        # ...and the function's answer, extending the conversation.
        follow_up.append(
            {
                "role": "function",
                "name": call.function.name,
                "content": result,
            }
        )

    return follow_up
|
||
async def send_chat_request(request_body, request_headers): | ||
filtered_messages = [] | ||
messages = request_body.get("messages", []) | ||
|
@@ -370,18 +448,107 @@ async def complete_chat_request(request_body, request_headers): | |
else: | ||
response, apim_request_id = await send_chat_request(request_body, request_headers) | ||
history_metadata = request_body.get("history_metadata", {}) | ||
return format_non_streaming_response(response, history_metadata, apim_request_id) | ||
non_streaming_response = format_non_streaming_response(response, history_metadata, apim_request_id) | ||
|
||
if app_settings.azure_openai.function_call_azure_functions_enabled: | ||
function_response = process_function_call(response) | ||
|
||
if function_response: | ||
request_body["messages"].extend(function_response) | ||
|
||
response, apim_request_id = await send_chat_request(request_body, request_headers) | ||
history_metadata = request_body.get("history_metadata", {}) | ||
non_streaming_response = format_non_streaming_response(response, history_metadata, apim_request_id) | ||
|
||
return non_streaming_response | ||
|
||
|
||
async def stream_chat_request(request_body, request_headers):
    """Stream a chat completion, transparently executing remote tool calls.

    Sends the conversation to Azure OpenAI and returns an async generator of
    formatted response chunks. When function calling is enabled, tool-call
    deltas are accumulated from the stream; once the model stops emitting
    tool calls, each collected tool is executed via its Azure Function, the
    results are appended to the conversation, and a follow-up completion is
    requested and streamed to the client in place of the tool-call chunks.

    Args:
        request_body: Parsed request payload; "messages" is mutated in place
            when tool results are appended.
        request_headers: Incoming request headers, forwarded to the model call.

    Returns:
        An async generator yielding formatted stream-response dicts.
    """
    response, apim_request_id = await send_chat_request(request_body, request_headers)
    history_metadata = request_body.get("history_metadata", {})

    async def generate(apim_request_id, history_metadata):
        tool_calls = []               # completed tool calls collected from the stream
        current_tool_call = None      # the tool call currently being assembled
        tool_arguments_stream = ""    # accumulates streamed argument-JSON fragments
        function_messages = []
        tool_name = ""
        tool_call_streaming_state = "INITIAL"  # INITIAL -> STREAMING -> COMPLETED

        async for completionChunk in response:
            if app_settings.azure_openai.function_call_azure_functions_enabled:
                # Chunks without choices are dropped in function-calling mode.
                if hasattr(completionChunk, "choices") and len(completionChunk.choices) > 0:
                    response_message = completionChunk.choices[0].delta

                    # Function calling stream processing
                    if response_message.tool_calls and tool_call_streaming_state in ["INITIAL", "STREAMING"]:
                        tool_call_streaming_state = "STREAMING"
                        for tool_call_chunk in response_message.tool_calls:
                            # A chunk carrying an id starts a new tool call;
                            # flush the previous one first.
                            if tool_call_chunk.id:
                                if current_tool_call:
                                    tool_arguments_stream += tool_call_chunk.function.arguments if tool_call_chunk.function.arguments else ""
                                    current_tool_call["tool_arguments"] = tool_arguments_stream
                                    tool_arguments_stream = ""
                                    tool_name = ""
                                    tool_calls.append(current_tool_call)

                                current_tool_call = {
                                    "tool_id": tool_call_chunk.id,
                                    "tool_name": tool_call_chunk.function.name if tool_name == "" else tool_name
                                }
                            else:
                                tool_arguments_stream += tool_call_chunk.function.arguments if tool_call_chunk.function.arguments else ""

                    # Function call - streaming completed: execute the tools
                    # and re-query the model with their results.
                    elif response_message.tool_calls is None and tool_call_streaming_state == "STREAMING":
                        current_tool_call["tool_arguments"] = tool_arguments_stream
                        tool_calls.append(current_tool_call)

                        for tool_call in tool_calls:
                            tool_response = openai_remote_azure_function_call(tool_call["tool_name"], tool_call["tool_arguments"])

                            # Assistant's request to run the function, then
                            # the function's response.
                            function_messages.append({
                                "role": "assistant",
                                "function_call": {
                                    "name": tool_call["tool_name"],
                                    "arguments": tool_call["tool_arguments"]
                                },
                                "content": None
                            })
                            function_messages.append({
                                "tool_call_id": tool_call["tool_id"],
                                "role": "function",
                                "name": tool_call["tool_name"],
                                "content": tool_response,
                            })

                        # Reset state for any further tool calls in this stream.
                        messages = function_messages
                        function_messages = []
                        tool_calls = []
                        current_tool_call = None
                        tool_arguments_stream = ""
                        tool_name = ""
                        tool_call_streaming_state = "COMPLETED"

                        request_body["messages"].extend(messages)

                        function_response, apim_request_id = await send_chat_request(request_body, request_headers)

                        async for functionCompletionChunk in function_response:
                            yield format_stream_response(functionCompletionChunk, history_metadata, apim_request_id)

                    else:
                        # No function call: plain assistant content chunk.
                        yield format_stream_response(completionChunk, history_metadata, apim_request_id)

            else:
                yield format_stream_response(completionChunk, history_metadata, apim_request_id)

    return generate(apim_request_id=apim_request_id, history_metadata=history_metadata)
|
||
|
||
async def conversation_internal(request_body, request_headers): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -123,6 +123,11 @@ class _AzureOpenAISettings(BaseSettings): | |
embedding_endpoint: Optional[str] = None | ||
embedding_key: Optional[str] = None | ||
embedding_name: Optional[str] = None | ||
function_call_azure_functions_enabled: bool = False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we make this Optional[bool] for consistency? |
||
function_call_azure_functions_tools_key: Optional[str] = None | ||
function_call_azure_functions_tools_base_url: Optional[str] = None | ||
function_call_azure_functions_tool_key: Optional[str] = None | ||
function_call_azure_functions_tool_base_url: Optional[str] = None | ||
|
||
@field_validator('tools', mode='before') | ||
@classmethod | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's use an async call here