Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dependency parsing #3

Merged
merged 2 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
dist/
/*.ipynb
.DS_Store
*.wal
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ views/
table_6.sql
```

Each view will be named according to its location, following the warehouse convention. For instance, `schema_1/table_1.sql` will be named `dataset.schema_1__table_1` in BigQuery and `schema_1.table_1` in DuckDB.
Each view will be named according to its location, following the warehouse convention. For instance, lea names the `schema/table.sql` view `dataset.schema__table` in BigQuery and `schema.table` in DuckDB.

To reference a table in a sub-schema, the convention in lea is to use a double underscore `__`. For instance, `schema/sub_schema/table.sql` should be referred to as `dataset.schema__sub_schema__table` in BigQuery and `schema.sub_schema__table` in DuckDB.

The schemas are expected to be placed under a `views` directory. This can be changed by providing an argument to the `run` command:

Expand Down Expand Up @@ -342,6 +344,8 @@ staging.payments
>>> views = lea.views.load_views('examples/jaffle_shop/views', sqlglot_dialect='duckdb')
>>> views = [v for v in views if v.schema != 'tests']
>>> dag = lea.views.DAGOfViews(views)
>>> dag.prepare()

>>> while dag.is_active():
... for node in sorted(dag.get_ready()):
... print(dag[node])
Expand Down
2 changes: 2 additions & 0 deletions lea/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@

from . import app, clients, views

_SEP = "__"

__all__ = ["app", "clients", "views"]
5 changes: 4 additions & 1 deletion lea/app/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def run(

# Organize the views into a directed acyclic graph
dag = lea.views.DAGOfViews(views)
dag.prepare()

# Determine which views need to be run
whitelist = (
Expand All @@ -199,7 +200,9 @@ def run(
continue
console_log(f"Removing {schema}.{table}")
if not dry:
view_to_delete = lea.views.GenericSQLView(schema=schema, name=table, query="")
view_to_delete = lea.views.GenericSQLView(
schema=schema, name=table, query="", sqlglot_dialect=client.sqlglot_dialect
)
client.delete_view(view=view_to_delete)
console_log(f"Removed {schema}.{table}")

Expand Down
14 changes: 10 additions & 4 deletions lea/clients/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import pandas as pd
import sqlglot

from lea import views

Expand All @@ -25,7 +26,7 @@ def prepare(self, console):

@property
def sqlglot_dialect(self):
return "bigquery"
return sqlglot.dialects.Dialects.BIGQUERY

@property
def dataset_name(self):
Expand Down Expand Up @@ -104,7 +105,8 @@ def _load_sql(self, view: views.SQLView) -> pd.DataFrame:

def list_existing_view_names(self):
return [
table.table_id.split("__", 1) for table in self.client.list_tables(self.dataset_name)
table.table_id.split(lea._SEP, 1)
for table in self.client.list_tables(self.dataset_name)
]

def delete_view(self, view: views.View):
Expand All @@ -119,10 +121,14 @@ def get_columns(self, schema=None) -> pd.DataFrame:
data_type AS type
FROM {schema}.INFORMATION_SCHEMA.COLUMNS
"""
return self._load_sql(views.GenericSQLView(schema=None, name=None, query=query))
return self._load_sql(
views.GenericSQLView(
schema=None, name=None, query=query, sqlglot_dialect=self.sqlglot_dialect
)
)

def _make_view_path(self, view: views.View) -> str:
return f"{self.dataset_name}.{view.schema}__{view.name}"
return f"{self.dataset_name}.{view.schema}{lea._SEP}{view.name}"

def make_test_unique_column(self, view: views.View, column: str) -> str:
return f"""
Expand Down
27 changes: 14 additions & 13 deletions lea/clients/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import duckdb
import pandas as pd
import sqlglot

from lea import views
import lea

from .base import Client

Expand All @@ -18,7 +19,7 @@ def __init__(self, path: str, username: str | None):

@property
def sqlglot_dialect(self):
return "duckdb"
return sqlglot.dialects.Dialects.DUCKDB

def prepare(self, views, console):
schemas = set(
Expand All @@ -28,27 +29,27 @@ def prepare(self, views, console):
self.con.sql(f"CREATE SCHEMA IF NOT EXISTS {schema}")
console.log(f"Created schema {schema}")

def _create_python(self, view: views.PythonView):
def _create_python(self, view: lea.views.PythonView):
    """Materialize a Python view as a DuckDB table."""
    # The local name is referenced by name inside the SQL string below:
    # DuckDB resolves `dataframe` from the surrounding Python scope
    # (replacement scan), hence the noqa on the seemingly unused variable.
    dataframe = self._load_python(view)  # noqa: F841
    self.con.sql(
        f"CREATE OR REPLACE TABLE {self._make_view_path(view)} AS SELECT * FROM dataframe"
    )

def _create_sql(self, view: views.SQLView):
def _create_sql(self, view: lea.views.SQLView):
query = view.query
if self.username:
for schema, *_ in view.dependencies:
for schema in {schema for schema, *_ in view.dependencies}:
query = query.replace(f"{schema}.", f"{schema}_{self.username}.")
self.con.sql(f"CREATE OR REPLACE TABLE {self._make_view_path(view)} AS ({query})")

def _load_sql(self, view: views.SQLView):
def _load_sql(self, view: lea.views.SQLView):
query = view.query
if self.username:
for schema, *_ in view.dependencies:
for schema in {schema for schema, *_ in view.dependencies}:
query = query.replace(f"{schema}.", f"{schema}_{self.username}.")
return self.con.cursor().sql(query).df()

def delete_view(self, view: views.View):
def delete_view(self, view: lea.views.View):
    """Drop the table backing *view*; a no-op if it does not exist."""
    self.con.sql(f"DROP TABLE IF EXISTS {self._make_view_path(view)}")

def teardown(self):
Expand All @@ -61,23 +62,23 @@ def list_existing_view_names(self) -> list[tuple[str, str]]:
def get_columns(self, schema=None) -> pd.DataFrame:
query = """
SELECT
table_schema || '.' || table_name AS view_name,
table_schema || '.' || table_name AS view_name,
column_name AS column,
data_type AS type
FROM information_schema.columns
"""
return self.con.sql(query).df()

def _make_view_path(self, view: views.View) -> str:
def _make_view_path(self, view: lea.views.View) -> str:
schema, *leftover = view.key
schema = f"{schema}_{self.username}" if self.username else schema
return f"{schema}.{'__'.join(leftover)}"
return f"{schema}.{lea._SEP.join(leftover)}"

def make_test_unique_column(self, view: views.View, column: str) -> str:
def make_test_unique_column(self, view: lea.views.View, column: str) -> str:
schema, *leftover = view.key
return f"""
SELECT {column}, COUNT(*) AS n
FROM {f"{schema}.{'__'.join(leftover)}"}
FROM {f"{schema}.{lea._SEP.join(leftover)}"}
GROUP BY {column}
HAVING n > 1
"""
2 changes: 1 addition & 1 deletion lea/views/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def load_views(
# Massage the inputs
if isinstance(views_dir, str):
views_dir = pathlib.Path(views_dir)
if isinstance(views_dir, str):
if isinstance(sqlglot_dialect, str):
sqlglot_dialect = sqlglot.dialects.Dialects(sqlglot_dialect)

def _load_view_from_path(path, origin, sqlglot_dialect):
Expand Down
1 change: 0 additions & 1 deletion lea/views/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def __init__(self, views: list[lea.views.View]):
graphlib.TopologicalSorter.__init__(self, view_to_dependencies)
collections.UserDict.__init__(self, {view.key: view for view in views})
self.dependencies = view_to_dependencies
self.prepare()

@property
def schemas(self) -> set:
Expand Down
35 changes: 22 additions & 13 deletions lea/views/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

import jinja2
import sqlglot
import sqlglot.optimizer.scope

import lea

from .base import View

Expand Down Expand Up @@ -50,18 +53,24 @@ def query(self):
return template.render(env=os.environ)
return text

def parse_dependencies(self, query):
parse = sqlglot.parse_one(query, dialect=self.sqlglot_dialect)
cte_names = {(None, cte.alias) for cte in parse.find_all(sqlglot.exp.CTE)}
table_names = {
(table.sql().split(".")[0], table.name)
if "__" not in table.name and "." in table.sql()
else (table.name.split("__")[0], table.name.split("__", 1)[1])
if "__" in table.name
else (None, table.name)
for table in parse.find_all(sqlglot.exp.Table)
}
return table_names - cte_names
def parse_dependencies(self, query) -> set[tuple[str, ...]]:
    """Return the set of view keys that *query* reads from.

    The query is parsed with sqlglot and every scope is walked. Tables that
    are really CTEs defined in the same scope, or table functions, are not
    external dependencies and are excluded. Each remaining table name is
    split on ``lea._SEP`` (``"__"``) so sub-schemas become extra tuple
    components — e.g. a BigQuery table ``schema__sub_schema__table`` yields
    ``("schema", "sub_schema", "table")``; tuples are therefore of variable
    length, hence the ``tuple[str, ...]`` return annotation.

    Raises:
        ValueError: if the view's dialect is neither BigQuery nor DuckDB.
    """
    expression = sqlglot.parse_one(query, dialect=self.sqlglot_dialect)
    dependencies = set()

    for scope in sqlglot.optimizer.scope.traverse_scope(expression):
        for table in scope.tables:
            if (
                # A Table node wrapping a Func is a table function, not a
                # real table reference — skip it.
                not isinstance(table.this, sqlglot.exp.Func)
                # Names that this scope defines as CTEs are internal.
                and sqlglot.exp.table_name(table) not in scope.cte_sources
            ):
                if self.sqlglot_dialect is sqlglot.dialects.Dialects.BIGQUERY:
                    # BigQuery: schema(s) and table are packed into the
                    # table name itself, separated by _SEP; the dataset
                    # qualifier (table.db) is deliberately ignored.
                    dependencies.add(tuple(table.name.split(lea._SEP)))
                elif self.sqlglot_dialect is sqlglot.dialects.Dialects.DUCKDB:
                    # DuckDB: the db part carries the schema, while any
                    # sub-schemas live in the table name.
                    dependencies.add((table.db, *table.name.split(lea._SEP)))
                else:
                    raise ValueError(f"Unsupported SQL dialect: {self.sqlglot_dialect}")

    return dependencies

@property
def dependencies(self):
Expand All @@ -79,7 +88,7 @@ def dependencies(self):
):
schema, view_name = (
(
match.group("view").split("__")[0],
match.group("view").split(lea._SEP)[0],
match.group("view").split("__", 1)[1],
)
if "__" in match.group("view")
Expand Down
64 changes: 64 additions & 0 deletions lea/views/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import itertools

import pytest
import sqlglot

import lea


# One pytest.param per (dialect, query) pair; ids look like "BIGQUERY#0".
# BigQuery packs schema + table into one identifier separated by "__",
# whereas DuckDB uses a real schema qualifier and "__" only for sub-schemas.
@pytest.mark.parametrize(
    "view, expected",
    [
        pytest.param(
            lea.views.GenericSQLView(
                schema=None,
                name=None,
                query=query,
                sqlglot_dialect=sqlglot_dialect,
            ),
            expected,
            id=f"{sqlglot_dialect.name}#{i}",
        )
        for sqlglot_dialect, cases in {
            sqlglot.dialects.Dialects.BIGQUERY: [
                (
                    """
                    SELECT *
                    FROM dataset.schema__table

                    """,
                    {("schema", "table")},
                ),
                (
                    """
                    SELECT *
                    FROM dataset.schema__sub_schema__table

                    """,
                    {("schema", "sub_schema", "table")},
                ),
            ],
            sqlglot.dialects.Dialects.DUCKDB: [
                (
                    """
                    SELECT *
                    FROM schema.table

                    """,
                    {("schema", "table")},
                ),
                (
                    """
                    SELECT *
                    FROM schema.sub_schema__table

                    """,
                    {("schema", "sub_schema", "table")},
                ),
            ],
        }.items()
        for i, (query, expected) in enumerate(cases)
    ],
)
def test_dependency_parsing(view, expected):
    """Check that a view's parsed dependencies match the expected keys."""
    assert view.dependencies == expected
8 changes: 5 additions & 3 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,20 @@ def test_jaffle_shop():

# Write .env file
with open(env_path, "w") as f:
f.write("LEA_USERNAME=max\n" "LEA_WAREHOUSE=duckdb\n" "LEA_DUCKDB_PATH=duckdb.db\n")
f.write(
"LEA_USERNAME=max\n" "LEA_WAREHOUSE=duckdb\n" "LEA_DUCKDB_PATH=test_jaffle_shop.db\n"
)

# Prepare
result = runner.invoke(app, ["prepare", views_path, "--env", env_path])
assert result.exit_code == 0

# Run
result = runner.invoke(app, ["run", views_path, "--env", env_path])
result = runner.invoke(app, ["run", views_path, "--env", env_path, "--fresh"])
assert result.exit_code == 0

# Check number of tables created
con = duckdb.connect("duckdb.db")
con = duckdb.connect("test_jaffle_shop.db")
tables = con.sql("SELECT table_schema, table_name FROM information_schema.tables").df()
assert tables.shape[0] == 7

Expand Down
Loading