Skip to content

Commit

Permalink
fix(pyspark): unwind catalog/database settings in same order they wer…
Browse files Browse the repository at this point in the history
…e set (#9067)
  • Loading branch information
gforsyth authored Apr 30, 2024
1 parent fd35b66 commit 962ee00
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 9 deletions.
36 changes: 27 additions & 9 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,34 @@ def current_catalog(self) -> str:
return catalog

@contextlib.contextmanager
def _active_database(self, name: str | None):
if name is None:
def _active_catalog_database(self, catalog: str | None, db: str | None):
if catalog is None and db is None:
yield
return
current = self.current_database
if catalog is not None and PYSPARK_LT_34:
raise com.UnsupportedArgumentError(
"Catalogs are not supported in pyspark < 3.4"
)
current_catalog = self.current_catalog
current_db = self.current_database

# This little horrible bit of work is to avoid trying to set
# the `CurrentDatabase` inside of a catalog where we don't have permission
# to do so. We can't have the catalog and database context managers work
# separately because we need to:
# 1. set catalog
# 2. set database
# 3. set catalog to previous
# 4. set database to previous
try:
self._session.catalog.setCurrentDatabase(name)
if catalog is not None:
self._session.catalog.setCurrentCatalog(catalog)
self._session.catalog.setCurrentDatabase(db)
yield
finally:
self._session.catalog.setCurrentDatabase(current)
if catalog is not None:
self._session.catalog.setCurrentCatalog(current_catalog)
self._session.catalog.setCurrentDatabase(current_db)

@contextlib.contextmanager
def _active_catalog(self, name: str | None):
Expand Down Expand Up @@ -438,7 +456,7 @@ def get_schema(

table_loc = self._to_sqlglot_table((catalog, database))
catalog, db = self._to_catalog_db_tuple(table_loc)
with self._active_catalog(catalog), self._active_database(db):
with self._active_catalog_database(catalog, db):
df = self._session.table(table_name)
struct = PySparkType.to_ibis(df.schema)

Expand Down Expand Up @@ -500,18 +518,18 @@ def create_table(
table = obj if isinstance(obj, ir.Expr) else ibis.memtable(obj)
query = self.compile(table)
mode = "overwrite" if overwrite else "error"
with self._active_catalog(catalog), self._active_database(db):
with self._active_catalog_database(catalog, db):
self._run_pre_execute_hooks(table)
df = self._session.sql(query)
df.write.saveAsTable(name, format=format, mode=mode)
elif schema is not None:
schema = PySparkSchema.from_ibis(schema)
with self._active_catalog(catalog), self._active_database(db):
with self._active_catalog_database(catalog, db):
self._session.catalog.createTable(name, schema=schema, format=format)
else:
raise com.IbisError("The schema or obj parameter is required")

return self.table(name, database=db)
return self.table(name, database=(catalog, db))

def create_view(
self,
Expand Down
19 changes: 19 additions & 0 deletions ibis/backends/pyspark/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import pytest

import ibis


@pytest.mark.xfail_version(pyspark=["pyspark<3.4"], reason="no catalog support")
def test_catalog_db_args(con, monkeypatch):
monkeypatch.setattr(ibis.options, "default_backend", con)
t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
Expand All @@ -20,3 +23,19 @@ def test_catalog_db_args(con, monkeypatch):
con.drop_table("t2", database="spark_catalog.default")

assert "t2" not in con.list_tables(database="default")


def test_create_table_no_catalog(con, monkeypatch):
monkeypatch.setattr(ibis.options, "default_backend", con)
t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})

# create a table in specified catalog and db
con.create_table("t2", database=("default"), obj=t, overwrite=True)

assert "t2" not in con.list_tables()
assert "t2" in con.list_tables(database="default")
assert "t2" in con.list_tables(database=("default"))

con.drop_table("t2", database="default")

assert "t2" not in con.list_tables(database="default")

0 comments on commit 962ee00

Please sign in to comment.