From da85415148966d4a834d7a9c4873581d8f2ce4ea Mon Sep 17 00:00:00 2001 From: Gil Forsyth Date: Tue, 13 Feb 2024 17:05:14 -0500 Subject: [PATCH] feat(api): add disconnect method This adds a `disconnect` method to all backends. Previously we didn't do this since the actual connection was often wrapped in SQLAlchemy and our ability to terminate the connection was unclear. The DB-API spec states that there should be a `close` method, and that any subsequent operations on a given `Connection` should raise an error after `close` is called. This _mostly_ works. Trino, ClickHouse, Impala, and BigQuery do not conform to the DB-API in this way. They have the `close` method but don't raise when you make a subsequent call. For the in-process backends I've chosen to raise `NotImplementedError`, since it isn't clear what closing a connection would mean for them; I'm happy to take any suggestions there. --- ibis/backends/base/__init__.py | 4 +++ ibis/backends/base/sqlglot/__init__.py | 5 +++ ibis/backends/bigquery/__init__.py | 3 ++ ibis/backends/clickhouse/__init__.py | 2 +- ibis/backends/dask/__init__.py | 3 ++ ibis/backends/datafusion/__init__.py | 3 ++ ibis/backends/pandas/__init__.py | 3 ++ ibis/backends/polars/__init__.py | 3 ++ ibis/backends/pyspark/__init__.py | 3 ++ ibis/backends/tests/errors.py | 34 +++++++++++++++----- ibis/backends/tests/test_client.py | 43 ++++++++++++++++++++++++++ 11 files changed, 97 insertions(+), 9 deletions(-) diff --git a/ibis/backends/base/__init__.py b/ibis/backends/base/__init__.py index 7e2dbd46b1e3e..f5419eaea4694 100644 --- a/ibis/backends/base/__init__.py +++ b/ibis/backends/base/__init__.py @@ -864,6 +864,10 @@ def connect(self, *args, **kwargs) -> BaseBackend: new_backend.reconnect() return new_backend + @abc.abstractmethod + def disconnect(self) -> None: + """Close the connection to the backend.""" + @staticmethod def _convert_kwargs(kwargs: MutableMapping) -> None: """Manipulate keyword arguments to `.connect` method.""" diff --git 
a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index 47c715129cf2f..a280d5277cdd2 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -383,3 +383,8 @@ def truncate_table( ).sql(self.dialect) with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"): pass + + def disconnect(self): + # This is part of the Python DB-API specification so should work for + # _most_ sqlglot backends + self.con.close() diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 43fe70af5c8a8..a7d8e20d520e0 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -447,6 +447,9 @@ def do_connect( self.partition_column = partition_column + def disconnect(self) -> None: + self.client.close() + def _parse_project_and_dataset(self, dataset) -> tuple[str, str]: if not dataset and not self.dataset: raise ValueError("Unable to determine BigQuery dataset.") diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index 9e6cde739f1b4..43f490a07f4d2 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -446,7 +446,7 @@ def raw_sql( self._log(query) return self.con.query(query, external_data=external_data, **kwargs) - def close(self) -> None: + def disconnect(self) -> None: """Close ClickHouse connection.""" self.con.close() diff --git a/ibis/backends/dask/__init__.py b/ibis/backends/dask/__init__.py index df1de12c2359b..d5a259119e3da 100644 --- a/ibis/backends/dask/__init__.py +++ b/ibis/backends/dask/__init__.py @@ -61,6 +61,9 @@ def do_connect( ) super().do_connect(dictionary) + def disconnect(self) -> None: + raise NotImplementedError + @property def version(self): return dask.__version__ diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 44946ac3abf51..79bedbc59ee47 100644 --- a/ibis/backends/datafusion/__init__.py +++ 
b/ibis/backends/datafusion/__init__.py @@ -94,6 +94,9 @@ def do_connect( for name, path in config.items(): self.register(path, table_name=name) + def disconnect(self) -> None: + raise NotImplementedError + @contextlib.contextmanager def _safe_raw_sql(self, sql: sge.Statement) -> Any: yield self.raw_sql(sql).collect() diff --git a/ibis/backends/pandas/__init__.py b/ibis/backends/pandas/__init__.py index 4d3cddb665c0d..f3f88e492e9ef 100644 --- a/ibis/backends/pandas/__init__.py +++ b/ibis/backends/pandas/__init__.py @@ -53,6 +53,9 @@ def do_connect( self.dictionary = dictionary or {} self.schemas: MutableMapping[str, sch.Schema] = {} + def disconnect(self) -> None: + raise NotImplementedError + def from_dataframe( self, df: pd.DataFrame, diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index 819fb995dc7fa..927ffd058371f 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -59,6 +59,9 @@ def do_connect( for name, table in (tables or {}).items(): self._add_table(name, table) + def disconnect(self) -> None: + raise NotImplementedError() + @property def version(self) -> str: return pl.__version__ diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index e77839fa939fc..3507636e1cc29 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -158,6 +158,9 @@ def do_connect(self, session: SparkSession) -> None: self._session.conf.set("spark.sql.session.timeZone", "UTC") self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN") + def disconnect(self) -> None: + self._session.stop() + def _metadata(self, query: str): cursor = self.raw_sql(query) struct_dtype = PySparkType.to_ibis(cursor.query.schema) diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py index 8aec1a1912ed2..0a603ecdcf678 100644 --- a/ibis/backends/tests/errors.py +++ b/ibis/backends/tests/errors.py @@ -1,6 +1,7 @@ from __future__ import annotations try: 
+ from duckdb import ConnectionException as DuckDBConnectionException from duckdb import ConversionException as DuckDBConversionException from duckdb import InvalidInputException as DuckDBInvalidInputException from duckdb import NotImplementedException as DuckDBNotImplementedException @@ -8,7 +9,9 @@ except ImportError: DuckDBConversionException = ( DuckDBInvalidInputException - ) = DuckDBParserException = DuckDBNotImplementedException = None + ) = ( + DuckDBParserException + ) = DuckDBNotImplementedException = DuckDBConnectionException = None try: from clickhouse_connect.driver.exceptions import ( @@ -27,9 +30,9 @@ try: - from pyexasol.exceptions import ExaQueryError + from pyexasol.exceptions import ExaQueryError, ExaRuntimeError except ImportError: - ExaQueryError = None + ExaQueryError = ExaRuntimeError = None try: from pyspark.sql.utils import AnalysisException as PySparkAnalysisException @@ -85,9 +88,10 @@ PyDeltaTableError = None try: + from snowflake.connector.errors import DatabaseError as SnowflakeDatabaseError from snowflake.connector.errors import ProgrammingError as SnowflakeProgrammingError except ImportError: - SnowflakeProgrammingError = None + SnowflakeProgrammingError = SnowflakeDatabaseError = None try: from trino.exceptions import TrinoUserError @@ -97,6 +101,7 @@ try: from psycopg2.errors import DivisionByZero as PsycoPg2DivisionByZero from psycopg2.errors import IndeterminateDatatype as PsycoPg2IndeterminateDatatype + from psycopg2.errors import InterfaceError as PsycoPg2InterfaceError from psycopg2.errors import InternalError_ as PsycoPg2InternalError from psycopg2.errors import ( InvalidTextRepresentation as PsycoPg2InvalidTextRepresentation, @@ -116,27 +121,40 @@ PsycoPg2InternalError ) = ( PsycoPg2ProgrammingError - ) = PsycoPg2OperationalError = PsycoPg2UndefinedObject = None + ) = ( + PsycoPg2OperationalError + ) = PsycoPg2UndefinedObject = PsycoPg2InterfaceError = None try: + from pymysql.err import InterfaceError as MySQLInterfaceError 
from pymysql.err import NotSupportedError as MySQLNotSupportedError from pymysql.err import OperationalError as MySQLOperationalError from pymysql.err import ProgrammingError as MySQLProgrammingError except ImportError: - MySQLNotSupportedError = MySQLProgrammingError = MySQLOperationalError = None + MySQLNotSupportedError = ( + MySQLProgrammingError + ) = MySQLOperationalError = MySQLInterfaceError = None try: + from pydruid.db.exceptions import Error as PyDruidError from pydruid.db.exceptions import ProgrammingError as PyDruidProgrammingError except ImportError: - PyDruidProgrammingError = None + PyDruidProgrammingError = PyDruidError = None try: from oracledb.exceptions import DatabaseError as OracleDatabaseError + from oracledb.exceptions import InterfaceError as OracleInterfaceError except ImportError: - OracleDatabaseError = None + OracleDatabaseError = OracleInterfaceError = None try: from pyodbc import DataError as PyODBCDataError from pyodbc import ProgrammingError as PyODBCProgrammingError except ImportError: PyODBCProgrammingError = PyODBCDataError = None + + +try: + from sqlite3 import ProgrammingError as sqlite3ProgrammingError +except ImportError: + sqlite3ProgrammingError = None diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 62ffb72c0ccfa..82253874c09f8 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -26,14 +26,23 @@ import ibis.expr.operations as ops from ibis.backends.conftest import ALL_BACKENDS from ibis.backends.tests.errors import ( + DuckDBConnectionException, ExaQueryError, + ExaRuntimeError, ImpalaHiveServer2Error, + MySQLInterfaceError, OracleDatabaseError, + OracleInterfaceError, + PsycoPg2InterfaceError, PsycoPg2InternalError, PsycoPg2UndefinedObject, + Py4JJavaError, + PyDruidError, PyODBCProgrammingError, + SnowflakeDatabaseError, SnowflakeProgrammingError, TrinoUserError, + sqlite3ProgrammingError, ) from ibis.util import gen_name @@ -1460,3 
+1469,37 @@ def test_list_databases_schemas(con_create_database_schema): con_create_database_schema.drop_schema(schema, database=database) finally: con_create_database_schema.drop_database(database) + + +@pytest.mark.notyet( + ["pandas", "dask", "polars", "datafusion"], + raises=NotImplementedError, + reason="In process backends have nothing to close", +) +@pytest.mark.notyet( + ["trino", "clickhouse", "impala", "bigquery"], + reason="Backend client does not conform to DB-API, subsequent op does not raise", +) +def test_close_connection(con): + new_con = getattr(ibis, con.name).connect(*con._con_args, **con._con_kwargs) + + # Run any command that hits the backend + _ = new_con.list_tables() + new_con.disconnect() + + # DB-API states that subsequent execution attempt should raise + with pytest.raises( + ( + DuckDBConnectionException, + ExaRuntimeError, + MySQLInterfaceError, + OracleInterfaceError, + PsycoPg2InterfaceError, + Py4JJavaError, + PyDruidError, + PyODBCProgrammingError, + SnowflakeDatabaseError, + sqlite3ProgrammingError, + ) + ): + new_con.list_tables()