Skip to content

Commit

Permalink
pylint: Fix pyline error for frontend-rust
Browse files Browse the repository at this point in the history
Signed-off-by: Arthur Chan <[email protected]>
  • Loading branch information
arthurscchan committed Jan 21, 2025
1 parent 93943dd commit bb51eb2
Showing 1 changed file with 79 additions and 66 deletions.
145 changes: 79 additions & 66 deletions src/fuzz_introspector/frontends/frontend_rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@ def language_specific_process(self):
# Load functions/methods delcaration
self._set_function_method_declaration(self.root)

def _set_function_method_declaration(self,
start_object: Node,
start_prefix: list[str] = []):
def _set_function_method_declaration(
self,
start_object: Node,
start_prefix: Optional[list[str]] = None):
"""Internal helper for retrieving all classes."""
start_prefix = [] if not start_prefix else start_prefix

for node in start_object.children:
# Reset prefix
prefix = start_prefix[:]
Expand Down Expand Up @@ -76,30 +79,31 @@ def _set_function_method_declaration(self,
# Handle mod functions
elif node.type == 'mod_item':
mod_body = node.child_by_field_name('body')
if mod_body:
# Basic info of this mod
mod_name = node.child_by_field_name('name')
if mod_name:
prefix.append(mod_name.text.decode())

# Loop through the body of this mod
for mod in mod_body.children:
# Handle general function in this mod
if mod.type == 'function_item':
self.functions.append(
RustFunction(mod, self.tree_sitter_lang, self,
prefix))

# Handles inner impl
elif mod.type == 'impl_item':
self._set_function_method_declaration(mod, prefix)

# Handles inner mod
elif mod.type == 'mod_item':
inner_body = mod.child_by_field_name('body')
if inner_body:
self._set_function_method_declaration(
inner_body, prefix)
if not mod_body:
return

# Basic info of this mod
mod_name = node.child_by_field_name('name')
if mod_name:
prefix.append(mod_name.text.decode())

# Loop through the body of this mod
for mod in mod_body.children:
# Handle general function in this mod
if mod.type == 'function_item':
self.functions.append(
RustFunction(mod, self.tree_sitter_lang, self,
prefix))
# Handles inner impl
elif mod.type == 'impl_item':
self._set_function_method_declaration(mod, prefix)

# Handles inner mod
elif mod.type == 'mod_item':
inner_body = mod.child_by_field_name('body')
if inner_body:
self._set_function_method_declaration(
inner_body, prefix)

# Handling trait item
elif node.type == 'trait_item':
Expand Down Expand Up @@ -317,17 +321,17 @@ def _process_macro_declaration(self):
if token_tree.type == 'token_tree':
content = token_tree.text.decode()
if content.startswith('{'):
bytes = content.encode('utf-8')
root = self.parent_source.parser.parse(bytes)
cbytes = content.encode('utf-8')
root = self.parent_source.parser.parse(cbytes)
self.fuzzing_token_tree = root.root_node

elif child.type == 'macro_rule':
token_tree = child.child_by_field_name('right')
if token_tree:
content = token_tree.text.decode()
if content.startswith('{'):
bytes = content.encode('utf-8')
root = self.parent_source.parser.parse(bytes)
cbytes = content.encode('utf-8')
root = self.parent_source.parser.parse(cbytes)
self.fuzzing_token_tree = root.root_node

def _process_variables(self):
Expand Down Expand Up @@ -392,7 +396,8 @@ def extract_callsites(self, functions: dict[str, 'RustFunction']):
"""Extract callsites."""

def _process_invoke(expr: Node) -> list[tuple[str, int, int]]:
"""Internal helper for processing the function invocation statement."""
"""Internal helper for processing the function invocation
statement."""
callsites = []
target_name: str = ''

Expand Down Expand Up @@ -437,26 +442,26 @@ def _process_field_expr_return_type(
field_expr: Node) -> tuple[Optional[str], str]:
"""Helper for determining the return type of a field expression
in a chained call and its full qualified name."""
type = None
return_type = None

name = field_expr.child_by_field_name('field').text.decode()
object = field_expr.child_by_field_name('value')
obj = field_expr.child_by_field_name('value')
full_name = name

object_type = None
if object.type == 'call_expression':
object_type = _retrieve_return_type(object)
elif object.type in ['identifier', 'scoped_identifier']:
object_text = object.text.decode()
if obj.type == 'call_expression':
object_type = _retrieve_return_type(obj)
elif obj.type in ['identifier', 'scoped_identifier']:
object_text = obj.text.decode()
node = get_function_node(object_text, functions)
if node:
object_type = node.return_type
else:
object_type = self.var_map.get(object_text)
elif object.type == 'self':
elif obj.type == 'self':
object_type = self.name.rsplit('::', 1)[0]

elif object.type == 'string_literal':
elif obj.type == 'string_literal':
object_type = '&str'

if object_type:
Expand All @@ -467,25 +472,25 @@ def _process_field_expr_return_type(

node = get_function_node(full_name, functions)
if node:
type = node.return_type
return_type = node.return_type

return ((type, full_name))
return ((return_type, full_name))

def _retrieve_return_type(call_expr: Node) -> Optional[str]:
"""Helper for determining the return type of a call expression."""
type = None
return_type = None

func = call_expr.child_by_field_name('function')
if func:
if func.type in ['identifier', 'scoped_identifier']:
func_name = func.text.decode()
node = get_function_node(func_name, functions)
if node:
type = node.return_type
return_type = node.return_type
elif func.type == 'field_expression':
type, _ = _process_field_expr_return_type(func)
return_type, _ = _process_field_expr_return_type(func)

return type
return return_type

def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]:
"""Process and store the callsites of the function."""
Expand All @@ -498,31 +503,32 @@ def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]:
param_type = stmt.child_by_field_name('value')
if param_name and param_type:
name = param_name.text.decode()
type = None
return_type = None
if param_type.type == 'identifier':
target = param_type.text.decode()
type = self.var_map.get(target)
if not type:
type = self.parent_source.uses.get(target)
if not type:
type = target
return_type = self.var_map.get(target)
if not return_type:
return_type = self.parent_source.uses.get(target)
if not return_type:
return_type = target
elif param_type.type == 'type_cast_expression':
# In general, type casted object are not callable
# This exists for type safety in case variable tracing for
# pointers and primitive types are needed.
type = param_type.child_by_field_name(
# This exists for type safety in case variable tracing
# for pointers and primitive types are needed.
return_type = param_type.child_by_field_name(
'type').text.decode()
elif param_type.type == 'call_expression':
type = _retrieve_return_type(param_type)
return_type = _retrieve_return_type(param_type)
elif param_type.type == 'reference_expression':
for ref_type in param_type.children:
if ref_type.type == 'identifier':
type = self.var_map.get(ref_type.text.decode())
return_type = self.var_map.get(
ref_type.text.decode())
elif ref_type.type == 'call_expression':
type = _retrieve_return_type(ref_type)
return_type = _retrieve_return_type(ref_type)

if type:
self.var_map[name] = type
if return_type:
self.var_map[name] = return_type

elif stmt.type == 'macro_invocation':
for child in stmt.children:
Expand Down Expand Up @@ -559,7 +565,7 @@ def _process_callsites(stmt: Node) -> list[tuple[str, int, int]]:

if not self.detailed_callsites:
for dst, src_line in self.base_callsites:
src_loc = self.parent_source.source_file + ':%d,1' % (src_line)
src_loc = f'self.parent_source.source_file :{src_line},1'
self.detailed_callsites.append({'Src': src_loc, 'Dst': dst})


Expand All @@ -576,6 +582,8 @@ def dump_module_logic(self,
harness_source: str = '',
dump_output: bool = True):
"""Dumps the data for the module in full."""
_ = entry_function

logger.info('Dumping project-wide logic.')
report: dict[str, Any] = {'report': 'name'}
report['sources'] = []
Expand Down Expand Up @@ -676,7 +684,7 @@ def calculate_function_uses(self, target_name: str,
if callsite[0] == target_name:
found = True
break
elif callsite[0].endswith(target_name):
if callsite[0].endswith(target_name):
found = True
break
if found:
Expand Down Expand Up @@ -761,13 +769,14 @@ def extract_calltree(self,
line_to_print += str(line_number)
line_to_print += '\n'

if function in visited_functions or not func_node or not source_code or not function:
if (function in visited_functions or not func_node or not source_code
or not function):
return line_to_print

callsites = func_node.base_callsites
visited_functions.add(function)

for cs, line_number in callsites:
for cs, line in callsites:
is_macro = bool(func_node and func_node.is_macro
and func_node.name != 'fuzz_target')
other_props = {}
Expand All @@ -777,7 +786,7 @@ def extract_calltree(self,
function=cs,
visited_functions=visited_functions,
depth=depth + 1,
line_number=line_number,
line_number=line,
other_props=other_props)

return line_to_print
Expand All @@ -792,6 +801,8 @@ def get_reachable_functions(
function: Optional[str] = None,
visited_functions: Optional[set[str]] = None) -> set[str]:
"""Get a list of reachable functions for a provided function name."""
_ = source_file

func_node = None

if not visited_functions:
Expand Down Expand Up @@ -832,7 +843,8 @@ def get_reachable_functions(

def load_treesitter_trees(source_files: list[str],
is_log: bool = True) -> RustProject:
"""Creates treesitter trees for all files in a given list of source files."""
"""Creates treesitter trees for all files in a given list of
source files."""
results = []

for code_file in source_files:
Expand All @@ -848,6 +860,7 @@ def load_treesitter_trees(source_files: list[str],
def analyse_source_code(source_content: str,
entrypoint: str) -> RustSourceCodeFile:
"""Returns a source abstraction based on a single source string."""
_ = entrypoint
source_code = RustSourceCodeFile('rust',
source_file='in-memory string',
source_content=source_content.encode())
Expand Down

0 comments on commit bb51eb2

Please sign in to comment.