From 4c270d4a922e2ec6291b119caab8661449e40cf7 Mon Sep 17 00:00:00 2001 From: Max Halford Date: Mon, 16 Oct 2023 22:44:33 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitmodules | 3 ++ .pre-commit-config.yaml | 23 +++++++++++++++ CONTRIBUTING.md | 45 +++++++++++++++++++++++++++++ examples/jaffle_shop/docs/README.md | 2 +- examples/jaffle_shop/jaffle_shop | 1 + lea/app/__init__.py | 20 ++++++------- lea/app/diff.py | 16 +++------- lea/app/docs.py | 9 ++---- lea/app/run.py | 23 +++++---------- lea/app/test.py | 9 ++---- lea/clients/__init__.py | 1 + lea/clients/base.py | 28 +++++++++++------- lea/clients/bigquery.py | 17 +++++------ lea/clients/duckdb.py | 16 ++++------ lea/views/__init__.py | 14 +++++++-- lea/views/dag.py | 12 ++++---- lea/views/python.py | 4 +-- lea/views/sql.py | 35 ++++++++++------------ pyproject.toml | 22 ++++++++++++++ tests/test_examples.py | 14 +++++++-- 20 files changed, 196 insertions(+), 118 deletions(-) create mode 100644 .gitmodules create mode 100644 .pre-commit-config.yaml create mode 100644 CONTRIBUTING.md create mode 160000 examples/jaffle_shop/jaffle_shop diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..7416538 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "examples/jaffle_shop/jaffle_shop"] + path = examples/jaffle_shop/jaffle_shop + url = https://github.com/dbt-labs/jaffle_shop/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..5b2e656 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +files: . 
+repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-json + - id: check-yaml + - id: trailing-whitespace + - id: mixed-line-ending + + - repo: local + hooks: + - id: black + name: black + language: python + types: [python] + entry: black + + - id: ruff + name: ruff + language: python + types: [python] + entry: ruff diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..0de5102 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,45 @@ +# Contributing + +## Setup + +Start by cloning the repository: + +```sh +git clone https://github.com/carbonfact/lea +``` + +Next, you'll need a Python environment: + +```sh +pyenv install -v 3.11 +``` + +You'll also need [Poetry](https://python-poetry.org/): + +```sh +curl -sSL https://install.python-poetry.org | python3 - +poetry install +poetry shell +``` + +## Testing + +You can run tests once the environment is set up: + +```sh +pytest +``` + +## Code quality + +Install the code quality routine so that it runs each time you try to push your commits. + +```sh +pre-commit install --hook-type pre-push +``` + +You can also run the code quality routine ad-hoc. 
+ +```sh +pre-commit run --all-files +``` diff --git a/examples/jaffle_shop/docs/README.md b/examples/jaffle_shop/docs/README.md index 0909ef2..e5553bf 100644 --- a/examples/jaffle_shop/docs/README.md +++ b/examples/jaffle_shop/docs/README.md @@ -10,8 +10,8 @@ ```mermaid %%{init: {"flowchart": {"defaultRenderer": "elk"}} }%% flowchart TB - staging(staging) core(core) + staging(staging) staging --> core ``` diff --git a/examples/jaffle_shop/jaffle_shop b/examples/jaffle_shop/jaffle_shop new file mode 160000 index 0000000..b0b77aa --- /dev/null +++ b/examples/jaffle_shop/jaffle_shop @@ -0,0 +1 @@ +Subproject commit b0b77aac70f490770a1e77c02bb0a2b8771d3203 diff --git a/lea/app/__init__.py b/lea/app/__init__.py index 2adb1ce..ccfcd94 100644 --- a/lea/app/__init__.py +++ b/lea/app/__init__.py @@ -28,18 +28,14 @@ def env_validate_callback(env_path: str | None): @app.command() def prepare(production: bool = False, env: str = EnvPath): - """ - - """ + """ """ client = _make_client(production) client.prepare(console) @app.command() def teardown(production: bool = False, env: str = EnvPath): - """ - - """ + """ """ if production: raise ValueError("This is a dangerous operation, so it is not allowed in production.") @@ -61,7 +57,7 @@ def run( threads: int = 8, show: int = 20, raise_exceptions: bool = False, - env: str = EnvPath + env: str = EnvPath, ): from lea.app.run import run @@ -92,7 +88,7 @@ def test( threads: int = 8, production: bool = False, raise_exceptions: bool = False, - env: str = EnvPath + env: str = EnvPath, ): from lea.app.test import test @@ -110,7 +106,12 @@ def test( @app.command() -def docs(views_dir: str = ViewsDir, output_dir: str = "docs", production: bool = False, env: str = EnvPath): +def docs( + views_dir: str = ViewsDir, + output_dir: str = "docs", + production: bool = False, + env: str = EnvPath, +): from lea.app.docs import docs client = _make_client(production=production) @@ -118,7 +119,6 @@ def docs(views_dir: str = ViewsDir, output_dir: str = 
"docs", production: bool = docs(views_dir=views_dir, output_dir=output_dir, client=client, console=console) - @app.command() def diff(origin: str, destination: str, env: str = EnvPath): from lea.app.diff import calculate_diff diff --git a/lea/app/diff.py b/lea/app/diff.py index fb2e532..ec4a989 100644 --- a/lea/app/diff.py +++ b/lea/app/diff.py @@ -7,28 +7,20 @@ def calculate_diff(origin: str, destination: str, client: lea.clients.Client) -> str: - diff_table = client.get_diff_summary( - origin=origin, destination=destination - ) + diff_table = client.get_diff_summary(origin=origin, destination=destination) if diff_table.empty: return "No field additions or removals detected" removed_tables = set( - diff_table[ - diff_table.column.isnull() & (diff_table.diff_kind == "REMOVED") - ].table + diff_table[diff_table.column.isnull() & (diff_table.diff_kind == "REMOVED")].table ) added_tables = set( - diff_table[ - diff_table.column.isnull() & (diff_table.diff_kind == "ADDED") - ].table + diff_table[diff_table.column.isnull() & (diff_table.diff_kind == "ADDED")].table ) buffer = io.StringIO() print_ = functools.partial(print, file=buffer) - for table, columns in diff_table[diff_table.column.notnull()].groupby( - "table" - ): + for table, columns in diff_table[diff_table.column.notnull()].groupby("table"): if table in removed_tables: print_(f"- {table}") elif table in added_tables: diff --git a/lea/app/docs.py b/lea/app/docs.py index 0d0b011..e430251 100644 --- a/lea/app/docs.py +++ b/lea/app/docs.py @@ -62,16 +62,11 @@ def docs( # Write down the query content.write( - "```sql\n" - "SELECT *\n" - f"FROM {client._make_view_path(view)}\n" - "```\n\n" + "```sql\n" "SELECT *\n" f"FROM {client._make_view_path(view)}\n" "```\n\n" ) # Write down the columns view_columns = columns.query(f"table == '{schema}__{view.name}'")[["column", "type"]] - view_comments = view.extract_comments( - columns=view_columns["column"].tolist() - ) + view_comments = 
view.extract_comments(columns=view_columns["column"].tolist()) view_columns["Description"] = ( view_columns["column"] .map( diff --git a/lea/app/run.py b/lea/app/run.py index 4fa1c9f..76faf4a 100644 --- a/lea/app/run.py +++ b/lea/app/run.py @@ -49,12 +49,12 @@ def make_blacklist(dag: lea.views.DAGOfViews, only: list) -> set: for schema, table in only: # Ancestors - if schema.startswith('+'): + if schema.startswith("+"): blacklist.difference_update(dag.list_ancestors((schema[1:], table))) schema = schema[1:] # Descendants - if table.endswith('+'): + if table.endswith("+"): blacklist.difference_update(dag.list_descendants((schema, table[:-1]))) table = table[:-1] @@ -138,17 +138,13 @@ def display_progress() -> rich.table.Table: exceptions = {} skipped = set() cache_path = pathlib.Path(".cache.pkl") - cache = ( - set() - if fresh or not cache_path.exists() - else pickle.loads(cache_path.read_bytes()) - ) + cache = set() if fresh or not cache_path.exists() else pickle.loads(cache_path.read_bytes()) tic = time.time() console_log(f"{len(cache):,d} view(s) already done") with rich.live.Live( - display_progress() , vertical_overflow="visible", refresh_per_second=6 + display_progress(), vertical_overflow="visible", refresh_per_second=6 ) as live: while dag.is_active(): # We check if new views have been unlocked @@ -169,10 +165,7 @@ def display_progress() -> rich.table.Table: # A node can only be computed if all its dependencies have been computed # If all the dependencies have not been computed succesfully, we skip the node - if any( - dep in skipped or dep in exceptions - for dep in dag[node].dependencies - ): + if any(dep in skipped or dep in exceptions for dep in dag[node].dependencies): skipped.add(node) dag.done(node) continue @@ -180,7 +173,8 @@ def display_progress() -> rich.table.Table: jobs[node] = executor.submit( _do_nothing if dry or node in cache - else functools.partial(pretty_print_view, view=dag[node], console=console) if print_to_cli + else 
functools.partial(pretty_print_view, view=dag[node], console=console) + if print_to_cli else functools.partial(client.create, view=dag[node]) ) jobs_started_at[node] = dt.datetime.now() @@ -200,8 +194,7 @@ def display_progress() -> rich.table.Table: cache = ( set() if all_done - else cache - | {node for node in order if node not in exceptions and node not in skipped} + else cache | {node for node in order if node not in exceptions and node not in skipped} ) if cache: cache_path.write_bytes(pickle.dumps(cache)) diff --git a/lea/app/test.py b/lea/app/test.py index 73effe5..4e51db2 100644 --- a/lea/app/test.py +++ b/lea/app/test.py @@ -26,9 +26,7 @@ def test( generic_tests = [] for view in views: - view_columns = columns.query(f"table == '{view.schema}__{view.name}'")[ - "column" - ].tolist() + view_columns = columns.query(f"table == '{view.schema}__{view.name}'")["column"].tolist() for generic_test in client.yield_unit_tests(view=view, view_columns=view_columns): generic_tests.append(generic_test) console.log(f"Found {len(generic_tests):,d} generic tests") @@ -40,10 +38,7 @@ def test( tests = [test for test in tests if test.name not in blacklist] with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor: - jobs = { - executor.submit(client.load, test): test - for test in tests - } + jobs = {executor.submit(client.load, test): test for test in tests} for job in concurrent.futures.as_completed(jobs): test = jobs[job] conflicts = job.result() diff --git a/lea/clients/__init__.py b/lea/clients/__init__.py index 52c4695..4034032 100644 --- a/lea/clients/__init__.py +++ b/lea/clients/__init__.py @@ -27,6 +27,7 @@ def make_client(production: bool): ) elif warehouse == "duckdb": from lea.clients.duckdb import DuckDB + return DuckDB( path=os.environ["LEA_DUCKDB_PATH"], schema=os.environ["LEA_SCHEMA"], diff --git a/lea/clients/base.py b/lea/clients/base.py index c18f022..f966ecd 100644 --- a/lea/clients/base.py +++ b/lea/clients/base.py @@ -80,8 +80,15 @@ def 
get_columns(self, schema: str) -> pd.DataFrame: def get_diff_summary(self, origin: str, destination: str) -> pd.DataFrame: - origin_columns = set(map(tuple, self.get_columns(origin)[["table", "column"]].values.tolist())) - destination_columns = set(map(tuple, self.get_columns(destination)[["table", "column"]].values.tolist())) + origin_columns = set( + map(tuple, self.get_columns(origin)[["table", "column"]].values.tolist()) + ) + destination_columns = set( + map( + tuple, + self.get_columns(destination)[["table", "column"]].values.tolist(), + ) + ) return pd.DataFrame( [ @@ -90,25 +97,25 @@ def get_diff_summary(self, origin: str, destination: str) -> pd.DataFrame: "column": None, "diff_kind": "ADDED", } - for table in {t for t, _ in origin_columns} - {t for t, _ in destination_columns} - ] + - [ + for table in {t for t, _ in origin_columns} - {t for t, _ in destination_columns} + ] + + [ { "table": table, "column": column, "diff_kind": "ADDED", } for table, column in origin_columns - destination_columns - ] + - [ + ] + + [ { "table": table, "column": None, "diff_kind": "REMOVED", } - for table in {t for t, _ in destination_columns } - {t for t, _ in origin_columns} - ] + - [ + for table in {t for t, _ in destination_columns} - {t for t, _ in origin_columns} + ] + + [ { "table": table, "column": column, @@ -118,7 +125,6 @@ def get_diff_summary(self, origin: str, destination: str) -> pd.DataFrame: ] ) - @abc.abstractmethod def make_test_unique_column(self, view: views.View, column: str) -> str: ... 
diff --git a/lea/clients/bigquery.py b/lea/clients/bigquery.py index 25e5ca4..d7b5f21 100644 --- a/lea/clients/bigquery.py +++ b/lea/clients/bigquery.py @@ -29,15 +29,14 @@ def sqlglot_dialect(self): @property def dataset_name(self): - return ( - f"{self._dataset_name}_{self.username}" - if self.username - else self._dataset_name - ) + return f"{self._dataset_name}_{self.username}" if self.username else self._dataset_name def create_dataset(self): from google.cloud import bigquery - dataset_ref = bigquery.DatasetReference(project=self.project_id, dataset_id=self.dataset_name) + + dataset_ref = bigquery.DatasetReference( + project=self.project_id, dataset_id=self.dataset_name + ) dataset = bigquery.Dataset(dataset_ref) dataset.location = self.location dataset = self.client.create_dataset(dataset, exists_ok=True) @@ -64,7 +63,7 @@ def _make_job(self, view: views.SQLView): "tableId": f"{view.schema}__{view.name}" if view.schema else view.name, }, "createDisposition": "CREATE_IF_NEEDED", - "writeDisposition": "WRITE_TRUNCATE" + "writeDisposition": "WRITE_TRUNCATE", }, "labels": { "job_dataset": self.dataset_name, @@ -109,9 +108,7 @@ def list_existing_view_names(self): ] def delete_view(self, view: views.View): - self.client.delete_table( - f"{self.project_id}.{self._make_view_path(view)}" - ) + self.client.delete_table(f"{self.project_id}.{self._make_view_path(view)}") def get_columns(self, schema=None) -> pd.DataFrame: schema = schema or self.dataset_name diff --git a/lea/clients/duckdb.py b/lea/clients/duckdb.py index 2e2936e..90d8105 100644 --- a/lea/clients/duckdb.py +++ b/lea/clients/duckdb.py @@ -11,7 +11,6 @@ class DuckDB(Client): - def __init__(self, path: str, schema: str, username: str): self.path = path self._schema = schema @@ -24,11 +23,7 @@ def sqlglot_dialect(self): @property def schema(self): - return ( - f"{self._schema}_{self.username}" - if self.username - else self._schema - ) + return f"{self._schema}_{self.username}" if self.username else 
self._schema def prepare(self, console): self.con.sql(f"CREATE SCHEMA IF NOT EXISTS {self.schema}") @@ -36,7 +31,9 @@ def prepare(self, console): def _create_python(self, view: views.PythonView): dataframe = self._load_python(view) # noqa: F841 - self.con.sql(f"CREATE OR REPLACE TABLE {self._make_view_path(view)} AS SELECT * FROM dataframe") + self.con.sql( + f"CREATE OR REPLACE TABLE {self._make_view_path(view)} AS SELECT * FROM dataframe" + ) def _create_sql(self, view: views.SQLView): query = view.query.replace(f"{self._schema}.", f"{self.schema}.") @@ -56,10 +53,7 @@ def teardown(self): def list_existing_view_names(self) -> list[tuple[str, str]]: results = duckdb.sql("SELECT table_schema, table_name FROM information_schema.tables").df() - return [ - (r["table_schema"], r["table_name"]) - for r in results.to_dict(orient="records") - ] + return [(r["table_schema"], r["table_name"]) for r in results.to_dict(orient="records")] def get_columns(self, schema=None) -> pd.DataFrame: schema = schema or self.schema diff --git a/lea/views/__init__.py b/lea/views/__init__.py index d8da80c..7e18d0e 100644 --- a/lea/views/__init__.py +++ b/lea/views/__init__.py @@ -10,7 +10,9 @@ from .sql import GenericSQLView, SQLView -def load_views(views_dir: pathlib.Path | str, sqlglot_dialect: sqlglot.dialects.Dialect | str) -> list[View]: +def load_views( + views_dir: pathlib.Path | str, sqlglot_dialect: sqlglot.dialects.Dialect | str +) -> list[View]: # Massage the inputs if isinstance(views_dir, str): @@ -36,5 +38,11 @@ def _load_view_from_path(path, origin, sqlglot_dialect): ] - -__all__ = ["load_views", "DAGOfViews", "View", "PythonView", "SQLView", "GenericSQLView"] +__all__ = [ + "load_views", + "DAGOfViews", + "View", + "PythonView", + "SQLView", + "GenericSQLView", +] diff --git a/lea/views/dag.py b/lea/views/dag.py index 4b6ca18..8deaf5c 100644 --- a/lea/views/dag.py +++ b/lea/views/dag.py @@ -10,13 +10,9 @@ class DAGOfViews(graphlib.TopologicalSorter, collections.UserDict): 
def __init__(self, views: list[lea.views.View]): - view_to_dependencies = { - (view.schema, view.name): view.dependencies for view in views - } + view_to_dependencies = {(view.schema, view.name): view.dependencies for view in views} graphlib.TopologicalSorter.__init__(self, view_to_dependencies) - collections.UserDict.__init__( - self, {(view.schema, view.name): view for view in views} - ) + collections.UserDict.__init__(self, {(view.schema, view.name): view for view in views}) self.dependencies = view_to_dependencies self.prepare() @@ -33,19 +29,23 @@ def schema_dependencies(self): def list_ancestors(self, node): """Returns a list of all the ancestors for a given node.""" + def _list_ancestors(node): for child in self.dependencies.get(node, []): yield child yield from _list_ancestors(child) + return list(_list_ancestors(node)) def list_descendants(self, node): """Returns a list of all the descendants for a given node.""" + def _list_descendants(node): for parent in self.dependencies: if node in self.dependencies[parent]: yield parent yield from _list_descendants(parent) + return list(_list_descendants(node)) def _to_mermaid_views(self): diff --git a/lea/views/python.py b/lea/views/python.py index 05d07a0..8280d5d 100644 --- a/lea/views/python.py +++ b/lea/views/python.py @@ -26,9 +26,7 @@ def _dependencies(): # .query try: - if isinstance(node, ast.Call) and node.func.attr.startswith( - "query" - ): + if isinstance(node, ast.Call) and node.func.attr.startswith("query"): yield from SQLView.parse_dependencies(node.args[0].value) except AttributeError: pass diff --git a/lea/views/sql.py b/lea/views/sql.py index ae2fd89..6e30b7c 100644 --- a/lea/views/sql.py +++ b/lea/views/sql.py @@ -47,9 +47,7 @@ def query(self): loader = jinja2.FileSystemLoader(self.origin) environment = jinja2.Environment(loader=loader) template = environment.get_template(str(self.relative_path)) - return template.render( - env=os.environ - ) + return template.render(env=os.environ) return text 
def parse_dependencies(self, query): @@ -70,13 +68,20 @@ def dependencies(self): try: return self.parse_dependencies(self.query) except sqlglot.errors.ParseError: - warnings.warn(f"SQLGlot couldn't parse {self.path} with dialect {self.dialect}. Falling back to regex.") + warnings.warn( + f"SQLGlot couldn't parse {self.path} with dialect {self.dialect}. Falling back to regex." + ) dependencies = set() for match in re.finditer( - r"(JOIN|FROM)\s+(?P<schema>[a-z][a-z_]+[a-z])\.(?P<view>[a-z][a-z_]+[a-z])", self.query, re.IGNORECASE + r"(JOIN|FROM)\s+(?P<schema>[a-z][a-z_]+[a-z])\.(?P<view>[a-z][a-z_]+[a-z])", + self.query, + re.IGNORECASE, ): schema, view_name = ( - (match.group("view").split("__")[0], match.group("view").split("__", 1)[1]) + ( + match.group("view").split("__")[0], + match.group("view").split("__", 1)[1], + ) if "__" in match.group("view") else (match.group("schema"), match.group("view")) ) @@ -92,9 +97,7 @@ def description(self): ) ) - def extract_comments( - self, columns: list[str] - ) -> dict[str, CommentBlock]: + def extract_comments(self, columns: list[str]) -> dict[str, CommentBlock]: dialect = sqlglot.Dialect.get_or_raise(self.dialect)() tokens = dialect.tokenizer.tokenize(self.query) @@ -115,11 +118,7 @@ def extract_comments( change = False for comment_block in comment_blocks: next_comment_block = next( - ( - cb - for cb in comment_blocks - if cb.first_line == comment_block.last_line + 1 - ), + (cb for cb in comment_blocks if cb.first_line == comment_block.last_line + 1), None, ) if next_comment_block: @@ -132,15 +131,11 @@ def extract_comments( # We assume the tokens are stored. Therefore, by looping over them and building a dictionary, # each key will be unique and the last value will be the last variable in the line. 
var_tokens = [ - token - for token in tokens - if token.token_type.value == "VAR" and token.text in columns + token for token in tokens if token.token_type.value == "VAR" and token.text in columns ] def is_var_line(line): - line_tokens = [ - t for t in tokens if t.line == line and t.token_type.value != "COMMA" - ] + line_tokens = [t for t in tokens if t.line == line and t.token_type.value != "COMMA"] return line_tokens[-1].token_type.value == "VAR" last_var_per_line = { diff --git a/pyproject.toml b/pyproject.toml index 4c1ad59..ca4027d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,3 +29,25 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] lea = "lea.main:app" + +[tool.black] +line-length = 100 +target-version = ['py310'] + +[tool.ruff] +select = ["E", "F", "I", "UP"] # https://beta.ruff.rs/docs/rules/ +line-length = 100 +target-version = 'py310' +ignore = ["E501"] + +[tool.ruff.isort] +required-imports = ["from __future__ import annotations"] + +[tool.pytest.ini_options] +addopts = [ + "--doctest-modules", + "--doctest-glob=README.md", + "--ignore=lea/examples", + "--verbose", + "--color=yes" +] diff --git a/tests/test_examples.py b/tests/test_examples.py index a0fbb6c..107c405 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -11,8 +11,8 @@ runner = CliRunner() -def test_jaffle_shop(): +def test_jaffle_shop(): app = make_app(make_client=make_client) here = pathlib.Path(__file__).parent env_path = str((here.parent / "examples" / "jaffle_shop" / ".env").absolute()) @@ -58,7 +58,17 @@ def test_jaffle_shop(): # Build docs docs_path = here.parent / "examples" / "jaffle_shop" / "docs" shutil.rmtree(docs_path, ignore_errors=True) - result = runner.invoke(app, ["docs", views_path, "--env", env_path, "--output-dir", str(docs_path.absolute())]) + result = runner.invoke( + app, + [ + "docs", + views_path, + "--env", + env_path, + "--output-dir", + str(docs_path.absolute()), + ], + ) assert result.exit_code == 0 assert 
docs_path.exists() assert (docs_path / "README.md").exists()