-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement AstBuilder to extract metadata from the query
- Loading branch information
Showing
1 changed file
with
50 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import typing as t | ||
|
||
from cratedb_sqlparse.generated_parser.SqlBaseParser import SqlBaseParser | ||
from cratedb_sqlparse.generated_parser.SqlBaseParserVisitor import SqlBaseParserVisitor | ||
|
||
|
||
class AstBuilder(SqlBaseParserVisitor): | ||
""" | ||
The class implements the antlr4 visitor pattern similar to how we do it in CrateDB | ||
https://github.com/crate/crate/blob/master/libs/sql-parser/src/main/java/io/crate/sql/parser/AstBuilder.java | ||
The biggest difference is that in CrateDB, `AstBuilder`, visitor methods | ||
return a specialized Statement visitor. | ||
Sqlparse just extracts whatever data it needs from the context and injects it to the current | ||
visited statement, enriching its metadata. | ||
""" | ||
|
||
@property | ||
def stmt(self): | ||
if not hasattr(self, "_stmt"): | ||
raise Exception("You should call `enrich` first, that is the entrypoint.") | ||
return self._stmt | ||
|
||
@stmt.setter | ||
def stmt(self, value): | ||
self._stmt = value | ||
|
||
def enrich(self, stmt) -> None: | ||
self.stmt = stmt | ||
self.visit(self.stmt.ctx) | ||
|
||
def visitTableName(self, ctx: SqlBaseParser.TableNameContext): | ||
fqn = ctx.qname() | ||
parts = self.get_text(fqn).replace('"', "").split(".") | ||
|
||
if len(parts) == 1: | ||
name = parts[0] | ||
schema = None | ||
else: | ||
schema, name = parts | ||
|
||
self.stmt.metadata.table_name = name | ||
self.stmt.metadata.schema = schema | ||
|
||
def get_text(self, node) -> t.Optional[str]: | ||
"""Gets the text representation of the node or None if it doesn't have one""" | ||
if node: | ||
return node.getText() | ||
return node |