diff --git a/src/fuzz_introspector/frontends/frontend_go.py b/src/fuzz_introspector/frontends/frontend_go.py index 52921b08f..be435bc28 100644 --- a/src/fuzz_introspector/frontends/frontend_go.py +++ b/src/fuzz_introspector/frontends/frontend_go.py @@ -185,10 +185,14 @@ class Project(): def __init__(self, source_code_files: list[SourceCodeFile]): self.source_code_files = source_code_files - self.full_functions_methods = [ + full_functions_methods = [ item for src in source_code_files for item in src.functions + src.methods ] + self.functions_methods_map = { + item.function_name: item + for item in full_functions_methods + } def dump_module_logic(self, report_name: str, @@ -217,8 +221,8 @@ def dump_module_logic(self, functions_methods = source_code.functions + source_code.methods for func_def in functions_methods: func_def.extract_local_variable_type( - self.full_functions_methods) - func_def.extract_callsites() + self.functions_methods_map) + func_def.extract_callsites(self.functions_methods_map) func_dict: dict[str, Any] = {} func_dict['functionName'] = func_def.function_name func_dict['functionSourceFile'] = source_code.source_file @@ -239,9 +243,9 @@ def dump_module_logic(self, func_dict['BranchProfiles'] = [] func_dict['Callsites'] = func_def.detailed_callsites func_dict['functionUses'] = func_def.get_function_uses( - self.full_functions_methods) + list(self.functions_methods_map.values())) func_dict['functionDepth'] = func_def.get_function_depth( - self.full_functions_methods) + list(self.functions_methods_map.values())) func_dict['constantsTouched'] = [] func_dict['BBCount'] = 0 func_dict['signature'] = func_def.sig @@ -383,6 +387,7 @@ def __init__(self, root: Node, tree_sitter_lang: Language, # Other properties self.function_name = '' self.receiver = '' + self.receiver_name = '' self.complexity = 0 self.icount = 0 self.arg_names: list[str] = [] @@ -392,7 +397,7 @@ def __init__(self, root: Node, tree_sitter_lang: Language, self.function_uses = 0 self.function_depth = 0 self.base_callsites: list[tuple[str, int]] = [] - self.detailed_callsites: list[tuple[str, int]] = [] + self.detailed_callsites: list[dict[str, str]] = [] self.var_map: dict[str, str] = {} # Process properties @@ -457,12 +462,14 @@ def _process_properties(self): receiver = self.root.child_by_field_name('receiver') if receiver: for child in receiver.children: - receiver_name = child.child_by_field_name('name') - receiver_type = child.child_by_field_name('type') + if child.type == 'parameter_declaration': + receiver_name = child.child_by_field_name('name') + receiver_type = child.child_by_field_name('type') - if receiver_name and receiver_type: - self.receiver = receiver_type.text.decode() - self.var_map[receiver_name.text.decode()] = self.receiver + if receiver_name and receiver_type: + self.receiver = receiver_type.text.decode() + self.receiver_name = receiver_name.text.decode() + self.var_map[self.receiver_name] = self.receiver # Process name name_node = self.root @@ -473,8 +480,8 @@ def _process_properties(self): self.function_name = f'{self.receiver}.{self.function_name}' # Process arguments - param_names: list[str] = [] - param_types: list[str] = [] + param_names = [] + param_types = [] query = self.tree_sitter_lang.query('( parameter_list ) @pl') for _, exprs in query.captures(self.root).items(): for param_node in exprs: @@ -482,16 +489,17 @@ def _process_properties(self): if not param.is_named: continue + param_name = '' + param_type = '' + # Param name param_tmp = param while param_tmp.child_by_field_name('name') is not None: param_tmp = param_tmp.child_by_field_name('name') - param_names.append(param_tmp.text.decode()) + param_name = param_tmp.text.decode() # Param type - if not param.child_by_field_name('type'): - param_types.append('') - else: + if param.child_by_field_name('type'): type_str = param.child_by_field_name( 'type').text.decode() param_tmp = param @@ -501,9 +509,13 @@ def _process_properties(self): type_str += '*' param_tmp = param_tmp.child_by_field_name( 'declarator') - param_types.append(type_str) + param_type = type_str - self.var_map[param_names[-1]] = param_types[-1] + if param_name: + if param_name != self.receiver_name: + param_names.append(param_name) + param_types.append(param_type) + self.var_map[param_names[-1]] = param_types[-1] self.arg_names = param_names self.arg_types = param_types @@ -578,8 +590,95 @@ def _traverse_node_instr_count(node: Node) -> int: self.icount = _traverse_node_instr_count(self.root) + def _process_call_expr_child( + self, call_child: Node, + all_funcs_meths: dict[str, 'FunctionMethod']) -> Optional[str]: + """Internal helper to process call expr.""" + target_name = None + + # Simple call + if call_child.type == 'identifier': + target_name = call_child.text.decode() + + # Package/method call + if call_child.type == 'selector_expression': + target_name = call_child.text.decode() + + # Variable call + split_call = target_name.split('.') + if len(split_call) > 1: + var_name = self.var_map.get(split_call[-2]) + if var_name: + target_name = f'{var_name}.{split_call[-1]}' + + elif split_call[0] not in self.parent_source.imports: + target_name = target_name.split('.')[-1] + + # Chain call + split_call = target_name.rsplit(').', 1) + if len(split_call) > 1: + target_name = split_call[1] + + return target_name + + def _detect_variable_type( + self, node: Node, + all_funcs_meths: dict[str, 'FunctionMethod']) -> Optional[str]: + """Internal recursive helper to determine the return type of the expression.""" + + for child in node.children: + # Literals + if child.type in LITERAL_TYPE_MAP: + return LITERAL_TYPE_MAP[child.type] + + # Identifier + elif child.type == 'identifier': + if child.text.decode() in self.var_map: + return self.var_map[child.text.decode()] + + # Composite Literal + elif child.type == 'composite_literal': + composite_type = child.child_by_field_name('type') + if composite_type: + return composite_type.text.decode() + + # Call expression + elif child.type == 'call_expression': + call = child.child_by_field_name('function') + args = child.child_by_field_name('arguments') + target_name = self._process_call_expr_child( + call, all_funcs_meths) + if target_name in all_funcs_meths: + return all_funcs_meths[target_name].return_type + + elif target_name == 'new': + for arg in args.children: + if arg.type.endswith('identifier'): + return arg.text.decode() + + # Selector expression + elif child.type == 'selector_expression': + target_name = self._process_call_expr_child( + child, all_funcs_meths) + if target_name: + return target_name + + # TODO Handles the following type + # index_expression slice_expression + # type_assertion_expression type_conversion_expression + # type_instantiation_expression + + # Other expression that need to recursive deeper + # unary_expression binary_expression + # parenthesized_expression + else: + return self._detect_variable_type(child, all_funcs_meths) + + return None + def extract_local_variable_type(self, - all_funcs_meths: list['FunctionMethod']): + all_funcs_meths: dict[str, + 'FunctionMethod']): """Gets the local variable types of the function.""" # TODO The handling of all kind of variable declaration approach is not done. # There are some requires extensive search to determine a type. @@ -593,41 +692,19 @@ def extract_local_variable_type(self, query = self.tree_sitter_lang.query('( short_var_declaration ) @vd') for _, exprs in query.captures(self.root).items(): for decl_node in exprs: - decl_name = '' - decl_type = '' left = decl_node.child_by_field_name('left') right = decl_node.child_by_field_name('right') + for child in left.children: if child.type == 'identifier': decl_name = child.text.decode() - for child in right.children: - # Literals - if child.type in LITERAL_TYPE_MAP: - decl_type = LITERAL_TYPE_MAP[child.type] - - # Identifier - elif child.type == 'identifier': - if child.text.decode() in self.var_map: - decl_type = self.var_map[child.text.decode()] - - # Composite Literal - elif child.type == 'composite_literal': - composite_type = child.child_by_field_name('type') - if composite_type: - decl_type = composite_type.text.decode() - - # TODO Handles the following type - # unary_expression binary_expression selector_expression - # index_expression slice_expression call_expression - # type_assertion_expression type_conversion_expression - # type_instantiation_expression new make - # parenthesized_expression + decl_type = self._detect_variable_type(right, all_funcs_meths) if decl_name and decl_type: self.var_map[decl_name] = decl_type - def extract_callsites(self): + def extract_callsites(self, all_funcs_meths: dict[str, 'FunctionMethod']): """Gets the callsites of the function.""" callsites = [] @@ -635,39 +712,19 @@ def extract_callsites(self): call_res = call_query.captures(self.root) for _, call_exprs in call_res.items(): for call_expr in call_exprs: - for call_child in call_expr.children: - target_name = None - - # Simple call - if call_child.type == 'identifier': - target_name = call_child.text.decode() - - # Package/method call - if call_child.type == 'selector_expression': - target_name = call_child.text.decode() - - # Variable call - split_call = target_name.split('.') - if len(split_call) > 1: - var_name = self.var_map.get(split_call[-2]) - if var_name: - target_name = f'{var_name}.{split_call[-1]}' - - elif split_call[ - 0] not in self.parent_source.imports: - target_name = target_name.split('.')[-1] - - # Chain call - split_call = target_name.rsplit(').', 1) - if len(split_call) > 1: - target_name = split_call[1] - - if target_name: - callsites.append(( - target_name, - call_child.byte_range, - call_child.start_point.row + 1, - )) + call = call_expr.child_by_field_name('function') + target_name = self._process_call_expr_child( + call, all_funcs_meths) + if target_name in ['new', 'make']: + if target_name not in all_funcs_meths: + target_name = None + + if target_name: + callsites.append(( + target_name, + call_expr.byte_range, + call_expr.start_point.row + 1, + )) callsites = sorted(callsites, key=lambda x: x[1][1]) self.base_callsites = [(x[0], x[2]) for x in callsites] diff --git a/src/test/data/source-code/go/test-project-3/fuzzer.go b/src/test/data/source-code/go/test-project-3/fuzzer.go new file mode 100644 index 000000000..5645e6a49 --- /dev/null +++ b/src/test/data/source-code/go/test-project-3/fuzzer.go @@ -0,0 +1,93 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package structs + +import ( + "testing" + "fmt" + "strconv" +) + +type Person struct { + Name string + Age int +} + +func (p Person) Greet() string { + return fmt.Sprintf("Hello, my name is %s and I am %d years old.", p.Name, p.Age) +} + +func (p Person) Introduce() string { + return fmt.Sprintf("I am %s, a person of age %d.", p.Name, p.Age) +} + +func (p Person) Describe() string { + return fmt.Sprintf("Person: %s, Age: %d", p.Name, p.Age) +} + +type Dog struct { + Name string +} + +func (d Dog) Greet() string { + return fmt.Sprintf("Hello, my dog's name is %s.", d.Name) +} + +func (d Dog) Introduce() string { + return fmt.Sprintf("This is my dog, %s.", d.Name) +} + +func (d Dog) Describe() string { + return fmt.Sprintf("Dog: %s", d.Name) +} + +func NewDog(name string) Dog { + return Dog{Name: name} +} + +type Robot struct { + Model string +} + +func (r Robot) Greet() string { + return fmt.Sprintf("Hello, I am a robot of model %s.", r.Model) +} + +func (r Robot) Introduce() string { + return fmt.Sprintf("I am %s, a highly advanced robot.", r.Model) +} + +func (r Robot) Describe() string { + return fmt.Sprintf("Robot Model: %s", r.Model) +} + +func FuzzStructs(f *testing.F) { + f.Fuzz(func(t *testing.T, name string, ageString string, model string) { + age, err := strconv.Atoi(ageString) + if err != nil { + return + } + + p := Person{Name: name, Age: age} + d := NewDog(name) + r := new(Robot) + r.Model = model + + _ = p.Greet() + _ = d.Introduce() + _ = r.Describe() + }) +} diff --git a/src/test/test_frontends_go.py b/src/test/test_frontends_go.py index 740fdeaca..b606ae14f 100644 --- a/src/test/test_frontends_go.py +++ b/src/test/test_frontends_go.py @@ -64,7 +64,22 @@ def test_tree_sitter_go_sample3(): # Project check harness = project.get_source_codes_with_harnesses() - assert len(harness) == 0 + assert len(harness) == 1 + + functions_reached = project.get_reachable_functions(harness[0].source_file, harness[0]) + + # Callsite check + assert 'strconv.Atoi' in functions_reached + assert 'NewDog' in functions_reached + assert 'Person.Greet' in functions_reached + assert 'Dog.Introduce' in functions_reached + assert 'Robot.Describe' in functions_reached + assert 'Person.Introduce' not in functions_reached + assert 'Person.Describe' not in functions_reached + assert 'Dog.Greet' not in functions_reached + assert 'Dog.Describe' not in functions_reached + assert 'Robot.Greet' not in functions_reached + assert 'Robot.Introduce' not in functions_reached def test_tree_sitter_go_sample4():