Skip to content

Commit

Permalink
fix(memtables): track memtables with a weakset to allow overwriting tables with the same name but different data
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 18, 2024
1 parent 966c5e8 commit 3b1fe38
Show file tree
Hide file tree
Showing 19 changed files with 117 additions and 194 deletions.
48 changes: 42 additions & 6 deletions ibis/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import urllib.parse
import weakref
from collections import Counter
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple

Expand Down Expand Up @@ -863,6 +864,10 @@ def __init__(self, *args, **kwargs):
self._con_args: tuple[Any] = args
self._con_kwargs: dict[str, Any] = kwargs
self._can_reconnect: bool = True
# mapping of memtable names to their finalizers
self._finalizers = {}
self._memtables = weakref.WeakSet()
self._current_memtables = weakref.WeakValueDictionary()
super().__init__()

@property
Expand Down Expand Up @@ -1110,16 +1115,47 @@ def _register_udfs(self, expr: ir.Expr) -> None:
if self.supports_python_udfs:
raise NotImplementedError(self.name)

def _in_memory_table_exists(self, name: str) -> bool:
return name in self.list_tables()
def _verify_in_memory_tables_are_unique(
    self, expr: ir.Expr
) -> list[ops.InMemoryTable]:
    """Collect the in-memory tables of `expr`, rejecting duplicate names.

    Parameters
    ----------
    expr
        Expression whose operation graph is searched for
        `ops.InMemoryTable` nodes.

    Returns
    -------
    list[ops.InMemoryTable]
        The in-memory table operations found in `expr`. Callers iterate
        over this result, so the annotation must not be `None`.

    Raises
    ------
    exc.IbisError
        If two or more of the in-memory tables found share a name. The
        message lists the offending names in sorted order.
    """
    memtables = expr.op().find(ops.InMemoryTable)
    name_counts = Counter(op.name for op in memtables)

    # sorted so the error message is deterministic regardless of the order in
    # which the tables were discovered
    if duplicate_names := sorted(
        name for name, count in name_counts.items() if count > 1
    ):
        raise exc.IbisError(f"Duplicate in-memory table names: {duplicate_names}")
    return memtables

def _register_in_memory_tables(self, expr: ir.Expr) -> None:
for memtable in expr.op().find(ops.InMemoryTable):
if not self._in_memory_table_exists(memtable.name):
for memtable in self._verify_in_memory_tables_are_unique(expr):
name = memtable.name

# this particular memtable has never been registered
if memtable not in self._memtables:
# but we have a memtable with the same name
if (
current_memtable := self._current_memtables.pop(name, None)
) is not None:
# if we're here this means we overwrite, so do the following:
# 1. remove the old memtable from the set of memtables
# 2. grab the old finalizer and invoke it
self._memtables.remove(current_memtable)
finalizer = self._finalizers.pop(name)
finalizer()
else:
# if memtable is in the set, then by construction it must be
# true that the name of this memtable is in the current
# memtables mapping
assert name in self._current_memtables

