From 39b22cf2f952fa23c422d370a6fbc3383d4dc3a2 Mon Sep 17 00:00:00 2001
From: Stuart Pernsteiner
Date: Thu, 15 Aug 2024 17:01:54 -0700
Subject: [PATCH] analyze: auto_fix_errors.py: handle suggested lifetime bounds

---
 c2rust-analyze/scripts/auto_fix_errors.py | 109 +++++++++++++++++++++-
 1 file changed, 108 insertions(+), 1 deletion(-)

diff --git a/c2rust-analyze/scripts/auto_fix_errors.py b/c2rust-analyze/scripts/auto_fix_errors.py
index 8a045a8e3..d0370479c 100644
--- a/c2rust-analyze/scripts/auto_fix_errors.py
+++ b/c2rust-analyze/scripts/auto_fix_errors.py
@@ -1,6 +1,7 @@
 import argparse
 from dataclasses import dataclass
 import json
+import re
 import sys
 
 def parse_args() -> argparse.Namespace:
@@ -21,6 +22,23 @@ class Fix:
     new_text: str
     message: str
 
+@dataclass(frozen=True)
+class LifetimeBound:
+    file_path: str
+    line_number: int
+    # Byte offset of the start of the lifetime parameter declaration.
+    start_byte: int
+    # Byte offset of the end of the lifetime parameter declaration.
+    end_byte: int
+    # The lifetime to use in the new bound. If `'a: 'b` is the suggested
+    # bound, then `start/end_byte` points to the declaration of `'a`, and
+    # `bound_lifetime` is the string `"'b"`.
+    bound_lifetime: str
+
+LIFETIME_DEFINED_RE = re.compile(r'^lifetime `([^`]*)` defined here$')
+CONSIDER_ADDING_BOUND_RE = re.compile(r'^consider adding the following bound: `([^`:]*): ([^`]*)`$')
+SPACE_COLON_RE = re.compile(rb'\s*:')
+
 def main():
     args = parse_args()
 
@@ -42,6 +60,50 @@ def gather_fixes(j, message):
         for child in j['children']:
             gather_fixes(child, message)
 
+    lifetime_bounds = []
+    def gather_lifetime_bounds(j):
+        # We look for a particular pattern seen in lifetime errors. First,
+        # there should be a span with label "lifetime 'a defined here" pointing
+        # at the declaration of `'a`. Second, there should be a child of type
+        # `help` with the text "consider adding the following bound: `'a: 'b`".
+
+        decl_spans = {}
+        for span in j['spans']:
+            m = LIFETIME_DEFINED_RE.match(span['label'])
+            if m is not None:
+                lifetime = m.group(1)
+                if lifetime in decl_spans:
+                    # Duplicate declaration for this lifetime. This shouldn't
+                    # happen, but we can proceed as long as the lifetime isn't
+                    # the target of the bound. We mark the duplicate lifetime
+                    # so it can't be used as the target.
+                    decl_spans[lifetime] = None
+                    continue
+                decl_spans[lifetime] = span
+
+        for child in j['children']:
+            if child['level'] != 'help':
+                continue
+            m = CONSIDER_ADDING_BOUND_RE.match(child['message'])
+            if m is None:
+                continue
+            lifetime_a = m.group(1)
+            lifetime_b = m.group(2)
+            span = decl_spans.get(lifetime_a)
+            if span is None:
+                # We don't have anywhere to insert the new bound. This can
+                # also happen if there were duplicate declaration spans for
+                # this lifetime (we explicitly insert `None` into the map in
+                # that case).
+                continue
+            lifetime_bounds.append(LifetimeBound(
+                file_path=span['file_name'],
+                line_number=span['line_start'],
+                start_byte=span['byte_start'],
+                end_byte=span['byte_end'],
+                bound_lifetime=lifetime_b,
+            ))
+
     with open(args.path, 'r') as f:
         for line in f:
             j = json.loads(line)
@@ -58,6 +120,51 @@ def gather_fixes(j, message):
 
             gather_fixes(j, j['message'])
 
+            if j['message'] == 'lifetime may not live long enough':
+                gather_lifetime_bounds(j)
+
+    # Convert suggested lifetime bounds to fixes. We have to group the bounds
+    # first because there may be multiple suggested bounds for a single
+    # declaration, in which case we want to generate a single `Fix` that adds
+    # all of them at once.
+
+    # Maps the `(file_path, line_number, start_byte, end_byte)` of the
+    # declaration site to the set of new lifetimes to apply at that site.
+    grouped_lifetime_bounds = {}
+    for lb in lifetime_bounds:
+        key = (lb.file_path, lb.line_number, lb.start_byte, lb.end_byte)
+        if key not in grouped_lifetime_bounds:
+            grouped_lifetime_bounds[key] = set()
+        grouped_lifetime_bounds[key].add(lb.bound_lifetime)
+
+    file_content = {}
+    def read_file(file_path):
+        if file_path not in file_content:
+            file_content[file_path] = open(file_path, 'rb').read()
+        return file_content[file_path]
+
+    for key, bound_lifetimes in sorted(grouped_lifetime_bounds.items()):
+        (file_path, line_number, start_byte, end_byte) = key
+        content = read_file(file_path)
+        decl_lifetime = content[start_byte : end_byte].decode('utf-8')
+        bound_lifetimes = ' + '.join(bound_lifetimes)
+        m = SPACE_COLON_RE.match(content, end_byte)
+        if m is None:
+            fix_end_byte = end_byte
+            fix_new_text = '%s: %s' % (decl_lifetime, bound_lifetimes)
+        else:
+            space_colon = m.group().decode('utf-8')
+            fix_end_byte = m.end()
+            fix_new_text = '%s%s %s +' % (decl_lifetime, space_colon, bound_lifetimes)
+        fixes.append(Fix(
+            file_path=file_path,
+            line_number=line_number,
+            start_byte=start_byte,
+            end_byte=fix_end_byte,
+            new_text=fix_new_text,
+            message='lifetime may not live long enough',
+        ))
+
     fixes_by_file = {}
     for fix in fixes:
         file_fixes = fixes_by_file.get(fix.file_path)
@@ -70,7 +177,7 @@ def gather_fixes(j, message):
 
     # Apply fixes
     for file_path, file_fixes in sorted(fixes_by_file.items()):
-        content = open(file_path, 'rb').read()
+        content = read_file(file_path)
         chunks = []
         pos = 0
         prev_fix = None
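
Illustration only (not part of the diff): a minimal Python sketch of the text
edit that one such generated `Fix` performs on a Rust signature. The signature,
byte offsets, and lifetime names below are hypothetical examples, not output
from the script or from rustc.

    # Sketch: the declaration `'a` has no existing bound list, so the fix
    # replaces just the declaration span with "'a: 'b".
    src = b"fn f<'a, 'b>(x: &'a u32) -> &'b u32 { x }"
    start_byte, end_byte = 5, 7   # hypothetical span covering `'a`
    decl_lifetime = src[start_byte:end_byte].decode('utf-8')
    new_text = '%s: %s' % (decl_lifetime, "'b")
    patched = src[:start_byte] + new_text.encode('utf-8') + src[end_byte:]
    print(patched.decode('utf-8'))
    # fn f<'a: 'b, 'b>(x: &'a u32) -> &'b u32 { x }

If the declaration already carries a bound list (e.g. `'a: 'c`), the generated
`Fix` instead extends through the colon and emits `'a: 'b +`, so the result
reads `'a: 'b + 'c`.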