Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxHalford committed Sep 11, 2024
1 parent 13fde2e commit b6ee504
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 20 deletions.
2 changes: 1 addition & 1 deletion lea/clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def _view_key_to_table_reference(self, view_key: tuple[str], with_username: bool
...

@abc.abstractmethod
def _table_reference_to_view_key(self, table_reference: str) -> tuple[str]:
def _table_reference_to_view_key(self, table_reference: str) -> tuple[str, ...]:
...

def materialize_view(self, view: lea.views.View) -> QueryResult:
Expand Down
10 changes: 6 additions & 4 deletions lea/clients/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ def _view_key_to_table_reference(
>>> client = BigQuery(
... credentials=None,
... location="US",
... project_id="project",
... compute_project_id="compute",
... write_project_id="write",
... dataset_name="dataset",
... username="max",
... wap_mode=False
Expand All @@ -188,7 +189,7 @@ def _view_key_to_table_reference(
'dataset.schema__table'
>>> client._view_key_to_table_reference(("schema", "table"), with_context=True)
'`project`.dataset_max.schema__table'
'`write`.dataset_max.schema__table'
"""
table_reference = f"{self._dataset_name}.{lea._SEP.join(view_key)}"
Expand All @@ -203,13 +204,14 @@ def _view_key_to_table_reference(
table_reference = f"{self.write_project_id}.{table_reference}"
return table_reference

def _table_reference_to_view_key(self, table_reference: str) -> tuple[str]:
def _table_reference_to_view_key(self, table_reference: str) -> tuple[str, ...]:
"""
>>> client = BigQuery(
... credentials=None,
... location="US",
... project_id="project",
... compute_project_id="compute",
... write_project_id="write",
... dataset_name="dataset",
... username="max",
... wap_mode=False
Expand Down
33 changes: 20 additions & 13 deletions lea/clients/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@ class DuckDB(Client):
def __init__(self, path: str, username: str | None = None, wap_mode: bool = False):
import duckdb

path_ = pathlib.Path(path)
if path.startswith("md:"):
path = f"{path}_{username}" if username is not None else path
else:
path = pathlib.Path(path)
if username is not None:
path = (path.parent / f"{path.stem}_{username}{path.suffix}").absolute()
self.path = path
path_ = pathlib.Path(f"{path}_{username}" if username is not None else path)
elif username is not None:
path_ = pathlib.Path(path)
path_ = path_.parent / f"{path_.stem}_{username}{path_.suffix}"
self.path_ = path_
self.username = username
self.wap_mode = wap_mode
self.con = duckdb.connect(str(self.path))
self.con = (
duckdb.connect(":memory:")
if path == ":memory:"
else duckdb.connect(str(self.path_.absolute()))
)

@property
def sqlglot_dialect(self):
Expand Down Expand Up @@ -73,7 +77,7 @@ def list_tables(self) -> pd.DataFrame:
return self.read_sql(
f"""
SELECT
'{self.path.stem}' || '.' || schema_name || '.' || table_name AS table_reference,
'{self.path_.stem}' || '.' || schema_name || '.' || table_name AS table_reference,
estimated_size AS n_rows, -- TODO: Figure out how to get the exact number
estimated_size AS n_bytes -- TODO: Figure out how to get this
FROM duckdb_tables()
Expand All @@ -84,14 +88,16 @@ def list_columns(self) -> pd.DataFrame:
return self.read_sql(
f"""
SELECT
'{self.path.stem}' || '.' || table_schema || '.' || table_name AS table_reference,
'{self.path_.stem}' || '.' || table_schema || '.' || table_name AS table_reference,
column_name AS column,
data_type AS type
FROM information_schema.columns
"""
)

def _view_key_to_table_reference(self, view_key: tuple[str], with_context: bool) -> str:
def _view_key_to_table_reference(
self, view_key: tuple[str], with_context: bool, with_project_id=False
) -> str:
"""
>>> client = DuckDB(path=":memory:", username=None)
Expand All @@ -103,16 +109,17 @@ def _view_key_to_table_reference(self, view_key: tuple[str], with_context: bool)
'schema.subschema__table'
"""
leftover: list[str] = []
schema, *leftover = view_key
table_reference = f"{schema}.{lea._SEP.join(leftover)}"
if with_context:
if self.username:
table_reference = f"{self.path.stem}.{table_reference}"
table_reference = f"{self.path_.stem}.{table_reference}"
if self.wap_mode:
table_reference = f"{table_reference}{lea._SEP}{lea._WAP_MODE_SUFFIX}"
return table_reference

def _table_reference_to_view_key(self, table_reference: str) -> tuple[str]:
def _table_reference_to_view_key(self, table_reference: str) -> tuple[str, ...]:
"""
>>> client = DuckDB(path=":memory:", username=None)
Expand All @@ -125,7 +132,7 @@ def _table_reference_to_view_key(self, table_reference: str) -> tuple[str]:
"""
database, leftover = table_reference.split(".", 1)
if database == self.path.stem:
if database == self.path_.stem:
schema, leftover = leftover.split(".", 1)
else:
schema = database
Expand Down
2 changes: 1 addition & 1 deletion lea/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _make_table_reference_mapping(
for view_key in self.regular_views
}

return table_reference_mapping
return {k: v for k, v in table_reference_mapping.items() if v != k}

def prepare(self):
self.client.prepare(self.regular_views.values())
Expand Down
3 changes: 2 additions & 1 deletion lea/views/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
lea.clients.BigQuery(
credentials=None,
location=None,
project_id=None,
compute_project_id=None,
write_project_id=None,
dataset_name="dataset",
username="max",
wap_mode=False,
Expand Down

0 comments on commit b6ee504

Please sign in to comment.