Skip to content

Commit

Permalink
优化组件标准化单测框架:更新系统变量,增加tool_eval参数和manifests匹配性检查
Browse files Browse the repository at this point in the history
  • Loading branch information
yepeiwen01 committed Dec 19, 2024
1 parent 9a95c5e commit e39043b
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 178 deletions.
246 changes: 150 additions & 96 deletions python/tests/component_check.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import os
import json
import os
import inspect
import time
from jsonschema import validate, ValidationError, SchemaError
from jsonschema import validate
from pydantic import BaseModel
from typing import Generator
from appbuilder.utils.func_utils import Singleton
from appbuilder.tests.component_schemas import type_to_json_schemas
from appbuilder.utils.json_schema_to_model import json_schema_to_pydantic_model
Expand Down Expand Up @@ -40,12 +38,15 @@ def register_rule(self, rule_name: str, rule_obj: RuleBase):
def remove_rule(self, rule_name: str):
del self.rules[rule_name]

def notify(self, component_cls) -> tuple[bool, list]:
def notify(self, component_cls, component_case) -> tuple[bool, list]:
check_pass = True
check_details = {}
reasons = []
for rule_name, rule_obj in self.rules.items():
res = rule_obj.check(component_cls)
if rule_name == "ToolEvalOutputJsonRule":
res = rule_obj.check(component_cls, component_case)
else:
res = rule_obj.check(component_cls)
check_details[rule_name] = res
if res.check_result == False:
check_pass = False
Expand All @@ -63,53 +64,40 @@ class ManifestValidRule(RuleBase):
def __init__(self, **kwargs):
super().__init__()
self.rule_name = "ManifestValidRule"
self.component_tool_eval_cases = kwargs.get("component_tool_eval_cases", {})

def check(self, component_cls) -> CheckInfo:
def check(self, component_obj) -> CheckInfo:
check_pass_flag = True
invalid_details = []
component_cls_name = component_cls.__name__
if component_cls_name not in self.component_tool_eval_cases:
invalid_details.append("{} 没有添加测试case到 component_tool_eval_cases 中".format(component_cls_name))
else:
component_case = self.component_tool_eval_cases[component_cls_name]()
envs = component_case.envs()
os.environ.update(envs)
init_args = component_case.init_args()

try:
component_obj = component_cls(**init_args)
if not hasattr(component_obj, "manifests"):
raise ValueError("No manifests found")
manifests = component_obj.manifests
# NOTE(暂时检查manifest中的第一个mainfest)
if not manifests or len(manifests) == 0:
raise ValueError("No manifests found")
manifest = manifests[0]
tool_name = manifest['name']
tool_desc = manifest['description']
schema = manifest["parameters"]
schema["title"] = tool_name
# 第一步,将json schema转换为pydantic模型
pydantic_model = json_schema_to_pydantic_model(schema, tool_name)
check_to_json = pydantic_model.schema_json()
json_to_dict = json.loads(check_to_json)

if "properties" in schema:
properties = schema["properties"]
for key, value in properties.items():
if "type" not in value:
invalid_details.append("\'type' must be in properties item: {}".format(key))
if "description" not in value:
invalid_details.append("\'description' must be in properties item: {}".format(key))

except Exception as e:
print(e)
check_pass_flag = False
invalid_details.append(str(e))

for env in envs:
os.environ.pop(env)
try:
if not hasattr(component_obj, "manifests"):
raise ValueError("No manifests found")
manifests = component_obj.manifests
# NOTE(暂时检查manifest中的第一个mainfest)
if not manifests or len(manifests) == 0:
raise ValueError("No manifests found")
manifest = manifests[0]
tool_name = manifest['name']
tool_desc = manifest['description']
schema = manifest["parameters"]
schema["title"] = tool_name
# 第一步,将json schema转换为pydantic模型
pydantic_model = json_schema_to_pydantic_model(schema, tool_name)
check_to_json = pydantic_model.schema_json()
json_to_dict = json.loads(check_to_json)

if "properties" in schema:
properties = schema["properties"]
for key, value in properties.items():
if "type" not in value:
invalid_details.append("\'type' must be in properties item: {}".format(key))
if "description" not in value:
invalid_details.append("\'description' must be in properties item: {}".format(key))

