diff --git a/.gitignore b/.gitignore
index 675cd01..b7d98b5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,5 @@
 *.db
 .env
 dist/
+/*.ipynb
+.DS_Store
diff --git a/README.md b/README.md
index 4473a66..28c35d4 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,7 @@ lea aims to be simple and opinionated, and yet offers the possibility to be exte
 Right now lea is compatible with BigQuery (used at Carbonfact) and DuckDB (quack quack).

 - [Example](#example)
+- [Teaser](#teaser)
 - [Installation](#installation)
 - [Usage](#usage)
 - [Configuration](#configuration)
@@ -53,6 +54,8 @@ Right now lea is compatible with BigQuery (used at Carbonfact) and DuckDB (quack

 - [Jaffle shop 🥪](examples/jaffle_shop/)

+## Teaser
+
 ## Installation

 ```sh
@@ -69,7 +72,6 @@ lea is configured by setting environment variables. The following variables are

 ```sh
 # General configuration
-LEA_SCHEMA=kaya
 LEA_USERNAME=max
 LEA_WAREHOUSE=bigquery

@@ -79,6 +81,7 @@ LEA_DUCKDB_PATH=duckdb.db

 # BigQuery 🦏
 LEA_BQ_LOCATION=EU
 LEA_BQ_PROJECT_ID=carbonfact-dwh
+LEA_BQ_DATASET_NAME=kaya
 LEA_BQ_SERVICE_ACCOUNT=
 ```
@@ -117,6 +120,8 @@ views/
         table_6.sql
 ```

+Each view will be named according to its location, following the warehouse convention. For instance, `schema_1/table_1.sql` will be named `dataset.schema_1__table_1` in BigQuery and `schema_1.table_1` in DuckDB.
+
 The schemas are expected to be placed under a `views` directory. This can be changed by providing an argument to the `run` command:

 ```sh
@@ -163,6 +168,12 @@ You can select all views in a schema:
 lea run --only core/
 ```

+This also works with sub-schemas:
+
+```sh
+lea run --only analytics.finance/
+```
+
 There are thus 8 possible operators:

 ```
@@ -213,10 +224,10 @@ lea test views

 There are two types of tests:

 - Singular tests -- these are queries which return failing rows. They are stored in a `tests` directory.
-- Annotation tests -- these are comment annotations in the queries themselves:
+- Assertion tests -- these are comment annotations in the queries themselves:
   - `@UNIQUE` -- checks that a column's values are unique.

-As with the `run` command, there is a `--production` flag to disable the `` suffix.
+As with the `run` command, there is a `--production` flag to disable the `` suffix and thus test production data.

 ### `lea docs`
@@ -302,6 +313,8 @@ lea is meant to be used as a CLI. But you can import it as a Python library too.
 >>> for view in sorted(views, key=str):
 ...     print(view)
 ...     print(sorted(view.dependencies))
+analytics.finance.kpis
+[('core', 'orders')]
 analytics.kpis
 [('core', 'customers'), ('core', 'orders')]
 core.customers
@@ -326,14 +339,15 @@ staging.payments

 >>> views = [v for v in views if v.schema != 'tests']
 >>> dag = lea.views.DAGOfViews(views)
 >>> while dag.is_active():
-...     for schema, table in sorted(dag.get_ready()):
-...         print(f'{schema}.{table}')
-...         dag.done((schema, table))
+...     for node in sorted(dag.get_ready()):
+...         print(dag[node])
+...         dag.done(node)
 staging.customers
 staging.orders
 staging.payments
 core.customers
 core.orders
+analytics.finance.kpis
 analytics.kpis
 ```
diff --git a/examples/jaffle_shop/README.md b/examples/jaffle_shop/README.md
index 7caa6a3..10535b2 100644
--- a/examples/jaffle_shop/README.md
+++ b/examples/jaffle_shop/README.md
@@ -19,19 +19,24 @@ The first thing to do is create an `.env` file, as so:

 ```sh
 echo "
-LEA_SCHEMA=jaffle_shop
 LEA_USERNAME=max
 LEA_WAREHOUSE=duckdb
 LEA_DUCKDB_PATH=duckdb.db
 " > .env
 ```

-Next, run the following command to create the `duckdb.db` file and the `jaffle_shop` schema therein:
+Next, run the following command to create the `duckdb.db` file and the schemas therein:

 ```sh
 lea prepare
 ```

+```
+Created schema analytics_max
+Created schema core_max
+Created schema staging_max
+```
+
 Now you can run the views:

 ```sh
diff --git a/examples/jaffle_shop/docs/README.md b/examples/jaffle_shop/docs/README.md
index 74fcc93..96370dd 100644
--- a/examples/jaffle_shop/docs/README.md
+++ b/examples/jaffle_shop/docs/README.md
@@ -3,8 +3,8 @@
 ## Schemas

 - [`analytics`](./analytics)
-- [`core`](./core)
 - [`staging`](./staging)
+- [`core`](./core)

 ## Schema flowchart

@@ -32,6 +32,7 @@ flowchart TB
     subgraph staging
     end

+    core.orders --> analytics.finance.kpis
     core.customers --> analytics.kpis
     core.orders --> analytics.kpis
     staging.customers --> core.customers
diff --git a/examples/jaffle_shop/docs/analytics/README.md b/examples/jaffle_shop/docs/analytics/README.md
index c0325c1..68514e4 100644
--- a/examples/jaffle_shop/docs/analytics/README.md
+++ b/examples/jaffle_shop/docs/analytics/README.md
@@ -2,15 +2,28 @@

 ## Table of contents

-- [kpis](#kpis)
+- [analytics.finance.kpis](#analytics.finance.kpis)
+- [analytics.kpis](#analytics.kpis)

 ## Views

-### kpis
+### analytics.finance.kpis

 ```sql
 SELECT *
-FROM jaffle_shop_max.analytics__kpis
+FROM analytics_max.finance__kpis
+```
+
+| Column | Type | Description | Unique |
+|:--------------------|:---------|:--------------|:---------|
+| average_order_value | `DOUBLE` | | |
+| total_order_value | `DOUBLE` | | |
+
+### analytics.kpis
+
+```sql
+SELECT *
+FROM analytics_max.kpis
 ```

 | Column | Type | Description | Unique |
diff --git a/examples/jaffle_shop/docs/core/README.md b/examples/jaffle_shop/docs/core/README.md
index 35986be..72caa2f 100644
--- a/examples/jaffle_shop/docs/core/README.md
+++ b/examples/jaffle_shop/docs/core/README.md
@@ -2,16 +2,16 @@

 ## Table of contents

-- [customers](#customers)
-- [orders](#orders)
+- [core.customers](#core.customers)
+- [core.orders](#core.orders)

 ## Views

-### customers
+### core.customers

 ```sql
 SELECT *
-FROM jaffle_shop_max.core__customers
+FROM core_max.customers
 ```

 | Column | Type | Description | Unique |
@@ -24,11 +24,11 @@ FROM jaffle_shop_max.core__customers
 | most_recent_order | `VARCHAR` | | |
 | number_of_orders | `BIGINT` | | |

-### orders
+### core.orders

 ```sql
 SELECT *
-FROM jaffle_shop_max.core__orders
+FROM core_max.orders
 ```

 | Column | Type | Description | Unique |
diff --git a/examples/jaffle_shop/docs/staging/README.md b/examples/jaffle_shop/docs/staging/README.md
index bd805ca..302188f 100644
--- a/examples/jaffle_shop/docs/staging/README.md
+++ b/examples/jaffle_shop/docs/staging/README.md
@@ -2,19 +2,19 @@

 ## Table of contents

-- [customers](#customers)
-- [orders](#orders)
-- [payments](#payments)
+- [staging.customers](#staging.customers)
+- [staging.orders](#staging.orders)
+- [staging.payments](#staging.payments)

 ## Views

-### customers
+### staging.customers

 Docstring for the customers view.

 ```sql
 SELECT *
-FROM jaffle_shop_max.staging__customers
+FROM staging_max.customers
 ```

 | Column | Type | Description | Unique |
@@ -23,13 +23,13 @@ FROM jaffle_shop_max.staging__customers
 | first_name | `VARCHAR` | | |
 | last_name | `VARCHAR` | | |

-### orders
+### staging.orders

 Docstring for the orders view.

 ```sql
 SELECT *
-FROM jaffle_shop_max.staging__orders
+FROM staging_max.orders
 ```

 | Column | Type | Description | Unique |
@@ -39,11 +39,11 @@ FROM jaffle_shop_max.staging__orders
 | order_id | `BIGINT` | | |
 | status | `VARCHAR` | | |

-### payments
+### staging.payments

 ```sql
 SELECT *
-FROM jaffle_shop_max.staging__payments
+FROM staging_max.payments
 ```

 | Column | Type | Description | Unique |
diff --git a/examples/jaffle_shop/views/analytics/finance/kpis.sql b/examples/jaffle_shop/views/analytics/finance/kpis.sql
new file mode 100644
index 0000000..6d5d4e3
--- /dev/null
+++ b/examples/jaffle_shop/views/analytics/finance/kpis.sql
@@ -0,0 +1,4 @@
+SELECT
+    SUM(amount) AS total_order_value,
+    AVG(amount) AS average_order_value
+FROM core.orders
diff --git a/examples/jaffle_shop/views/analytics/kpis.sql b/examples/jaffle_shop/views/analytics/kpis.sql
index ef3b5ba..49dbd27 100644
--- a/examples/jaffle_shop/views/analytics/kpis.sql
+++ b/examples/jaffle_shop/views/analytics/kpis.sql
@@ -2,7 +2,7 @@ SELECT
     'n_customers' AS metric,
     COUNT(*) AS value
 FROM
-    jaffle_shop.core__customers
+    core.customers

 UNION ALL

@@ -10,4 +10,4 @@ SELECT
     'n_orders' AS metric,
     COUNT(*) AS value
 FROM
-    jaffle_shop.core__orders
+    core.orders
diff --git a/examples/jaffle_shop/views/core/customers.sql b/examples/jaffle_shop/views/core/customers.sql
index e338d07..c288aad 100644
--- a/examples/jaffle_shop/views/core/customers.sql
+++ b/examples/jaffle_shop/views/core/customers.sql
@@ -1,18 +1,18 @@
 with customers as (

-    select * from jaffle_shop.staging__customers
+    select * from staging.customers

 ),

 orders as (

-    select * from jaffle_shop.staging__orders
+    select * from staging.orders

 ),

 payments as (

-    select * from jaffle_shop.staging__payments
+    select * from staging.payments

 ),
diff --git a/examples/jaffle_shop/views/core/orders.sql.jinja b/examples/jaffle_shop/views/core/orders.sql.jinja
index e98f68f..d81b6e8 100644
--- a/examples/jaffle_shop/views/core/orders.sql.jinja
+++ b/examples/jaffle_shop/views/core/orders.sql.jinja
@@ -2,13 +2,13 @@

 with orders as (

-    select * from jaffle_shop.staging__orders
+    select * from staging.orders

 ),

 payments as (

-    select * from jaffle_shop.staging__payments
+    select * from staging.payments

 ),
diff --git a/examples/jaffle_shop/views/tests/orders_are_dated.sql b/examples/jaffle_shop/views/tests/orders_are_dated.sql
index dbf4efa..59498d0 100644
--- a/examples/jaffle_shop/views/tests/orders_are_dated.sql
+++ b/examples/jaffle_shop/views/tests/orders_are_dated.sql
@@ -1,3 +1,3 @@
 SELECT *
-FROM jaffle_shop.core__orders
+FROM core.orders
 WHERE order_date IS NULL
diff --git a/lea/app/__init__.py b/lea/app/__init__.py
index 2aef105..5c56cf6 100644
--- a/lea/app/__init__.py
+++ b/lea/app/__init__.py
@@ -6,6 +6,8 @@
 import rich.console
 import typer

+import lea
+
 app = typer.Typer()
 console = rich.console.Console()

@@ -27,10 +29,12 @@ def env_validate_callback(env_path: str | None):


 @app.command()
-def prepare(production: bool = False, env: str = EnvPath):
-    """ """
+def prepare(views_dir: str = ViewsDir, production: bool = False, env: str = EnvPath):
     client = _make_client(production)
-    client.prepare(console)
+    views = lea.views.load_views(views_dir, sqlglot_dialect=client.sqlglot_dialect)
+    views = [view for view in views if view.schema not in {"tests", "funcs"}]
+
+    client.prepare(views, console)


 @app.command()
@@ -64,9 +68,13 @@ def run(
     # The client determines where the views will be written
     client = _make_client(production)

+    # Load views
+    views = lea.views.load_views(views_dir, sqlglot_dialect=client.sqlglot_dialect)
+    views = [view for view in views if view.schema not in {"tests", "funcs"}]
+
     run(
         client=client,
-        views_dir=views_dir,
+        views=views,
         only=only,
         dry=dry,
         print_to_cli=print,
diff --git a/lea/app/docs.py b/lea/app/docs.py
index 252bd87..afc3548 100644
--- a/lea/app/docs.py
+++ b/lea/app/docs.py
@@ -20,7 +20,6 @@ def docs(
     # List all the relevant views
     views = lea.views.load_views(views_dir, sqlglot_dialect=client.sqlglot_dialect)
     views = [view for view in views if view.schema not in {"tests", "funcs"}]
-    console.log(f"Found {len(views):,d} views")

     # Organize the views into a directed acyclic graph
     dag = lea.views.DAGOfViews(views)
@@ -32,7 +31,7 @@ def docs(
     readme_content = io.StringIO()
     readme_content.write("# Views\n\n")
     readme_content.write("## Schemas\n\n")
-    for schema in dag.schemas:
+    for schema in sorted(dag.schemas):
         readme_content.write(f"- [`{schema}`](./{schema})\n")

     content = io.StringIO()
@@ -44,18 +43,18 @@ def docs(

         # Write down table of contents
         content.write("## Table of contents\n\n")
-        for view in sorted(dag.values(), key=lambda view: view.name):
+        for view in sorted(dag.values(), key=lambda view: view.key):
             if view.schema != schema:
                 continue
-            content.write(f"- [{view.name}](#{view.name})\n")
+            content.write(f"- [{view}](#{view})\n")
         content.write("\n")

         # Write down the views
         content.write("## Views\n\n")
-        for view in sorted(dag.values(), key=lambda view: view.name):
+        for view in sorted(dag.values(), key=lambda view: view.key):
             if view.schema != schema:
                 continue
-            content.write(f"### {view.name}\n\n")
+            content.write(f"### {view}\n\n")

             if view.description:
                 content.write(f"{view.description}\n\n")
@@ -64,7 +63,9 @@ def docs(
                 "```sql\n" "SELECT *\n" f"FROM {client._make_view_path(view)}\n" "```\n\n"
             )
             # Write down the columns
-            view_columns = columns.query(f"table == '{schema}__{view.name}'")[["column", "type"]]
+            view_columns = columns.query(f"view_name == '{client._make_view_path(view)}'")[
+                ["column", "type"]
+            ]
             view_comments = view.extract_comments(columns=view_columns["column"].tolist())
             view_columns["Description"] = (
                 view_columns["column"]
@@ -121,3 +122,4 @@ def docs(
     readme = output_dir / "README.md"
     readme.parent.mkdir(parents=True, exist_ok=True)
     readme.write_text(readme_content.getvalue())
+    console.log(f"Wrote {readme}", style="bold green")
diff --git a/lea/app/run.py b/lea/app/run.py
index 0839745..933f85c 100644
--- a/lea/app/run.py
+++ b/lea/app/run.py
@@ -57,8 +57,8 @@ def make_whitelist(query: str, dag: lea.views.DAGOfViews) -> set:
     >>> dag = lea.views.DAGOfViews(views)

     >>> def pprint(whitelist):
-    ...     for schema, table in sorted(whitelist):
-    ...         print(f'{schema}.{table}')
+    ...     for key in sorted(whitelist):
+    ...         print('.'.join(key))

     schema.table

@@ -68,6 +68,7 @@ def make_whitelist(query: str, dag: lea.views.DAGOfViews) -> set:
     schema.table+ (descendants)

     >>> pprint(make_whitelist('staging.orders+', dag))
+    analytics.finance.kpis
     analytics.kpis
     core.customers
     core.orders
@@ -100,6 +101,7 @@ def make_whitelist(query: str, dag: lea.views.DAGOfViews) -> set:
     schema/+ (all tables in schema with their descendants)

     >>> pprint(make_whitelist('staging/+', dag))
+    analytics.finance.kpis
     analytics.kpis
     core.customers
     core.orders
@@ -119,6 +121,7 @@ def make_whitelist(query: str, dag: lea.views.DAGOfViews) -> set:
     +schema/+ (all tables in schema with their ancestors and descendants)

     >>> pprint(make_whitelist('+core/+', dag))
+    analytics.finance.kpis
     analytics.kpis
     core.customers
     core.orders
@@ -126,6 +129,11 @@ def make_whitelist(query: str, dag: lea.views.DAGOfViews) -> set:
     staging.orders
     staging.payments

+    schema.subschema/
+
+    >>> pprint(make_whitelist('analytics.finance/', dag))
+    analytics.finance.kpis
+
     """

     def _yield_whitelist(query, include_ancestors, include_descendants):
@@ -140,27 +148,27 @@ def _yield_whitelist(query, include_ancestors, include_descendants):
             )
             return
         if query.endswith("/"):
-            for schema, table in dag:
-                if schema == query[:-1]:
+            for key in dag:
+                if str(dag[key]).startswith(query[:-1]):
                     yield from _yield_whitelist(
-                        f"{schema}.{table}",
+                        ".".join(key),
                         include_ancestors=include_ancestors,
                         include_descendants=include_descendants,
                     )
         else:
-            schema, table = query.split(".")
-            yield schema, table
+            key = tuple(query.split("."))
+            yield key
             if include_ancestors:
-                yield from dag.list_ancestors((schema, table))
+                yield from dag.list_ancestors(key)
             if include_descendants:
-                yield from dag.list_descendants((schema, table))
+                yield from dag.list_descendants(key)

     return set(_yield_whitelist(query, include_ancestors=False, include_descendants=False))


 def run(
     client: lea.clients.Client,
-    views_dir: str,
+    views: list[lea.views.View],
     only: list[str],
     dry: bool,
     print_to_cli: bool,
@@ -174,8 +182,6 @@ def run(
     console_log = _do_nothing if print_to_cli else console.log
     # List the relevant views
-    views = lea.views.load_views(views_dir, sqlglot_dialect=client.sqlglot_dialect)
-    views = [view for view in views if view.schema not in {"tests", "funcs"}]
     console_log(f"{len(views):,d} view(s) in total")

     # Organize the views into a directed acyclic graph
     dag = lea.views.DAGOfViews(views)
@@ -202,26 +208,22 @@ def display_progress() -> rich.table.Table:
             return None

         table = rich.table.Table(box=None)
         table.add_column("#", header_style="italic")
-        table.add_column("schema", header_style="italic")
         table.add_column("view", header_style="italic")
         table.add_column("status", header_style="italic")
         table.add_column("duration", header_style="italic")

         order_not_done = [node for node in order if node not in cache]
-        for i, (schema, view_name) in list(enumerate(order_not_done, start=1))[-show:]:
-            status = SUCCESS if (schema, view_name) in jobs_ended_at else RUNNING
-            status = ERRORED if (schema, view_name) in exceptions else status
-            status = SKIPPED if (schema, view_name) in skipped else status
+        for i, node in list(enumerate(order_not_done, start=1))[-show:]:
+            status = SUCCESS if node in jobs_ended_at else RUNNING
+            status = ERRORED if node in exceptions else status
+            status = SKIPPED if node in skipped else status
             duration = (
-                (
-                    jobs_ended_at.get((schema, view_name), dt.datetime.now())
-                    - jobs_started_at[(schema, view_name)]
-                )
-                if (schema, view_name) in jobs_started_at
+                (jobs_ended_at.get(node, dt.datetime.now()) - jobs_started_at[node])
+                if node in jobs_started_at
                 else dt.timedelta(seconds=0)
             )
             rounded_seconds = round(duration.total_seconds(), 1)
-            table.add_row(str(i), schema, view_name, status, f"{rounded_seconds}s")
+            table.add_row(str(i), str(dag[node]), status, f"{rounded_seconds}s")

         return table
@@ -313,8 +315,8 @@ def display_progress() -> rich.table.Table:

     # Summary of errors
     if exceptions:
-        for (schema, view_name), exception in exceptions.items():
-            console.print(f"{schema}.{view_name}", style="bold red")
+        for node, exception in exceptions.items():
+            console.print(str(dag[node]), style="bold red")
             console.print(exception)

         if raise_exceptions:
diff --git a/lea/app/test.py b/lea/app/test.py
index 4e51db2..f2de262 100644
--- a/lea/app/test.py
+++ b/lea/app/test.py
@@ -19,24 +19,30 @@ def test(
     columns = client.get_columns()

     # The client determines where the views will be written
-    # List the test views
+
+    # List singular tests
     views = lea.views.load_views(views_dir, sqlglot_dialect=client.sqlglot_dialect)
     singular_tests = [view for view in views if view.schema == "tests"]
     console.log(f"Found {len(singular_tests):,d} singular tests")

-    generic_tests = []
-    for view in views:
-        view_columns = columns.query(f"table == '{view.schema}__{view.name}'")["column"].tolist()
-        for generic_test in client.yield_unit_tests(view=view, view_columns=view_columns):
-            generic_tests.append(generic_test)
-    console.log(f"Found {len(generic_tests):,d} generic tests")
+    # List assertion tests
+    assertion_tests = []
+    for view in filter(lambda v: v.schema not in {"funcs", "tests"}, views):
+        view_columns = columns.query(f"view_name == '{client._make_view_path(view)}'")[
+            "column"
+        ].tolist()
+
+        for test in client.yield_unit_tests(view=view, view_columns=view_columns):
+            assertion_tests.append(test)
+    console.log(f"Found {len(assertion_tests):,d} assertion tests")

     # Determine which tests need to be run
-    tests = singular_tests + generic_tests
-    blacklist = set(t.name for t in tests).difference(only) if only else set()
+    tests = singular_tests + assertion_tests
+    blacklist = set(t.key for t in tests).difference(only) if only else set()
     console.log(f"{len(tests) - len(blacklist):,d} test(s) selected")
-    tests = [test for test in tests if test.name not in blacklist]
+    tests = [test for test in tests if test.key not in blacklist]

+    # Run tests concurrently
     with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
         jobs = {executor.submit(client.load, test): test for test in tests}
         for job in concurrent.futures.as_completed(jobs):
diff --git a/lea/clients/__init__.py b/lea/clients/__init__.py
index 2691974..88e600f 100644
--- a/lea/clients/__init__.py
+++ b/lea/clients/__init__.py
@@ -21,7 +21,7 @@ def make_client(production: bool):
             ),
             location=os.environ["LEA_BQ_LOCATION"],
             project_id=os.environ["LEA_BQ_PROJECT_ID"],
-            dataset_name=os.environ["LEA_SCHEMA"],
+            dataset_name=os.environ["LEA_BQ_DATASET_NAME"],
             username=username,
         )
     elif warehouse == "duckdb":
@@ -29,7 +29,6 @@ def make_client(production: bool):

         return DuckDB(
             path=os.environ["LEA_DUCKDB_PATH"],
-            schema=os.environ["LEA_SCHEMA"],
             username=username,
         )
     else:
diff --git a/lea/clients/base.py b/lea/clients/base.py
index e3ac5b3..87cbe37 100644
--- a/lea/clients/base.py
+++ b/lea/clients/base.py
@@ -53,9 +53,9 @@ def _load_python(self, view: views.PythonView):
         spec.loader.exec_module(module)

         # Step 2: Retrieve the variable from the module's namespace
-        dataframe = getattr(module, view.name, None)
+        dataframe = getattr(module, view.key[1], None)  # HACK
         if dataframe is None:
-            raise ValueError(f"Could not find variable {view.name} in {view.path}")
+            raise ValueError(f"Could not find variable {view.key[1]} in {view.path}")
         return dataframe

     def load(self, view: views.View):
@@ -142,8 +142,9 @@ def yield_unit_tests(self, view, view_columns):
                 if comment.text == "@UNIQUE":
                     yield views.GenericSQLView(
                         schema="tests",
-                        name=f"{view.schema}.{view.name}.{column}@UNIQUE",
+                        name=f"{view}.{column}@UNIQUE",
                         query=self.make_test_unique_column(view, column),
+                        sqlglot_dialect=self.sqlglot_dialect,
                     )
                 else:
                     raise ValueError(f"Unhandled tag: {comment.text}")
diff --git a/lea/clients/bigquery.py b/lea/clients/bigquery.py
index d7b5f21..9a9d4c9 100644
--- a/lea/clients/bigquery.py
+++ b/lea/clients/bigquery.py
@@ -60,7 +60,7 @@ def _make_job(self, view: views.SQLView):
                     "destinationTable": {
                         "projectId": self.project_id,
                         "datasetId": self.dataset_name,
-                        "tableId": f"{view.schema}__{view.name}" if view.schema else view.name,
+                        "tableId": f"{self._make_view_path(view).split('.', 1)[1]}",
                     },
                     "createDisposition": "CREATE_IF_NEEDED",
                     "writeDisposition": "WRITE_TRUNCATE",
@@ -68,7 +68,7 @@ def _make_job(self, view: views.SQLView):
                 "labels": {
                     "job_dataset": self.dataset_name,
                     "job_schema": view.schema,
-                    "job_table": f"{view.schema}__{view.name}" if view.schema else view.name,
+                    "job_table": f"{self._make_view_path(view).split('.', 1)[1]}",
                     "job_username": self.username,
                     "job_is_github_actions": "GITHUB_ACTIONS" in os.environ,
                 },
@@ -99,7 +99,7 @@ def _create_python(self, view: views.PythonView):
     def _load_sql(self, view: views.SQLView) -> pd.DataFrame:
         query = view.query
         if self.username:
-            query = query.replace(f"{self._dataset_name}.", f"{self.dataset_name}.")
+            query = query.replace(f"{self._dataset_name}_{self.username}.", f"{self.dataset_name}.")
         return pd.read_gbq(query, credentials=self.client._credentials)

     def list_existing_view_names(self):
diff --git a/lea/clients/duckdb.py b/lea/clients/duckdb.py
index 90d8105..05e1317 100644
--- a/lea/clients/duckdb.py
+++ b/lea/clients/duckdb.py
@@ -11,9 +11,8 @@


 class DuckDB(Client):
-    def __init__(self, path: str, schema: str, username: str):
+    def __init__(self, path: str, username: str | None):
         self.path = path
-        self._schema = schema
         self.username = username
         self.con = duckdb.connect(self.path)

@@ -21,13 +20,13 @@ def __init__(self, path: str, schema: str, username: str):
     def sqlglot_dialect(self):
         return "duckdb"

-    @property
-    def schema(self):
-        return f"{self._schema}_{self.username}" if self.username else self._schema
-
-    def prepare(self, console):
-        self.con.sql(f"CREATE SCHEMA IF NOT EXISTS {self.schema}")
-        console.log(f"Created schema {self.schema}")
+    def prepare(self, views, console):
+        schemas = set(
+            f"{view.schema}_{self.username}" if self.username else view.schema for view in views
+        )
+        for schema in schemas:
+            self.con.sql(f"CREATE SCHEMA IF NOT EXISTS {schema}")
+            console.log(f"Created schema {schema}")

     def _create_python(self, view: views.PythonView):
         dataframe = self._load_python(view)  # noqa: F841
@@ -36,13 +35,17 @@ def _create_python(self, view: views.PythonView):
         )

     def _create_sql(self, view: views.SQLView):
-        query = view.query.replace(f"{self._schema}.", f"{self.schema}.")
+        query = view.query
+        if self.username:
+            for schema, *_ in view.dependencies:
+                query = query.replace(f"{schema}.", f"{schema}_{self.username}.")
         self.con.sql(f"CREATE OR REPLACE TABLE {self._make_view_path(view)} AS ({query})")

     def _load_sql(self, view: views.SQLView):
         query = view.query
         if self.username:
-            query = query.replace(f"{self._schema}.", f"{self.schema}.")
+            for schema, *_ in view.dependencies:
+                query = query.replace(f"{schema}.", f"{schema}_{self.username}.")
         return self.con.cursor().sql(query).df()

     def delete_view(self, view: views.View):
@@ -56,24 +59,25 @@ def list_existing_view_names(self) -> list[tuple[str, str]]:
         return [(r["table_schema"], r["table_name"]) for r in results.to_dict(orient="records")]

     def get_columns(self, schema=None) -> pd.DataFrame:
-        schema = schema or self.schema
         query = f"""
         SELECT
-            table_name AS table,
+            table_schema || '.' || table_name AS view_name,
             column_name AS column,
             data_type AS type
         FROM information_schema.columns
-        WHERE table_schema = '{schema}'
         """
         return self.con.sql(query).df()

     def _make_view_path(self, view: views.View) -> str:
-        return f"{self.schema}.{view.dunder_name}"
+        schema, *leftover = view.key
+        schema = f"{schema}_{self.username}" if self.username else schema
+        return f"{schema}.{'__'.join(leftover)}"

     def make_test_unique_column(self, view: views.View, column: str) -> str:
+        schema, *leftover = view.key
         return f"""
         SELECT {column}, COUNT(*) AS n
-        FROM {self._make_view_path(view)}
+        FROM {f"{schema}.{'__'.join(leftover)}"}
         GROUP BY {column}
         HAVING n > 1
         """
diff --git a/lea/views/__init__.py b/lea/views/__init__.py
index d9dcdfa..32d3181 100644
--- a/lea/views/__init__.py
+++ b/lea/views/__init__.py
@@ -24,7 +24,7 @@ def _load_view_from_path(path, origin, sqlglot_dialect):
         if path.suffix == ".py":
             return PythonView(origin, relative_path)
         if path.suffix == ".sql" or path.suffixes == [".sql", ".jinja"]:
-            return SQLView(origin, relative_path, dialect=sqlglot_dialect)
+            return SQLView(origin, relative_path, sqlglot_dialect=sqlglot_dialect)

     return [
         _load_view_from_path(path, origin=views_dir, sqlglot_dialect=sqlglot_dialect)
diff --git a/lea/views/base.py b/lea/views/base.py
index 13cb378..4e42f96 100644
--- a/lea/views/base.py
+++ b/lea/views/base.py
@@ -24,15 +24,8 @@ def schema(self):
         return self.relative_path.parts[0]

     @property
-    def name(self):
-        name_parts = itertools.chain(
-            self.relative_path.parts[1:-1], [self.relative_path.name.split(".")[0]]
-        )
-        return "__".join(name_parts)
-
-    @property
-    def dunder_name(self):
-        return f"{self.schema}__{self.name}"
+    def key(self):
+        return tuple([*self.relative_path.parts[:-1], self.relative_path.name.split(".")[0]])

     @property
     @abc.abstractmethod
diff --git a/lea/views/dag.py b/lea/views/dag.py
index f31839f..0762f92 100644
--- a/lea/views/dag.py
+++ b/lea/views/dag.py
@@ -10,20 +10,20 @@


 class DAGOfViews(graphlib.TopologicalSorter, collections.UserDict):
     def __init__(self, views: list[lea.views.View]):
-        view_to_dependencies = {(view.schema, view.name): view.dependencies for view in views}
+        view_to_dependencies = {view.key: view.dependencies for view in views}
         graphlib.TopologicalSorter.__init__(self, view_to_dependencies)
-        collections.UserDict.__init__(self, {(view.schema, view.name): view for view in views})
+        collections.UserDict.__init__(self, {view.key: view for view in views})
         self.dependencies = view_to_dependencies
         self.prepare()

     @property
-    def schemas(self):
-        return sorted(set(schema for schema, _ in self))
+    def schemas(self) -> set:
+        return set(schema for schema, *_ in self)

     @property
     def schema_dependencies(self):
         deps = collections.defaultdict(set)
-        for (src_schema, _), dsts in self.dependencies.items():
+        for (src_schema, *_), dsts in self.dependencies.items():
             deps[src_schema].update([schema for schema, _ in dsts if schema != src_schema])
         return deps
diff --git a/lea/views/python.py b/lea/views/python.py
index 8638a1c..7c612f4 100644
--- a/lea/views/python.py
+++ b/lea/views/python.py
@@ -45,4 +45,4 @@ def extract_comments(self, columns: list[str]):
         return {}

     def __repr__(self):
-        return f"{self.schema}.{self.name}"
+        return ".".join(self.key)
diff --git a/lea/views/sql.py b/lea/views/sql.py
index 0a51621..2acfdb4 100644
--- a/lea/views/sql.py
+++ b/lea/views/sql.py
@@ -35,10 +35,10 @@ def last_line(self):

 @dataclasses.dataclass
 class SQLView(View):
-    dialect: sqlglot.Dialect
+    sqlglot_dialect: sqlglot.Dialect

     def __repr__(self):
-        return f"{self.schema}.{self.name}"
+        return ".".join(self.key)

     @property
     def query(self):
@@ -51,7 +51,7 @@ def query(self):
         return text

     def parse_dependencies(self, query):
-        parse = sqlglot.parse_one(query, dialect=self.dialect)
+        parse = sqlglot.parse_one(query, dialect=self.sqlglot_dialect)
         cte_names = {(None, cte.alias) for cte in parse.find_all(sqlglot.exp.CTE)}
         table_names = {
             (table.sql().split(".")[0], table.name)
@@ -69,7 +69,7 @@ def dependencies(self):
             return self.parse_dependencies(self.query)
         except sqlglot.errors.ParseError:
             warnings.warn(
-                f"SQLGlot couldn't parse {self.path} with dialect {self.dialect}. Falling back to regex."
+                f"SQLGlot couldn't parse {self.path} with dialect {self.sqlglot_dialect}. Falling back to regex."
             )
             dependencies = set()
             for match in re.finditer(
@@ -98,7 +98,7 @@ def description(self):
         )

     def extract_comments(self, columns: list[str]) -> dict[str, CommentBlock]:
-        dialect = sqlglot.Dialect.get_or_raise(self.dialect)()
+        dialect = sqlglot.Dialect.get_or_raise(self.sqlglot_dialect)()
         tokens = dialect.tokenizer.tokenize(self.query)

         # Extract comments, which are lines that start with --
@@ -159,19 +159,24 @@ def is_var_line(line):


 class GenericSQLView(SQLView):
-    def __init__(self, schema, name, query):
+    def __init__(self, schema, name, query, sqlglot_dialect):
         self._schema = schema
         self._name = name
         self._query = textwrap.dedent(query)
+        self._sqlglot_dialect = sqlglot_dialect

     @property
     def schema(self):
         return self._schema

     @property
-    def name(self):
-        return self._name
+    def key(self):
+        return (self._name,)

     @property
     def query(self):
         return self._query
+
+    @property
+    def sqlglot_dialect(self):
+        return self._sqlglot_dialect
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 21d129c..421090e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -20,15 +20,10 @@ def test_jaffle_shop():

     # Write .env file
     with open(env_path, "w") as f:
-        f.write(
-            "LEA_SCHEMA=jaffle_shop\n"
-            "LEA_USERNAME=max\n"
-            "LEA_WAREHOUSE=duckdb\n"
-            "LEA_DUCKDB_PATH=duckdb.db\n"
-        )
+        f.write("LEA_USERNAME=max\n" "LEA_WAREHOUSE=duckdb\n" "LEA_DUCKDB_PATH=duckdb.db\n")

     # Prepare
-    result = runner.invoke(app, ["prepare", "--env", env_path])
+    result = runner.invoke(app, ["prepare", views_path, "--env", env_path])
     assert result.exit_code == 0

     # Run
@@ -38,21 +33,21 @@ def test_jaffle_shop():
     # Check number of tables created
     con = duckdb.connect("duckdb.db")
     tables = con.sql("SELECT table_schema, table_name FROM information_schema.tables").df()
-    assert tables.shape[0] == 6
+    assert tables.shape[0] == 7

     # Check number of rows in core__customers
-    customers = con.sql("SELECT * FROM jaffle_shop_max.core__customers").df()
+    customers = con.sql("SELECT * FROM core_max.customers").df()
     assert customers.shape[0] == 100

     # Check number of rows in core__orders
-    orders = con.sql("SELECT * FROM jaffle_shop_max.core__orders").df()
+    orders = con.sql("SELECT * FROM core_max.orders").df()
     assert orders.shape[0] == 99

     # Run unit tests
     result = runner.invoke(app, ["test", views_path, "--env", env_path])
     assert result.exit_code == 0
-    assert "Found 1 generic tests" in result.stdout
     assert "Found 1 singular tests" in result.stdout
+    assert "Found 1 assertion tests" in result.stdout
     assert "SUCCESS" in result.stdout

     # Build docs
@@ -74,3 +69,4 @@ def test_jaffle_shop():
     assert (docs_path / "README.md").exists()
     assert (docs_path / "core" / "README.md").exists()
     assert (docs_path / "staging" / "README.md").exists()
+    assert (docs_path / "analytics" / "README.md").exists()
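
The net effect of this changeset is that a view is now identified by a key derived from its file path rather than by a `(schema, name)` pair, and warehouse table names are derived from that key. The standalone sketch below is illustrative only — it is not part of the diff, and the helper names `view_key_from_path` and `duckdb_view_path` are invented for the example — but it mirrors the `key` property in `lea/views/base.py` and the DuckDB `_make_view_path` logic shown above.

```python
from pathlib import Path


def view_key_from_path(relative_path: Path) -> tuple:
    # e.g. analytics/finance/kpis.sql -> ('analytics', 'finance', 'kpis')
    return (*relative_path.parts[:-1], relative_path.name.split(".")[0])


def duckdb_view_path(key: tuple, username: str | None) -> str:
    # The first part of the key is the schema, optionally suffixed with the
    # username; the remaining parts are joined with a double underscore.
    schema, *leftover = key
    if username:
        schema = f"{schema}_{username}"
    return f"{schema}.{'__'.join(leftover)}"


key = view_key_from_path(Path("analytics/finance/kpis.sql"))
print(key)                           # ('analytics', 'finance', 'kpis')
print(duckdb_view_path(key, "max"))  # analytics_max.finance__kpis
```

The second printed value matches the `FROM analytics_max.finance__kpis` query in the generated `examples/jaffle_shop/docs/analytics/README.md` above.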