Skip to content

Commit

Permalink
Add ability to understand pattern matching.
Browse files Browse the repository at this point in the history
This was added in Python 3.10, we need o explicitly walk the
MatchMapping and MatchAs node that were introduced in the same Python
version .
  • Loading branch information
Carreau committed May 13, 2024
1 parent 031e655 commit f2f1111
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 56 deletions.
34 changes: 26 additions & 8 deletions lib/python/pyflyby/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,40 @@
# License: MIT http://opensource.org/licenses/MIT
from __future__ import annotations, print_function


import ast
from ast import AsyncFunctionDef, TypeIgnore

from collections import namedtuple
from doctest import DocTestParser
from functools import total_ordering
from itertools import groupby

from pyflyby._file import FilePos, FileText, Filename
from pyflyby._flags import CompilerFlags
from pyflyby._log import logger
from pyflyby._util import cached_attribute, cmp

import re
import sys
from textwrap import dedent
import types
from typing import Any, List, Optional, Tuple, Union, cast
from typing import Any, List, Optional, Tuple, Union, cast
import warnings

from pyflyby._file import FilePos, FileText, Filename
from pyflyby._flags import CompilerFlags
from pyflyby._log import logger
from pyflyby._util import cached_attribute, cmp
_sentinel = object()

if sys.version_info < (3, 10):

from ast import AsyncFunctionDef, TypeIgnore
class MatchAs:
name: str
pattern: ast.AST

_sentinel = object()
class MatchMapping:
keys: List[ast.AST]
patterns: List[MatchAs]

else:
from ast import MatchAs, MatchMapping


def _is_comment_or_blank(line, /):
Expand Down Expand Up @@ -162,6 +174,12 @@ def _iter_child_nodes_in_order_internal_1(node):
elif isinstance(node, ast.FormattedValue):
assert node._fields == ('value', 'conversion', 'format_spec')
yield node.value,
elif isinstance(node, MatchAs):
yield node.pattern
yield node.name,
elif isinstance(node, MatchMapping):
for k, p in zip(node.keys, node.patterns):
yield k, p
else:
# Default behavior.
yield ast.iter_child_nodes(node)
Expand Down
104 changes: 56 additions & 48 deletions tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,55 +572,64 @@ async def func(self, location: str) -> bytes:
def f(x, y=None, / , z=None):
pass
""",
"""
match { "foo": 1, "bar": 2 }:
case {
"foo": foo,
"bar": bar,
**rest,
}:
pass
case _:
pass
""",
"""
match event.get():
case Click(position=(x, y)):
handle_click_at(x, y)
case KeyPress(key_name="Q") | Quit():
game.quit()
case KeyPress(key_name="up arrow"):
game.go_north()
case KeyPress():
pass # Ignore other keystrokes
case other_event:
raise ValueError(f"Unrecognized event: {other_event}")
""",
"""
match event.get():
case Click((x, y), button=Button.LEFT): # This is a left click
handle_click_at(x, y)
case Click():
pass # ignore other clicks
""",
"""
def http_error(status):
match status:
case 400:
return "Bad request"
case 404:
return "Not found"
case 418:
return "I'm a teapot"
case 500 | 501 | 502:
return "I'm a teapot"
case _:
return "Something's wrong with the Internet"
"""
]
]

if sys.version_info >= (3, 10):
examples_transform.extend(
[
dedent(x)
for x in [
"""
match { "foo": 1, "bar": 2 }:
case {
"foo": foo,
"bar": bar,
**rest,
}:
pass
case _:
pass
""",
"""
match event.get():
case Click(position=(x, y)):
handle_click_at(x, y)
case KeyPress(key_name="Q") | Quit():
game.quit()
case KeyPress(key_name="up arrow"):
game.go_north()
case KeyPress():
pass # Ignore other keystrokes
case other_event:
raise ValueError(f"Unrecognized event: {other_event}")
""",
"""
match event.get():
case Click((x, y), button=Button.LEFT): # This is a left click
handle_click_at(x, y)
case Click():
pass # ignore other clicks
""",
"""
def http_error(status):
match status:
case 400:
return "Bad request"
case 404:
return "Not found"
case 418:
return "I'm a teapot"
case 500 | 501 | 502:
return "I'm a teapot"
case _:
return "Something's wrong with the Internet"
"""
]
]
)


@pytest.mark.parametrize("source", examples_transform)
def test_PythonBlock_flags_type_comment_ignore_fails_transform(source):
Expand All @@ -629,8 +638,7 @@ def test_PythonBlock_flags_type_comment_ignore_fails_transform(source):
Type: ignore are custom ast.AST who have no col_offset.
"""
block = PythonBlock(
dedent(source))
block = PythonBlock(dedent(source))
s = SourceToSourceFileImportsTransformation(block)
assert s.output() == block

Expand Down

0 comments on commit f2f1111

Please sign in to comment.