except Exception as e:
print(e)
check_pass_flag = False
invalid_details.append(str(e))

if len(invalid_details) > 0:
check_pass_flag = False
Expand Down Expand Up @@ -137,14 +125,14 @@ def __init__(self):
self.rule_name = "MainfestMatchToolEvalRule"


def check(self, component_cls) -> CheckInfo:
def check(self, component_obj) -> CheckInfo:
check_pass_flag = True
invalid_details = []

try:
if not hasattr(component_cls, "manifests"):
if not hasattr(component_obj, "manifests"):
raise ValueError("No manifests found")
manifests = component_cls.manifests
manifests = component_obj.manifests
# NOTE(暂时检查manifest中的第一个mainfest)
if not manifests or len(manifests) == 0:
raise ValueError("No manifests found")
Expand All @@ -158,7 +146,7 @@ def check(self, component_cls) -> CheckInfo:
# 交互检查
tool_eval_input_params = []
print("required_params: {}".format(manifest_var))
signature = inspect.signature(component_cls.tool_eval)
signature = inspect.signature(component_obj.tool_eval)
ileagal_params = []
for param_name, param in signature.parameters.items():
if param_name == 'kwargs' or param_name == 'args' or param_name == 'self':
Expand Down Expand Up @@ -193,10 +181,6 @@ def check(self, component_cls) -> CheckInfo:
check_detail=",".join(invalid_details))






