From d6bc112a969b92e86d1ae2d1e92f5e077b6f8a7c Mon Sep 17 00:00:00 2001 From: aicookies Date: Thu, 30 Nov 2023 18:49:36 +0800 Subject: [PATCH] [Feature]Add StarRocks 2.5+ DB adapter --- pilot/common/schema.py | 1 + .../connections/manages/connection_manager.py | 1 + pilot/connections/rdbms/conn_starrocks.py | 147 +++++++++++++++ pilot/connections/rdbms/dialect/__init__.py | 0 .../rdbms/dialect/starrocks/__init__.py | 14 ++ .../dialect/starrocks/sqlalchemy/__init__.py | 22 +++ .../dialect/starrocks/sqlalchemy/datatype.py | 104 +++++++++++ .../dialect/starrocks/sqlalchemy/dialect.py | 173 ++++++++++++++++++ 8 files changed, 462 insertions(+) create mode 100644 pilot/connections/rdbms/conn_starrocks.py create mode 100644 pilot/connections/rdbms/dialect/__init__.py create mode 100644 pilot/connections/rdbms/dialect/starrocks/__init__.py create mode 100644 pilot/connections/rdbms/dialect/starrocks/sqlalchemy/__init__.py create mode 100644 pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py create mode 100644 pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py diff --git a/pilot/common/schema.py b/pilot/common/schema.py index f06159da9..561b019dc 100644 --- a/pilot/common/schema.py +++ b/pilot/common/schema.py @@ -29,6 +29,7 @@ class DBType(Enum): MSSQL = DbInfo("mssql") Postgresql = DbInfo("postgresql") Clickhouse = DbInfo("clickhouse") + StarRocks = DbInfo("starrocks") Spark = DbInfo("spark", True) def value(self): diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index 340c58f0b..473449f94 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -17,6 +17,7 @@ from pilot.connections.rdbms.base import RDBMSDatabase from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect from pilot.connections.rdbms.conn_postgresql import PostgreSQLDatabase +from pilot.connections.rdbms.conn_starrocks import StarRocksConnect from 
pilot.singleton import Singleton
 from pilot.common.sql_database import Database
 from pilot.connections.db_conn_info import DBConfig
