diff --git a/tools/typeshed_patcher/patching.py b/tools/typeshed_patcher/patching.py index fb9fa685b5e..03b6cbd537e 100644 --- a/tools/typeshed_patcher/patching.py +++ b/tools/typeshed_patcher/patching.py @@ -50,10 +50,14 @@ def patch_one_file( if original_code is None: raise ValueError(f"Could not find content for {file_patch.path}") else: - patched_code = transforms.apply_patches_in_sequence( - code=original_code, - patches=file_patch.patches, - ) + try: + patched_code = transforms.apply_patches_in_sequence( + code=original_code, + patches=file_patch.patches, + ) + except NotImplementedError as e: + e.args = (f"{file_patch.path}: {e.args[0]}",) + e.args[1:] + raise e diff_view = compute_diff_view( original_code=original_code, patched_code=patched_code, diff --git a/tools/typeshed_patcher/tests/transforms_test.py b/tools/typeshed_patcher/tests/transforms_test.py index b62ef23dd9a..7950953aca5 100644 --- a/tools/typeshed_patcher/tests/transforms_test.py +++ b/tools/typeshed_patcher/tests/transforms_test.py @@ -326,6 +326,32 @@ def f(x: float) -> float: ... ), ) + def test_delete__multistatement_if_block(self) -> None: + self.assert_transform( + original_code=( + """ + if condition: + class A: + pass + class B: + pass + """ + ), + patch=patch_specs.Patch( + parent=patch_specs.QualifiedName.from_string(""), + action=patch_specs.DeleteAction( + name="A", + ), + ), + expected_code=( + """ + if condition: + class B: + pass + """ + ), + ) + def test_delete__typealias(self) -> None: self.assert_transform( original_code=( @@ -649,3 +675,31 @@ def test_replace__import(self) -> None: """ ), ) + + def test_replace__multistatement_if_block(self) -> None: + self.assert_transform( + original_code=( + """ + if condition: + class A: + pass + class B: + pass + """ + ), + patch=patch_specs.Patch( + parent=patch_specs.QualifiedName.from_string(""), + action=patch_specs.ReplaceAction( + name="A", + content="A = int", + ), + ), + expected_code=( + """ + if condition: + A = int + class B: + pass + """ + ), + ) diff --git a/tools/typeshed_patcher/transforms.py b/tools/typeshed_patcher/transforms.py index d8ad64345e8..b54637acbc3 100644 --- a/tools/typeshed_patcher/transforms.py +++ b/tools/typeshed_patcher/transforms.py @@ -12,13 +12,22 @@ from __future__ import annotations -from typing import Callable, Iterable, Sequence +import enum +from typing import Callable, Iterable, Sequence, TypeVar import libcst import libcst.codemod from . import patch_specs +_PASS_STATEMENT = libcst.SimpleStatementLine([libcst.Pass()]) + + +# Options for handling multistatement blocks in is_matching_if_block +class _MatchType(enum.Enum): + ALL = all + ANY = any + def statements_from_content(content: str) -> Sequence[libcst.BaseStatement]: """ @@ -62,33 +71,20 @@ def statements_in_if_block( def is_matching_if_block( block: libcst.If, predicate: Callable[[libcst.BaseStatement | libcst.BaseSmallStatement], bool], + match_type: _MatchType, ) -> bool: """ For the most part stubs do not use control flow. The one common exception is the use of if(/else) conditions to gate logic, typically on the Python version. This helper allows us to extend statement matchers to support - if/else blocks in which all the statements in the block match our condition. + if/else blocks. - We don't currently handle the case where many different names - (e.g. methods on a class) are defined in a single if block. This does - happen occasionally in stub files, but so far we've never needed to patch - such a case. - - If we detect that this is likely (the predicate matches on some but not - all statements), we raise a NotImplementedError rather than failing silently. + The match type (ALL or ANY) determines how multiple statements + (e.g. methods on a class) in a single if block are handled. """ statements = statements_in_if_block(block) predicate_evaluations = [predicate(statement) for statement in statements] - if all(predicate_evaluations): - return True - elif any(predicate_evaluations): - raise NotImplementedError( - "Typeshed patcher does not yet support complex if statements where " - "some inner statements match a condition and others don't.\n" - f"Got this if-statement: {block}" - ) - else: - return False + return match_type.value(predicate_evaluations) def import_names_match_name( @@ -149,10 +145,6 @@ def statement_matches_name( """ Given a statement in the parent scope, determine whether it matches the name (used for delete and replace actions). - - We handle if blocks correctly as long as all the conditions - are indented blocks consisting statements that themselves - match the name. """ if isinstance(statement, libcst.SimpleStatementLine): if len(statement.body) != 1: @@ -190,6 +182,7 @@ def statement_matches_name( return is_matching_if_block( statement, predicate=lambda s: statement_matches_name(name, s), + match_type=_MatchType.ANY, ) if isinstance(statement, libcst.Import): return import_names_match_name(statement.names, name) @@ -204,9 +197,6 @@ def is_import_statement( """ Given a statement, determine whether it is only performing imports. - - We handle if blocks correctly as long as all the conditions - are indented blocks consisting only of import statements. """ def is_import_indented_block( @@ -224,6 +214,7 @@ def is_import_indented_block( return is_matching_if_block( statement, predicate=is_import_statement, + match_type=_MatchType.ALL, ) return False @@ -268,25 +259,85 @@ def run_add_action( raise RuntimeError(f"Unknown position {action.position}") +_ActionT = TypeVar("_ActionT", bound=patch_specs.Action) + + +def run_action_on_indented_block( + action: _ActionT, + block: libcst.IndentedBlock, + parent: str, + runner: Callable[ + [_ActionT, Sequence[libcst.BaseStatement], str], Sequence[libcst.BaseStatement] + ], +) -> libcst.IndentedBlock: + statements = statements_in_indented_block(block) + try: + new_statements = runner(action, statements, parent) + except ValueError: + return block + return block.with_changes(body=new_statements) + + +def run_action_on_if_block( + action: _ActionT, + block: libcst.If, + parent: str, + runner: Callable[ + [_ActionT, Sequence[libcst.BaseStatement], str], Sequence[libcst.BaseStatement] + ], +) -> libcst.If: + # First run the action on the body of the `if`. + assert isinstance(block.body, libcst.IndentedBlock) # type hint for pyre + new_body = run_action_on_indented_block(action, block.body, parent, runner) + # Then run the action on any `elif` or `else` block. + if isinstance(block.orelse, libcst.If): + new_orelse = run_action_on_if_block(action, block.orelse, parent, runner) + elif isinstance(block.orelse, libcst.Else): + assert isinstance(block.orelse.body, libcst.IndentedBlock) # type hint for pyre + new_orelse_body = run_action_on_indented_block( + action, block.orelse.body, parent, runner + ) + new_orelse = block.orelse.with_changes(body=new_orelse_body) + else: + new_orelse = block.orelse + return block.with_changes(body=new_body, orelse=new_orelse) + + def run_delete_action( action: patch_specs.DeleteAction, existing_body: Sequence[libcst.BaseStatement], parent: str, ) -> Sequence[libcst.BaseStatement]: - new_body = [ - statement - for statement in existing_body - if not statement_matches_name(action.name, statement) - ] + new_body = [] + success = False + for statement in existing_body: + if not statement_matches_name(action.name, statement): + new_body.append(statement) + continue + if isinstance(statement, libcst.If): + new_statement = run_action_on_if_block( + action, statement, parent, run_delete_action + ) + if new_statement.body.body == [_PASS_STATEMENT] and ( + new_statement.orelse is None + or new_statement.orelse.body.body == [_PASS_STATEMENT] + ): + # If the `if` now contains nothing but `pass`es, delete the whole thing. + new_statement = None + else: + new_statement = None + success |= new_statement is None or not statement.deep_equals(new_statement) + if new_statement: + new_body.append(new_statement) # Always make sure we successfully deleted the target. This # might fail if the target has disappeared, or if our # `matches_name` logic needs to be extended. - if len(new_body) == len(existing_body): + if not success: raise ValueError(f"Could not find deletion target {action.name} in {parent}") # There's an edge case where we delete the entire scope body; # we can deal with this by inserting a pass. if len(new_body) == 0: - new_body = [libcst.SimpleStatementLine([libcst.Pass()])] + new_body = [_PASS_STATEMENT] return new_body @@ -295,16 +346,25 @@ def run_replace_action( existing_body: Sequence[libcst.BaseStatement], parent: str, ) -> Sequence[libcst.BaseStatement]: - statements_to_add = statements_from_content(action.content) new_body: list[libcst.BaseStatement] = [] added_replacements = False for statement in existing_body: - if statement_matches_name(action.name, statement): - if not added_replacements: - added_replacements = True - new_body.extend(statements_to_add) - else: + if not statement_matches_name(action.name, statement): new_body.append(statement) + continue + if added_replacements: + # We've already replaced the first occurrence of the name, drop any later + # occurrences of it. + continue + if isinstance(statement, libcst.If): + new_statement = run_action_on_if_block( + action, statement, parent, run_replace_action + ) + added_replacements = not statement.deep_equals(new_statement) + new_body.append(new_statement) + else: + added_replacements = True + new_body.extend(statements_from_content(action.content)) if not added_replacements: raise ValueError(f"Could not find replacement target {action.name} in {parent}") return new_body