From bb51eb2f15513c0dbaf22c630e50a7129851e4d7 Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Tue, 21 Jan 2025 20:59:12 +0000 Subject: [PATCH] pylint: Fix pyline error for frontend-rust Signed-off-by: Arthur Chan --- .../frontends/frontend_rust.py | 145 ++++++++++-------- 1 file changed, 79 insertions(+), 66 deletions(-) diff --git a/src/fuzz_introspector/frontends/frontend_rust.py b/src/fuzz_introspector/frontends/frontend_rust.py index 3f5b8395d..46fd17581 100644 --- a/src/fuzz_introspector/frontends/frontend_rust.py +++ b/src/fuzz_introspector/frontends/frontend_rust.py @@ -40,10 +40,13 @@ def language_specific_process(self): # Load functions/methods delcaration self._set_function_method_declaration(self.root) - def _set_function_method_declaration(self, - start_object: Node, - start_prefix: list[str] = []): + def _set_function_method_declaration( + self, + start_object: Node, + start_prefix: Optional[list[str]] = None): """Internal helper for retrieving all classes.""" + start_prefix = [] if not start_prefix else start_prefix + for node in start_object.children: # Reset prefix prefix = start_prefix[:] @@ -76,30 +79,31 @@ def _set_function_method_declaration(self, # Handle mod functions elif node.type == 'mod_item': mod_body = node.child_by_field_name('body') - if mod_body: - # Basic info of this mod - mod_name = node.child_by_field_name('name') - if mod_name: - prefix.append(mod_name.text.decode()) - - # Loop through the body of this mod - for mod in mod_body.children: - # Handle general function in this mod - if mod.type == 'function_item': - self.functions.append( - RustFunction(mod, self.tree_sitter_lang, self, - prefix)) - - # Handles inner impl - elif mod.type == 'impl_item': - self._set_function_method_declaration(mod, prefix) - - # Handles inner mod - elif mod.type == 'mod_item': - inner_body = mod.child_by_field_name('body') - if inner_body: - self._set_function_method_declaration( - inner_body, prefix) + if not mod_body: + return + + # Basic info of this mod + mod_name = node.child_by_field_name('name') + if mod_name: + prefix.append(mod_name.text.decode()) + + # Loop through the body of this mod + for mod in mod_body.children: + # Handle general function in this mod + if mod.type == 'function_item': + self.functions.append( + RustFunction(mod, self.tree_sitter_lang, self, + prefix)) + # Handles inner impl + elif mod.type == 'impl_item': + self._set_function_method_declaration(mod, prefix) + + # Handles inner mod + elif mod.type == 'mod_item': + inner_body = mod.child_by_field_name('body') + if inner_body: + self._set_function_method_declaration( + inner_body, prefix) # Handling trait item elif node.type == 'trait_item': @@ -317,8 +321,8 @@ def _process_macro_declaration(self): if token_tree.type == 'token_tree': content = token_tree.text.decode() if content.startswith('{'): - bytes = content.encode('utf-8') - root = self.parent_source.parser.parse(bytes) + cbytes = content.encode('utf-8') + root = self.parent_source.parser.parse(cbytes) self.fuzzing_token_tree = root.root_node elif child.type == 'macro_rule': @@ -326,8 +330,8 @@ def _process_macro_declaration(self): if token_tree: content = token_tree.text.decode() if content.startswith('{'): - bytes = content.encode('utf-8') - root = self.parent_source.parser.parse(bytes) + cbytes = content.encode('utf-8') + root = self.parent_source.parser.parse(cbytes) self.fuzzing_token_tree = root.root_node def _process_variables(self): @@ -392,7 +396,8 @@ def extract_callsites(self, functions: dict[str, 'RustFunction']): """Extract callsites.""" def _process_invoke(expr: Node) -> list[tuple[str, int, int]]: - """Internal helper for processing the function invocation statement.""" + """Internal helper for processing the function invocation + statement.""" callsites = [] target_name: str = '' @@ -437,26 +442,26 @@ def _process_field_expr_return_type( field_expr: Node) -> tuple[Optional[str], str]: """Helper for determining the return type of a field expression in a chained call and its full qualified name.""" - type = None + return_type = None name = field_expr.child_by_field_name('field').text.decode() - object = field_expr.child_by_field_name('value') + obj = field_expr.child_by_field_name('value') full_name = name object_type = None - if object.type == 'call_expression': - object_type = _retrieve_return_type(object) - elif object.type in ['identifier', 'scoped_identifier']: - object_text = object.text.decode() + if obj.type == 'call_expression': + object_type = _retrieve_return_type(obj) + elif obj.type in ['identifier', 'scoped_identifier']: + object_text = obj.text.decode() node = get_function_node(object_text, functions) if node: object_type = node.return_type else: object_type = self.var_map.get(object_text) - elif object.type == 'self': + elif obj.type == 'self': object_type = self.name.rsplit('::', 1)[0] - elif object.type == 'string_literal': + elif obj.type == 'string_literal': object_type = '&str' if object_type: @@ -467,13 +472,13 @@ def _process_field_expr_return_type( node = get_function_node(full_name, functions) if node: - type = node.return_type + return_type = node.return_type - return ((type, full_name)) + return ((return_type, full_name)) def _retrieve_return_type(call_expr: Node) -> Optional[str]: """Helper for determining the return type of a call expression.""" - type = None + return_type = None func = call_expr.child_by_field_name('function') if func: @@ -481,11 +486,11 @@ def _retrieve_return_type(call_expr: Node) -> Optional[str]: func_name = func.text.decode() node = get_function_node(func_name, functions) if node: - type = node.return_type + return_type = node.return_type elif func.type == 'field_expression': - type, _ = _process_field_expr_return_type(func) + return_type, _ = _process_field_expr_return_type(func) - return type + return return_type def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]: """Process and store the callsites of the function.""" @@ -498,31 +503,32 @@ def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]: param_type = stmt.child_by_field_name('value') if param_name and param_type: name = param_name.text.decode() - type = None + return_type = None if param_type.type == 'identifier': target = param_type.text.decode() - type = self.var_map.get(target) - if not type: - type = self.parent_source.uses.get(target) - if not type: - type = target + return_type = self.var_map.get(target) + if not return_type: + return_type = self.parent_source.uses.get(target) + if not return_type: + return_type = target elif param_type.type == 'type_cast_expression': # In general, type casted object are not callable - # This exists for type safety in case variable tracing for - # pointers and primitive types are needed. - type = param_type.child_by_field_name( + # This exists for type safety in case variable tracing + # for pointers and primitive types are needed. + return_type = param_type.child_by_field_name( 'type').text.decode() elif param_type.type == 'call_expression': - type = _retrieve_return_type(param_type) + return_type = _retrieve_return_type(param_type) elif param_type.type == 'reference_expression': for ref_type in param_type.children: if ref_type.type == 'identifier': - type = self.var_map.get(ref_type.text.decode()) + return_type = self.var_map.get( + ref_type.text.decode()) elif ref_type.type == 'call_expression': - type = _retrieve_return_type(ref_type) + return_type = _retrieve_return_type(ref_type) - if type: - self.var_map[name] = type + if return_type: + self.var_map[name] = return_type elif stmt.type == 'macro_invocation': for child in stmt.children: @@ -559,7 +565,7 @@ def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]: if not self.detailed_callsites: for dst, src_line in self.base_callsites: - src_loc = self.parent_source.source_file + ':%d,1' % (src_line) + src_loc = f'self.parent_source.source_file :{src_line},1' self.detailed_callsites.append({'Src': src_loc, 'Dst': dst}) @@ -576,6 +582,8 @@ def dump_module_logic(self, harness_source: str = '', dump_output: bool = True): """Dumps the data for the module in full.""" + _ = entry_function + logger.info('Dumping project-wide logic.') report: dict[str, Any] = {'report': 'name'} report['sources'] = [] @@ -676,7 +684,7 @@ def calculate_function_uses(self, target_name: str, if callsite[0] == target_name: found = True break - elif callsite[0].endswith(target_name): + if callsite[0].endswith(target_name): found = True break if found: @@ -761,13 +769,14 @@ def extract_calltree(self, line_to_print += str(line_number) line_to_print += '\n' - if function in visited_functions or not func_node or not source_code or not function: + if (function in visited_functions or not func_node or not source_code + or not function): return line_to_print callsites = func_node.base_callsites visited_functions.add(function) - for cs, line_number in callsites: + for cs, line in callsites: is_macro = bool(func_node and func_node.is_macro and func_node.name != 'fuzz_target') other_props = {} @@ -777,7 +786,7 @@ def extract_calltree(self, function=cs, visited_functions=visited_functions, depth=depth + 1, - line_number=line_number, + line_number=line, other_props=other_props) return line_to_print @@ -792,6 +801,8 @@ def get_reachable_functions( function: Optional[str] = None, visited_functions: Optional[set[str]] = None) -> set[str]: """Get a list of reachable functions for a provided function name.""" + _ = source_file + func_node = None if not visited_functions: @@ -832,7 +843,8 @@ def get_reachable_functions( def load_treesitter_trees(source_files: list[str], is_log: bool = True) -> RustProject: - """Creates treesitter trees for all files in a given list of source files.""" + """Creates treesitter trees for all files in a given list of + source files.""" results = [] for code_file in source_files: @@ -848,6 +860,7 @@ def load_treesitter_trees(source_files: list[str], def analyse_source_code(source_content: str, entrypoint: str) -> RustSourceCodeFile: """Returns a source abstraction based on a single source string.""" + _ = entrypoint source_code = RustSourceCodeFile('rust', source_file='in-memory string', source_content=source_content.encode())