From 61fc40ee5b0716d8370b29fb0929b34c18c74fb4 Mon Sep 17 00:00:00 2001 From: Arthur Chan Date: Wed, 22 Jan 2025 12:49:08 +0000 Subject: [PATCH] pylint: Fix pylint error for frontend-rust (#2010) * pylint: Fix pylint error for frontend-rust Signed-off-by: Arthur Chan * Fix bug Signed-off-by: Arthur Chan * Add disable to pylint in some unused parameters Signed-off-by: Arthur Chan * Fix formatting Signed-off-by: Arthur Chan * Fix unnecessary disable Signed-off-by: Arthur Chan --------- Signed-off-by: Arthur Chan --- .../frontends/frontend_rust.py | 208 ++++++++++-------- 1 file changed, 115 insertions(+), 93 deletions(-) diff --git a/src/fuzz_introspector/frontends/frontend_rust.py b/src/fuzz_introspector/frontends/frontend_rust.py index 3f5b8395d..388d659b4 100644 --- a/src/fuzz_introspector/frontends/frontend_rust.py +++ b/src/fuzz_introspector/frontends/frontend_rust.py @@ -22,15 +22,15 @@ import logging import yaml -from fuzz_introspector.frontends.datatypes import Project, SourceCodeFile +from fuzz_introspector.frontends import datatypes logger = logging.getLogger(name=__name__) -class RustSourceCodeFile(SourceCodeFile): +class RustSourceCodeFile(datatypes.SourceCodeFile): """Class for holding file-specific information.""" - def language_specific_process(self): + def language_specific_process(self) -> None: """Perform some language specific processes in subclasses.""" self.uses: dict[str, str] = {} @@ -38,12 +38,16 @@ def language_specific_process(self): self.functions: list['RustFunction'] = [] # Load functions/methods delcaration - self._set_function_method_declaration(self.root) + if self.root: + self._set_function_method_declaration(self.root) - def _set_function_method_declaration(self, - start_object: Node, - start_prefix: list[str] = []): + def _set_function_method_declaration( + self, + start_object: Node, + start_prefix: Optional[list[str]] = None): """Internal helper for retrieving all classes.""" + start_prefix = [] if not start_prefix else
start_prefix + for node in start_object.children: # Reset prefix prefix = start_prefix[:] @@ -58,9 +62,13 @@ def _set_function_method_declaration(self, # Basic info of this impl impl_type = node.child_by_field_name('type') impl_body = node.child_by_field_name('body') - if impl_type: + if impl_type and impl_type.text: prefix.append(impl_type.text.decode().split('<')[0]) + # Check impl_body + if not impl_body: + continue + # Loop through the items in this impl for impl in impl_body.children: # Handle general methods in this impl @@ -76,39 +84,44 @@ def _set_function_method_declaration(self, # Handle mod functions elif node.type == 'mod_item': mod_body = node.child_by_field_name('body') - if mod_body: - # Basic info of this mod - mod_name = node.child_by_field_name('name') - if mod_name: - prefix.append(mod_name.text.decode()) - - # Loop through the body of this mod - for mod in mod_body.children: - # Handle general function in this mod - if mod.type == 'function_item': - self.functions.append( - RustFunction(mod, self.tree_sitter_lang, self, - prefix)) - - # Handles inner impl - elif mod.type == 'impl_item': - self._set_function_method_declaration(mod, prefix) - - # Handles inner mod - elif mod.type == 'mod_item': - inner_body = mod.child_by_field_name('body') - if inner_body: - self._set_function_method_declaration( - inner_body, prefix) + if not mod_body: + continue + + # Basic info of this mod + mod_name = node.child_by_field_name('name') + if mod_name and mod_name.text: + prefix.append(mod_name.text.decode()) + + # Loop through the body of this mod + for mod in mod_body.children: + # Handle general function in this mod + if mod.type == 'function_item': + self.functions.append( + RustFunction(mod, self.tree_sitter_lang, self, + prefix)) + # Handles inner impl + elif mod.type == 'impl_item': + self._set_function_method_declaration(mod, prefix) + + # Handles inner mod + elif mod.type == 'mod_item': + inner_body = mod.child_by_field_name('body') + if inner_body: +
self._set_function_method_declaration( + inner_body, prefix) # Handling trait item elif node.type == 'trait_item': # Basic info of this trait trait_name = node.child_by_field_name('name') trait_body = node.child_by_field_name('body') - if trait_name: + if trait_name and trait_name.text: prefix.append(trait_name.text.decode().split('<')[0]) + # Check trait_body + if not trait_body: + continue + # Loop through the items in this trait for trait in trait_body.children: # Handle general methods in this trait @@ -135,7 +148,7 @@ def _set_function_method_declaration(self, # Handling specific use declaration elif node.type == 'use_declaration': use_stmt = node.child_by_field_name('argument') - if use_stmt: + if use_stmt and use_stmt.text: use_map = self._process_recursive_use( use_stmt.text.decode()) self.uses.update(use_map) @@ -317,8 +330,8 @@ def _process_macro_declaration(self): if token_tree.type == 'token_tree': content = token_tree.text.decode() if content.startswith('{'): - bytes = content.encode('utf-8') - root = self.parent_source.parser.parse(bytes) + cbytes = content.encode('utf-8') + root = self.parent_source.parser.parse(cbytes) self.fuzzing_token_tree = root.root_node elif child.type == 'macro_rule': @@ -326,8 +339,8 @@ def _process_macro_declaration(self): if token_tree: content = token_tree.text.decode() if content.startswith('{'): - bytes = content.encode('utf-8') - root = self.parent_source.parser.parse(bytes) + cbytes = content.encode('utf-8') + root = self.parent_source.parser.parse(cbytes) self.fuzzing_token_tree = root.root_node def _process_variables(self): @@ -392,7 +405,8 @@ def extract_callsites(self, functions: dict[str, 'RustFunction']): """Extract callsites.""" def _process_invoke(expr: Node) -> list[tuple[str, int, int]]: - """Internal helper for processing the function invocation statement.""" + """Internal helper for processing the function invocation + statement.""" callsites = [] target_name: str = '' @@ -402,7 +416,8 @@ def 
_process_invoke(expr: Node) -> list[tuple[str, int, int]]: if func: # Simple function call if func.type in ['identifier', 'scoped_identifier']: - target_name = func.text.decode() + if func.text: + target_name = func.text.decode() # Ignore lambda function calls if target_name in self.var_map: @@ -414,12 +429,12 @@ def _process_invoke(expr: Node) -> list[tuple[str, int, int]]: elif func.type == 'field_expression': _, target_name = _process_field_expr_return_type(func) - elif func.type == 'generic_function': + elif func.type == 'generic_function' and func.text: target_name = func.text.decode().split('.', 1)[-1] - if target_name: - callsites.append((target_name, func.byte_range[1], - func.start_point.row + 1)) + if target_name and func.byte_range and func.start_point: + callsites.append((target_name, func.byte_range[1], + func.start_point.row + 1)) return callsites @@ -437,55 +452,56 @@ def _process_field_expr_return_type( field_expr: Node) -> tuple[Optional[str], str]: """Helper for determining the return type of a field expression in a chained call and its full qualified name.""" - type = None + return_type = None + + name = field_expr.child_by_field_name('field') + obj = field_expr.child_by_field_name('value') + full_name = name.text.decode() if name and name.text else '' - name = field_expr.child_by_field_name('field').text.decode() - object = field_expr.child_by_field_name('value') - full_name = name + if not obj: + return (return_type, full_name) object_type = None - if object.type == 'call_expression': - object_type = _retrieve_return_type(object) - elif object.type in ['identifier', 'scoped_identifier']: - object_text = object.text.decode() + if obj.type == 'call_expression': + object_type = _retrieve_return_type(obj) + elif obj.type in ['identifier', 'scoped_identifier']: + object_text = obj.text.decode() if obj.text else '' node = get_function_node(object_text, functions) if node: object_type = node.return_type else: object_type = 
self.var_map.get(object_text) - elif object.type == 'self': + elif obj.type == 'self': object_type = self.name.rsplit('::', 1)[0] - elif object.type == 'string_literal': + elif obj.type == 'string_literal': object_type = '&str' if object_type: - if object_type == 'void': - full_name = name - else: - full_name = f'{object_type}::{name}' + if object_type != 'void': + full_name = f'{object_type}::{full_name}' node = get_function_node(full_name, functions) if node: - type = node.return_type + return_type = node.return_type - return ((type, full_name)) + return (return_type, full_name) def _retrieve_return_type(call_expr: Node) -> Optional[str]: """Helper for determining the return type of a call expression.""" - type = None + return_type = None func = call_expr.child_by_field_name('function') if func: if func.type in ['identifier', 'scoped_identifier']: - func_name = func.text.decode() + func_name = func.text.decode() if func.text else '' node = get_function_node(func_name, functions) if node: - type = node.return_type + return_type = node.return_type elif func.type == 'field_expression': - type, _ = _process_field_expr_return_type(func) + return_type, _ = _process_field_expr_return_type(func) - return type + return return_type def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]: """Process and store the callsites of the function.""" @@ -497,36 +513,40 @@ def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]: param_name = stmt.child_by_field_name('pattern') param_type = stmt.child_by_field_name('value') if param_name and param_type: - name = param_name.text.decode() - type = None + name = param_name.text.decode() if param_name.text else '' + return_type = None if param_type.type == 'identifier': - target = param_type.text.decode() - type = self.var_map.get(target) - if not type: - type = self.parent_source.uses.get(target) - if not type: - type = target + target = (param_type.text.decode() + if param_type.text else '') + return_type = 
self.var_map.get(target) + if not return_type: + return_type = self.parent_source.uses.get(target) + if not return_type: + return_type = target elif param_type.type == 'type_cast_expression': # In general, type casted object are not callable - # This exists for type safety in case variable tracing for - # pointers and primitive types are needed. - type = param_type.child_by_field_name( - 'type').text.decode() + # This exists for type safety in case variable tracing + # for pointers and primitive types are needed. + return_node = param_type.child_by_field_name('type') + if return_node and return_node.text: + return_type = return_node.text.decode() elif param_type.type == 'call_expression': - type = _retrieve_return_type(param_type) + return_type = _retrieve_return_type(param_type) elif param_type.type == 'reference_expression': for ref_type in param_type.children: if ref_type.type == 'identifier': - type = self.var_map.get(ref_type.text.decode()) + key_bytes = ref_type.text + key = key_bytes.decode() if key_bytes else '' + return_type = self.var_map.get(key) elif ref_type.type == 'call_expression': - type = _retrieve_return_type(ref_type) + return_type = _retrieve_return_type(ref_type) - if type: - self.var_map[name] = type + if return_type: + self.var_map[name] = return_type elif stmt.type == 'macro_invocation': for child in stmt.children: - if child.type == 'identifier': + if child.type == 'identifier' and child.text: macro_name = child.text.decode() target_func = get_function_node(macro_name, functions) if target_func and target_func.is_macro: @@ -559,11 +579,11 @@ def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]: if not self.detailed_callsites: for dst, src_line in self.base_callsites: - src_loc = self.parent_source.source_file + ':%d,1' % (src_line) + src_loc = f'{self.parent_source.source_file}:{src_line},1' self.detailed_callsites.append({'Src': src_loc, 'Dst': dst}) -class RustProject(Project[RustSourceCodeFile]): +class
RustProject(datatypes.Project[RustSourceCodeFile]): """Wrapper for doing analysis of a collection of source files.""" def __init__(self, source_code_files: list[RustSourceCodeFile]): @@ -576,6 +596,7 @@ def dump_module_logic(self, harness_source: str = '', dump_output: bool = True): """Dumps the data for the module in full.""" + # pylint: disable=unused-argument logger.info('Dumping project-wide logic.') report: dict[str, Any] = {'report': 'name'} report['sources'] = [] @@ -676,7 +697,7 @@ def calculate_function_uses(self, target_name: str, if callsite[0] == target_name: found = True break - elif callsite[0].endswith(target_name): + if callsite[0].endswith(target_name): found = True break if found: @@ -761,13 +782,14 @@ def extract_calltree(self, line_to_print += str(line_number) line_to_print += '\n' - if function in visited_functions or not func_node or not source_code or not function: + if (function in visited_functions or not func_node or not source_code + or not function): return line_to_print callsites = func_node.base_callsites visited_functions.add(function) - for cs, line_number in callsites: + for cs, line in callsites: is_macro = bool(func_node and func_node.is_macro and func_node.name != 'fuzz_target') other_props = {} @@ -777,14 +799,11 @@ def extract_calltree(self, function=cs, visited_functions=visited_functions, depth=depth + 1, - line_number=line_number, + line_number=line, other_props=other_props) return line_to_print - def get_source_codes_with_harnesses(self) -> list[RustSourceCodeFile]: - return super().get_source_codes_with_harnesses() - def get_reachable_functions( self, source_file: str = '', @@ -792,6 +811,7 @@ def get_reachable_functions( function: Optional[str] = None, visited_functions: Optional[set[str]] = None) -> set[str]: """Get a list of reachable functions for a provided function name.""" + # pylint: disable=unused-argument func_node = None if not visited_functions: @@ -832,7 +852,8 @@ def get_reachable_functions( def 
load_treesitter_trees(source_files: list[str], is_log: bool = True) -> RustProject: - """Creates treesitter trees for all files in a given list of source files.""" + """Creates treesitter trees for all files in a given list of + source files.""" results = [] for code_file in source_files: @@ -848,6 +869,7 @@ def load_treesitter_trees(source_files: list[str], def analyse_source_code(source_content: str, entrypoint: str) -> RustSourceCodeFile: """Returns a source abstraction based on a single source string.""" + # pylint: disable=unused-argument source_code = RustSourceCodeFile('rust', source_file='in-memory string', source_content=source_content.encode())