From bd30ba5ccebbf44ee1ef9642c0cd5359d0392854 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Wed, 15 Jan 2025 14:53:22 +0000
Subject: [PATCH] refactor(risingwave): decouple the risingwave backend from
 the postgres backend (#10669)

Decouple the risingwave backend from the postgres backend to allow use
of `psycopg` (not `psycopg2`) with the postgres backend in #10659.
---
 ibis/backends/risingwave/__init__.py | 444 ++++++++++++++++++++++++++-
 1 file changed, 441 insertions(+), 3 deletions(-)

diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py
index fd3400fb79a1..f2c3d66802ca 100644
--- a/ibis/backends/risingwave/__init__.py
+++ b/ibis/backends/risingwave/__init__.py
@@ -2,25 +2,33 @@
 
 from __future__ import annotations
 
+import contextlib
 from operator import itemgetter
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+from urllib.parse import unquote_plus
 
 import psycopg2
 import sqlglot as sg
 import sqlglot.expressions as sge
+from pandas.api.types import is_float_dtype
 from psycopg2 import extras
 
 import ibis
 import ibis.backends.sql.compilers as sc
 import ibis.common.exceptions as com
+import ibis.common.exceptions as exc
 import ibis.expr.operations as ops
 import ibis.expr.schema as sch
 import ibis.expr.types as ir
 from ibis import util
-from ibis.backends.postgres import Backend as PostgresBackend
+from ibis.backends import CanCreateDatabase, CanListCatalog
+from ibis.backends.sql import SQLBackend
+from ibis.backends.sql.compilers.base import TRUE, C, ColGen
 from ibis.util import experimental
 
 if TYPE_CHECKING:
+    from urllib.parse import ParseResult
+
     import pandas as pd
     import polars as pl
     import pyarrow as pa
@@ -44,11 +52,424 @@ def format_properties(props):
     return "( {} ) ".format(", ".join(tokens))
 
 
-class Backend(PostgresBackend):
+class Backend(SQLBackend, CanListCatalog, CanCreateDatabase):
     name = "risingwave"
     compiler = sc.risingwave.compiler
     supports_python_udfs = False
 
+    def _from_url(self, url: ParseResult, **kwargs):
+        """Connect to a backend using a URL `url`.
+
+        Parameters
+        ----------
+        url
+            URL with which to connect to a backend.
+        kwargs
+            Additional keyword arguments
+
+        Returns
+        -------
+        BaseBackend
+            A backend instance
+
+        """
+        database, *schema = url.path[1:].split("/", 1)
+        connect_args = {
+            "user": url.username,
+            "password": unquote_plus(url.password or ""),
+            "host": url.hostname,
+            "database": database or "",
+            "schema": schema[0] if schema else "",
+            "port": url.port,
+        }
+
+        kwargs.update(connect_args)
+        self._convert_kwargs(kwargs)
+
+        if "user" in kwargs and not kwargs["user"]:
+            del kwargs["user"]
+
+        if "host" in kwargs and not kwargs["host"]:
+            del kwargs["host"]
+
+        if "database" in kwargs and not kwargs["database"]:
+            del kwargs["database"]
+
+        if "schema" in kwargs and not kwargs["schema"]:
+            del kwargs["schema"]
+
+        if "password" in kwargs and kwargs["password"] is None:
+            del kwargs["password"]
+
+        if "port" in kwargs and kwargs["port"] is None:
+            del kwargs["port"]
+
+        return self.connect(**kwargs)
+
+    def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
+        from psycopg2.extras import execute_batch
+
+        schema = op.schema
+        if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
+            raise exc.IbisTypeError(
+                f"{self.name} cannot yet reliably handle `null` typed columns; "
+                f"got null typed columns: {null_columns}"
+            )
+
+        name = op.name
+        quoted = self.compiler.quoted
+        create_stmt = sg.exp.Create(
+            kind="TABLE",
+            this=sg.exp.Schema(
+                this=sg.to_identifier(name, quoted=quoted),
+                expressions=schema.to_sqlglot(self.dialect),
+            ),
+            properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]),
+        )
+        create_stmt_sql = create_stmt.sql(self.dialect)
+
+        df = op.data.to_frame()
+        # nan gets compiled into 'NaN'::float which throws errors in non-float columns
+        # In order to hold NaN values, pandas automatically converts integer columns
+        # to float columns if there are NaN values in them. Therefore, we need to convert
+        # them to their original dtypes (that support pd.NA) to figure out which columns
+        # are actually non-float, then fill the NaN values in those columns with None.
+        convert_df = df.convert_dtypes()
+        for col in convert_df.columns:
+            if not is_float_dtype(convert_df[col]):
+                df[col] = df[col].replace(float("nan"), None)
+
+        data = df.itertuples(index=False)
+        sql = self._build_insert_template(
+            name, schema=schema, columns=True, placeholder="%s"
+        )
+
+        with self.begin() as cur:
+            cur.execute(create_stmt_sql)
+            execute_batch(cur, sql, data, 128)
+
+    @contextlib.contextmanager
+    def begin(self):
+        con = self.con
+        cursor = con.cursor()
+        try:
+            yield cursor
+        except Exception:
+            con.rollback()
+            raise
+        else:
+            con.commit()
+        finally:
+            cursor.close()
+
+    def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
+        import pandas as pd
+
+        from ibis.backends.postgres.converter import PostgresPandasData
+
+        try:
+            df = pd.DataFrame.from_records(
+                cursor, columns=schema.names, coerce_float=True
+            )
+        except Exception:
+            # clean up the cursor if we fail to create the DataFrame
+            #
+            # in the sqlite case failing to close the cursor results in
+            # artificially locked tables
+            cursor.close()
+            raise
+        df = PostgresPandasData.convert_table(df, schema)
+        return df
+
+    @property
+    def version(self):
+        version = f"{self.con.server_version:0>6}"
+        major = int(version[:2])
+        minor = int(version[2:4])
+        patch = int(version[4:])
+        pieces = [major]
+        if minor:
+            pieces.append(minor)
+        pieces.append(patch)
+        return ".".join(map(str, pieces))
+
+    @util.experimental
+    @classmethod
+    def from_connection(cls, con: psycopg2.extensions.connection) -> Backend:
+        """Create an Ibis client from an existing connection to a RisingWave database.
+
+        Parameters
+        ----------
+        con
+            An existing connection to a RisingWave database.
+        """
+        new_backend = cls()
+        new_backend._can_reconnect = False
+        new_backend.con = con
+        new_backend._post_connect()
+        return new_backend
+
+    def _post_connect(self) -> None:
+        with self.begin() as cur:
+            cur.execute("SET TIMEZONE = UTC")
+
+    @property
+    def _session_temp_db(self) -> str | None:
+        # Postgres doesn't assign the temporary table database until the first
+        # temp table is created in a given session.
+        # Before that temp table is created, this will return `None`
+        # After a temp table is created, it will return `pg_temp_N` where N is
+        # some integer
+        res = self.raw_sql(
+            "select nspname from pg_namespace where oid = pg_my_temp_schema()"
+        ).fetchone()
+        if res is not None:
+            return res[0]
+        return res
+
+    def list_tables(
+        self,
+        like: str | None = None,
+        database: tuple[str, str] | str | None = None,
+    ) -> list[str]:
+        """List the tables in the database.
+
+        ::: {.callout-note}
+        ## Ibis does not use the word `schema` to refer to database hierarchy.
+
+        A collection of tables is referred to as a `database`.
+        A collection of `database`s is referred to as a `catalog`.
+
+        These terms are mapped onto the corresponding features in each
+        backend (where available), regardless of whether the backend itself
+        uses the same terminology.
+        :::
+
+        Parameters
+        ----------
+        like
+            A pattern to use for listing tables.
+        database
+            Database to list tables from. Default behavior is to show tables in
+            the current database.
+        """
+
+        if database is not None:
+            table_loc = database
+        else:
+            table_loc = (self.current_catalog, self.current_database)
+
+        table_loc = self._to_sqlglot_table(table_loc)
+
+        conditions = [TRUE]
+
+        if (db := table_loc.args["db"]) is not None:
+            db.args["quoted"] = False
+            db = db.sql(dialect=self.name)
+            conditions.append(C.table_schema.eq(sge.convert(db)))
+        if (catalog := table_loc.args["catalog"]) is not None:
+            catalog.args["quoted"] = False
+            catalog = catalog.sql(dialect=self.name)
+            conditions.append(C.table_catalog.eq(sge.convert(catalog)))
+
+        sql = (
+            sg.select("table_name")
+            .from_(sg.table("tables", db="information_schema"))
+            .distinct()
+            .where(*conditions)
+            .sql(self.dialect)
+        )
+
+        with self._safe_raw_sql(sql) as cur:
+            out = cur.fetchall()
+
+        # Include temporary tables only if no database has been explicitly specified
+        # to avoid temp tables showing up in all calls to `list_tables`
+        if db == "public":
+            out += self._fetch_temp_tables()
+
+        return self._filter_with_like(map(itemgetter(0), out), like)
+
+    def _fetch_temp_tables(self):
+        # postgres temporary tables are stored in a separate schema
+        # so we need to independently grab them and return them along with
+        # the existing results
+
+        sql = (
+            sg.select("table_name")
+            .from_(sg.table("tables", db="information_schema"))
+            .distinct()
+            .where(C.table_type.eq(sge.convert("LOCAL TEMPORARY")))
+            .sql(self.dialect)
+        )
+
+        with self._safe_raw_sql(sql) as cur:
+            out = cur.fetchall()
+
+        return out
+
+    def list_catalogs(self, like=None) -> list[str]:
+        # http://dba.stackexchange.com/a/1304/58517
+        cats = (
+            sg.select(C.datname)
+            .from_(sg.table("pg_database", db="pg_catalog"))
+            .where(sg.not_(C.datistemplate))
+        )
+        with self._safe_raw_sql(cats) as cur:
+            catalogs = list(map(itemgetter(0), cur))
+
+        return self._filter_with_like(catalogs, like)
+
+    @property
+    def current_catalog(self) -> str:
+        with self._safe_raw_sql(sg.select(sg.func("current_database"))) as cur:
+            (db,) = cur.fetchone()
+        return db
+
+    @property
+    def current_database(self) -> str:
+        with self._safe_raw_sql(sg.select(sg.func("current_schema"))) as cur:
+            (schema,) = cur.fetchone()
+        return schema
+
+    def get_schema(
+        self,
+        name: str,
+        *,
+        catalog: str | None = None,
+        database: str | None = None,
+    ):
+        a = ColGen(table="a")
+        c = ColGen(table="c")
+        n = ColGen(table="n")
+
+        format_type = self.compiler.f["pg_catalog.format_type"]
+
+        # If no database is specified, assume the current database
+        db = database or self.current_database
+
+        dbs = [sge.convert(db)]
+
+        # If a database isn't specified, then include temp tables in the
+        # returned values
+        if database is None and (temp_table_db := self._session_temp_db) is not None:
+            dbs.append(sge.convert(temp_table_db))
+
+        type_info = (
+            sg.select(
+                a.attname.as_("column_name"),
+                format_type(a.atttypid, a.atttypmod).as_("data_type"),
+                sg.not_(a.attnotnull).as_("nullable"),
+            )
+            .from_(sg.table("pg_attribute", db="pg_catalog").as_("a"))
+            .join(
+                sg.table("pg_class", db="pg_catalog").as_("c"),
+                on=c.oid.eq(a.attrelid),
+                join_type="INNER",
+            )
+            .join(
+                sg.table("pg_namespace", db="pg_catalog").as_("n"),
+                on=n.oid.eq(c.relnamespace),
+                join_type="INNER",
+            )
+            .where(
+                a.attnum > 0,
+                sg.not_(a.attisdropped),
+                n.nspname.isin(*dbs),
+                c.relname.eq(sge.convert(name)),
+            )
+            .order_by(a.attnum)
+        )
+
+        type_mapper = self.compiler.type_mapper
+
+        with self._safe_raw_sql(type_info) as cur:
+            rows = cur.fetchall()
+
+        if not rows:
+            raise com.TableNotFound(name)
+
+        return sch.Schema(
+            {
+                col: type_mapper.from_string(typestr, nullable=nullable)
+                for col, typestr, nullable in rows
+            }
+        )
+
+    def _get_schema_using_query(self, query: str) -> sch.Schema:
+        name = util.gen_name(f"{self.name}_metadata")
+
+        create_stmt = sge.Create(
+            kind="VIEW",
+            this=sg.table(name),
+            expression=sg.parse_one(query, read=self.dialect),
+            properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
+        )
+        drop_stmt = sge.Drop(kind="VIEW", this=sg.table(name), exists=True).sql(
+            self.dialect
+        )
+
+        with self._safe_raw_sql(create_stmt):
+            pass
+        try:
+            return self.get_schema(name)
+        finally:
+            with self._safe_raw_sql(drop_stmt):
+                pass
+
+    def create_database(
+        self, name: str, catalog: str | None = None, force: bool = False
+    ) -> None:
+        if catalog is not None and catalog != self.current_catalog:
+            raise exc.UnsupportedOperationError(
+                f"{self.name} does not support creating a database in a different catalog"
+            )
+        sql = sge.Create(
+            kind="SCHEMA", this=sg.table(name, catalog=catalog), exists=force
+        )
+        with self._safe_raw_sql(sql):
+            pass
+
+    def drop_database(
+        self,
+        name: str,
+        catalog: str | None = None,
+        force: bool = False,
+        cascade: bool = False,
+    ) -> None:
+        if catalog is not None and catalog != self.current_catalog:
+            raise exc.UnsupportedOperationError(
+                f"{self.name} does not support dropping a database in a different catalog"
+            )
+
+        sql = sge.Drop(
+            kind="SCHEMA",
+            this=sg.table(name, catalog=catalog),
+            exists=force,
+            cascade=cascade,
+        )
+        with self._safe_raw_sql(sql):
+            pass
+
+    def drop_table(
+        self,
+        name: str,
+        database: str | None = None,
+        force: bool = False,
+    ) -> None:
+        drop_stmt = sg.exp.Drop(
+            kind="TABLE",
+            this=sg.table(name, db=database, quoted=self.compiler.quoted),
+            exists=force,
+        )
+        with self._safe_raw_sql(drop_stmt):
+            pass
+
+    @contextlib.contextmanager
+    def _safe_raw_sql(self, *args, **kwargs):
+        with contextlib.closing(self.raw_sql(*args, **kwargs)) as result:
+            yield result
+
     def do_connect(
         self,
         host: str | None = None,
@@ -556,3 +977,20 @@ def _session_temp_db(self) -> str | None:
         # Return `None`, because RisingWave does not implement temp tables like
         # Postgres
         return None
+
+    def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
+        with contextlib.suppress(AttributeError):
+            query = query.sql(dialect=self.dialect)
+
+        con = self.con
+        cursor = con.cursor()
+
+        try:
+            cursor.execute(query, **kwargs)
+        except Exception:
+            con.rollback()
+            cursor.close()
+            raise
+        else:
+            con.commit()
+            return cursor
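
---

For reviewers who want to smoke-test the decoupling, here is a minimal usage sketch (not part of the patch). It assumes a local RisingWave instance; the host, user, and database names are placeholders, and 4566 is RisingWave's default Postgres-wire-protocol port.

```python
import ibis

# Keyword-argument form: goes through do_connect and then _post_connect,
# which pins the session timezone to UTC.
con = ibis.risingwave.connect(
    host="localhost",  # placeholder host
    port=4566,         # RisingWave's default wire-protocol port
    user="root",       # placeholder credentials
    database="dev",    # placeholder database
)

# URL form: exercises the new _from_url, which splits the URL path into
# database/schema and drops empty components before reconnecting.
con = ibis.connect("risingwave://root@localhost:4566/dev")

print(con.list_tables())  # served by the list_tables implementation added above
print(con.version)        # parsed from the psycopg2 server_version integer
```

Either entry point should now resolve entirely within `ibis/backends/risingwave`; the only remaining import from `ibis.backends.postgres` is the `PostgresPandasData` converter used in `_fetch_from_cursor`.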