From 3b1fe388351ab357dcd0628e6210c449b5b9469b Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Tue, 17 Sep 2024 07:07:02 -0400 Subject: [PATCH] fix(memtables): track memtables with a weakset to allow overwriting tables with the same name but different data --- ibis/backends/__init__.py | 48 ++++++++++++++++--- ibis/backends/bigquery/__init__.py | 10 ---- ibis/backends/clickhouse/__init__.py | 24 ++-------- ibis/backends/datafusion/__init__.py | 9 ---- ibis/backends/duckdb/__init__.py | 10 +--- ibis/backends/exasol/__init__.py | 3 -- ibis/backends/flink/__init__.py | 1 + ibis/backends/impala/__init__.py | 4 -- ibis/backends/mssql/__init__.py | 10 ---- ibis/backends/mysql/__init__.py | 20 ++------ ibis/backends/oracle/__init__.py | 17 +------ ibis/backends/polars/__init__.py | 3 -- ibis/backends/postgres/__init__.py | 15 ------ ibis/backends/pyspark/__init__.py | 4 -- ibis/backends/risingwave/__init__.py | 15 ------ ibis/backends/snowflake/__init__.py | 19 -------- ibis/backends/sqlite/__init__.py | 12 ----- ibis/backends/tests/test_client.py | 72 +++++++++++++++++++++++++--- ibis/backends/trino/__init__.py | 15 ------ 19 files changed, 117 insertions(+), 194 deletions(-) diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index 8b73d812a34b0..c22977edcae96 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -10,6 +10,7 @@ import sys import urllib.parse import weakref +from collections import Counter from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple @@ -863,6 +864,10 @@ def __init__(self, *args, **kwargs): self._con_args: tuple[Any] = args self._con_kwargs: dict[str, Any] = kwargs self._can_reconnect: bool = True + # mapping of memtable names to their finalizers + self._finalizers = {} + self._memtables = weakref.WeakSet() + self._current_memtables = weakref.WeakValueDictionary() super().__init__() @property @@ -1110,16 +1115,47 @@ def _register_udfs(self, expr: ir.Expr) -> None: if self.supports_python_udfs: raise NotImplementedError(self.name) - def _in_memory_table_exists(self, name: str) -> bool: - return name in self.list_tables() + def _verify_in_memory_tables_are_unique(self, expr: ir.Expr) -> None: + memtables = expr.op().find(ops.InMemoryTable) + name_counts = Counter(op.name for op in memtables) + + if duplicate_names := sorted( + name for name, count in name_counts.items() if count > 1 + ): + raise exc.IbisError(f"Duplicate in-memory table names: {duplicate_names}") + return memtables def _register_in_memory_tables(self, expr: ir.Expr) -> None: - for memtable in expr.op().find(ops.InMemoryTable): - if not self._in_memory_table_exists(memtable.name): + for memtable in self._verify_in_memory_tables_are_unique(expr): + name = memtable.name + + # this particular memtable has never been registered + if memtable not in self._memtables: + # but we have a memtable with the same name + if ( + current_memtable := self._current_memtables.pop(name, None) + ) is not None: + # if we're here this means we overwrite, so do the following: + # 1. remove the old memtable from the set of memtables + # 2. grab the old finalizer and invoke it + self._memtables.remove(current_memtable) + finalizer = self._finalizers.pop(name) + finalizer() + else: + # if memtable is in the set, then by construction it must be + # true that the name of this memtable is in the current + # memtables mapping + assert name in self._current_memtables + + # if there's no memtable named `name` then register it, setup a + # finalizer, and set it as the current memtable with `name` + if self._current_memtables.get(name) is None: self._register_in_memory_table(memtable) - weakref.finalize( - memtable, self._finalize_in_memory_table, memtable.name + self._memtables.add(memtable) + self._finalizers[name] = weakref.finalize( + memtable, self._finalize_in_memory_table, name ) + self._current_memtables[name] = memtable def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: if self.supports_in_memory_tables: diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index df1b2f3674d98..1d7d8a58acd85 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -170,16 +170,6 @@ def _session_dataset(self): self.__session_dataset = self._make_session() return self.__session_dataset - def _in_memory_table_exists(self, name: str) -> bool: - table_ref = bq.TableReference(self._session_dataset, name) - - try: - self._get_table(table_ref) - except com.TableNotFound: - return False - else: - return True - def _finalize_memtable(self, name: str) -> None: table_ref = bq.TableReference(self._session_dataset, name) self.client.delete_table(table_ref, not_found_ok=True) diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index be3334373594c..af003ef441157 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -265,7 +265,9 @@ def _normalize_external_tables(self, external_tables=None) -> ExternalData | Non def _collect_in_memory_tables( self, expr: ir.Table | None, external_tables: Mapping | None = None ): - memtables = {op.name: op for op in expr.op().find(ops.InMemoryTable)} + memtables = { + op.name: op for op in self._verify_in_memory_tables_are_unique(expr) + } externals = toolz.valmap(_to_memtable, external_tables or {}) return toolz.merge(memtables, externals) @@ -779,23 +781,3 @@ def create_view( with self._safe_raw_sql(src, external_tables=external_tables): pass return self.table(name, database=database) - - def _in_memory_table_exists(self, name: str) -> bool: - name = sg.table(name, quoted=self.compiler.quoted).sql(self.dialect) - try: - # DESCRIBE TABLE $TABLE FORMAT NULL is the fastest way to check - # table existence in clickhouse; FORMAT NULL produces no data which - # is ideal since we don't care about the output for existence - # checking - # - # Other methods compared were - # 1. SELECT 1 FROM $TABLE LIMIT 0 - # 2. SHOW TABLES LIKE $TABLE LIMIT 1 - # - # if the table exists nothing is returned and there's no error - # otherwise there's an error - self.con.raw_query(f"DESCRIBE {name} FORMAT NULL") - except cc.driver.exceptions.DatabaseError: - return False - else: - return True diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 0570f163e9f99..62bd7535ef2c1 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -408,15 +408,6 @@ def _register_failure(self): f"please call one of {msg} directly" ) - def _in_memory_table_exists(self, name: str) -> bool: - db = self.con.catalog().database() - try: - db.table(name) - except Exception: # noqa: BLE001 because DataFusion has nothing better - return False - else: - return True - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: # self.con.register_table is broken, so we do this roundabout thing # of constructing a datafusion DataFrame, which has a side effect diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index bb4d4f7aa9fdd..d570afb02e7d6 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -1606,16 +1606,8 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: } ) - def _in_memory_table_exists(self, name: str) -> bool: - try: - # this handles both tables and views - self.con.table(name) - except (duckdb.CatalogException, duckdb.InvalidInputException): - return False - else: - return True - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: + self.con.unregister(op.name) self.con.register(op.name, op.data.to_pyarrow(op.schema)) def _finalize_memtable(self, name: str) -> None: diff --git a/ibis/backends/exasol/__init__.py b/ibis/backends/exasol/__init__.py index ee381b2b363cf..e6626a228003f 100644 --- a/ibis/backends/exasol/__init__.py +++ b/ibis/backends/exasol/__init__.py @@ -276,9 +276,6 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: finally: self.con.execute(drop_view) - def _in_memory_table_exists(self, name: str) -> bool: - return self.con.meta.table_exists(name) - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = op.schema if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index 97ea8a06a8f54..e45615c4fa307 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -371,6 +371,7 @@ def compile( def execute(self, expr: ir.Expr, **kwargs: Any) -> Any: """Execute an expression.""" + self._verify_in_memory_tables_are_unique(expr) self._register_udfs(expr) table_expr = expr.as_table() diff --git a/ibis/backends/impala/__init__.py b/ibis/backends/impala/__init__.py index 522c0538bf172..88ef7d1d34027 100644 --- a/ibis/backends/impala/__init__.py +++ b/ibis/backends/impala/__init__.py @@ -1223,10 +1223,6 @@ def explain( return "\n".join(["Query:", util.indent(query, 2), "", *results.iloc[:, 0]]) - def _in_memory_table_exists(self, name: str) -> bool: - with contextlib.closing(self.con.cursor()) as cur: - return cur.table_exists(name) - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = op.schema if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: diff --git a/ibis/backends/mssql/__init__.py b/ibis/backends/mssql/__init__.py index 3a215e1ece0bf..434381d51f044 100644 --- a/ibis/backends/mssql/__init__.py +++ b/ibis/backends/mssql/__init__.py @@ -738,16 +738,6 @@ def create_table( namespace=ops.Namespace(catalog=catalog, database=db), ).to_expr() - def _in_memory_table_exists(self, name: str) -> bool: - # The single character U here means user-defined table - # see https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-objects-transact-sql?view=sql-server-ver16 - sql = sg.select(sg.func("object_id", sge.convert(name), sge.convert("U"))).sql( - self.dialect - ) - with self.begin() as cur: - [(result,)] = cur.execute(sql).fetchall() - return result is not None - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = op.schema if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: diff --git a/ibis/backends/mysql/__init__.py b/ibis/backends/mysql/__init__.py index 664a1eff1da8d..1a4447a42c875 100644 --- a/ibis/backends/mysql/__init__.py +++ b/ibis/backends/mysql/__init__.py @@ -468,23 +468,6 @@ def create_table( name, schema=schema, source=self, namespace=ops.Namespace(database=database) ).to_expr() - def _in_memory_table_exists(self, name: str) -> bool: - name = sg.to_identifier(name, quoted=self.compiler.quoted).sql(self.dialect) - # just return the single field with column names; no need to bring back - # everything if the command succeeds - sql = f"SHOW COLUMNS FROM {name} LIKE 'Field'" - try: - with self.begin() as cur: - cur.execute(sql) - cur.fetchall() - except pymysql.err.ProgrammingError as e: - err_code, _ = e.args - if err_code == ER.NO_SUCH_TABLE: - return False - raise - else: - return True - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = op.schema if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: @@ -496,6 +479,9 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: name = op.name quoted = self.compiler.quoted + # TODO(cpcloud): remove when port to mysqldb is merged + self.drop_table(name, force=True) + create_stmt = sg.exp.Create( kind="TABLE", this=sg.exp.Schema( diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py index 6bcfe69b95918..4df42baf6fd9b 100644 --- a/ibis/backends/oracle/__init__.py +++ b/ibis/backends/oracle/__init__.py @@ -24,7 +24,7 @@ from ibis import util from ibis.backends import CanListDatabase, CanListSchema from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers.base import NULL, STAR, C +from ibis.backends.sql.compilers.base import STAR, C if TYPE_CHECKING: from urllib.parse import ParseResult @@ -522,21 +522,6 @@ def drop_table( super().drop_table(name, database=(catalog, db), force=force) - def _in_memory_table_exists(self, name: str) -> bool: - sql = ( - sg.select(NULL) - .from_(sg.to_identifier("USER_OBJECTS", quoted=self.compiler.quoted)) - .where( - C.OBJECT_TYPE.eq(sge.convert("TABLE")), - C.OBJECT_NAME.eq(sge.convert(name)), - ) - .limit(sge.convert(1)) - .sql(self.dialect) - ) - with self.begin() as cur: - results = cur.execute(sql).fetchall() - return bool(results) - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = op.schema diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index 9e9c25ceb40ad..25d05c02bed64 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -94,9 +94,6 @@ def table(self, name: str) -> ir.Table: schema = sch.infer(table) return ops.DatabaseTable(name, schema, self).to_expr() - def _in_memory_table_exists(self, name: str) -> bool: - return name in self._tables - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: self._add_table(op.name, op.data.to_polars(op.schema).lazy()) diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py index 68feacd6c02c4..ee249edad355b 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -89,21 +89,6 @@ def _from_url(self, url: ParseResult, **kwargs): return self.connect(**kwargs) - def _in_memory_table_exists(self, name: str) -> bool: - import psycopg2.errors - - ident = sg.to_identifier(name, quoted=self.compiler.quoted) - sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect) - - try: - with self.begin() as cur: - cur.execute(sql) - cur.fetchall() - except psycopg2.errors.UndefinedTable: - return False - else: - return True - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: from psycopg2.extras import execute_batch diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index babbd124eeb66..9fc5ab49d989f 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -446,10 +446,6 @@ def _register_udfs(self, expr: ir.Expr) -> None: self._session.udf.register(f"unwrap_json_{typ.__name__}", unwrap_json(typ)) self._session.udf.register("unwrap_json_float", unwrap_json_float) - def _in_memory_table_exists(self, name: str) -> bool: - sql = f"SHOW TABLES IN {self.current_database} LIKE '{name}'" - return bool(self._session.sql(sql).count()) - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = PySparkSchema.from_ibis(op.schema) df = self._session.createDataFrame(data=op.data.to_frame(), schema=schema) diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py index 05927651b4b4a..fd3400fb79a17 100644 --- a/ibis/backends/risingwave/__init__.py +++ b/ibis/backends/risingwave/__init__.py @@ -262,21 +262,6 @@ def create_table( name, schema=schema, source=self, namespace=ops.Namespace(database=database) ).to_expr() - def _in_memory_table_exists(self, name: str) -> bool: - import psycopg2.errors - - ident = sg.to_identifier(name, quoted=self.compiler.quoted) - sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect) - - try: - with self.begin() as cur: - cur.execute(sql) - cur.fetchall() - except psycopg2.errors.InternalError: - return False - else: - return True - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = op.schema if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 0d62bb76f4383..29fc8b9ef93da 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -663,25 +663,6 @@ def list_tables( return self._filter_with_like(tables + views, like=like) - def _in_memory_table_exists(self, name: str) -> bool: - import snowflake.connector - - ident = sg.to_identifier(name, quoted=self.compiler.quoted) - sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect) - - try: - with self.con.cursor() as cur: - cur.execute(sql).fetchall() - except snowflake.connector.errors.ProgrammingError as e: - # this cryptic error message is the only generic and reliable way - # to tell if the error means "table not found for any reason" - # otherwise, we need to reraise the exception - if e.sqlstate == "42S02": - return False - raise - else: - return True - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: import pyarrow.parquet as pq diff --git a/ibis/backends/sqlite/__init__.py b/ibis/backends/sqlite/__init__.py index a1f4f078178db..2ac538b06181b 100644 --- a/ibis/backends/sqlite/__init__.py +++ b/ibis/backends/sqlite/__init__.py @@ -345,18 +345,6 @@ def _generate_create_table(self, table: sge.Table, schema: sch.Schema): return sge.Create(kind="TABLE", this=target) - def _in_memory_table_exists(self, name: str) -> bool: - ident = sg.to_identifier(name, quoted=self.compiler.quoted) - query = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect) - try: - with self.begin() as cur: - cur.execute(query) - cur.fetchall() - except sqlite3.OperationalError: - return False - else: - return True - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: table = sg.table(op.name, quoted=self.compiler.quoted, catalog="temp") create_stmt = self._generate_create_table(table, op.schema).sql(self.name) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 8e3618b4454bc..9f587b49b4eed 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1723,18 +1723,78 @@ def test_memtable_cleanup(con): # the table isn't registered until we actually execute, and since we # haven't yet executed anything, the table shouldn't be there - assert not con._in_memory_table_exists(name) + assert name not in con.list_tables() # execute, which means the table is registered and should be visible in # con.list_tables() con.execute(t.select("a")) - assert con._in_memory_table_exists(name) + assert name in con.list_tables() con.execute(t.select("b")) - assert con._in_memory_table_exists(name) + assert name in con.list_tables() # remove all references to `t`, which means the `op` shouldn't be reachable - # and the table should thus be dropped and no longer visible in - # con.list_tables() + # and the table should thus be dropped and no longer visible del t - assert not con._in_memory_table_exists(name) + assert name not in con.list_tables() + + +def test_same_name_memtable_is_overwritten(con): + name = ibis.util.gen_name("temp_memtable") + + data = {"a": [1, 2, 3], "b": ["a", "b", "c"]} + + t = ibis.memtable(data, name=name) + assert len(con.execute(t)) == 3 + + s = ibis.memtable({"a": [1], "b": ["a"]}, name=name) + assert len(con.execute(s)) == 1 + + +@pytest.mark.notimpl( + ["clickhouse", "flink"], + raises=AssertionError, + reason="backend doesn't use _register_in_memory_table", +) +def test_memtable_registered_exactly_once(con, mocker): + name = ibis.util.gen_name("temp_memtable") + + spy = mocker.spy(con, "_register_in_memory_table") + + data = {"a": [1, 2, 3], "b": ["a", "b", "c"]} + + t = ibis.memtable(data, name=name) + + assert len(con.execute(t)) == 3 + assert len(con.execute(t)) == 3 + + spy.assert_called_once_with(t.op()) + + +def test_unreachable_memtable_does_not_clobber_existing_data(con): + t1 = ibis.memtable({"a": [1, 2, 3]}, name="test") + t2 = ibis.memtable({"a": [4, 5]}, name="test") + + assert len(con.execute(t1)) == 3 + + assert len(con.execute(t2)) == 2 + + assert len(con.execute(t1)) == 3 + + # Drop t1, its finalizer runs, deleting table `test` + del t1 + + # ensure that the previous clean up doesn't erase t2 + assert len(con.execute(t2)) == 2 + + +def test_identically_named_memtables_cannot_be_joined(con): + t1 = ibis.memtable({"a": [1, 2, 3]}, name="test") + t2 = ibis.memtable({"a": [4, 5]}, name="test") + + # mixing two memtables with the same name is an error + expr = t1.join(t2, "a") + with pytest.raises( + com.IbisError, match=r"Duplicate in-memory table names: \['test'\]" + ): + con.execute(expr) diff --git a/ibis/backends/trino/__init__.py b/ibis/backends/trino/__init__.py index 61c3c9a4b2237..9926f49766021 100644 --- a/ibis/backends/trino/__init__.py +++ b/ibis/backends/trino/__init__.py @@ -551,21 +551,6 @@ def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: df = TrinoPandasData.convert_table(df, schema) return df - def _in_memory_table_exists(self, name: str) -> bool: - ident = sg.to_identifier(name, quoted=self.compiler.quoted) - sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect) - - try: - with self.begin() as cur: - cur.execute(sql) - cur.fetchall() - except trino.exceptions.TrinoUserError as e: - if e.error_name == "TABLE_NOT_FOUND": - return False - raise - else: - return True - def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: schema = op.schema if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]: