diff --git a/luaparser/astnodes.py b/luaparser/astnodes.py index 082e81a..4e69512 100644 --- a/luaparser/astnodes.py +++ b/luaparser/astnodes.py @@ -57,7 +57,7 @@ def display_name(self) -> str: def __eq__(self, other) -> bool: if isinstance(self, other.__class__): return _equal_dicts( - self.__dict__, other.__dict__, ["_first_token", "_last_token"] + self.__dict__, other.__dict__, ["_tokens"] ) return False diff --git a/luaparser/builder.py b/luaparser/builder.py index a00672c..190fd20 100644 --- a/luaparser/builder.py +++ b/luaparser/builder.py @@ -1658,6 +1658,7 @@ def parse_field_sep(self) -> bool: class BuilderVisitor(LuaParserVisitor): + VISITED_CHANNEL = -1 # a magic number used to mark a token as visited COMMENT_CHANNEL = 2 def __init__(self, comment_token_stream: CommonTokenStream): @@ -1710,7 +1711,12 @@ def add_comment_context(self, ctx: ParserRuleContext, node: Node): for token in hidden_tokens_left: if token.channel == self.COMMENT_CHANNEL: - node.comments.append(Comment(token.text, is_multi_line=token.type == LuaLexer.COMMENT)) + node.comments.append(Comment( + token.text, + is_multi_line=token.type == LuaLexer.COMMENT, + tokens=[token], + )) + token.channel = self.VISITED_CHANNEL # prevent from being visited again return @@ -1838,6 +1844,12 @@ def visitStat_for(self, ctx: LuaParser.Stat_forContext): def visitStat_function(self, ctx: LuaParser.Stat_functionContext): func_name = self.visitFuncname(ctx.funcname()) param_list, block = self.visitFuncbody(ctx.funcbody()) + + if isinstance(func_name, Method): + func_name.args = param_list + func_name.body = block + return func_name + return self.add_context(ctx, Function(func_name, param_list, block)) # Visit a parse tree produced by LuaParser#stat_localfunction. @@ -1906,7 +1918,12 @@ def visitFuncname(self, ctx: LuaParser.FuncnameContext): ) if has_invoke: - return self.add_context(ctx, Invoke(root, self.visit(names[-1]), [])) + return self.add_context(ctx, Method( + source=root, + name=self.visit(names[-1]), + args=[], + body=Block([]) + )) return self.add_context(ctx, root) @@ -2006,7 +2023,7 @@ def visitExp(self, ctx: LuaParser.ExpContext) -> Expression: # Visit a parse tree produced by LuaParser#var. def visitVar(self, ctx: LuaParser.VarContext): if ctx.NAME(): - return self.add_context(ctx, Name(ctx.NAME().getText())) + return Name(ctx.NAME().getText()) else: # prefixexp tail root = self.visit(ctx.prefixexp()) return self.visitAllTails(root, [ctx.tail()]) @@ -2029,14 +2046,16 @@ def visitFunctioncall_name(self, ctx: LuaParser.Functioncall_nameContext): name = self.visit(ctx.NAME()) tail = self.visitAllTails(name, ctx.tail()) par, args = self.visitArgs(ctx.args()) - return self.add_context(ctx, Call(tail, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) + return self.add_context(ctx, Call(tail, _listify(args), + style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) # Visit a parse tree produced by LuaParser#functioncall_nested. def visitFunctioncall_nested(self, ctx: LuaParser.Functioncall_nestedContext): call = self.visit(ctx.functioncall()) tail = self.visitAllTails(call, ctx.tail()) par, args = self.visitArgs(ctx.args()) - return self.add_context(ctx, Call(tail, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) + return self.add_context(ctx, Call(tail, _listify(args), + style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) # Visit a parse tree produced by LuaParser#functioncall_exp. def visitFunctioncall_exp(self, ctx: LuaParser.Functioncall_expContext): @@ -2044,7 +2063,8 @@ def visitFunctioncall_exp(self, ctx: LuaParser.Functioncall_expContext): exp.wrapped = True tail = self.visitAllTails(exp, ctx.tail()) par, args = self.visitArgs(ctx.args()) - return self.add_context(ctx, Call(tail, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) + return self.add_context(ctx, Call(tail, _listify(args), + style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) # Visit a parse tree produced by LuaParser#functioncall_expinvoke. def visitFunctioncall_expinvoke(self, ctx: LuaParser.Functioncall_expinvokeContext): @@ -2053,7 +2073,8 @@ def visitFunctioncall_expinvoke(self, ctx: LuaParser.Functioncall_expinvokeConte tail = self.visitAllTails(exp, ctx.tail()) par, args = self.visitArgs(ctx.args()) func = self.visit(ctx.NAME()) - return self.add_context(ctx, Invoke(tail, func, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) + return self.add_context(ctx, Invoke(tail, func, _listify(args), + style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) # Visit a parse tree produced by LuaParser#functioncall_invoke. def visitFunctioncall_invoke(self, ctx: LuaParser.Functioncall_invokeContext): @@ -2061,7 +2082,8 @@ def visitFunctioncall_invoke(self, ctx: LuaParser.Functioncall_invokeContext): func = self.visit(ctx.NAME(1)) tail = self.visitAllTails(source, ctx.tail()) par, args = self.visitArgs(ctx.args()) - return self.add_context(ctx, Invoke(tail, func, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) + return self.add_context(ctx, Invoke(tail, func, _listify(args), + style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) # Visit a parse tree produced by LuaParser#functioncall_nestedinvoke. def visitFunctioncall_nestedinvoke(self, ctx: LuaParser.Functioncall_nestedinvokeContext): @@ -2069,7 +2091,8 @@ def visitFunctioncall_nestedinvoke(self, ctx: LuaParser.Functioncall_nestedinvok func = self.visit(ctx.NAME()) tail = self.visitAllTails(call, ctx.tail()) par, args = self.visitArgs(ctx.args()) - return self.add_context(ctx, Invoke(tail, func, _listify(args), style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) + return self.add_context(ctx, Invoke(tail, func, _listify(args), + style=CallStyle.DEFAULT if par else CallStyle.NO_PARENTHESIS)) def visitAllTails(self, root_exp: Expression, tails: List[LuaParser.TailContext]): if not tails: diff --git a/luaparser/tests/test_statements.py b/luaparser/tests/test_statements.py index e121503..87c5dbb 100644 --- a/luaparser/tests/test_statements.py +++ b/luaparser/tests/test_statements.py @@ -583,4 +583,30 @@ def test_attr(self): ) ]) ) - self.assertEqual(exp, tree) \ No newline at end of file + self.assertEqual(exp, tree) + + def test_method(self): + tree = ast.parse( + textwrap.dedent( + """ + function foo.bar:print(arg) + end + """ + ) + ) + exp = Chunk( + Block( + [ + Method( + source=Index( + idx=Name("bar"), + value=Name("foo"), + ), + name=Name("print"), + args=[Name("arg")], + body=Block([]), + ) + ] + ) + ) + self.assertEqual(exp, tree)