Skip to content

Commit

Permalink
[Feature]Add StarRocks connection and dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
quqibing committed Nov 30, 2023
1 parent 48ab392 commit cbe6f5b
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 45 deletions.
2 changes: 1 addition & 1 deletion pilot/common/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DBType(Enum):
Postgresql = DbInfo("postgresql")
Clickhouse = DbInfo("clickhouse")
StarRocks = DbInfo("starrocks")

Spark = DbInfo("spark", True)

def value(self):
Expand Down
48 changes: 25 additions & 23 deletions pilot/connections/rdbms/conn_starrocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def from_uri_db(
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = f"{cls.driver}://{quote(user)}:{quote(pwd)}@{host}:{str(port)}/{db_name}"
db_url: str = (
f"{cls.driver}://{quote(user)}:{quote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)

def _sync_tables_from_db(self) -> Iterable[str]:
Expand All @@ -32,11 +34,7 @@ def _sync_tables_from_db(self) -> Iterable[str]:
f'SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'
)
)
view_results = self.session.execute(
text(
f'SHOW MATERIALIZED VIEWS'
)
)
view_results = self.session.execute(text(f"SHOW MATERIALIZED VIEWS"))
table_results = set(row[0] for row in table_results)
view_results = set(row[2] for row in view_results)
self._all_tables = table_results.union(view_results)
Expand All @@ -45,19 +43,15 @@ def _sync_tables_from_db(self) -> Iterable[str]:

def get_grants(self):
"""Return element 2 of every row produced by ``SHOW GRANTS``."""
session = self._db_sessions()
# NOTE(review): the multi-line call below and the one-line call after it are
# the same statement. This text is a rendered diff, so they are presumably
# the before/after formatting of one line, not two queries — confirm
# against the real file.
cursor = session.execute(
text(
"SHOW GRANTS"
)
)
cursor = session.execute(text("SHOW GRANTS"))
grants = cursor.fetchall()
# presumably column 2 carries the grant text — verify against the server's
# SHOW GRANTS output
grants_list = [x[2] for x in grants]
return grants_list

def _get_current_version(self):
"""Get database current version.

Runs ``select current_version()`` on ``self.session`` and converts the
scalar result with ``int()``.
"""
# NOTE(review): int() raises if the scalar is None or not integer-like
# (for example a dotted version string) — confirm what current_version()
# returns on the target server.
return int(self.session.execute(text("select current_version()")).scalar())

def get_collation(self):
"""Get collation."""
# In StarRocks, sort order (collation) is defined at the table level
Expand All @@ -73,9 +67,9 @@ def get_users(self):
# return [user[0] for user in users]
# except Exception as e:
# print("starrocks get users error: ", e)
# return []
# return []
return []

def get_fields(self, table_name, db_name="database()"):
"""Get column fields about specified table."""
session = self._db_sessions()
Expand All @@ -100,7 +94,7 @@ def get_fields(self, table_name, db_name="database()"):

def get_charset(self):
"""Get character_set.

Always returns the literal ``"utf-8"``; no query is issued.
"""

return "utf-8"

def get_show_create_table(self, table_name):
Expand All @@ -114,7 +108,11 @@ def get_show_create_table(self, table_name):

# return create_sql
# A table description is wanted here; returning the CREATE TABLE statement would make the token count too long and fail
cur = self.session.execute(text(f'SELECT TABLE_COMMENT FROM information_schema.tables where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'))
cur = self.session.execute(
text(
f'SELECT TABLE_COMMENT FROM information_schema.tables where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
)
)
table = cur.fetchone()
if table:
return str(table[0])
Expand All @@ -132,7 +130,11 @@ def get_table_comments(self, db_name=None):
# comments.append((table_name, table_comment))
if not db_name:
db_name = self.get_current_db_name()
cur = self.session.execute(text(f'SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'))
cur = self.session.execute(
text(
f'SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'
)
)
tables = cur.fetchall()
return [(table[0], table[1]) for table in tables]

Expand All @@ -143,7 +145,11 @@ def get_database_names(self):
session = self._db_sessions()
cursor = session.execute(text("SHOW DATABASES;"))
results = cursor.fetchall()
return [d[0] for d in results if d[0] not in ["information_schema", "sys", "_statistics_","dataease"]]
return [
d[0]
for d in results
if d[0] not in ["information_schema", "sys", "_statistics_", "dataease"]
]

