Skip to content

Commit

Permalink
add support for Method
Browse files Browse the repository at this point in the history
  • Loading branch information
boolangery committed Oct 25, 2024
1 parent 9410d94 commit ec266c3
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
2 changes: 1 addition & 1 deletion luaparser/astnodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def display_name(self) -> str:
def __eq__(self, other) -> bool:
if isinstance(self, other.__class__):
return _equal_dicts(
self.__dict__, other.__dict__, ["_first_token", "_last_token"]
self.__dict__, other.__dict__, ["_tokens"]
)
return False

Expand Down
41 changes: 32 additions & 9 deletions luaparser/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,7 @@ def parse_field_sep(self) -> bool:


class BuilderVisitor(LuaParserVisitor):
VISITED_CHANNEL = -1 # a magic number used to mark a token as visited
COMMENT_CHANNEL = 2

def __init__(self, comment_token_stream: CommonTokenStream):
Expand Down Expand Up @@ -1710,7 +1711,12 @@ def add_comment_context(self, ctx: ParserRuleContext, node: Node):

for token in hidden_tokens_left:
if token.channel == self.COMMENT_CHANNEL:
node.comments.append(Comment(token.text, is_multi_line=token.type == LuaLexer.COMMENT))
node.comments.append(Comment(
token.text,
is_multi_line=token.type == LuaLexer.COMMENT,
tokens=[token],
))
token.channel = self.VISITED_CHANNEL # prevent from being visited again

return

Expand Down Expand Up @@ -1838,6 +1844,12 @@ def visitStat_for(self, ctx: LuaParser.Stat_forContext):
def visitStat_function(self, ctx: LuaParser.Stat_functionContext):
func_name = self.visitFuncname(ctx.funcname())
param_list, block = self.visitFuncbody(ctx.funcbody())

if isinstance(func_name, Method):
func_name.args = param_list
func_name.body = block
return func_name

return self.add_context(ctx, Function(func_name, param_list, block))

# Visit a parse tree produced by LuaParser#stat_localfunction.
Expand Down Expand Up @@ -1906,7 +1918,12 @@ def visitFuncname(self, ctx: LuaParser.FuncnameContext):
)

if has_invoke:
return self.add_context(ctx, Invoke(root, self.visit(names[-1]), []))
return self.add_context(ctx, Method(
source=root,
name=self.visit(names[-1]),
args=[],
body=Block([])
))

return self.add_context(ctx, root)

Expand Down Expand Up @@ -2006,7 +2023,7 @@ def visitExp(self, ctx: LuaParser.ExpContext) -> Expression:
# Visit a parse tree produced by LuaParser#var.
def visitVar(self, ctx: LuaParser.VarContext):
if ctx.NAME():
return self.add_context(ctx, Name(ctx.NAME().getText()))
return Name(ctx.NAME().getText())
else: # prefixexp tail
root = self.visit(ctx.prefixexp())
return self.visitAllTails(root, [ctx.tail()])
Expand All @@ -2029,22 +2046,25 @@ def visitFunctioncall_name(self, ctx: LuaParser.Functioncall_nameContext):
name = self.visit(ctx.NAME())
tail = self.visitAllTails(name, ctx.tail())
par, args = self.visitArgs(ctx.args())
return self.add_context(ctx, Call(tail, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))
return self.add_context(ctx, Call(tail, _listify(args),
style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))

# Visit a parse tree produced by LuaParser#functioncall_nested.
def visitFunctioncall_nested(self, ctx: LuaParser.Functioncall_nestedContext):
call = self.visit(ctx.functioncall())
tail = self.visitAllTails(call, ctx.tail())
par, args = self.visitArgs(ctx.args())
return self.add_context(ctx, Call(tail, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))
return self.add_context(ctx, Call(tail, _listify(args),
style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))

# Visit a parse tree produced by LuaParser#functioncall_exp.
def visitFunctioncall_exp(self, ctx: LuaParser.Functioncall_expContext):
exp = self.visitExp(ctx.exp())
exp.wrapped = True
tail = self.visitAllTails(exp, ctx.tail())
par, args = self.visitArgs(ctx.args())
return self.add_context(ctx, Call(tail, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))
return self.add_context(ctx, Call(tail, _listify(args),
style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))

# Visit a parse tree produced by LuaParser#functioncall_expinvoke.
def visitFunctioncall_expinvoke(self, ctx: LuaParser.Functioncall_expinvokeContext):
Expand All @@ -2053,23 +2073,26 @@ def visitFunctioncall_expinvoke(self, ctx: LuaParser.Functioncall_expinvokeConte
tail = self.visitAllTails(exp, ctx.tail())
par, args = self.visitArgs(ctx.args())
func = self.visit(ctx.NAME())
return self.add_context(ctx, Invoke(tail, func, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))
return self.add_context(ctx, Invoke(tail, func, _listify(args),
style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))

# Visit a parse tree produced by LuaParser#functioncall_invoke.
def visitFunctioncall_invoke(self, ctx: LuaParser.Functioncall_invokeContext):
source = self.visit(ctx.NAME(0))
func = self.visit(ctx.NAME(1))
tail = self.visitAllTails(source, ctx.tail())
par, args = self.visitArgs(ctx.args())
return self.add_context(ctx, Invoke(tail, func, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))
return self.add_context(ctx, Invoke(tail, func, _listify(args),
style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))

# Visit a parse tree produced by LuaParser#functioncall_nestedinvoke.
def visitFunctioncall_nestedinvoke(self, ctx: LuaParser.Functioncall_nestedinvokeContext):
call = self.visit(ctx.functioncall())
func = self.visit(ctx.NAME())
tail = self.visitAllTails(call, ctx.tail())
par, args = self.visitArgs(ctx.args())
return self.add_context(ctx, Invoke(tail, func, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))
return self.add_context(ctx, Invoke(tail, func, _listify(args),
style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS))

def visitAllTails(self, root_exp: Expression, tails: List[LuaParser.TailContext]):
if not tails:
Expand Down
28 changes: 27 additions & 1 deletion luaparser/tests/test_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,4 +583,30 @@ def test_attr(self):
)
])
)
self.assertEqual(exp, tree)
self.assertEqual(exp, tree)

def test_method(self):
tree = ast.parse(
textwrap.dedent(
"""
function foo.bar:print(arg)
end
"""
)
)
exp = Chunk(
Block(
[
Method(
source=Index(
idx=Name("bar"),
value=Name("foo"),
),
name=Name("print"),
args=[Name("arg")],
body=Block([]),
)
]
)
)
self.assertEqual(exp, tree)

0 comments on commit ec266c3

Please sign in to comment.