Skip to content

Commit

Permalink
feat(api): add disconnect method
Browse files Browse the repository at this point in the history
This adds a `disconnect` method to all backends.  Previously we didn't do
this since the actual connection was often wrapped in SQLAlchemy and our
ability to terminate the connection was unclear.

The DB-API spec states that there should be a `close` method, and that
any subsequent operations on a given `Connection` should raise an error
after `close` is called.

This _mostly_ works.

Trino, Clickhouse, Impala, and BigQuery do not conform to the DB-API in
this way.  They have the `close` method but don't raise when you make a
subsequent call.

For the in-process backends I've chosen to raise, since closing a
connection to an in-process backend doesn't have a clear meaning, but I'm
happy to take any suggestions there.
  • Loading branch information
gforsyth committed Feb 13, 2024
1 parent d587166 commit da85415
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 9 deletions.
4 changes: 4 additions & 0 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,10 @@ def connect(self, *args, **kwargs) -> BaseBackend:
new_backend.reconnect()
return new_backend

@abc.abstractmethod
def disconnect(self) -> None:
    """Close the connection to the backend.

    Concrete backends must implement this.  Per the DB-API, subsequent
    operations on a closed connection should raise, though not every
    driver enforces this.
    """

@staticmethod
def _convert_kwargs(kwargs: MutableMapping) -> None:
"""Manipulate keyword arguments to `.connect` method."""
Expand Down
5 changes: 5 additions & 0 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,8 @@ def truncate_table(
).sql(self.dialect)
with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"):
pass

def disconnect(self) -> None:
    """Close the underlying DB-API connection.

    `close` is part of the Python DB-API specification, so this default
    should work for _most_ sqlglot backends; backends whose connection
    object differs override this method.
    """
    self.con.close()