class ToolEvalInputNameRule(RuleBase):
"""
检查tool_eval的输入参数中,是否包含系统保留的输入名称
Expand All @@ -222,10 +206,15 @@ def __init__(self):
"_sys_custom_variables",
"_sys_thought_model_config",
"_sys_rag_model_config",
"_sys_parent_span_id",
"_sys_span_id",
"_sys_memory",
"_sys_code_execution_endpoint",
"_sys_session_id"
]

def check(self, component_cls) -> CheckInfo:
tool_eval_signature = inspect.signature(component_cls.__init__)
def check(self, component_obj) -> CheckInfo:
tool_eval_signature = inspect.signature(component_obj.tool_eval)
params = tool_eval_signature.parameters
invalid_details = []
check_pass_flag = True
Expand All @@ -250,7 +239,6 @@ class ToolEvalOutputJsonRule(RuleBase):
def __init__(self, **kwargs):
super().__init__()
self.rule_name = 'ToolEvalOutputJsonRule'
self.component_tool_eval_cases = kwargs.get("component_tool_eval_cases")

def _check_pre_format(self, outputs):
invalid_details = []
Expand Down Expand Up @@ -351,42 +339,26 @@ def _check_text_and_code(self, component_case, output_dict):
else:
return []

def check(self, component_cls) -> CheckInfo:
def check(self, component_obj, component_case) -> CheckInfo:
invalid_details = []
component_cls_name = component_cls.__name__

if component_cls_name not in self.component_tool_eval_cases:
invalid_details.append("{} 没有添加测试case到 component_tool_eval_cases 中".format(component_cls_name))
else:
component_case = self.component_tool_eval_cases[component_cls_name]()

envs = {}
if hasattr(component_case, "envs"):
envs = component_case.envs()
os.environ.update(envs)

input_dict = component_case.inputs()
init_args = component_case.init_args()
component_obj = component_cls(**init_args)
output_json_schemas = component_case.schemas()

try:
stream_output_dict = {"text": "", "oral_text":"", "code": ""}
stream_outputs = component_obj.tool_eval(**input_dict)
for stream_output in stream_outputs:
iter_invalid_detail = self._check_jsonschema(stream_output.model_dump(), output_json_schemas)
invalid_details.extend(iter_invalid_detail)
iter_output_dict = self._gather_iter_outputs(stream_output)
stream_output_dict["text"] += iter_output_dict["text"]
stream_output_dict["oral_text"] += iter_output_dict["oral_text"]
stream_output_dict["code"] += iter_output_dict["code"]
if len(invalid_details) == 0:
invalid_details.extend(self._check_text_and_code(component_case, stream_output_dict))
except Exception as e:
invalid_details.append("ToolEval执行失败: {}".format(e))

for env in envs:
os.environ.pop(env)
input_dict = component_case.inputs()
output_json_schemas = component_case.schemas()

try:
stream_output_dict = {"text": "", "oral_text":"", "code": ""}
stream_outputs = component_obj.tool_eval(**input_dict)
for stream_output in stream_outputs:
iter_invalid_detail = self._check_jsonschema(stream_output.model_dump(), output_json_schemas)
invalid_details.extend(iter_invalid_detail)
iter_output_dict = self._gather_iter_outputs(stream_output)
stream_output_dict["text"] += iter_output_dict["text"]
stream_output_dict["oral_text"] += iter_output_dict["oral_text"]
stream_output_dict["code"] += iter_output_dict["code"]
if len(invalid_details) == 0:
invalid_details.extend(self._check_text_and_code(component_case, stream_output_dict))
except Exception as e:
invalid_details.append("ToolEval执行失败: {}".format(e))

if len(invalid_details) > 0:
return CheckInfo(
Expand All @@ -400,6 +372,88 @@ def check(self, component_cls) -> CheckInfo:
check_detail="")


def register_component_check_rule(rule_name: str, rule_cls: RuleBase, init_args={}):
def register_component_check_rule(rule_name: str, rule_cls: RuleBase):
component_checker = ComponentCheckBase()
component_checker.register_rule(rule_name, rule_cls(**init_args))
component_checker.register_rule(rule_name, rule_cls())


def check_component_with_retry(component_import_res_tuple):
"""
使用重试机制检查组件。测试用例失败后会重试两次。
Args:
component_import_res_tuple (tuple): 包含组件和导入结果的元组。
Returns:
list: 包含错误信息的数据列表。
"""
component, import_res, component_case_cls = component_import_res_tuple
component_check_base = ComponentCheckBase()

error_data = []
max_retries = 2 # 设置最大重试次数
attempts = 0

while attempts <= max_retries:
if import_res["import_error"] != "":
error_data.append({"Component Name": component, "Error Message": import_res["import_error"]})
print("组件名称:{} 错误信息:{}".format(component, import_res["import_error"]))
break

component_case = component_case_cls()
envs = component_case.envs()
os.environ.update(envs)
component_cls = import_res["obj"]
component_obj = component_cls(**component_case.init_args())

try:
# 此处的self.component_check_base.notify需要根据实际情况修改
pass_check, reasons = component_check_base.notify(component_obj, component_case) # 示例修改
reasons = list(set(reasons))
if not pass_check:
error_data.append({"Component Name": component, "Error Message": ", ".join(reasons)})
print("组件名称:{} 错误信息:{}".format(component, ", ".join(reasons)))
# 如果检查失败,增加尝试次数并重试
attempts += 1
if attempts <= max_retries:
print("组件名称:{} 将重试,当前尝试次数:{}".format(component, attempts))
continue
# 如果检查通过,则退出循环
break
except Exception as e:
error_data.append({"Component Name": component, "Error Message": str(e)})
print("组件名称:{} 错误信息:{}".format(component, str(e)))
# 如果发生异常,增加尝试次数并重试
attempts += 1
if attempts <= max_retries:
print("组件名称:{} 将重试,当前尝试次数:{}".format(component, attempts))
continue

finally:
for env in envs:
os.environ.pop(env)

return error_data

def write_error_data(txt_file_path, error_df, error_stats):
"""将组件错误信息写入文件
Args:
error_df (Union[pd.DataFrame, None]): 错误信息表格
error_stats (dict): 错误统计信息
"""
with open(txt_file_path, 'w') as file:
file.write("Component Name\tError Message\n")
for _, row in error_df.iterrows():
file.write(f"{row['Component Name']}\t{row['Error Message']}\n")
file.write("\n错误统计信息:\n")
for error, count in error_stats.items():
file.write(f"错误信息: {error}, 出现次数: {count}\n")
print(f"\n错误信息已写入: {txt_file_path}")


register_component_check_rule("ManifestValidRule", ManifestValidRule)
register_component_check_rule("MainfestMatchToolEvalRule", MainfestMatchToolEvalRule)
register_component_check_rule("ToolEvalInputNameRule", ToolEvalInputNameRule)
register_component_check_rule("ToolEvalOutputJsonRule", ToolEvalOutputJsonRule)
Loading

0 comments on commit e39043b

Please sign in to comment.