# if there's no memtable named `name` then register it, setup a
# finalizer, and set it as the current memtable with `name`
if self._current_memtables.get(name) is None:
self._register_in_memory_table(memtable)
weakref.finalize(
memtable, self._finalize_in_memory_table, memtable.name
self._memtables.add(memtable)
self._finalizers[name] = weakref.finalize(
memtable, self._finalize_in_memory_table, name
)
self._current_memtables[name] = memtable

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
if self.supports_in_memory_tables:
Expand Down
10 changes: 0 additions & 10 deletions ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,6 @@ def _session_dataset(self):
self.__session_dataset = self._make_session()
return self.__session_dataset

def _in_memory_table_exists(self, name: str) -> bool:
table_ref = bq.TableReference(self._session_dataset, name)

try:
self._get_table(table_ref)
except com.TableNotFound:
return False
else:
return True

def _finalize_memtable(self, name: str) -> None:
table_ref = bq.TableReference(self._session_dataset, name)
self.client.delete_table(table_ref, not_found_ok=True)
Expand Down
24 changes: 3 additions & 21 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,9 @@ def _normalize_external_tables(self, external_tables=None) -> ExternalData | Non
def _collect_in_memory_tables(
self, expr: ir.Table | None, external_tables: Mapping | None = None
):
memtables = {op.name: op for op in expr.op().find(ops.InMemoryTable)}
memtables = {
op.name: op for op in self._verify_in_memory_tables_are_unique(expr)
}
externals = toolz.valmap(_to_memtable, external_tables or {})
return toolz.merge(memtables, externals)

Expand Down Expand Up @@ -779,23 +781,3 @@ def create_view(
with self._safe_raw_sql(src, external_tables=external_tables):
pass
return self.table(name, database=database)

def _in_memory_table_exists(self, name: str) -> bool:
name = sg.table(name, quoted=self.compiler.quoted).sql(self.dialect)
try:
# DESCRIBE TABLE $TABLE FORMAT NULL is the fastest way to check
# table existence in clickhouse; FORMAT NULL produces no data which
# is ideal since we don't care about the output for existence
# checking
#
# Other methods compared were
# 1. SELECT 1 FROM $TABLE LIMIT 0
# 2. SHOW TABLES LIKE $TABLE LIMIT 1
#
# if the table exists nothing is returned and there's no error
# otherwise there's an error
self.con.raw_query(f"DESCRIBE {name} FORMAT NULL")
except cc.driver.exceptions.DatabaseError:
return False
else:
return True
9 changes: 0 additions & 9 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,6 @@ def _register_failure(self):
f"please call one of {msg} directly"
)

def _in_memory_table_exists(self, name: str) -> bool:
db = self.con.catalog().database()
try:
db.table(name)
except Exception: # noqa: BLE001 because DataFusion has nothing better
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
# self.con.register_table is broken, so we do this roundabout thing
# of constructing a datafusion DataFrame, which has a side effect
Expand Down
10 changes: 1 addition & 9 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,16 +1606,8 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
}
)

def _in_memory_table_exists(self, name: str) -> bool:
try:
# this handles both tables and views
self.con.table(name)
except (duckdb.CatalogException, duckdb.InvalidInputException):
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
self.con.unregister(op.name)
self.con.register(op.name, op.data.to_pyarrow(op.schema))

def _finalize_memtable(self, name: str) -> None:
Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/exasol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,6 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
finally:
self.con.execute(drop_view)