3 changes: 3 additions & 0 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,9 @@ def do_connect(

self.partition_column = partition_column

def disconnect(self) -> None:
    """Release the resources held by the underlying BigQuery client."""
    client = self.client
    client.close()

def _parse_project_and_dataset(self, dataset) -> tuple[str, str]:
if not dataset and not self.dataset:
raise ValueError("Unable to determine BigQuery dataset.")
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def raw_sql(
self._log(query)
return self.con.query(query, external_data=external_data, **kwargs)

def close(self) -> None:
def disconnect(self) -> None:
    """Close the ClickHouse connection.

    NOTE(review): per the commit description, the ClickHouse client does
    not fully conform to the DB-API here — operations after `close` may
    not raise.
    """
    self.con.close()

Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/dask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def do_connect(
)
super().do_connect(dictionary)

def disconnect(self) -> None:
    # The Dask backend runs in-process, so there is no real connection
    # to close; raise rather than silently doing nothing.
    raise NotImplementedError

@property
def version(self):
return dask.__version__
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def do_connect(
for name, path in config.items():
self.register(path, table_name=name)

def disconnect(self) -> None:
    # The DataFusion backend runs in-process, so there is no real
    # connection to close; raise rather than silently doing nothing.
    raise NotImplementedError

@contextlib.contextmanager
def _safe_raw_sql(self, sql: sge.Statement) -> Any:
yield self.raw_sql(sql).collect()
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/pandas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def do_connect(
self.dictionary = dictionary or {}
self.schemas: MutableMapping[str, sch.Schema] = {}

def disconnect(self) -> None:
    # The pandas backend runs in-process, so there is no real connection
    # to close; raise rather than silently doing nothing.
    raise NotImplementedError

def from_dataframe(
self,
df: pd.DataFrame,
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def do_connect(
for name, table in (tables or {}).items():
self._add_table(name, table)

def disconnect(self) -> None:
    # The Polars backend runs in-process, so there is no real connection
    # to close; raise rather than silently doing nothing.  Raise the
    # class itself (not an instance) for consistency with the other
    # in-process backends (pandas, dask, datafusion).
    raise NotImplementedError

@property
def version(self) -> str:
return pl.__version__
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def do_connect(self, session: SparkSession) -> None:
self._session.conf.set("spark.sql.session.timeZone", "UTC")
self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")

def disconnect(self) -> None:
    """Stop the underlying SparkSession for this backend."""
    session = self._session
    session.stop()

def _metadata(self, query: str):
cursor = self.raw_sql(query)
struct_dtype = PySparkType.to_ibis(cursor.query.schema)
Expand Down
34 changes: 26 additions & 8 deletions ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations

try:
from duckdb import ConnectionException as DuckDBConnectionException
from duckdb import ConversionException as DuckDBConversionException
from duckdb import InvalidInputException as DuckDBInvalidInputException
from duckdb import NotImplementedException as DuckDBNotImplementedException
from duckdb import ParserException as DuckDBParserException
except ImportError:
DuckDBConversionException = (
DuckDBInvalidInputException
) = DuckDBParserException = DuckDBNotImplementedException = None
) = (
DuckDBParserException
) = DuckDBNotImplementedException = DuckDBConnectionException = None

try:
from clickhouse_connect.driver.exceptions import (
Expand All @@ -27,9 +30,9 @@


try:
from pyexasol.exceptions import ExaQueryError
from pyexasol.exceptions import ExaQueryError, ExaRuntimeError
except ImportError:
ExaQueryError = None
ExaQueryError = ExaRuntimeError = None

try:
from pyspark.sql.utils import AnalysisException as PySparkAnalysisException
Expand Down Expand Up @@ -85,9 +88,10 @@
PyDeltaTableError = None

try:
from snowflake.connector.errors import DatabaseError as SnowflakeDatabaseError
from snowflake.connector.errors import ProgrammingError as SnowflakeProgrammingError
except ImportError:
SnowflakeProgrammingError = None
SnowflakeProgrammingError = SnowflakeDatabaseError = None

try:
from trino.exceptions import TrinoUserError
Expand All @@ -97,6 +101,7 @@
try:
from psycopg2.errors import DivisionByZero as PsycoPg2DivisionByZero
from psycopg2.errors import IndeterminateDatatype as PsycoPg2IndeterminateDatatype
from psycopg2.errors import InterfaceError as PsycoPg2InterfaceError
from psycopg2.errors import InternalError_ as PsycoPg2InternalError
from psycopg2.errors import (
InvalidTextRepresentation as PsycoPg2InvalidTextRepresentation,
Expand All @@ -116,27 +121,40 @@
PsycoPg2InternalError
) = (
PsycoPg2ProgrammingError
) = PsycoPg2OperationalError = PsycoPg2UndefinedObject = None
) = (
PsycoPg2OperationalError
) = PsycoPg2UndefinedObject = PsycoPg2InterfaceError = None

try:
from pymysql.err import InterfaceError as MySQLInterfaceError
from pymysql.err import NotSupportedError as MySQLNotSupportedError
from pymysql.err import OperationalError as MySQLOperationalError
from pymysql.err import ProgrammingError as MySQLProgrammingError
except ImportError:
MySQLNotSupportedError = MySQLProgrammingError = MySQLOperationalError = None
MySQLNotSupportedError = (
MySQLProgrammingError
) = MySQLOperationalError = MySQLInterfaceError = None

try:
from pydruid.db.exceptions import Error as PyDruidError
from pydruid.db.exceptions import ProgrammingError as PyDruidProgrammingError
except ImportError:
PyDruidProgrammingError = None
PyDruidProgrammingError = PyDruidError = None

try:
from oracledb.exceptions import DatabaseError as OracleDatabaseError
from oracledb.exceptions import InterfaceError as OracleInterfaceError
except ImportError:
OracleDatabaseError = None
OracleDatabaseError = OracleInterfaceError = None

# pyodbc is an optional dependency: fall back to None so modules that
# reference these exception types still import when it is absent.
try:
    from pyodbc import DataError as PyODBCDataError
    from pyodbc import ProgrammingError as PyODBCProgrammingError
except ImportError:
    PyODBCProgrammingError = PyODBCDataError = None


# sqlite3 ships with the standard library, but guard the import anyway to
# match the pattern used for the other driver exception types in this file.
try:
    from sqlite3 import ProgrammingError as sqlite3ProgrammingError
except ImportError:
    sqlite3ProgrammingError = None
43 changes: 43 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,23 @@
import ibis.expr.operations as ops
from ibis.backends.conftest import ALL_BACKENDS
from ibis.backends.tests.errors import (
DuckDBConnectionException,
ExaQueryError,
ExaRuntimeError,
ImpalaHiveServer2Error,
MySQLInterfaceError,
OracleDatabaseError,
OracleInterfaceError,
PsycoPg2InterfaceError,
PsycoPg2InternalError,
PsycoPg2UndefinedObject,
Py4JJavaError,
PyDruidError,
PyODBCProgrammingError,
SnowflakeDatabaseError,
SnowflakeProgrammingError,
TrinoUserError,
sqlite3ProgrammingError,
)
from ibis.util import gen_name

Expand Down Expand Up @@ -1460,3 +1469,37 @@ def test_list_databases_schemas(con_create_database_schema):
con_create_database_schema.drop_schema(schema, database=database)
finally:
con_create_database_schema.drop_database(database)


@pytest.mark.notyet(
    ["pandas", "dask", "polars", "datafusion"],
    raises=NotImplementedError,
    reason="In process backends have nothing to close",
)
@pytest.mark.notyet(
    ["trino", "clickhouse", "impala", "bigquery"],
    reason="Backend client does not conform to DB-API, subsequent op does not raise",
)
def test_close_connection(con):
    """`disconnect` closes the connection and later operations raise."""
    # Open a second, independent connection so closing it cannot affect
    # the shared `con` fixture used by other tests.
    fresh_con = getattr(ibis, con.name).connect(*con._con_args, **con._con_kwargs)

    # Touch the backend once to prove the connection is live.
    fresh_con.list_tables()
    fresh_con.disconnect()

    # DB-API states that any subsequent execution attempt should raise.
    closed_connection_errors = (
        DuckDBConnectionException,
        ExaRuntimeError,
        MySQLInterfaceError,
        OracleInterfaceError,
        PsycoPg2InterfaceError,
        Py4JJavaError,
        PyDruidError,
        PyODBCProgrammingError,
        SnowflakeDatabaseError,
        sqlite3ProgrammingError,
    )
    with pytest.raises(closed_connection_errors):
        fresh_con.list_tables()

0 comments on commit da85415

Please sign in to comment.