def get_current_db_name(self) -> str:
"""Return the scalar result of ``select database()`` run on ``self.session``."""
return self.session.execute(text("select database()")).scalar()
Expand All @@ -160,10 +166,6 @@ def table_simple_info(self):
def get_indexes(self, table_name):
"""Get table indexes about specified table.

Runs ``SHOW INDEX FROM <table_name>`` and returns one
``(row[2], row[4])`` tuple per result row.
"""
session = self._db_sessions()
# NOTE(review): the multi-line call below and the one-line call after it are
# the same statement. This text is a rendered diff, so they are presumably
# the before/after formatting of one line, not two queries — confirm
# against the real file.
cursor = session.execute(
text(
f"SHOW INDEX FROM {table_name}"
)
)
# NOTE(review): table_name is interpolated into the SQL text as-is, so it
# must be a trusted identifier.
cursor = session.execute(text(f"SHOW INDEX FROM {table_name}"))
indexes = cursor.fetchall()
# presumably columns 2 and 4 are the index name and column name — verify
# against the server's SHOW INDEX output
return [(index[2], index[4]) for index in indexes]
1 change: 0 additions & 1 deletion pilot/connections/rdbms/dialect/starrocks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#! /usr/bin/python3
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#! /usr/bin/python3
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
Expand All @@ -16,4 +15,8 @@

from sqlalchemy.dialects import registry

registry.register("starrocks", "pilot.connections.rdbms.dialect.starrocks.sqlalchemy.dialect", "StarRocksDialect")
registry.register(
"starrocks",
"pilot.connections.rdbms.dialect.starrocks.sqlalchemy.dialect",
"StarRocksDialect",
)
12 changes: 6 additions & 6 deletions pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def python_type(self) -> Optional[Type[Any]]:
"datetime": sqltypes.DATETIME,
"timestamp": sqltypes.DATETIME,
# === Structural ===
'array': ARRAY,
'map': MAP,
'struct': STRUCT,
'hll': HLL,
'percentile': PERCENTILE,
'bitmap': BITMAP,
"array": ARRAY,
"map": MAP,
"struct": STRUCT,
"hll": HLL,
"percentile": PERCENTILE,
"bitmap": BITMAP,
}


Expand Down
37 changes: 25 additions & 12 deletions pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class StarRocksDialect(MySQLDialect_pymysql):
# and tests are needed before being enabled
supports_statement_cache = False

name = 'starrocks'
name = "starrocks"

def __init__(self, *args, **kw):
super(StarRocksDialect, self).__init__(*args, **kw)
Expand Down Expand Up @@ -90,7 +90,9 @@ def get_view_names(self, connection, schema=None, **kw):
if row[1] in ("VIEW", "SYSTEM VIEW")
]

def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
def get_columns(
self, connection: Connection, table_name: str, schema: str = None, **kw
) -> List[Dict[str, Any]]:
if not self.has_table(connection, table_name, schema):
raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")
schema = schema or self._get_default_schema_name(connection)
Expand All @@ -112,49 +114,60 @@ def get_columns(self, connection: Connection, table_name: str, schema: str = Non
columns.append(column)
return columns


def get_pk_constraint(self, connection, table_name, schema=None, **kw):
"""Return an empty primary-key description.

The result always has ``name`` set to ``None`` and an empty
``constrained_columns`` list; the arguments are not used and no query is
issued.
"""
return { # type: ignore # pep-655 not supported
"name": None,
"constrained_columns": [],
}

def get_unique_constraints(
self, connection: Connection, table_name: str, schema: str = None, **kw
self, connection: Connection, table_name: str, schema: str = None, **kw
) -> List[Dict[str, Any]]:
return []

def get_check_constraints(
self, connection: Connection, table_name: str, schema: str = None, **kw
self, connection: Connection, table_name: str, schema: str = None, **kw
) -> List[Dict[str, Any]]:
return []

def get_foreign_keys(
self, connection: Connection, table_name: str, schema: str = None, **kw
self, connection: Connection, table_name: str, schema: str = None, **kw
) -> List[Dict[str, Any]]:
return []

def get_primary_keys(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[str]:
def get_primary_keys(
self, connection: Connection, table_name: str, schema: str = None, **kw
) -> List[str]:
pk = self.get_pk_constraint(connection, table_name, schema)
return pk.get("constrained_columns") # type: ignore

def get_indexes(self, connection, table_name, schema=None, **kw):
"""Return an empty list; the arguments are not used and no query is issued."""
return []

def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool:
def has_sequence(
self, connection: Connection, sequence_name: str, schema: str = None, **kw
) -> bool:
return False

def get_sequence_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
def get_sequence_names(
self, connection: Connection, schema: str = None, **kw
) -> List[str]:
return []

def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
def get_temp_view_names(
self, connection: Connection, schema: str = None, **kw
) -> List[str]:
return []

def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]:
def get_temp_table_names(
self, connection: Connection, schema: str = None, **kw
) -> List[str]:
return []

def get_table_options(self, connection, table_name, schema=None, **kw):
"""Return an empty dict; the arguments are not used and no query is issued."""
return {}

def get_table_comment(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]:
def get_table_comment(
self, connection: Connection, table_name: str, schema: str = None, **kw
) -> Dict[str, Any]:
return dict(text=None)

0 comments on commit cbe6f5b

Please sign in to comment.