def _in_memory_table_exists(self, name: str) -> bool:
return self.con.meta.table_exists(name)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/flink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def compile(

def execute(self, expr: ir.Expr, **kwargs: Any) -> Any:
"""Execute an expression."""
self._verify_in_memory_tables_are_unique(expr)
self._register_udfs(expr)

table_expr = expr.as_table()
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/impala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,10 +1223,6 @@ def explain(

return "\n".join(["Query:", util.indent(query, 2), "", *results.iloc[:, 0]])

def _in_memory_table_exists(self, name: str) -> bool:
with contextlib.closing(self.con.cursor()) as cur:
return cur.table_exists(name)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
10 changes: 0 additions & 10 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,16 +738,6 @@ def create_table(
namespace=ops.Namespace(catalog=catalog, database=db),
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
# The single character U here means user-defined table
# see https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-objects-transact-sql?view=sql-server-ver16
sql = sg.select(sg.func("object_id", sge.convert(name), sge.convert("U"))).sql(
self.dialect
)
with self.begin() as cur:
[(result,)] = cur.execute(sql).fetchall()
return result is not None

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
20 changes: 3 additions & 17 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,23 +468,6 @@ def create_table(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
name = sg.to_identifier(name, quoted=self.compiler.quoted).sql(self.dialect)
# just return the single field with column names; no need to bring back
# everything if the command succeeds
sql = f"SHOW COLUMNS FROM {name} LIKE 'Field'"
try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except pymysql.err.ProgrammingError as e:
err_code, _ = e.args
if err_code == ER.NO_SUCH_TABLE:
return False
raise
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand All @@ -496,6 +479,9 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
name = op.name
quoted = self.compiler.quoted

# TODO(cpcloud): remove when port to mysqldb is merged
self.drop_table(name, force=True)

create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
Expand Down
17 changes: 1 addition & 16 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ibis import util
from ibis.backends import CanListDatabase, CanListSchema
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers.base import NULL, STAR, C
from ibis.backends.sql.compilers.base import STAR, C

if TYPE_CHECKING:
from urllib.parse import ParseResult
Expand Down Expand Up @@ -522,21 +522,6 @@ def drop_table(

super().drop_table(name, database=(catalog, db), force=force)

def _in_memory_table_exists(self, name: str) -> bool:
sql = (
sg.select(NULL)
.from_(sg.to_identifier("USER_OBJECTS", quoted=self.compiler.quoted))
.where(
C.OBJECT_TYPE.eq(sge.convert("TABLE")),
C.OBJECT_NAME.eq(sge.convert(name)),
)
.limit(sge.convert(1))
.sql(self.dialect)
)
with self.begin() as cur:
results = cur.execute(sql).fetchall()
return bool(results)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema

Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ def table(self, name: str) -> ir.Table:
schema = sch.infer(table)
return ops.DatabaseTable(name, schema, self).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
return name in self._tables

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
self._add_table(op.name, op.data.to_polars(op.schema).lazy())

Expand Down
15 changes: 0 additions & 15 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,6 @@ def _from_url(self, url: ParseResult, **kwargs):

return self.connect(**kwargs)

def _in_memory_table_exists(self, name: str) -> bool:
import psycopg2.errors

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except psycopg2.errors.UndefinedTable:
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
from psycopg2.extras import execute_batch

Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,10 +446,6 @@ def _register_udfs(self, expr: ir.Expr) -> None:
self._session.udf.register(f"unwrap_json_{typ.__name__}", unwrap_json(typ))
self._session.udf.register("unwrap_json_float", unwrap_json_float)

def _in_memory_table_exists(self, name: str) -> bool:
sql = f"SHOW TABLES IN {self.current_database} LIKE '{name}'"
return bool(self._session.sql(sql).count())

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = PySparkSchema.from_ibis(op.schema)
df = self._session.createDataFrame(data=op.data.to_frame(), schema=schema)
Expand Down
15 changes: 0 additions & 15 deletions ibis/backends/risingwave/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,21 +262,6 @@ def create_table(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
import psycopg2.errors

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except psycopg2.errors.InternalError:
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
19 changes: 0 additions & 19 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,25 +663,6 @@ def list_tables(

return self._filter_with_like(tables + views, like=like)

def _in_memory_table_exists(self, name: str) -> bool:
import snowflake.connector

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.con.cursor() as cur:
cur.execute(sql).fetchall()
except snowflake.connector.errors.ProgrammingError as e:
# this cryptic error message is the only generic and reliable way
# to tell if the error means "table not found for any reason"
# otherwise, we need to reraise the exception
if e.sqlstate == "42S02":
return False
raise
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
import pyarrow.parquet as pq

Expand Down
12 changes: 0 additions & 12 deletions ibis/backends/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,18 +345,6 @@ def _generate_create_table(self, table: sge.Table, schema: sch.Schema):

return sge.Create(kind="TABLE", this=target)

def _in_memory_table_exists(self, name: str) -> bool:
ident = sg.to_identifier(name, quoted=self.compiler.quoted)
query = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)
try:
with self.begin() as cur:
cur.execute(query)
cur.fetchall()
except sqlite3.OperationalError:
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
table = sg.table(op.name, quoted=self.compiler.quoted, catalog="temp")
create_stmt = self._generate_create_table(table, op.schema).sql(self.name)
Expand Down
Loading

0 comments on commit 3b1fe38

Please sign in to comment.