diff --git a/.gitignore b/.gitignore
index aee02fc..6a56064 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,5 @@
 build
 dist
 luaparser.egg-info
-venv
\ No newline at end of file
+venv
+.venv
diff --git a/luaparser/astnodes.py b/luaparser/astnodes.py
index ea2fa69..6f1f3f6 100644
--- a/luaparser/astnodes.py
+++ b/luaparser/astnodes.py
@@ -7,7 +7,7 @@
 from enum import Enum
 from typing import List, Optional

-from antlr4 import Token
+from antlr4.Token import CommonToken

 Comments = Optional[List["Comment"]]

@@ -30,8 +30,8 @@ def __init__(
         self,
         name: str,
         comments: Comments = None,
-        first_token: Optional[Token] = None,
-        last_token: Optional[Token] = None,
+        first_token: Optional[CommonToken] = None,
+        last_token: Optional[CommonToken] = None,
     ):
         """

@@ -45,8 +45,19 @@ def __init__(
             comments = []
         self._name: str = name
         self.comments: Comments = comments
-        self._first_token: Optional[Token] = first_token
-        self._last_token: Optional[Token] = last_token
+        self._first_token: Optional[CommonToken] = first_token
+        self._last_token: Optional[CommonToken] = last_token
+
+        # We want nodes to be serializable with pickle.
+        # To allow that, we must not have mutable fields such as streams.
+        # Tokens have streams, so create a stream-less copy of each token.
+        if self._first_token is not None:
+            self._first_token = self._first_token.clone()
+            self._first_token.source = CommonToken.EMPTY_SOURCE
+
+        if self._last_token is not None:
+            self._last_token = self._last_token.clone()
+            self._last_token.source = CommonToken.EMPTY_SOURCE

     @property
     def display_name(self) -> str:
@@ -60,20 +71,34 @@ def __eq__(self, other) -> bool:
         return False

     @property
-    def first_token(self) -> Optional[Token]:
+    def first_token(self) -> Optional[CommonToken]:
+        """
+        First token of a node.
+
+        Note: the token is disconnected from the underlying source stream.
+        """
         return self._first_token

     @first_token.setter
-    def first_token(self, val):
-        self._first_token = val
+    def first_token(self, val: Optional[CommonToken]):
+        if val is not None:
+            self._first_token = val.clone()
+            self._first_token.source = CommonToken.EMPTY_SOURCE

     @property
-    def last_token(self) -> Optional[Token]:
+    def last_token(self) -> Optional[CommonToken]:
+        """
+        Last token of a node.
+
+        Note: the token is disconnected from the underlying source stream.
+        """
         return self._last_token

     @last_token.setter
-    def last_token(self, val):
-        self._last_token = val
+    def last_token(self, val: Optional[CommonToken]):
+        if val is not None:
+            self._last_token = val.clone()
+            self._last_token.source = CommonToken.EMPTY_SOURCE

     @property
     def start_char(self) -> Optional[int]:
diff --git a/luaparser/builder.py b/luaparser/builder.py
index ed53505..490b369 100644
--- a/luaparser/builder.py
+++ b/luaparser/builder.py
@@ -220,7 +220,7 @@ def __init__(self, source):
         self._hidden_handled_stack: List[bool] = []

     @property
-    def _LT(self) -> Token:
+    def _LT(self) -> CommonToken:
         """Last token that was consumed in next_i*_* method."""
         return self._stream.LT(-1)

@@ -665,9 +665,7 @@ def parse_tail(self) -> Node or bool:
                 last_token=self._LT,
             )
             if self.next_is_rc(Tokens.STRING, False):
-                string = self.parse_lua_str(self.text)
-                string.first_token = self._LT.start
-                string.last_token = self._LT
+                string = self.parse_lua_str(self.text, self._LT)
                 self.success()
                 # noinspection PyTypeChecker
                 return Invoke(None, name, [string])
@@ -705,9 +703,7 @@ def parse_tail(self) -> Node or bool:

         self.failure_save()
         if self.next_is_rc(Tokens.STRING, False):
-            string = self.parse_lua_str(self.text)
-            string.first_token = self._LT
-            string.last_token = self._LT
+            string = self.parse_lua_str(self.text, self._LT)
             self.success()
             return string

@@ -1410,9 +1406,7 @@ def parse_atom(self) -> Expression or bool:
             )

         if self.next_is(Tokens.STRING) and self.next_is_rc(Tokens.STRING):
-            string = self.parse_lua_str(self.text)
-            string.first_token = self._LT
-            string.last_token = self._LT
+            string = self.parse_lua_str(self.text, self._LT)
             return string

         if self.next_is(Tokens.NIL) and self.next_is_rc(Tokens.NIL):
@@ -1426,7 +1420,7 @@
         return None

     @staticmethod
-    def parse_lua_str(lua_str) -> String:
+    def parse_lua_str(lua_str, token: Optional[CommonToken] = None) -> String:
         delimiter: StringDelimiter = StringDelimiter.SINGLE_QUOTE
         p = re.compile(r"^\[=+\[(.*)]=+]")  # nested quote pattern
         # try remove double quote:
@@ -1444,7 +1438,7 @@
         # nested quote
         elif p.match(lua_str):
             lua_str = p.search(lua_str).group(1)
-        return String(lua_str, delimiter)
+        return String(lua_str, delimiter, first_token=token, last_token=token)

     def parse_function_literal(self) -> AnonymousFunction or bool:
         self.save()