Skip to content

Commit

Permalink
Merge pull request #340 from Carreau/pattern-match
Browse files Browse the repository at this point in the history
Add ability to understand pattern matching.
  • Loading branch information
Carreau authored May 14, 2024
2 parents 00a253c + f2f1111 commit 8d84db2
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 26 deletions.
34 changes: 26 additions & 8 deletions lib/python/pyflyby/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,40 @@
# License: MIT http://opensource.org/licenses/MIT
from __future__ import annotations, print_function


import ast
from ast import AsyncFunctionDef, TypeIgnore

from collections import namedtuple
from doctest import DocTestParser
from functools import total_ordering
from itertools import groupby

from pyflyby._file import FilePos, FileText, Filename
from pyflyby._flags import CompilerFlags
from pyflyby._log import logger
from pyflyby._util import cached_attribute, cmp

import re
import sys
from textwrap import dedent
import types
from typing import Any, List, Optional, Tuple, Union, cast
from typing import Any, List, Optional, Tuple, Union, cast
import warnings

from pyflyby._file import FilePos, FileText, Filename
from pyflyby._flags import CompilerFlags
from pyflyby._log import logger
from pyflyby._util import cached_attribute, cmp
_sentinel = object()

if sys.version_info < (3, 10):

from ast import AsyncFunctionDef, TypeIgnore
class MatchAs:
name: str
pattern: ast.AST

_sentinel = object()
class MatchMapping:
keys: List[ast.AST]
patterns: List[MatchAs]

else:
from ast import MatchAs, MatchMapping


def _is_comment_or_blank(line, /):
Expand Down Expand Up @@ -162,6 +174,12 @@ def _iter_child_nodes_in_order_internal_1(node):
elif isinstance(node, ast.FormattedValue):
assert node._fields == ('value', 'conversion', 'format_spec')
yield node.value,
elif isinstance(node, MatchAs):
yield node.pattern
yield node.name,
elif isinstance(node, MatchMapping):
for k, p in zip(node.keys, node.patterns):
yield k, p
else:
# Default behavior.
yield ast.iter_child_nodes(node)
Expand Down
90 changes: 72 additions & 18 deletions tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,24 +555,79 @@ def f(x):
assert s.output() == block


examples_transform = ["""
a = None # type: ignore
"""]
examples_transform = [
dedent(x)
for x in [
"""
a = None # type: ignore
""",
"""
class A:
async def func(self, location: str) -> bytes:
async with aiofiles.open(location, "rb") as file:
return await file.read()
""",
# positional only
"""
def f(x, y=None, / , z=None):
pass
""",
]
]

examples_transform.append(
"""
class A:
async def func(self, location: str) -> bytes:
async with aiofiles.open(location, "rb") as file:
return await file.read()
"""
)
if sys.version_info >= (3, 10):
examples_transform.extend(
[
dedent(x)
for x in [
"""
match { "foo": 1, "bar": 2 }:
case {
"foo": foo,
"bar": bar,
**rest,
}:
pass
case _:
pass
""",
"""
match event.get():
case Click(position=(x, y)):
handle_click_at(x, y)
case KeyPress(key_name="Q") | Quit():
game.quit()
case KeyPress(key_name="up arrow"):
game.go_north()
case KeyPress():
pass # Ignore other keystrokes
case other_event:
raise ValueError(f"Unrecognized event: {other_event}")
""",
"""
match event.get():
case Click((x, y), button=Button.LEFT): # This is a left click
handle_click_at(x, y)
case Click():
pass # ignore other clicks
""",
"""
def http_error(status):
match status:
case 400:
return "Bad request"
case 404:
return "Not found"
case 418:
return "I'm a teapot"
case 500 | 501 | 502:
return "I'm a teapot"
case _:
return "Something's wrong with the Internet"
examples_transform.append(
# positional only
"""
def f(x, y=None, / , z=None):
pass"""
"""
]
]
)


Expand All @@ -583,8 +638,7 @@ def test_PythonBlock_flags_type_comment_ignore_fails_transform(source):
Type: ignore are custom ast.AST who have no col_offset.
"""
block = PythonBlock(
dedent(source))
block = PythonBlock(dedent(source))
s = SourceToSourceFileImportsTransformation(block)
assert s.output() == block

Expand Down

0 comments on commit 8d84db2

Please sign in to comment.