diff --git a/pilot/connections/rdbms/conn_starrocks.py b/pilot/connections/rdbms/conn_starrocks.py
new file mode 100644
index 000000000..530279340
--- /dev/null
+++ b/pilot/connections/rdbms/conn_starrocks.py
@@ -0,0 +1,201 @@
+from typing import Iterable, Optional, Any
+from sqlalchemy import text
+from urllib.parse import quote
+import re
+from pilot.connections.rdbms.base import RDBMSDatabase
+
+# Imported for its side effect: the package registers the "starrocks" dialect.
+from pilot.connections.rdbms.dialect.starrocks.sqlalchemy import *
+
+
+class StarRocksConnect(RDBMSDatabase):
+    """Connector that reaches StarRocks through the "starrocks" dialect.
+
+    Metadata comes from ``information_schema`` queries and ``SHOW`` statements;
+    values compared inside SQL text are passed as bound parameters.
+    """
+
+    driver = "starrocks"
+    db_type = "starrocks"
+    db_dialect = "starrocks"
+
+    @classmethod
+    def from_uri_db(
+        cls,
+        host: str,
+        port: int,
+        user: str,
+        pwd: str,
+        db_name: str,
+        engine_args: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> RDBMSDatabase:
+        """Build a connection from separate host, port and credential values.
+
+        The user name and password are percent-encoded with no characters
+        exempted, so a "/" in either one is escaped as well instead of being
+        left in the URL verbatim.
+        """
+        credentials = f"{quote(user, safe='')}:{quote(pwd, safe='')}"
+        db_url: str = f"{cls.driver}://{credentials}@{host}:{port}/{db_name}"
+        return cls.from_uri(db_url, engine_args, **kwargs)
+
+    def _sync_tables_from_db(self) -> Iterable[str]:
+        """Reload the table names of the current database and reflect metadata.
+
+        Only information_schema.tables is read;
+        information_schema.materialized_views is not queried.
+        """
+        rows = self.session.execute(
+            text(
+                "SELECT TABLE_NAME FROM information_schema.tables "
+                "where TABLE_SCHEMA = :db_name"
+            ),
+            {"db_name": self.get_current_db_name()},
+        )
+        self._all_tables = {row[0] for row in rows}
+        self._metadata.reflect(bind=self._engine)
+        return self._all_tables
+
+    def get_grants(self):
+        """Return one value per row of ``SHOW GRANTS``.
+
+        The value is taken from the second column of two-column rows and from
+        the third column otherwise (presumably the grant text in both layouts).
+        """
+        session = self._db_sessions()
+        grants = session.execute(text("SHOW GRANTS")).fetchall()
+        if not grants:
+            return []
+        # The width of the first row decides which column is read.
+        column = 1 if len(grants[0]) == 2 else 2
+        return [row[column] for row in grants]
+
+    def _get_current_version(self):
+        """Get database current version"""
+        # NOTE(review): int() raises ValueError unless current_version() yields
+        # a plain integer; confirm what the server actually returns.
+        return int(self.session.execute(text("select current_version()")).scalar())
+
+    def get_collation(self):
+        """Get collation."""
+        # Ordering in StarRocks is defined per table, so nothing is reported here.
+        return None
+
+    def get_users(self):
+        """Get user info."""
+        return []
+
+    def get_fields(self, table_name, db_name="database()"):
+        """Get column fields about specified table.
+
+        Args:
+            table_name: Table whose columns are listed.
+            db_name: Schema that owns the table. The default, the literal string
+                "database()", selects the current database.
+
+        Returns:
+            A list of (name, type, default, is_nullable, comment) tuples.
+        """
+        session = self._db_sessions()
+        params = {"table_name": table_name}
+        if db_name == "database()":
+            # Sentinel value: let the server resolve the current database.
+            schema_filter = "database()"
+        else:
+            schema_filter = ":db_name"
+            params["db_name"] = db_name
+        cursor = session.execute(
+            text(
+                "select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
+                "COLUMN_COMMENT from information_schema.columns "
+                f"where TABLE_NAME = :table_name and TABLE_SCHEMA = {schema_filter}"
+            ),
+            params,
+        )
+        fields = cursor.fetchall()
+        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
+
+    def get_charset(self):
+        """Get character_set."""
+        return "utf-8"
+
+    def get_show_create_table(self, table_name):
+        """Return the comment of ``table_name`` in the current database.
+
+        Despite the method name, the CREATE TABLE statement is not returned:
+        only the table description is wanted here, and the full statement made
+        the token count too long and caused failures.
+
+        Returns:
+            The table comment as a string, or "" when no row matches.
+        """
+        cur = self.session.execute(
+            text(
+                "SELECT TABLE_COMMENT FROM information_schema.tables "
+                "where TABLE_NAME = :table_name and TABLE_SCHEMA=database()"
+            ),
+            {"table_name": table_name},
+        )
+        table = cur.fetchone()
+        return str(table[0]) if table else ""
+
+    def get_table_comments(self, db_name=None):
+        """Return (table name, table comment) pairs for ``db_name``.
+
+        The current database is used when ``db_name`` is empty or None.
+        """
+        if not db_name:
+            db_name = self.get_current_db_name()
+        cur = self.session.execute(
+            text(
+                "SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables "
+                "where TABLE_SCHEMA = :db_name"
+            ),
+            {"db_name": db_name},
+        )
+        tables = cur.fetchall()
+        return [(table[0], table[1]) for table in tables]
+
+    def get_database_list(self):
+        """Return the same list as get_database_names()."""
+        return self.get_database_names()
+
+    def get_database_names(self):
+        """List the databases from ``SHOW DATABASES`` minus a fixed set of names."""
+        hidden = ("information_schema", "sys", "_statistics_", "dataease")
+        session = self._db_sessions()
+        results = session.execute(text("SHOW DATABASES;")).fetchall()
+        return [row[0] for row in results if row[0] not in hidden]
+
+    def get_current_db_name(self) -> str:
+        """Return the result of ``select database()``."""
+        return self.session.execute(text("select database()")).scalar()
+
+    def table_simple_info(self):
+        """Return one server-built summary string per table of the current database.
+
+        Each string is the table name followed by its column names in
+        parentheses, as assembled by concat()/group_concat() in the query.
+        """
+        _sql = """
+            SELECT concat(TABLE_NAME,"(",group_concat(COLUMN_NAME,","),");") FROM information_schema.columns where TABLE_SCHEMA=database()
+            GROUP BY TABLE_NAME
+            """
+        cursor = self.session.execute(text(_sql))
+        results = cursor.fetchall()
+        return [x[0] for x in results]
+
+    def get_indexes(self, table_name):
+        """Get table indexes about specified table.
+
+        Returns:
+            A list of (column 2, column 4) value pairs, one per row of
+            ``SHOW INDEX`` (zero-based positions).
+        """
+        session = self._db_sessions()
+        # NOTE(review): an identifier cannot be sent as a bound parameter, so
+        # table_name is still interpolated; callers must pass a trusted name.
+        cursor = session.execute(text(f"SHOW INDEX FROM {table_name}"))
+        indexes = cursor.fetchall()
+        return [(index[2], index[4]) for index in indexes]
diff --git a/pilot/connections/rdbms/dialect/__init__.py b/pilot/connections/rdbms/dialect/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pilot/connections/rdbms/dialect/starrocks/__init__.py b/pilot/connections/rdbms/dialect/starrocks/__init__.py
new file mode 100644
index 000000000..20fa42f53
--- /dev/null
+++ b/pilot/connections/rdbms/dialect/starrocks/__init__.py
@@ -0,0 +1,14 @@
+#! /usr/bin/python3
+# Copyright 2021-present StarRocks, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/__init__.py b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/__init__.py
new file mode 100644
index 000000000..601cc7346
--- /dev/null
+++ b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/__init__.py
@@ -0,0 +1,22 @@
+#! /usr/bin/python3
+# Copyright 2021-present StarRocks, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sqlalchemy.dialects import registry
+
+registry.register(
+    "starrocks",
+    "pilot.connections.rdbms.dialect.starrocks.sqlalchemy.dialect",
+    "StarRocksDialect",
+)
diff --git a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py
new file mode 100644
index 000000000..e8542f940
--- /dev/null
+++ b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/datatype.py
@@ -0,0 +1,125 @@
+import logging
+import re
+from typing import Optional, List, Any, Type, Dict
+
+from sqlalchemy import Numeric, Integer, Float
+from sqlalchemy.sql import sqltypes
+from sqlalchemy.sql.type_api import TypeEngine
+
+logger = logging.getLogger(__name__)
+
+# Marker classes for StarRocks column types: each one only sets
+# ``__visit_name__``; ARRAY, MAP and STRUCT also declare a ``python_type``.
+
+
+class TINYINT(Integer):  # pylint: disable=no-init
+    __visit_name__ = "TINYINT"
+
+
+class LARGEINT(Integer):  # pylint: disable=no-init
+    __visit_name__ = "LARGEINT"
+
+
+class DOUBLE(Float):  # pylint: disable=no-init
+    __visit_name__ = "DOUBLE"
+
+
+class HLL(Numeric):  # pylint: disable=no-init
+    __visit_name__ = "HLL"
+
+
+class BITMAP(Numeric):  # pylint: disable=no-init
+    __visit_name__ = "BITMAP"
+
+
+class PERCENTILE(Numeric):  # pylint: disable=no-init
+    __visit_name__ = "PERCENTILE"
+
+
+class ARRAY(TypeEngine):  # pylint: disable=no-init
+    __visit_name__ = "ARRAY"
+
+    @property
+    def python_type(self) -> Optional[Type[List[Any]]]:
+        return list
+
+
+class MAP(TypeEngine):  # pylint: disable=no-init
+    __visit_name__ = "MAP"
+
+    @property
+    def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
+        return dict
+
+
+class STRUCT(TypeEngine):  # pylint: disable=no-init
+    __visit_name__ = "STRUCT"
+
+    @property
+    def python_type(self) -> Optional[Type[Any]]:
+        return None
+
+
+# Lower-case StarRocks type name -> SQLAlchemy type class.
+# NOTE(review): "tinyint" maps to SMALLINT although TINYINT is defined above,
+# and "timestamp" maps to DATETIME; confirm both are intended.
+_type_map = {
+    # === Boolean ===
+    "boolean": sqltypes.BOOLEAN,
+    # === Integer ===
+    "tinyint": sqltypes.SMALLINT,
+    "smallint": sqltypes.SMALLINT,
+    "int": sqltypes.INTEGER,
+    "bigint": sqltypes.BIGINT,
+    "largeint": LARGEINT,
+    # === Floating-point ===
+    "float": sqltypes.FLOAT,
+    "double": DOUBLE,
+    # === Fixed-precision ===
+    "decimal": sqltypes.DECIMAL,
+    # === String ===
+    "varchar": sqltypes.VARCHAR,
+    "char": sqltypes.CHAR,
+    "json": sqltypes.JSON,
+    # === Date and time ===
+    "date": sqltypes.DATE,
+    "datetime": sqltypes.DATETIME,
+    "timestamp": sqltypes.DATETIME,
+    # === Structural ===
+    "array": ARRAY,
+    "map": MAP,
+    "struct": STRUCT,
+    "hll": HLL,
+    "percentile": PERCENTILE,
+    "bitmap": BITMAP,
+}
+
+
+def parse_sqltype(type_str: str) -> TypeEngine:
+    """Map a column type string to an instance of a SQLAlchemy type.
+
+    Only the leading type name selects the result; a parenthesised
+    argument list such as "(10, 2)" is matched but otherwise ignored.
+
+    Args:
+        type_str: Type text; surrounding whitespace and letter case
+            are ignored.
+
+    Returns:
+        A new instance of the mapped class, or ``sqltypes.NULLTYPE``
+        when the text does not start with a type name or the name is
+        not a key of ``_type_map``.
+    """
+    type_str = type_str.strip().lower()
+    # Both groups are named; "type" is read back by name below.
+    match = re.match(r"^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?", type_str)
+    if not match:
+        logger.warning("Could not parse type name '%s'", type_str)
+        return sqltypes.NULLTYPE
+    type_name = match.group("type")
+
+    type_class = _type_map.get(type_name)
+    if type_class is None:
+        logger.warning("Did not recognize type '%s'", type_name)
+        return sqltypes.NULLTYPE
+    return type_class()
diff --git a/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py
new file
mode 100644
index 000000000..f92054d90
--- /dev/null
+++ b/pilot/connections/rdbms/dialect/starrocks/sqlalchemy/dialect.py
@@ -0,0 +1,230 @@
+#! /usr/bin/python3
+# Copyright 2021-present StarRocks, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Any, Dict, List, Optional
+
+from sqlalchemy import log, exc, text
+from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
+from sqlalchemy.engine import Connection
+
+from pilot.connections.rdbms.dialect.starrocks.sqlalchemy import datatype
+
+logger = logging.getLogger(__name__)
+
+
+@log.class_logger
+class StarRocksDialect(MySQLDialect_pymysql):
+    """SQLAlchemy dialect for StarRocks, derived from the MySQL PyMySQL dialect.
+
+    Tables, views and columns are reflected with SHOW and DESCRIBE statements.
+    Constraints, indexes, sequences, temporary objects, table options and
+    table comments are always reported as absent.
+    """
+
+    # Caching
+    # Warnings are generated by SQLAlchemy if this flag is not explicitly set
+    # and tests are needed before being enabled
+    supports_statement_cache = False
+
+    name = "starrocks"
+
+    def __init__(self, *args, **kw):
+        super().__init__(*args, **kw)
+
+    def _sr_qualified_name(self, table_name, schema):
+        """Quote ``table_name`` and prefix it with the quoted schema, if any."""
+        quote = self.identifier_preparer.quote_identifier
+        full_name = quote(table_name)
+        if schema:
+            full_name = "{}.{}".format(quote(schema), full_name)
+        return full_name
+
+    def _sr_show_full_tables(self, connection, schema):
+        """Run ``SHOW FULL TABLES`` for ``schema`` and return the fetched rows."""
+        charset = self._connection_charset
+        rp = connection.exec_driver_sql(
+            "SHOW FULL TABLES FROM %s"
+            % self.identifier_preparer.quote_identifier(schema)
+        )
+        return self._compat_fetchall(rp, charset=charset)
+
+    def has_table(self, connection, table_name, schema=None, **kw):
+        """Return True when ``DESCRIBE`` yields at least one row for the table.
+
+        NOTE(review): an error raised while running DESCRIBE (for example for a
+        table that does not exist) is not caught here; confirm whether callers
+        expect False in that case.
+        """
+        self._ensure_has_table_connection(connection)
+
+        if schema is None:
+            schema = self.default_schema_name
+
+        assert schema is not None
+
+        full_name = self._sr_qualified_name(table_name, schema)
+        res = connection.execute(text(f"DESCRIBE {full_name}"))
+        return res.first() is not None
+
+    def get_schema_names(self, connection, **kw):
+        """Return the first column of every ``SHOW schemas`` row."""
+        rp = connection.exec_driver_sql("SHOW schemas")
+        return [r[0] for r in rp]
+
+    def get_table_names(self, connection, schema=None, **kw):
+        """Return a Unicode SHOW TABLES from a given schema."""
+        if schema is None:
+            schema = self.default_schema_name
+        return [
+            row[0]
+            for row in self._sr_show_full_tables(connection, schema)
+            if row[1] == "BASE TABLE"
+        ]
+
+    def get_view_names(self, connection, schema=None, **kw):
+        """Return the names listed as "VIEW" or "SYSTEM VIEW" in ``schema``."""
+        if schema is None:
+            schema = self.default_schema_name
+        return [
+            row[0]
+            for row in self._sr_show_full_tables(connection, schema)
+            if row[1] in ("VIEW", "SYSTEM VIEW")
+        ]
+
+    def get_columns(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw,
+    ) -> List[Dict[str, Any]]:
+        """Describe the columns of a table using ``SHOW COLUMNS``.
+
+        Returns:
+            One dict per column with the keys "name", "type", "nullable" and
+            "default".
+
+        Raises:
+            NoSuchTableError: If has_table() reports the table as missing.
+        """
+        if not self.has_table(connection, table_name, schema):
+            raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")
+        schema = schema or self._get_default_schema_name(connection)
+
+        full_name = self._sr_qualified_name(table_name, schema)
+        res = connection.execute(text(f"SHOW COLUMNS FROM {full_name}"))
+        return [
+            dict(
+                name=record.Field,
+                type=datatype.parse_sqltype(record.Type),
+                nullable=record.Null == "YES",
+                default=record.Default,
+            )
+            for record in res
+        ]
+
+    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+        """Report an unnamed primary key with no columns."""
+        return {  # type: ignore  # pep-655 not supported
+            "name": None,
+            "constrained_columns": [],
+        }
+
+    def get_unique_constraints(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw,
+    ) -> List[Dict[str, Any]]:
+        """Report no unique constraints."""
+        return []
+
+    def get_check_constraints(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw,
+    ) -> List[Dict[str, Any]]:
+        """Report no check constraints."""
+        return []
+
+    def get_foreign_keys(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw,
+    ) -> List[Dict[str, Any]]:
+        """Report no foreign keys."""
+        return []
+
+    def get_primary_keys(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw,
+    ) -> List[str]:
+        """Return the "constrained_columns" entry of get_pk_constraint()."""
+        pk = self.get_pk_constraint(connection, table_name, schema)
+        return pk.get("constrained_columns")  # type: ignore
+
+    def get_indexes(self, connection, table_name, schema=None, **kw):
+        """Report no indexes."""
+        return []
+
+    def has_sequence(
+        self,
+        connection: Connection,
+        sequence_name: str,
+        schema: Optional[str] = None,
+        **kw,
+    ) -> bool:
+        """Report that the sequence does not exist."""
+        return False
+
+    def get_sequence_names(
+        self, connection: Connection, schema: Optional[str] = None, **kw
+    ) -> List[str]:
+        """Report no sequences."""
+        return []
+
+    def get_temp_view_names(
+        self, connection: Connection, schema: Optional[str] = None, **kw
+    ) -> List[str]:
+        """Report no temporary views."""
+        return []
+
+    def get_temp_table_names(
+        self, connection: Connection, schema: Optional[str] = None, **kw
+    ) -> List[str]:
+        """Report no temporary tables."""
+        return []
+
+    def get_table_options(self, connection, table_name, schema=None, **kw):
+        """Report no table options."""
+        return {}
+
+    def get_table_comment(
+        self,
+        connection: Connection,
+        table_name: str,
+        schema: Optional[str] = None,
+        **kw,
+    ) -> Dict[str, Any]:
+        """Report that the table has no comment."""
+        return dict(text=None)