Skip to content

Commit

Permalink
Allow trailing separator for ListNonterm (#6963)
Browse files Browse the repository at this point in the history
* Add allow_trailing_separator argument to ListNonterm.

* Change list nonterms to use new argument.

* Combine ListNonterm with helper class.
  • Loading branch information
dnwpark authored Mar 2, 2024
1 parent ffe63aa commit df6a4e5
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 76 deletions.
108 changes: 85 additions & 23 deletions edb/common/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import json
import logging
import os
import sys
import types

import parsing
Expand Down Expand Up @@ -149,34 +150,95 @@ def __init_subclass__(cls, *, is_internal=False, **kwargs):

class ListNonterm(Nonterm, is_internal=True):
def __init_subclass__(cls, *, element, separator=None, is_internal=False,
**kwargs):
allow_trailing_separator=False, **kwargs):
"""Create reductions for list classes.
if not is_internal:
element_name = ListNonterm._component_name(element)
separator_name = ListNonterm._component_name(separator)
If trailing separator is not allowed, the class can handle all
reductions directly.
L := E
L := L S E
If trailing separator is allowed, create an inner class to handle
all non-trailing reductions. Then the class handles the trailing
separator.
I := E
I := I S E
L := I
L := I S
if separator_name:
tail_prod = (
lambda self, lst, sep, el: self._reduce_list(lst, el)
The inner class is added to the same module as the class.
"""
if not is_internal:
if not allow_trailing_separator:
# directly handle the list
ListNonterm.add_list_reductions(
cls, element=element, separator=separator
)
tail_prod_name = 'reduce_{}_{}_{}'.format(
cls.__name__, separator_name, element_name)

else:
tail_prod = (
lambda self, lst, el: self._reduce_list(lst, el)
)
tail_prod_name = 'reduce_{}_{}'.format(
cls.__name__, element_name)
setattr(cls, tail_prod_name, tail_prod)
# create inner list class and add to same module
mod = sys.modules[cls.__module__]

def inner_cls_exec(ns):
ns['__module__'] = mod.__name__
return ns

inner_cls_name = cls.__name__ + 'Inner'
inner_cls_kwds = dict(element=element, separator=separator)
inner_cls = types.new_class(inner_cls_name, (ListNonterm,),
inner_cls_kwds, inner_cls_exec)
setattr(mod, inner_cls_name, inner_cls)

# create reduce_inner function
separator_name = ListNonterm.component_name(separator)

setattr(cls,
'reduce_{}'.format(inner_cls_name),
lambda self, inner: (
ListNonterm._reduce_inner(self, inner)
))
setattr(cls,
'reduce_{}_{}'.format(inner_cls_name, separator_name),
lambda self, inner, sep: (
ListNonterm._reduce_inner(self, inner)
))

setattr(cls, 'reduce_' + element_name,
lambda self, el: self._reduce_el(el))

# reduce functions must be present before calling superclass
super().__init_subclass__(is_internal=is_internal, **kwargs)

def __iter__(self):
return iter(self.val)

def __len__(self):
return len(self.val)

@staticmethod
def _component_name(component: type) -> Optional[str]:
def add_list_reductions(cls, *, element, separator=None):
element_name = ListNonterm.component_name(element)
separator_name = ListNonterm.component_name(separator)

if separator_name:
tail_prod = lambda self, lst, sep, el: (
ListNonterm._reduce_list(self, lst, el)
)
tail_prod_name = 'reduce_{}_{}_{}'.format(
cls.__name__, separator_name, element_name)
else:
tail_prod = lambda self, lst, el: (
ListNonterm._reduce_list(self, lst, el)
)
tail_prod_name = 'reduce_{}_{}'.format(
cls.__name__, element_name)
setattr(cls, tail_prod_name, tail_prod)

setattr(cls, 'reduce_' + element_name,
lambda self, el: ListNonterm._reduce_el(self, el))

@staticmethod
def component_name(component: type) -> Optional[str]:
if component is None:
return None
elif issubclass(component, Token):
Expand All @@ -187,6 +249,7 @@ def _component_name(component: type) -> Optional[str]:
raise Exception(
'List component must be a Token or Nonterm')

@staticmethod
def _reduce_list(self, lst, el):
if el.val is None:
tail = []
Expand All @@ -195,6 +258,7 @@ def _reduce_list(self, lst, el):

self.val = lst.val + tail

@staticmethod
def _reduce_el(self, el):
if el.val is None:
tail = []
Expand All @@ -203,11 +267,9 @@ def _reduce_el(self, el):

self.val = tail

def __iter__(self):
return iter(self.val)

def __len__(self):
return len(self.val)
@staticmethod
def _reduce_inner(self, inner):
self.val = inner.val


def precedence(precedence):
Expand Down
12 changes: 2 additions & 10 deletions edb/edgeql/parser/grammar/commondl.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,11 @@ def reduce_OptParameterKind_FuncDeclArgName_OptDefault(


class FuncDeclArgList(parsing.ListNonterm, element=FuncDeclArg,
separator=tokens.T_COMMA):
separator=tokens.T_COMMA, allow_trailing_separator=True):
pass


class FuncDeclArgs(Nonterm):
@parsing.inline(0)
def reduce_FuncDeclArgList_COMMA(self, list, _):
pass

@parsing.inline(0)
def reduce_FuncDeclArgList(self, list):
pass
Expand Down Expand Up @@ -701,15 +697,11 @@ def reduce_FuncDeclArgName_OptDefault(self, name, default):


class IndexArgList(parsing.ListNonterm, element=IndexArg,
separator=tokens.T_COMMA):
separator=tokens.T_COMMA, allow_trailing_separator=True):
pass


class OptIndexArgList(Nonterm):
@parsing.inline(0)
def reduce_IndexArgList_COMMA(self, list, _):
pass

@parsing.inline(0)
def reduce_IndexArgList(self, list):
pass
Expand Down
56 changes: 13 additions & 43 deletions edb/edgeql/parser/grammar/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def reduce_Expr(self, *kids):


class AliasedExprList(ListNonterm, element=AliasedExpr,
separator=tokens.T_COMMA):
separator=tokens.T_COMMA, allow_trailing_separator=True):
pass


Expand Down Expand Up @@ -231,10 +231,6 @@ class UsingClause(Nonterm):
def reduce_USING_AliasedExprList(self, *kids):
pass

@parsing.inline(1)
def reduce_USING_AliasedExprList_COMMA(self, *kids):
pass


class OptUsingClause(Nonterm):
@parsing.inline(0)
Expand Down Expand Up @@ -379,12 +375,6 @@ def reduce_WITH_WithDeclList(self, *kids):
aliases.append(w)
self.val = WithBlockData(aliases=aliases)

def reduce_WITH_WithDeclList_COMMA(self, *kids):
aliases = []
for w in kids[1].val:
aliases.append(w)
self.val = WithBlockData(aliases=aliases)


class AliasDecl(Nonterm):
def reduce_MODULE_ModuleName(self, *kids):
Expand All @@ -408,7 +398,7 @@ def reduce_AliasDecl(self, *kids):


class WithDeclList(ListNonterm, element=WithDecl,
separator=tokens.T_COMMA):
separator=tokens.T_COMMA, allow_trailing_separator=True):
pass


Expand All @@ -420,10 +410,6 @@ def reduce_LBRACE_RBRACE(self, *kids):
def reduce_LBRACE_ShapeElementList_RBRACE(self, *kids):
pass

@parsing.inline(1)
def reduce_LBRACE_ShapeElementList_COMMA_RBRACE(self, *kids):
pass


class OptShape(Nonterm):
@parsing.inline(0)
Expand Down Expand Up @@ -452,9 +438,6 @@ class FreeShape(Nonterm):
def reduce_LBRACE_FreeComputableShapePointerList_RBRACE(self, *kids):
self.val = qlast.Shape(elements=kids[1].val)

