diff --git a/lib/python/pyflyby/_parse.py b/lib/python/pyflyby/_parse.py index 13744384..c991be56 100644 --- a/lib/python/pyflyby/_parse.py +++ b/lib/python/pyflyby/_parse.py @@ -3,28 +3,40 @@ # License: MIT http://opensource.org/licenses/MIT from __future__ import annotations, print_function - import ast +from ast import AsyncFunctionDef, TypeIgnore + from collections import namedtuple from doctest import DocTestParser from functools import total_ordering from itertools import groupby + +from pyflyby._file import FilePos, FileText, Filename +from pyflyby._flags import CompilerFlags +from pyflyby._log import logger +from pyflyby._util import cached_attribute, cmp + import re import sys from textwrap import dedent import types -from typing import Any, List, Optional, Tuple, Union, cast +from typing import Any, List, Optional, Tuple, Union, cast import warnings -from pyflyby._file import FilePos, FileText, Filename -from pyflyby._flags import CompilerFlags -from pyflyby._log import logger -from pyflyby._util import cached_attribute, cmp +_sentinel = object() +if sys.version_info < (3, 10): -from ast import AsyncFunctionDef, TypeIgnore + class MatchAs: + name: str + pattern: ast.AST -_sentinel = object() + class MatchMapping: + keys: List[ast.AST] + patterns: List[MatchAs] + +else: + from ast import MatchAs, MatchMapping def _is_comment_or_blank(line, /): @@ -162,6 +174,12 @@ def _iter_child_nodes_in_order_internal_1(node): elif isinstance(node, ast.FormattedValue): assert node._fields == ('value', 'conversion', 'format_spec') yield node.value, + elif isinstance(node, MatchAs): + yield node.pattern + yield node.name, + elif isinstance(node, MatchMapping): + for k, p in zip(node.keys, node.patterns): + yield k, p else: # Default behavior. yield ast.iter_child_nodes(node) diff --git a/tests/test_parse.py b/tests/test_parse.py index 0d538218..a66cc318 100644 --- a/tests/test_parse.py +++ b/tests/test_parse.py @@ -572,55 +572,64 @@ async def func(self, location: str) -> bytes: def f(x, y=None, / , z=None): pass """, - """ - match { "foo": 1, "bar": 2 }: - case { - "foo": foo, - "bar": bar, - **rest, - }: - pass - case _: - pass - """, - """ - match event.get(): - case Click(position=(x, y)): - handle_click_at(x, y) - case KeyPress(key_name="Q") | Quit(): - game.quit() - case KeyPress(key_name="up arrow"): - game.go_north() - case KeyPress(): - pass # Ignore other keystrokes - case other_event: - raise ValueError(f"Unrecognized event: {other_event}") - """, - """ - match event.get(): - case Click((x, y), button=Button.LEFT): # This is a left click - handle_click_at(x, y) - case Click(): - pass # ignore other clicks - """, - """ - def http_error(status): - match status: - case 400: - return "Bad request" - case 404: - return "Not found" - case 418: - return "I'm a teapot" - case 500 | 501 | 502: - return "I'm a teapot" - case _: - return "Something's wrong with the Internet" - - """ ] ] +if sys.version_info >= (3, 10): + examples_transform.extend( + [ + dedent(x) + for x in [ + """ + match { "foo": 1, "bar": 2 }: + case { + "foo": foo, + "bar": bar, + **rest, + }: + pass + case _: + pass + """, + """ + match event.get(): + case Click(position=(x, y)): + handle_click_at(x, y) + case KeyPress(key_name="Q") | Quit(): + game.quit() + case KeyPress(key_name="up arrow"): + game.go_north() + case KeyPress(): + pass # Ignore other keystrokes + case other_event: + raise ValueError(f"Unrecognized event: {other_event}") + """, + """ + match event.get(): + case Click((x, y), button=Button.LEFT): # This is a left click + handle_click_at(x, y) + case Click(): + pass # ignore other clicks + """, + """ + def http_error(status): + match status: + case 400: + return "Bad request" + case 404: + return "Not found" + case 418: + return "I'm a teapot" + case 500 | 501 | 502: + return "I'm a teapot" + case _: + return "Something's wrong with the Internet" + + """ + ] + ] + ) + @pytest.mark.parametrize("source", examples_transform) def test_PythonBlock_flags_type_comment_ignore_fails_transform(source): @@ -629,8 +638,7 @@ def test_PythonBlock_flags_type_comment_ignore_fails_transform(source): Type: ignore are custom ast.AST who have no col_offset. """ - block = PythonBlock( - dedent(source)) + block = PythonBlock(dedent(source)) s = SourceToSourceFileImportsTransformation(block) assert s.output() == block