Add batch processing for ClaudeHandler #716

Open · wants to merge 1 commit into base: main
@@ -3,6 +3,8 @@

from anthropic import Anthropic
from anthropic.types import TextBlock, ToolUseBlock
from anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.beta.messages.batch_create_params import Request
from bfcl.model_handler.base_handler import BaseHandler
from bfcl.model_handler.constant import DEFAULT_SYSTEM_PROMPT, GORILLA_TO_OPENAPI
from bfcl.model_handler.model_style import ModelStyle
@@ -66,23 +68,43 @@ def decode_execute(self, result):
else:
function_call = convert_to_function_call(result)
return function_call

# Helper function to process the batch response and return results
def _process_batch_response(self, batch_response):
results = []
for result in batch_response:
if result.result.type == "succeeded":
results.append(result.result.message.content)
elif result.result.type == "errored":
print(f"Error: {result.result.error}")

return results
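
A note on ordering while reading this helper: the Message Batches API does not guarantee that results come back in the same order as the submitted requests, so code that relies on positional order may prefer to match entries by `custom_id` instead. A minimal sketch of that alternative (the `_results_by_custom_id` helper below is hypothetical, not part of this PR):

```python
def _results_by_custom_id(batch_response):
    # Key successful results by the custom_id of the request that produced them,
    # so callers can look them up regardless of the order results arrive in.
    matched = {}
    for entry in batch_response:
        if entry.result.type == "succeeded":
            matched[entry.custom_id] = entry.result.message.content
        elif entry.result.type == "errored":
            print(f"Error for {entry.custom_id}: {entry.result.error}")
    return matched
```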

#### FC methods ####

def _query_FC(self, inference_data: dict):
inference_data["inference_input_log"] = {
"message": repr(inference_data["message"]),
"tools": inference_data["tools"],
}
# Initialize batch request list
batch_requests = []

# Add one batch request per message in the inference data
for message in inference_data["message"]:
batch_requests.append(
Request(
custom_id=f"fc-{message['content'][:20]}", # Custom ID for each request
params=MessageCreateParamsNonStreaming(
model=self.model_name.strip("-FC"),
max_tokens=8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096,
tools=inference_data["tools"],
messages=[message],
)
)
)

# Send the batch request
batch_response = self.client.beta.messages.batches.create(requests=batch_requests)

return self.client.messages.create(
model=self.model_name.strip("-FC"),
max_tokens=(
8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
), # 3.5 Sonnet has a higher max token limit
tools=inference_data["tools"],
messages=inference_data["message"],
)
# Process the batch response
return self._process_batch_response(batch_response)
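
Worth noting when reading this method: `batches.create` submits the batch and returns a batch object right away, while the per-request results only become available once Anthropic has finished processing the batch. A minimal sketch of how those results are typically retrieved, assuming the SDK's beta `retrieve`/`results` endpoints (`_wait_for_batch_results` and `poll_interval` are illustrative names, not part of this PR):

```python
import time

def _wait_for_batch_results(client, batch_id, poll_interval=30):
    # Poll until Anthropic reports the batch has finished processing.
    while True:
        batch = client.beta.messages.batches.retrieve(batch_id)
        if batch.processing_status == "ended":
            break
        time.sleep(poll_interval)
    # results() streams one entry per request, each carrying .custom_id and .result.
    return list(client.beta.messages.batches.results(batch_id))
```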

def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
for round_idx in range(len(test_entry["question"])):
@@ -183,22 +205,29 @@ def _add_execution_results_FC(
#### Prompting methods ####

def _query_prompting(self, inference_data: dict):
inference_data["inference_input_log"] = {
"message": repr(inference_data["message"]),
"system_prompt": inference_data["system_prompt"],
}

api_response = self.client.messages.create(
model=self.model_name,
max_tokens=(
8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096
), # 3.5 Sonnet has a higher max token limit
temperature=self.temperature,
system=inference_data["system_prompt"],
messages=inference_data["message"],
)

return api_response
# Initialize batch request list
batch_requests = []

# Add all the messages to the batch
for message in inference_data["message"]:
batch_requests.append(
Request(
custom_id=f"prompt-{message['content'][:20]}",
params=MessageCreateParamsNonStreaming(
model=self.model_name,
max_tokens=8192 if "claude-3-5-sonnet-20240620" in self.model_name else 4096,
temperature=self.temperature,
messages=[message],
system=inference_data["system_prompt"],
)
)
)

# Send the batch request
batch_response = self.client.beta.messages.batches.create(requests=batch_requests)

# Process the batch response
return self._process_batch_response(batch_response)
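
One more aside on the `fc-`/`prompt-` ids built above: `custom_id` values must be unique within a batch, and the first 20 characters of a message can repeat across test entries. A short sketch of a collision-resistant alternative (the `_batch_custom_id` helper is hypothetical, not part of this PR):

```python
import hashlib

def _batch_custom_id(prefix: str, index: int, content) -> str:
    # Combine the request index with a short content hash so ids stay unique
    # within a batch even when two messages share the same leading text.
    digest = hashlib.sha1(str(content).encode("utf-8")).hexdigest()[:8]
    return f"{prefix}-{index}-{digest}"

# e.g. custom_id=_batch_custom_id("prompt", idx, message["content"])
```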

def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
functions: list = test_entry["function"]
@@ -261,4 +290,4 @@ def _add_execution_results_prompting(
{"role": "user", "content": formatted_results_message}
)

return inference_data
return inference_data