-
Notifications
You must be signed in to change notification settings - Fork 96
/
utils.py
159 lines (136 loc) · 5.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import ast
import os
import re
import json
import logging
import datetime
import xml.etree.ElementTree as ET
from art import text2art
from logging.handlers import RotatingFileHandler
# Shared log-line layout, used both for the root logger configured by
# basicConfig and for the per-run file handler created below.
_LOG_FORMAT = "%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s"
_LOG_DATEFMT = "%Y-%m-%d:%H:%M:%S"

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATEFMT, level=logging.INFO)

# Directory containing this module; asset paths (prompt_assets/, chat_templates/,
# inference_logs/) are resolved relative to it.
script_dir = os.path.dirname(os.path.abspath(__file__))

# Each import of this module creates a new log file under
# <script_dir>/inference_logs/ whose name carries the import-time timestamp.
now = datetime.datetime.now()
log_folder = os.path.join(script_dir, "inference_logs")
os.makedirs(log_folder, exist_ok=True)
_log_stamp = now.strftime('%Y-%m-%d_%H-%M-%S')
log_file_path = os.path.join(log_folder, f"function-calling-inference_{_log_stamp}.log")

# With maxBytes=0 / backupCount=0 rollover never happens, so this handler
# behaves like a plain FileHandler; the timestamped name gives one file per run.
file_handler = RotatingFileHandler(log_file_path, maxBytes=0, backupCount=0)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATEFMT)
file_handler.setFormatter(formatter)

# Named logger used by the helpers in this module.
inference_logger = logging.getLogger("function-calling-inference")
inference_logger.addHandler(file_handler)
def print_nous_text_art(suffix=None):
    """Print the " nousresearch" ASCII-art banner to stdout.

    The banner is rendered with ``text2art`` in the "nancyj" font. When
    ``suffix`` is truthy the rendered text becomes " nousresearch x <suffix>".
    """
    banner = " nousresearch" + (f" x {suffix}" if suffix else "")
    print(text2art(banner, font="nancyj"))
def get_fewshot_examples(num_fewshot):
    """Return the first ``num_fewshot`` few-shot examples.

    Examples are loaded from ``prompt_assets/few_shot.json`` next to this
    module; the file is expected to hold a JSON list (TODO confirm).

    Args:
        num_fewshot (int): Number of examples to return, from 0 up to the
            number of examples stored in the file.

    Returns:
        The first ``num_fewshot`` items loaded from the file.

    Raises:
        ValueError: If ``num_fewshot`` is negative or exceeds the number of
            available examples.
    """
    if num_fewshot < 0:
        # A negative count would otherwise slice from the end (examples[:-k])
        # and silently return the wrong examples.
        raise ValueError(f"num_fewshot must be non-negative (got {num_fewshot}).")
    example_path = os.path.join(script_dir, 'prompt_assets', 'few_shot.json')
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(example_path, 'r', encoding='utf-8') as file:
        examples = json.load(file)
    if num_fewshot > len(examples):
        raise ValueError(f"Not enough examples (got {num_fewshot}, but there are only {len(examples)} examples).")
    return examples[:num_fewshot]
def get_chat_template(chat_template):
    """Read a Jinja chat template from ``chat_templates/<chat_template>.j2``.

    The path is resolved relative to this module's directory.

    Args:
        chat_template (str): Template name without the ``.j2`` extension.

    Returns:
        str | None: The template source, or ``None`` if the file does not
        exist or cannot be read. Both failures are logged to
        ``inference_logger``.
    """
    template_path = os.path.join(script_dir, 'chat_templates', f"{chat_template}.j2")
    if not os.path.exists(template_path):
        inference_logger.error(f"Template file not found: {chat_template}")
        return None
    try:
        with open(template_path, 'r', encoding='utf-8') as file:
            return file.read()
    except (OSError, UnicodeDecodeError) as e:
        # Report through the module logger, like the not-found branch above,
        # rather than printing to stdout.
        inference_logger.error(f"Error loading template: {e}")
        return None
# Per-template regexes capturing everything after the LAST assistant-turn
# marker: the negative lookahead forbids another marker before end of text.
_ASSISTANT_PATTERNS = {
    "zephyr": re.compile(r'<\|assistant\|>((?:(?!<\|assistant\|>).)*)$', re.DOTALL),
    "chatml": re.compile(r'<\|im_start\|>\s*assistant((?:(?!<\|im_start\|>\s*assistant).)*)$', re.DOTALL),
    "vicuna": re.compile(r'ASSISTANT:\s*((?:(?!ASSISTANT:).)*)$', re.DOTALL),
}