def reduce_LBRACE_FreeComputableShapePointerList_COMMA_RBRACE(self, *kids):
self.val = qlast.Shape(elements=kids[1].val)


class OptAnySubShape(Nonterm):
@parsing.inline(1)
Expand Down Expand Up @@ -483,7 +466,7 @@ def reduce_ComputableShapePointer(self, *kids):


class ShapeElementList(ListNonterm, element=ShapeElement,
separator=tokens.T_COMMA):
separator=tokens.T_COMMA, allow_trailing_separator=True):
pass


Expand Down Expand Up @@ -1004,7 +987,8 @@ def reduce_FreeSimpleShapePointer_ASSIGN_Expr(self, *kids):

class FreeComputableShapePointerList(ListNonterm,
element=FreeComputableShapePointer,
separator=tokens.T_COMMA):
separator=tokens.T_COMMA,
allow_trailing_separator=True):
pass


Expand Down Expand Up @@ -1507,9 +1491,6 @@ class NamedTuple(Nonterm):
def reduce_LPAREN_NamedTupleElementList_RPAREN(self, *kids):
self.val = qlast.NamedTuple(elements=kids[1].val)

def reduce_LPAREN_NamedTupleElementList_COMMA_RPAREN(self, *kids):
self.val = qlast.NamedTuple(elements=kids[1].val)


class NamedTupleElement(Nonterm):
def reduce_ShortNodeName_ASSIGN_Expr(self, *kids):
Expand All @@ -1520,7 +1501,8 @@ def reduce_ShortNodeName_ASSIGN_Expr(self, *kids):


class NamedTupleElementList(ListNonterm, element=NamedTupleElement,
separator=tokens.T_COMMA):
separator=tokens.T_COMMA,
allow_trailing_separator=True):
pass


Expand All @@ -1536,10 +1518,6 @@ def reduce_LBRACKET_OptExprList_RBRACKET(self, *kids):


class OptExprList(Nonterm):
@parsing.inline(0)
def reduce_ExprList_COMMA(self, *kids):
pass

@parsing.inline(0)
def reduce_ExprList(self, *kids):
pass
Expand All @@ -1548,7 +1526,8 @@ def reduce_empty(self, *kids):
self.val = []


class ExprList(ListNonterm, element=Expr, separator=tokens.T_COMMA):
class ExprList(ListNonterm, element=Expr, separator=tokens.T_COMMA,
allow_trailing_separator=True):
pass


Expand Down Expand Up @@ -1821,15 +1800,12 @@ def reduce_FuncCallArgExpr_OptFilterClause_OptSortClause(self, *kids):
self.val = (self.val[0], self.val[1], qry)


class FuncArgList(ListNonterm, element=FuncCallArg, separator=tokens.T_COMMA):
class FuncArgList(ListNonterm, element=FuncCallArg, separator=tokens.T_COMMA,
allow_trailing_separator=True):
pass


class OptFuncArgList(Nonterm):
@parsing.inline(0)
def reduce_FuncArgList_COMMA(self, *kids):
pass

@parsing.inline(0)
def reduce_FuncArgList(self, *kids):
pass
Expand Down Expand Up @@ -2022,13 +1998,6 @@ def reduce_NodeName_LANGBRACKET_SubtypeList_RANGBRACKET(self, *kids):
subtypes=kids[2].val,
)

def reduce_NodeName_LANGBRACKET_SubtypeList_COMMA_RANGBRACKET(self, *kids):
self.validate_subtype_list(kids[2])
self.val = qlast.TypeName(
maintype=kids[0].val,
subtypes=kids[2].val,
)


class TypeName(Nonterm):
@parsing.inline(0)
Expand Down Expand Up @@ -2128,7 +2097,8 @@ def reduce_BaseNumberConstant(self, *kids):
)


class SubtypeList(ListNonterm, element=Subtype, separator=tokens.T_COMMA):
class SubtypeList(ListNonterm, element=Subtype, separator=tokens.T_COMMA,
allow_trailing_separator=True):
pass


Expand Down

0 comments on commit df6a4e5

Please sign in to comment.