From cbe6f5b16b3083357085a926b813c4d3aafafbc1 Mon Sep 17 00:00:00 2001 From: aicookies Date: Thu, 30 Nov 2023 14:40:14 +0800 Subject: [PATCH] [Feature]Add StarRocks connection and dialect --- pilot/common/schema.py | 2 +- pilot/connections/rdbms/conn_starrocks.py | 48 ++++++++++--------- .../rdbms/dialect/starrocks/__init__.py | 1 - .../dialect/starrocks/sqlalchemy/__init__.py | 7 ++- .../dialect/starrocks/sqlalchemy/datatype.py | 12 ++--- .../dialect/starrocks/sqlalchemy/dialect.py | 37 +++++++++----- 6 files changed, 62 insertions(+), 45 deletions(-) diff --git a/pilot/common/schema.py b/pilot/common/schema.py index 7b065da67..5c8fda027 100644 --- a/pilot/common/schema.py +++ b/pilot/common/schema.py @@ -30,7 +30,7 @@ class DBType(Enum): Postgresql = DbInfo("postgresql") Clickhouse = DbInfo("clickhouse") StarRocks = DbInfo("starrocks") - + Spark = DbInfo("spark", True) def value(self): diff --git a/pilot/connections/rdbms/conn_starrocks.py b/pilot/connections/rdbms/conn_starrocks.py index 836a665c4..9b5853a94 100644 --- a/pilot/connections/rdbms/conn_starrocks.py +++ b/pilot/connections/rdbms/conn_starrocks.py @@ -22,7 +22,9 @@ def from_uri_db( engine_args: Optional[dict] = None, **kwargs: Any, ) -> RDBMSDatabase: - db_url: str = f"{cls.driver}://{quote(user)}:{quote(pwd)}@{host}:{str(port)}/{db_name}" + db_url: str = ( + f"{cls.driver}://{quote(user)}:{quote(pwd)}@{host}:{str(port)}/{db_name}" + ) return cls.from_uri(db_url, engine_args, **kwargs) def _sync_tables_from_db(self) -> Iterable[str]: @@ -32,11 +34,7 @@ def _sync_tables_from_db(self) -> Iterable[str]: f'SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA="{db_name}"' ) ) - view_results = self.session.execute( - text( - f'SHOW MATERIALIZED VIEWS' - ) - ) + view_results = self.session.execute(text(f"SHOW MATERIALIZED VIEWS")) table_results = set(row[0] for row in table_results) view_results = set(row[2] for row in view_results) self._all_tables = table_results.union(view_results) @@ -45,11 +43,7 @@ def _sync_tables_from_db(self) -> Iterable[str]: def get_grants(self): session = self._db_sessions() - cursor = session.execute( - text( - "SHOW GRANTS" - ) - ) + cursor = session.execute(text("SHOW GRANTS")) grants = cursor.fetchall() grants_list = [x[2] for x in grants] return grants_list @@ -57,7 +51,7 @@ def get_grants(self): def _get_current_version(self): """Get database current version""" return int(self.session.execute(text("select current_version()")).scalar()) - + def get_collation(self): """Get collation.""" # StarRocks 排序是表级别的 @@ -73,9 +67,9 @@ def get_users(self): # return [user[0] for user in users] # except Exception as e: # print("starrocks get users error: ", e) - # return [] + # return [] return [] - + def get_fields(self, table_name, db_name="database()"): """Get column fields about specified table.""" session = self._db_sessions() @@ -100,7 +94,7 @@ def get_fields(self, table_name, db_name="database()"): def get_charset(self): """Get character_set.""" - + return "utf-8" def get_show_create_table(self, table_name): @@ -114,7 +108,11 @@ def get_show_create_table(self, table_name): # return create_sql # 这里是要表描述, 返回建表语句会导致token过长而失败 - cur = self.session.execute(text(f'SELECT TABLE_COMMENT FROM information_schema.tables where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()')) + cur = self.session.execute( + text( + f'SELECT TABLE_COMMENT FROM information_schema.tables where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()' + ) + ) table = cur.fetchone() if table: return str(table[0]) @@ -132,7 +130,11 @@ def get_table_comments(self, db_name=None): # comments.append((table_name, table_comment)) if not db_name: db_name = self.get_current_db_name() - cur = self.session.execute(text(f'SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables where TABLE_SCHEMA="{db_name}"')) + cur = self.session.execute( + text( + f'SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables where TABLE_SCHEMA="{db_name}"' + ) + ) tables = cur.fetchall() return [(table[0], table[1]) for table in tables] @@ -143,7 +145,11 @@ def get_database_names(self): session = self._db_sessions() cursor = session.execute(text("SHOW DATABASES;")) results = cursor.fetchall() - return [d[0] for d in results if d[0] not in ["information_schema", "sys", "_statistics_","dataease"]] + return [ + d[0] + for d in results + if d[0] not in ["information_schema", "sys", "_statistics_", "dataease"] + ] def get_current_db_name(self) -> str: return self.session.execute(text("select database()")).scalar() @@ -160,10 +166,6 @@ def table_simple_info(self): def get_indexes(self, table_name): """Get table indexes about specified table.""" session = self._db_sessions() - cursor = session.execute( - text( - f"SHOW INDEX FROM {table_name}" - ) - ) + cursor = session.execute(text(f"SHOW INDEX FROM {table_name}")) indexes = cursor.fetchall() return [(index[2], index[4]) for index in indexes] diff --git a/pilot/connections/rdbms/dialect/starrocks/__init__.py b/pilot/connections/rdbms/dialect/starrocks/__init__.py index c2e11e974..20fa42f53 100644 --- a/pilot/connections/rdbms/dialect/starrocks/__init__.py +++ b/pilot/connections/rdbms/dialect/starrocks/__init__.py @@ -1,4 +1,3 @@ - #! /usr/bin/python3 # Copyright 2021-present StarRocks, Inc. All rights reserved. # diff --git a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/__init__.py b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/__init__.py index c21f12d33..601cc7346 100644 --- a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/__init__.py +++ b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/__init__.py @@ -1,4 +1,3 @@ - #! /usr/bin/python3 # Copyright 2021-present StarRocks, Inc. All rights reserved. # @@ -16,4 +15,8 @@ from sqlalchemy.dialects import registry -registry.register("starrocks", "pilot.connections.rdbms.dialect.starrocks.sqlalchemy.dialect", "StarRocksDialect") +registry.register( + "starrocks", + "pilot.connections.rdbms.dialect.starrocks.sqlalchemy.dialect", + "StarRocksDialect", +) diff --git a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py index c1dce586e..e8542f940 100644 --- a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py +++ b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py @@ -80,12 +80,12 @@ def python_type(self) -> Optional[Type[Any]]: "datetime": sqltypes.DATETIME, "timestamp": sqltypes.DATETIME, # === Structural === - 'array': ARRAY, - 'map': MAP, - 'struct': STRUCT, - 'hll': HLL, - 'percentile': PERCENTILE, - 'bitmap': BITMAP, + "array": ARRAY, + "map": MAP, + "struct": STRUCT, + "hll": HLL, + "percentile": PERCENTILE, + "bitmap": BITMAP, } diff --git a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py index a73021ff2..f92054d90 100644 --- a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py +++ b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py @@ -31,7 +31,7 @@ class StarRocksDialect(MySQLDialect_pymysql): # and tests are needed before being enabled supports_statement_cache = False - name = 'starrocks' + name = "starrocks" def __init__(self, *args, **kw): super(StarRocksDialect, self).__init__(*args, **kw) @@ -90,7 +90,9 @@ def get_view_names(self, connection, schema=None, **kw): if row[1] in ("VIEW", "SYSTEM VIEW") ] - def get_columns(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]: + def get_columns( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> List[Dict[str, Any]]: if not self.has_table(connection, table_name, schema): raise exc.NoSuchTableError(f"schema={schema}, table={table_name}") schema = schema or self._get_default_schema_name(connection) @@ -112,7 +114,6 @@ def get_columns(self, connection: Connection, table_name: str, schema: str = Non columns.append(column) return columns - def get_pk_constraint(self, connection, table_name, schema=None, **kw): return { # type: ignore # pep-655 not supported "name": None, @@ -120,41 +121,53 @@ def get_pk_constraint(self, connection, table_name, schema=None, **kw): } def get_unique_constraints( - self, connection: Connection, table_name: str, schema: str = None, **kw + self, connection: Connection, table_name: str, schema: str = None, **kw ) -> List[Dict[str, Any]]: return [] def get_check_constraints( - self, connection: Connection, table_name: str, schema: str = None, **kw + self, connection: Connection, table_name: str, schema: str = None, **kw ) -> List[Dict[str, Any]]: return [] def get_foreign_keys( - self, connection: Connection, table_name: str, schema: str = None, **kw + self, connection: Connection, table_name: str, schema: str = None, **kw ) -> List[Dict[str, Any]]: return [] - def get_primary_keys(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[str]: + def get_primary_keys( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> List[str]: pk = self.get_pk_constraint(connection, table_name, schema) return pk.get("constrained_columns") # type: ignore def get_indexes(self, connection, table_name, schema=None, **kw): return [] - def has_sequence(self, connection: Connection, sequence_name: str, schema: str = None, **kw) -> bool: + def has_sequence( + self, connection: Connection, sequence_name: str, schema: str = None, **kw + ) -> bool: return False - def get_sequence_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + def get_sequence_names( + self, connection: Connection, schema: str = None, **kw + ) -> List[str]: return [] - def get_temp_view_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + def get_temp_view_names( + self, connection: Connection, schema: str = None, **kw + ) -> List[str]: return [] - def get_temp_table_names(self, connection: Connection, schema: str = None, **kw) -> List[str]: + def get_temp_table_names( + self, connection: Connection, schema: str = None, **kw + ) -> List[str]: return [] def get_table_options(self, connection, table_name, schema=None, **kw): return {} - def get_table_comment(self, connection: Connection, table_name: str, schema: str = None, **kw) -> Dict[str, Any]: + def get_table_comment( + self, connection: Connection, table_name: str, schema: str = None, **kw + ) -> Dict[str, Any]: return dict(text=None)