def get_assistant_message(completion, chat_template, eos_token):
    """Extract the final assistant turn from a raw model completion.

    Args:
        completion (str): Raw completion text (may include the prompt);
            surrounding whitespace is stripped before matching.
        chat_template (str): "zephyr", "chatml" or "vicuna"; selects the
            marker that opens an assistant turn.
        eos_token (str | None): End-of-sequence text removed from the
            extracted message. ``None`` is treated as "nothing to remove".

    Returns:
        str | None: The text after the last assistant marker, stripped, with
        the EOS text removed; ``None`` (logged at INFO) if no marker is found.

    Raises:
        NotImplementedError: For any other ``chat_template``.
    """
    completion = completion.strip()
    assistant_pattern = _ASSISTANT_PATTERNS.get(chat_template)
    if assistant_pattern is None:
        raise NotImplementedError(f"Handling for chat_template '{chat_template}' is not implemented.")

    assistant_match = assistant_pattern.search(completion)
    if not assistant_match:
        inference_logger.info("No match found for the assistant pattern")
        return None

    assistant_content = assistant_match.group(1).strip()
    # Some tokenizers expose no EOS token; str.replace(None, ...) would raise.
    if eos_token is None:
        eos_token = ""
    if chat_template == "vicuna":
        # NOTE(review): this strips "</s>" + eos_token, so with eos_token="</s>"
        # only "</s></s>" is removed and a single trailing "</s>" is kept —
        # confirm this is intended.
        eos_token = f"</s>{eos_token}"
    return assistant_content.replace(eos_token, "")
def _parse_tool_call_payload(json_text):
    """Decode one <tool_call> payload string; return ``(data, error_message)``.

    ``json.loads`` is tried first. If it fails, ``ast.literal_eval`` is tried
    so Python-literal payloads (e.g. single-quoted dicts) are still accepted.
    On failure ``data`` is None and ``error_message`` describes both errors;
    on success ``error_message`` is None.
    """
    try:
        return json.loads(json_text), None
    except (json.JSONDecodeError, RecursionError) as json_err:
        # Keep a reference: the `as` name is unbound when the handler exits.
        first_error = json_err
    try:
        return ast.literal_eval(json_text), None
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError) as eval_err:
        # literal_eval raises TypeError for e.g. unhashable dict keys and
        # MemoryError/RecursionError for deeply nested input.
        return None, (
            f"JSON parsing failed with both json.loads and ast.literal_eval:\n"
            f"- JSON Decode Error: {first_error}\n"
            f"- Fallback Syntax/Value Error: {eval_err}\n"
            f"- Problematic JSON text: {json_text}"
        )


def validate_and_extract_tool_calls(assistant_content):
    """Extract payloads of ``<tool_call>...</tool_call>`` elements.

    The content is wrapped in a synthetic ``<root>`` element and parsed as
    XML; the text of every ``tool_call`` element (at any depth) is decoded
    with JSON, falling back to ``ast.literal_eval``.

    Args:
        assistant_content (str): Assistant message that may contain
            ``<tool_call>`` elements.

    Returns:
        tuple[bool, list, str | None]: ``(validation_result, tool_calls,
        error_message)``. ``validation_result`` is True when at least one
        payload decoded to a non-None value; ``tool_calls`` holds decoded
        payloads in document order; ``error_message`` is the most recent
        problem (XML parse error, empty tag, undecodable payload) or None.
        Every problem is also logged to ``inference_logger``.
    """
    validation_result = False
    tool_calls = []
    error_message = None

    try:
        root = ET.fromstring(f"<root>{assistant_content}</root>")
    except ET.ParseError as err:
        error_message = f"XML Parse Error: {err}"
        inference_logger.error(error_message)
        return validation_result, tool_calls, error_message

    for element in root.findall(".//tool_call"):
        # element.text is None for an empty tag such as <tool_call/>.
        json_text = (element.text or "").strip()
        if not json_text:
            error_message = "Empty <tool_call> element: no payload to parse"
            inference_logger.error(error_message)
            continue
        json_data, parse_error = _parse_tool_call_payload(json_text)
        if parse_error is not None:
            error_message = parse_error
            inference_logger.error(error_message)
            continue
        # A literal JSON `null` payload decodes to None and is skipped silently.
        if json_data is not None:
            tool_calls.append(json_data)
            validation_result = True

    return validation_result, tool_calls, error_message
# First fenced ```json block; tolerates both LF and CRLF line endings.
_JSON_FENCE_PATTERN = re.compile(r'```json\r?\n(.*?)\r?\n```', re.DOTALL)


def extract_json_from_markdown(text):
    """
    Extracts and decodes the JSON inside the first ```json fenced block of text.
    Args:
        text (str): The input text, e.g. Markdown containing a ```json block.
    Returns:
        The decoded JSON value (a dict for an object payload, but any JSON
        type is possible), or None if no fenced block is found or its
        content is not valid JSON. Failures are logged to inference_logger.
    """
    match = _JSON_FENCE_PATTERN.search(text)
    if not match:
        # Matches get_assistant_message's convention of logging a no-match at INFO.
        inference_logger.info("JSON string not found in the text.")
        return None
    json_string = match.group(1)
    try:
        return json.loads(json_string)
    except json.JSONDecodeError as e:
        inference_logger.error(f"Error decoding JSON string: {e}")
        return None