Skip to content

Commit

Permalink
[Frontend-go] Add unit test for object type detect and fix logic (#1972)
Browse files Browse the repository at this point in the history
Signed-off-by: Arthur Chan <[email protected]>
  • Loading branch information
arthurscchan authored Jan 14, 2025
1 parent 370fe00 commit 4577575
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 79 deletions.
213 changes: 135 additions & 78 deletions src/fuzz_introspector/frontends/frontend_go.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,14 @@ class Project():

def __init__(self, source_code_files: list[SourceCodeFile]):
self.source_code_files = source_code_files
self.full_functions_methods = [
full_functions_methods = [
item for src in source_code_files
for item in src.functions + src.methods
]
self.functions_methods_map = {
item.function_name: item
for item in full_functions_methods
}

def dump_module_logic(self,
report_name: str,
Expand Down Expand Up @@ -217,8 +221,8 @@ def dump_module_logic(self,
functions_methods = source_code.functions + source_code.methods
for func_def in functions_methods:
func_def.extract_local_variable_type(
self.full_functions_methods)
func_def.extract_callsites()
self.functions_methods_map)
func_def.extract_callsites(self.functions_methods_map)
func_dict: dict[str, Any] = {}
func_dict['functionName'] = func_def.function_name
func_dict['functionSourceFile'] = source_code.source_file
Expand All @@ -239,9 +243,9 @@ def dump_module_logic(self,
func_dict['BranchProfiles'] = []
func_dict['Callsites'] = func_def.detailed_callsites
func_dict['functionUses'] = func_def.get_function_uses(
self.full_functions_methods)
list(self.functions_methods_map.values()))
func_dict['functionDepth'] = func_def.get_function_depth(
self.full_functions_methods)
list(self.functions_methods_map.values()))
func_dict['constantsTouched'] = []
func_dict['BBCount'] = 0
func_dict['signature'] = func_def.sig
Expand Down Expand Up @@ -383,6 +387,7 @@ def __init__(self, root: Node, tree_sitter_lang: Language,
# Other properties
self.function_name = ''
self.receiver = ''
self.receiver_name = ''
self.complexity = 0
self.icount = 0
self.arg_names: list[str] = []
Expand All @@ -392,7 +397,7 @@ def __init__(self, root: Node, tree_sitter_lang: Language,
self.function_uses = 0
self.function_depth = 0
self.base_callsites: list[tuple[str, int]] = []
self.detailed_callsites: list[tuple[str, int]] = []
self.detailed_callsites: list[dict[str, str]] = []
self.var_map: dict[str, str] = {}

# Process properties
Expand Down Expand Up @@ -457,12 +462,14 @@ def _process_properties(self):
receiver = self.root.child_by_field_name('receiver')
if receiver:
for child in receiver.children:
receiver_name = child.child_by_field_name('name')
receiver_type = child.child_by_field_name('type')
if child.type == 'parameter_declaration':
receiver_name = child.child_by_field_name('name')
receiver_type = child.child_by_field_name('type')

if receiver_name and receiver_type:
self.receiver = receiver_type.text.decode()
self.var_map[receiver_name.text.decode()] = self.receiver
if receiver_name and receiver_type:
self.receiver = receiver_type.text.decode()
self.receiver_name = receiver_name.text.decode()
self.var_map[self.receiver_name] = self.receiver

# Process name
name_node = self.root
Expand All @@ -473,25 +480,26 @@ def _process_properties(self):
self.function_name = f'{self.receiver}.{self.function_name}'

# Process arguments
param_names: list[str] = []
param_types: list[str] = []
param_names = []
param_types = []
query = self.tree_sitter_lang.query('( parameter_list ) @pl')
for _, exprs in query.captures(self.root).items():
for param_node in exprs:
for param in param_node.children:
if not param.is_named:
continue

param_name = ''
param_type = ''

# Param name
param_tmp = param
while param_tmp.child_by_field_name('name') is not None:
param_tmp = param_tmp.child_by_field_name('name')
param_names.append(param_tmp.text.decode())
param_name = param_tmp.text.decode()

# Param type
if not param.child_by_field_name('type'):
param_types.append('')
else:
if param.child_by_field_name('type'):
type_str = param.child_by_field_name(
'type').text.decode()
param_tmp = param
Expand All @@ -501,9 +509,13 @@ def _process_properties(self):
type_str += '*'
param_tmp = param_tmp.child_by_field_name(
'declarator')
param_types.append(type_str)
param_type = type_str

self.var_map[param_names[-1]] = param_types[-1]
if param_name:
if param_name != self.receiver_name:
param_names.append(param_name)
param_types.append(param_type)
self.var_map[param_names[-1]] = param_types[-1]

self.arg_names = param_names
self.arg_types = param_types
Expand Down Expand Up @@ -578,8 +590,95 @@ def _traverse_node_instr_count(node: Node) -> int:

self.icount = _traverse_node_instr_count(self.root)

def _process_call_expr_child(
self, call_child: Node,
all_funcs_meths: dict[str, 'FunctionMethod']) -> Optional[str]:
"""Internal helper to process call expr."""
target_name = None

# Simple call
if call_child.type == 'identifier':
target_name = call_child.text.decode()

# Package/method call
if call_child.type == 'selector_expression':
target_name = call_child.text.decode()

# Variable call
split_call = target_name.split('.')
if len(split_call) > 1:
var_name = self.var_map.get(split_call[-2])
if var_name:
target_name = f'{var_name}.{split_call[-1]}'

elif split_call[0] not in self.parent_source.imports:
target_name = target_name.split('.')[-1]

# Chain call
split_call = target_name.rsplit(').', 1)
if len(split_call) > 1:
target_name = split_call[1]

return target_name

def _detect_variable_type(
self, node: Node,
all_funcs_meths: dict[str, 'FunctionMethod']) -> Optional[str]:
"""Internal recursive helper to determine the return type of the expression."""

for child in node.children:
# Literals
if child.type in LITERAL_TYPE_MAP:
return LITERAL_TYPE_MAP[child.type]

# Identifier
elif child.type == 'identifier':
if child.text.decode() in self.var_map:
return self.var_map[child.text.decode()]

# Composite Literal
elif child.type == 'composite_literal':
composite_type = child.child_by_field_name('type')
if composite_type:
return composite_type.text.decode()

# Call expression
elif child.type == 'call_expression':
call = child.child_by_field_name('function')
args = child.child_by_field_name('arguments')
target_name = self._process_call_expr_child(
call, all_funcs_meths)
if target_name in all_funcs_meths:
return all_funcs_meths[target_name].return_type

elif target_name == 'new':
for arg in args.children:
if arg.type.endswith('identifier'):
return arg.text.decode()

# Selector expression
elif child.type == 'selector_expression':
target_name = self._process_call_expr_child(
child, all_funcs_meths)
if target_name:
return target_name

# TODO Handles the following type
# index_expression slice_expression
# type_assertion_expression type_conversion_expression
# type_instantiation_expression

# Other expression that need to recursive deeper
# unary_expression binary_expression
# parenthesized_expression
else:
return self._detect_variable_type(child, all_funcs_meths)

return None

def extract_local_variable_type(self,
all_funcs_meths: list['FunctionMethod']):
all_funcs_meths: dict[str,
'FunctionMethod']):
"""Gets the local variable types of the function."""
# TODO The handling of all kind of variable declaration approach is not done.
# There are some requires extensive search to determine a type.
Expand All @@ -593,81 +692,39 @@ def extract_local_variable_type(self,
query = self.tree_sitter_lang.query('( short_var_declaration ) @vd')
for _, exprs in query.captures(self.root).items():
for decl_node in exprs:
decl_name = ''
decl_type = ''
left = decl_node.child_by_field_name('left')
right = decl_node.child_by_field_name('right')

for child in left.children:
if child.type == 'identifier':
decl_name = child.text.decode()

for child in right.children:
# Literals
if child.type in LITERAL_TYPE_MAP:
decl_type = LITERAL_TYPE_MAP[child.type]

# Identifier
elif child.type == 'identifier':
if child.text.decode() in self.var_map:
decl_type = self.var_map[child.text.decode()]

# Composite Literal
elif child.type == 'composite_literal':
composite_type = child.child_by_field_name('type')
if composite_type:
decl_type = composite_type.text.decode()

# TODO Handles the following type
# unary_expression binary_expression selector_expression
# index_expression slice_expression call_expression
# type_assertion_expression type_conversion_expression
# type_instantiation_expression new make
# parenthesized_expression
decl_type = self._detect_variable_type(right, all_funcs_meths)

if decl_name and decl_type:
self.var_map[decl_name] = decl_type

def extract_callsites(self):
def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']):
"""Gets the callsites of the function."""

callsites = []
call_query = self.tree_sitter_lang.query('( call_expression ) @ce')
call_res = call_query.captures(self.root)
for _, call_exprs in call_res.items():
for call_expr in call_exprs:
for call_child in call_expr.children:
target_name = None

# Simple call
if call_child.type == 'identifier':
target_name = call_child.text.decode()

# Package/method call
if call_child.type == 'selector_expression':
target_name = call_child.text.decode()

# Variable call
split_call = target_name.split('.')
if len(split_call) > 1:
var_name = self.var_map.get(split_call[-2])
if var_name:
target_name = f'{var_name}.{split_call[-1]}'

elif split_call[
0] not in self.parent_source.imports:
target_name = target_name.split('.')[-1]

# Chain call
split_call = target_name.rsplit(').', 1)
if len(split_call) > 1:
target_name = split_call[1]

if target_name:
callsites.append((
target_name,
call_child.byte_range,
call_child.start_point.row + 1,
))
call = call_expr.child_by_field_name('function')
target_name = self._process_call_expr_child(
call, all_funcs_meths)
if target_name in ['new', 'make']:
if target_name not in all_funcs_meths:
target_name = None

if target_name:
callsites.append((
target_name,
call_expr.byte_range,
call_expr.start_point.row + 1,
))

callsites = sorted(callsites, key=lambda x: x[1][1])
self.base_callsites = [(x[0], x[2]) for x in callsites]
Expand Down
Loading

0 comments on commit 4577575

Please sign in to comment.