diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 84e6415..3bf028e 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -3,26 +3,12 @@ on: [pull_request] permissions: contents: write jobs: - check: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: '3.10' - cache: 'pip' - - uses: extractions/setup-just@v1 - - run: just install - - run: | - git fetch origin - pre-commit run --from-ref origin/${{ github.event.pull_request.base.ref }} --to-ref ${{ github.event.pull_request.head.sha }} - test: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: [ "3.10", "3.11" ] + python-version: [ "3.10", "3.11", "3.12" ] steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/.gitignore b/.gitignore index 37c1e9c..94e914b 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ test_binary workflow *.pyz orbiter-* +.python-version diff --git a/docs/Rules_and_Rulesets/index.md b/docs/Rules_and_Rulesets/index.md index 6195811..c34a475 100644 --- a/docs/Rules_and_Rulesets/index.md +++ b/docs/Rules_and_Rulesets/index.md @@ -7,7 +7,3 @@ show_root_toc_entry: false ::: orbiter.rules.rulesets.translate - -::: orbiter.rules.rulesets.load_filetype - -::: orbiter.rules.rulesets.xmltodict_parse diff --git a/docs/Rules_and_Rulesets/rulesets.md b/docs/Rules_and_Rulesets/rulesets.md index d4d584d..7f0fd1e 100644 --- a/docs/Rules_and_Rulesets/rulesets.md +++ b/docs/Rules_and_Rulesets/rulesets.md @@ -1,24 +1,35 @@ +# Translation ::: orbiter.rules.rulesets.TranslationRuleset options: separate_signature: true show_signature_annotations: true signature_crossrefs: true +::: orbiter.rules.rulesets.xmltodict_parse +## Rulesets ::: orbiter.rules.rulesets.Ruleset + options: + heading_level: 3 ::: orbiter.rules.rulesets.DAGFilterRuleset options: + heading_level: 3 show_bases: true ::: orbiter.rules.rulesets.DAGRuleset options: + heading_level: 3 show_bases: true ::: orbiter.rules.rulesets.TaskFilterRuleset options: + heading_level: 3 show_bases: true ::: orbiter.rules.rulesets.TaskRuleset options: + heading_level: 3 show_bases: true ::: orbiter.rules.rulesets.TaskDependencyRuleset options: + heading_level: 3 show_bases: true ::: orbiter.rules.rulesets.PostProcessingRuleset options: + heading_level: 3 show_bases: true diff --git a/docs/index.md b/docs/index.md index a278dac..1b668f3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -25,7 +25,7 @@ from an [Origin](./origins) system to an Airflow project. ## Installation -You can install the [`orbiter` CLI](./CLI), if you have Python >= 3.10 installed via `pip`: +Install the [`orbiter` CLI](./CLI), if you have Python >= 3.10 installed via `pip`: ```shell pip install astronomer-orbiter ``` @@ -33,8 +33,8 @@ If you do not have a compatible Python environment, pre-built binary executables are available for download on the [Releases](https://github.com/astronomer/orbiter/releases) page. ## Translate -You can utilize the [`orbiter` CLI](./cli) with existing translations to convert workflows -from other systems to Apache Airflow. +Utilize the [`orbiter` CLI](./cli) with existing translations to convert workflows +from other systems to an Airflow project. 1. Set up a new folder, and create a `workflow/` folder. Add your workflows files to it ```shell @@ -46,21 +46,22 @@ from other systems to Apache Airflow. ``` 2. Determine the specific translation ruleset via: 1. the [Origins](origins) documentation - 2. the [`orbiter help`](./cli#help) command + 2. the [`orbiter list-rulesets`](./cli#list-rulesets) command 3. or [by creating a translation ruleset](#authoring-rulesets-customization), if one does not exist -3. Install the specific translation ruleset via the [`orbiter install`](./cli#install) command +3. Install the translation ruleset via the [`orbiter install`](./cli#install) command (substituting `` with the value in the last step) + ```shell + orbiter install --repo= + ``` 4. Use the [`orbiter translate`](./cli#translate) command with the `` determined in the last step This will produce output to an `output/` folder: ```shell - orbiter translate workflow/ output/ --ruleset + orbiter translate workflow/ --ruleset output/ ``` 5. Review the contents of the `output/` folder. If extensions or customizations are required, review [how to extend a translation ruleset](#extend-or-customize) -6. Utilize the [`astro` CLI](https://www.astronomer.io/docs/astro/cli/overview) +6. (optional) Utilize the [`astro` CLI](https://www.astronomer.io/docs/astro/cli/overview) to run Airflow instance with your migrated workloads -7. Deploy to [Astro](https://www.astronomer.io/try-astro/) to run your translated workflows in production! 🚀 - -You can see more specifics on how to use the Orbiter CLI in the [CLI](./cli) section. +7. (optional) Deploy to [Astro](https://www.astronomer.io/try-astro/) to run your translated workflows in production! 🚀 ## Authoring Rulesets & Customization Orbiter can be extended to fit specific needs, patterns, or to support additional origins. @@ -98,12 +99,12 @@ To extend or customize an existing ruleset, you can easily modify it with simple ```shell orbiter translate workflow/ output/ --ruleset override.translation_ruleset ``` -5. Follow the remaining steps 4 -> 6 of the [Translate](#translate) instructions +5. Follow the remaining steps of the [Translate](#translate) instructions ### Authoring a new Ruleset -You can utilize the [Template `TranslationRuleset`](./Rules_and_Rulesets/template) -as a starter, to create a new [`TranslationRuleset`][orbiter.rules.rulesets.TranslationRuleset]. +You can utilize the [`TranslationRuleset` Template](./Rules_and_Rulesets/template) +to create a new [`TranslationRuleset`][orbiter.rules.rulesets.TranslationRuleset]. ## FAQ - **Can this tool convert my workflows from tool X to Airflow?** diff --git a/docs/objects/Tasks/index.md b/docs/objects/Tasks/index.md index 2d82275..f067f7c 100644 --- a/docs/objects/Tasks/index.md +++ b/docs/objects/Tasks/index.md @@ -4,7 +4,7 @@ are units of work. An Operator is a pre-defined task with specific functionality Operators can be looked up in the [Astronomer Registry](https://registry.astronomer.io/). -The easiest way to utilize an operator is to use a subclass of `OrbiterOperator` (e.g. `OrbiterBashOperator`). +The easiest way to create an operator in a translation to [use an existing subclass of `OrbiterOperator` (e.g. `OrbiterBashOperator`)](./Operators_and_Callbacks/operators). If an `OrbiterOperator` subclass doesn't exist for your use case, you can: @@ -27,7 +27,8 @@ If an `OrbiterOperator` subclass doesn't exist for your use case, you can: ) ``` -2) Create a new subclass of `OrbiterOperator` (beneficial if you are using it frequently in separate `@task_rules`) +2) Create a new subclass of `OrbiterOperator`, which can be beneficial if you are using it frequently + in separate `@task_rules` ```python from orbiter.objects.task import OrbiterOperator from orbiter.objects.requirement import OrbiterRequirement diff --git a/docs/objects/index.md b/docs/objects/index.md index f730767..1debf8d 100644 --- a/docs/objects/index.md +++ b/docs/objects/index.md @@ -4,7 +4,7 @@ are **rendered** to produce an Apache Airflow Project An [`OrbiterProject`][orbiter.objects.project.OrbiterProject] holds everything necessary to render an Airflow Project. -This is generated by a [`TranslationRuleset.translate_fn`][orbiter.rules.rulesets.TranslationRuleset]. +It is generated by a [`TranslationRuleset.translate_fn`][orbiter.rules.rulesets.TranslationRuleset]. ![Diagram of Orbiter Translation](../orbiter_diagram.png) diff --git a/docs/objects/project.md b/docs/objects/project.md index 472c1c2..49b9f8c 100644 --- a/docs/objects/project.md +++ b/docs/objects/project.md @@ -1,5 +1,5 @@ An [`OrbiterProject`][orbiter.objects.project.OrbiterProject] holds everything necessary to render an Airflow Project. -This is generated by a [`TranslationRuleset.translate_fn`][orbiter.rules.rulesets.TranslationRuleset]. +It is generated by a [`TranslationRuleset.translate_fn`][orbiter.rules.rulesets.TranslationRuleset]. ## Diagram ```mermaid diff --git a/docs/objects/dags.md b/docs/objects/workflow.md similarity index 94% rename from docs/objects/dags.md rename to docs/objects/workflow.md index fa83376..65250ee 100644 --- a/docs/objects/dags.md +++ b/docs/objects/workflow.md @@ -28,10 +28,6 @@ classDiagram --8<-- "orbiter/objects/task.py:mermaid-op-props" } - class OrbiterTaskDependency["orbiter.objects.task.OrbiterTaskDependency"] { - --8<-- "orbiter/objects/task.py:mermaid-td-props" - } - class OrbiterTimetable["orbiter.objects.timetables.OrbiterTimetable"] { --8<-- "orbiter/objects/timetables/__init__.py:mermaid-props" } diff --git a/justfile b/justfile index 0e01f88..7f9f917 100644 --- a/justfile +++ b/justfile @@ -33,6 +33,10 @@ test: test-with-coverage: {{ PYTHON }} -m pytest -c pyproject.toml --cov=./ --cov-report=xml +# Run integration tests +test-integration $MANUAL_TESTS="true": + @just test + # Run ruff and black (normally done with pre-commit) lint: ruff check . @@ -51,7 +55,7 @@ deploy-docs UPSTREAM="origin": clean # Remove temporary or build folders clean: - rm -rf build dist site *.egg-info + rm -rf build dist site *.egg-info *.pyz orbiter-* workflow output find . | grep -E "(__pycache__|\.pyc|\.pyo$$)" | xargs rm -rf # Tag as v$(.__version__) and push to Github @@ -91,13 +95,14 @@ docker-build-binary: just build-binary EOF -docker-run-binary REPO='astronomer-orbiter-translations' RULESET='orbiter_translations.control_m.xml_base.translation_ruleset': +docker-run-binary REPO='orbiter-community-translations' RULESET='orbiter_translations.oozie.xml_demo.translation_ruleset': #!/usr/bin/env bash set -euxo pipefail cat <<"EOF" | docker run --platform linux/amd64 -v `pwd`:/data -w /data -i ubuntu /bin/bash chmod +x ./orbiter-linux-x86_64 && \ set -a && source .env && set +a && \ - ./orbiter-linux-x86_64 help && \ + ./orbiter-linux-x86_64 list-rulesets && \ + mkdir -p workflow && \ LOG_LEVEL=DEBUG ./orbiter-linux-x86_64 install --repo={{REPO}} && \ LOG_LEVEL=DEBUG ./orbiter-linux-x86_64 translate workflow/ output/ --ruleset {{RULESET}} EOF diff --git a/mkdocs.yml b/mkdocs.yml index fa24c54..0dffa09 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -109,5 +109,7 @@ plugins: docstring_options: trim_doctest_flags: true show_bases: false + extensions: + - griffe_inherited_docstrings copyright: "Apache Airflow® is a trademark of the Apache Software Foundation. Copyright 2024 Astronomer, Inc." diff --git a/orbiter/__init__.py b/orbiter/__init__.py index 5df03da..42c405b 100644 --- a/orbiter/__init__.py +++ b/orbiter/__init__.py @@ -1,15 +1,19 @@ from __future__ import annotations +import os import re from enum import Enum from typing import Any, Tuple -__version__ = "1.0.1" +__version__ = "1.0.2a1" version = __version__ KG_ACCOUNT_ID = "3b189b4c-c047-4fdb-9b46-408aa2978330" +ORBITER_TASK_SUFFIX = os.getenv("ORBITER_TASK_SUFFIX", "_task") +"""By default, we add `_task` as a suffix to a task name to prevent name collision issues. This can be overridden.""" + class FileType(Enum): YAML = "YAML" diff --git a/orbiter/__main__.py b/orbiter/__main__.py index 2200f15..9a58399 100644 --- a/orbiter/__main__.py +++ b/orbiter/__main__.py @@ -170,17 +170,19 @@ def translate( Provide a specific ruleset with the `--ruleset` flag. - Run `orbiter help` to see available rulesets. + Run `orbiter list-rulesets` to see available rulesets. `INPUT_DIR` defaults to `$CWD/workflow`. `OUTPUT_DIR` defaults to `$CWD/output` + + Formats output with Ruff (https://astral.sh/ruff), by default. """ logger.debug(f"Creating output directory {output_dir}") output_dir.mkdir(parents=True, exist_ok=True) - sys.path.insert(0, os.getcwd()) logger.debug(f"Adding current directory {os.getcwd()} to sys.path") + sys.path.insert(0, os.getcwd()) if RUNNING_AS_BINARY: _add_pyz() @@ -265,7 +267,7 @@ def _bin_install(repo: str, key: str): raise NotImplementedError() _add_pyz() (_, _version) = import_from_qualname("orbiter_translations.version") - logging.info(f"Successfully installed {repo}, version: {_version}") + logger.info(f"Successfully installed {repo}, version: {_version}") # noinspection t @@ -284,8 +286,8 @@ def _bin_install(repo: str, key: str): @click.option( "-k", "--key", - help="[Optional] License Key to use for the translation ruleset.\n\n" - "Should look like 'AAAA-BBBB-1111-2222-3333-XXXX-YYYY-ZZZZ'", + help="[Optional] License Key to use for the translation ruleset. Should look like " + "`AAAA-BBBB-1111-2222-3333-XXXX-YYYY-ZZZZ`", type=str, default=None, allow_from_autoenv=True, @@ -298,7 +300,7 @@ def install( ), key: str | None, ): - """Install a new Orbiter Translation Ruleset from a repository""" + """Install a new Translation Ruleset from a repository""" if not repo: choices = [ "astronomer-orbiter-translations", @@ -329,7 +331,7 @@ def install( # noinspection PyShadowingBuiltins @orbiter.command(help="List available Translation Rulesets") -def help(): +def list_rulesets(): console = Console() table = tabulate( diff --git a/orbiter/assets/supported_origins.csv b/orbiter/assets/supported_origins.csv index fda8541..8af95bf 100644 --- a/orbiter/assets/supported_origins.csv +++ b/orbiter/assets/supported_origins.csv @@ -1,12 +1,15 @@ -Origin,Maintainer,Repository,Ruleset(s),Task Equivalent,DAG Equivalent -DAG Factory,Community,[`orbiter-community-translations`](https://github.com/astronomer/orbiter-community-translations),`orbiter_translations.dag_factory.yaml_base.translation_ruleset`,---,--- -Control M,Astronomer,`astronomer-orbiter-translations`,`orbiter_translations.control_m.json_base.translation_ruleset`,Job,Folder -⠀,⠀,⠀,`orbiter_translations.control_m.json_ssh.translation_ruleset`,⠀,⠀ +Origin,Maintainer,Repository,Ruleset,DAG Equivalent,Task Equivalent +DAG Factory,Community,[`orbiter-community-translations`](https://github.com/astronomer/orbiter-community-translations),`orbiter_translations.dag_factory.yaml_base.translation_ruleset`,DAG,Task +Control M,Astronomer,`astronomer-orbiter-translations`,`orbiter_translations.control_m.json_base.translation_ruleset`,Folder,Job +⠀⠀,⠀,⠀,`orbiter_translations.control_m.json_ssh.translation_ruleset`,⠀,⠀ ⠀,⠀,⠀,`orbiter_translations.control_m.xml_base.translation_ruleset`,⠀,⠀ ⠀,⠀,⠀,`orbiter_translations.control_m.xml_ssh.translation_ruleset`,⠀,⠀ -Automic,Astronomer,`astronomer-orbiter-translations`,WIP,Job,Job Plan +⠀,⠀,[`orbiter-community-translations`](https://github.com/astronomer/orbiter-community-translations),`orbiter_translations.control_m.xml_demo.translation_ruleset`,⠀,⠀ +Automic,Astronomer,`astronomer-orbiter-translations`,WIP,Job Plan,Job Autosys,Astronomer,`astronomer-orbiter-translations`,WIP,⠀,⠀ -JAMS,Astronomer,`astronomer-orbiter-translations`,WIP,Job,Folder⠀ -SSIS,Astronomer,`astronomer-orbiter-translations`,`orbiter_translations.ssis.xml_base.translation_ruleset`,⠀,⠀ -Oozie,Astronomer,`astronomer-orbiter-translations`,`orbiter_translations.oozie.xml_base.translation_ruleset`,Node,Workflow +JAMS,Astronomer,`astronomer-orbiter-translations`,WIP,Folder,Job +SSIS,Astronomer,`astronomer-orbiter-translations`,`orbiter_translations.ssis.xml_base.translation_ruleset`,Pipeline,Component +⠀,⠀,[`orbiter-community-translations`](https://github.com/astronomer/orbiter-community-translations),`orbiter_translations.oozie.xml_demo.translation_ruleset`,⠀,⠀ +Oozie,Astronomer,`astronomer-orbiter-translations`,`orbiter_translations.oozie.xml_base.translation_ruleset`,Workflow,Node +⠀,⠀⠀,[`orbiter-community-translations`](https://github.com/astronomer/orbiter-community-translations),`orbiter_translations.oozie.xml_demo.translation_ruleset`,⠀,⠀ **& more!**,⠀,⠀,⠀,⠀,⠀ diff --git a/orbiter/ast_helper.py b/orbiter/ast_helper.py index 1255071..1e659ed 100644 --- a/orbiter/ast_helper.py +++ b/orbiter/ast_helper.py @@ -149,6 +149,17 @@ def py_function(c: Callable): return ast.parse(inspect.getsource(c)).body[0] +def py_reference(name: str) -> ast.Expr: + """ + ```pycon + >>> render_ast(py_reference("foo")) + 'foo' + + ``` + """ + return ast.Expr(value=ast.Name(id=name)) + + def render_ast(ast_object) -> str: return ast.unparse(ast_object) diff --git a/orbiter/objects/__init__.py b/orbiter/objects/__init__.py index 28688e7..f0aebda 100644 --- a/orbiter/objects/__init__.py +++ b/orbiter/objects/__init__.py @@ -11,6 +11,16 @@ from orbiter.objects.include import OrbiterInclude +CALLBACK_KEYS = [ + "on_success_callback", + "on_failure_callback", + "sla_miss_callback", + "on_retry_callback", + "on_execute_callback", + "on_skipped_callback", +] + + def validate_imports(v): assert len(v) for i in v: @@ -24,27 +34,27 @@ def validate_imports(v): class OrbiterBase(BaseModel, ABC, arbitrary_types_allowed=True): """**AbstractBaseClass** for Orbiter objects, provides a number of properties - :param imports: List of OrbiterRequirement objects + :param imports: List of [OrbiterRequirement][orbiter.objects.requirement.OrbiterRequirement] objects :type imports: List[OrbiterRequirement] :param orbiter_kwargs: Optional dictionary of keyword arguments, to preserve what was originally parsed by a rule :type orbiter_kwargs: dict, optional - :param orbiter_conns: Optional set of OrbiterConnection objects + :param orbiter_conns: Optional set of [OrbiterConnection][orbiter.objects.connection.OrbiterConnection] objects :type orbiter_conns: Set[OrbiterConnection], optional - :param orbiter_vars: Optional set of OrbiterVariable objects - :type orbiter_vars: Set[OrbiterVariable], optional - :param orbiter_env_vars: Optional set of OrbiterEnvVar objects + :param orbiter_env_vars: Optional set of [OrbiterEnvVar][orbiter.objects.env_var.OrbiterEnvVar] objects :type orbiter_env_vars: Set[OrbiterEnvVar], optional - :param orbiter_includes: Optional set of OrbiterInclude objects + :param orbiter_includes: Optional set of [OrbiterInclude][orbiter.objects.include.OrbiterInclude] objects :type orbiter_includes: Set[OrbiterInclude], optional + :param orbiter_vars: Optional set of [OrbiterVariable][orbiter.objects.variable.OrbiterVariable] objects + :type orbiter_vars: Set[OrbiterVariable], optional """ imports: ImportList orbiter_kwargs: dict = None orbiter_conns: Set[OrbiterConnection] | None = None - orbiter_vars: Set[OrbiterVariable] | None = None orbiter_env_vars: Set[OrbiterEnvVar] | None = None orbiter_includes: Set[OrbiterInclude] | None = None + orbiter_vars: Set[OrbiterVariable] | None = None def conn_id(conn_id: str, prefix: str = "", conn_type: str = "generic") -> dict: diff --git a/orbiter/objects/callbacks/__init__.py b/orbiter/objects/callbacks/__init__.py index 2ff435c..46ca789 100644 --- a/orbiter/objects/callbacks/__init__.py +++ b/orbiter/objects/callbacks/__init__.py @@ -1,7 +1,6 @@ import ast -from abc import ABC -from orbiter.ast_helper import OrbiterASTBase, py_object +from orbiter.ast_helper import OrbiterASTBase, py_object, py_reference from orbiter.objects import OrbiterBase, ImportList, OrbiterRequirement from orbiter.objects.task import RenderAttributes @@ -12,11 +11,30 @@ """ -class OrbiterCallback(OrbiterASTBase, OrbiterBase, ABC, extra="forbid"): - """**Abstract class** representing an Airflow +class OrbiterCallback(OrbiterASTBase, OrbiterBase, extra="forbid"): + """Represents an Airflow [callback function](https://airflow.apache.org/docs/apache-airflow/stable/administration-and-deployment/logging-monitoring/callbacks.html), which might be used in `DAG.on_failure_callback`, or `Task.on_success_callback`, or etc. + Can be instantiated directly as a bare callback function (with no arguments): + ```pycon + >>> from orbiter.objects.dag import OrbiterDAG + >>> from orbiter.objects.include import OrbiterInclude + >>> my_callback = OrbiterCallback( + ... function="my_callback", + ... imports=[OrbiterRequirement(module="my_callback", names=["my_callback"])], + ... orbiter_includes={OrbiterInclude(filepath="my_callback.py", contents="...")} + ... ) + >>> OrbiterDAG(dag_id='', file_path='', on_failure_callback=my_callback) + ... # doctest: +ELLIPSIS + from airflow import DAG + from my_callback import my_callback + ... + with DAG(... on_failure_callback=my_callback): + + ``` + + or be subclassed: ```pycon >>> class OrbiterMyCallback(OrbiterCallback): ... function: str = "my_callback" @@ -27,7 +45,6 @@ class OrbiterCallback(OrbiterASTBase, OrbiterBase, ABC, extra="forbid"): my_callback(foo='fop', bar='bop') ``` - :param function: The name of the function to call :type function: str :param **OrbiterBase: [OrbiterBase][orbiter.objects.OrbiterBase] inherited properties @@ -44,11 +61,14 @@ class OrbiterCallback(OrbiterASTBase, OrbiterBase, ABC, extra="forbid"): render_attributes: RenderAttributes = [] def _to_ast(self) -> ast.Expr: - return py_object( - name=self.function, - **{ - k: getattr(self, k) - for k in self.render_attributes - if (k and getattr(self, k)) or (k == "from_email") - }, - ) + if self.render_attributes: + return py_object( + name=self.function, + **{ + k: getattr(self, k) + for k in self.render_attributes + if (k and getattr(self, k)) or (k == "from_email") + }, + ) + else: + return py_reference(self.function) diff --git a/orbiter/objects/callbacks/smtp.py b/orbiter/objects/callbacks/smtp.py index c8f84e0..67e5780 100644 --- a/orbiter/objects/callbacks/smtp.py +++ b/orbiter/objects/callbacks/smtp.py @@ -18,7 +18,7 @@ class OrbiterSmtpNotifierCallback(OrbiterCallback, extra="allow"): """ - An [Airflow SMTP Callback (link)](https://airflow.apache.org/docs/apache-airflow-providers-smtp/stable/_api/airflow/providers/smtp/notifications/smtp/index.html) + An [Airflow SMTP Callback](https://airflow.apache.org/docs/apache-airflow-providers-smtp/stable/_api/airflow/providers/smtp/notifications/smtp/index.html) !!! note diff --git a/orbiter/objects/dag.py b/orbiter/objects/dag.py index 32de5fd..bb6e40c 100644 --- a/orbiter/objects/dag.py +++ b/orbiter/objects/dag.py @@ -2,6 +2,7 @@ import ast from datetime import datetime +from functools import reduce from pathlib import Path from typing import Annotated, Any, Dict, Iterable, List, Callable @@ -10,8 +11,7 @@ from orbiter import clean_value from orbiter.ast_helper import OrbiterASTBase, py_object, py_with -from orbiter.objects import ImportList, OrbiterBase -from orbiter.objects.callbacks import OrbiterCallback +from orbiter.objects import ImportList, OrbiterBase, CALLBACK_KEYS from orbiter.objects.requirement import OrbiterRequirement from orbiter.objects.task import OrbiterOperator from orbiter.objects.task_group import OrbiterTaskGroup @@ -22,7 +22,6 @@ OrbiterDAG --> "many" OrbiterInclude OrbiterDAG --> "many" OrbiterConnection OrbiterDAG --> "many" OrbiterEnvVar -OrbiterDAG --> "many" OrbiterPool OrbiterDAG --> "many" OrbiterRequirement OrbiterDAG --> "many" OrbiterVariable --8<-- [end:mermaid-project-relationships] @@ -41,37 +40,50 @@ def _get_imports_recursively( tasks: Iterable[OrbiterOperator | OrbiterTaskGroup], ) -> List[OrbiterRequirement]: - imports = [] - extra_attributes_imports = [] + """ + + >>> from orbiter.objects.task import OrbiterTask + >>> from orbiter.objects.task_group import OrbiterTaskGroup + >>> from orbiter.objects.callbacks import OrbiterCallback + >>> _get_imports_recursively([ + ... OrbiterTask(task_id="foo", imports=[OrbiterRequirement(names=['foo'])]), + ... OrbiterTaskGroup(task_group_id="bar", imports=[OrbiterRequirement(names=['bar'])], tasks=[ + ... OrbiterTask(task_id="baz", imports=[OrbiterRequirement(names=['baz'])], + ... on_failure_callback=OrbiterCallback(imports=[OrbiterRequirement(names=['qux'])], function='qux') + ... ) + ... ]) + ... ]) + ... # doctest: +ELLIPSIS + [OrbiterRequirements(...names=[bar]...names=[baz]...names=[foo]...names=[qux]...] + + """ + imports = set() for task in tasks: - for callback in [ - callback - for callback in [ - ((task.__dict__ or {}) | (task.model_extra or {})).get( - "on_failure_callback" - ), - ((task.__dict__ or {}) | (task.model_extra or {})).get( - "on_success_callback" - ), - ] - if callback - ]: - callback: OrbiterCallback - extra_attributes_imports.extend(callback.imports) - - imports.extend( - task.imports - + extra_attributes_imports - + _get_imports_recursively(task.tasks) - if isinstance(task, OrbiterTaskGroup) - else task.imports + extra_attributes_imports - ) - return imports + # Add task imports + imports |= set(task.imports) + + def reduce_imports_from_callback(old, item): + try: + # Look for on_failure_callback + task_props = (getattr(task, "model_extra", {}) or {}) | ( + getattr(task, "__dict__", {}) or {} + ) + callback = task_props.get(item) + # get imports from callback, merge them all + return old | set(getattr(callback, "imports")) + except (AttributeError, KeyError): + return old + + imports |= reduce(reduce_imports_from_callback, CALLBACK_KEYS, set()) + if hasattr(task, "tasks"): + # descend, for a task group + imports |= set(_get_imports_recursively(task.tasks)) + return list(sorted(imports, key=str)) class OrbiterDAG(OrbiterASTBase, OrbiterBase, extra="allow"): - """Represents an Airflow - [DAG](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/dags.html), + # noinspection PyUnresolvedReferences + """Represents an Airflow [DAG](https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/dags.html), with its tasks and dependencies. Renders to a `.py` file in the `/dags` folder @@ -95,6 +107,8 @@ class OrbiterDAG(OrbiterASTBase, OrbiterBase, extra="allow"): :type params: Dict[str, Any], optional :param doc_md: Documentation for the DAG with markdown support :type doc_md: str, optional + :param kwargs: Additional keyword arguments to pass to the DAG + :type kwargs: dict, optional :param **OrbiterBase: [OrbiterBase][orbiter.objects.OrbiterBase] inherited properties """ @@ -111,6 +125,7 @@ class OrbiterDAG(OrbiterASTBase, OrbiterBase, extra="allow"): params: Dict[str, Any] doc_md: str | None tasks: Dict[str, OrbiterOperator] + kwargs: dict orbiter_kwargs: dict orbiter_conns: Set[OrbiterConnection] orbiter_vars: Set[OrbiterVariable] @@ -301,6 +316,13 @@ def dedupe_callable(ast_collection): if isinstance(self.schedule, OrbiterTimetable) else set() ) + | reduce( + # Look for e.g. on_failure_callback in model_extra, get imports, merge them all + lambda old, item: old + | set(getattr(self.model_extra.get(item, {}), "imports", set())), + CALLBACK_KEYS, + set(), + ) ) imports = [i._to_ast() for i in sorted(pre_imports)] @@ -311,14 +333,11 @@ def dedupe_callable(ast_collection): ) # foo >> bar - task_dependencies = sorted( - [ - downstream - for task in self.tasks.values() - for downstream in task.downstream - ] - ) - task_dependencies = [downstream._to_ast() for downstream in task_dependencies] + task_dependencies = [ + task._downstream_to_ast() + for task in sorted(self.tasks.values()) + if task._downstream_to_ast() + ] # with DAG(...) as dag: with_dag = py_with( diff --git a/orbiter/objects/project.py b/orbiter/objects/project.py index 41cc7e2..7328b0f 100644 --- a/orbiter/objects/project.py +++ b/orbiter/objects/project.py @@ -254,7 +254,11 @@ def add_dags(self, dags: OrbiterDAG | Iterable[OrbiterDAG]) -> "OrbiterProject": # noinspection t def _add_recursively( things: Iterable[ - OrbiterOperator | OrbiterTaskGroup | OrbiterCallback | OrbiterTimetable + OrbiterOperator + | OrbiterTaskGroup + | OrbiterCallback + | OrbiterTimetable + | OrbiterDAG ], ): for thing in things: @@ -278,27 +282,14 @@ def _add_recursively( self.add_requirements(imports) if isinstance(thing, OrbiterTaskGroup) and (tasks := thing.tasks): _add_recursively(tasks) - - # find callbacks in any 'model extra' or attributes named - # "on_success_callback" or "on_failure_callback" - if ( - hasattr(thing, "__dict__") - and hasattr(thing, "model_extra") - and len( + if hasattr(thing, "__dict__") or hasattr(thing, "model_extra"): + # If it's a pydantic model or dict, check its properties for more things to add + _add_recursively( ( - callbacks := { - k: v - for k, v in ( - (thing.__dict__ or dict()) - | (thing.model_extra or dict()) - ).items() - if k in ("on_success_callback", "on_failure_callback") - and issubclass(type(v), OrbiterCallback) - } - ) + (getattr(thing, "__dict__", {}) or dict()) + | (getattr(thing, "model_extra", {}) or dict()) + ).values() ) - ): - _add_recursively(callbacks.values()) for dag in [dags] if isinstance(dags, OrbiterDAG) else dags: dag_id = dag.dag_id @@ -309,14 +300,11 @@ def _add_recursively( else: self.dags[dag_id] = dag - # Add imports to the project - self.add_requirements(dag.imports) - # Add anything that might be in the tasks of the DAG - such as imports, Connections, etc _add_recursively((dag.tasks or {}).values()) # Add anything that might be in the `dag.schedule` - such as Includes, Timetables, Connections, etc - _add_recursively([dag.schedule]) + _add_recursively([dag]) return self def add_env_vars( diff --git a/orbiter/objects/task.py b/orbiter/objects/task.py index 52e02e5..f0206e7 100644 --- a/orbiter/objects/task.py +++ b/orbiter/objects/task.py @@ -2,13 +2,17 @@ import ast from abc import ABC -from typing import Set, List, ClassVar, Annotated, Callable, Literal +from typing import Set, List, ClassVar, Annotated, Callable from loguru import logger from pydantic import AfterValidator, BaseModel, validate_call -from orbiter import clean_value -from orbiter.ast_helper import OrbiterASTBase, py_bitshift, py_function +from orbiter import clean_value, ORBITER_TASK_SUFFIX +from orbiter.ast_helper import ( + OrbiterASTBase, + py_function, + py_bitshift, +) from orbiter.ast_helper import py_assigned_object from orbiter.objects import ImportList from orbiter.objects import OrbiterBase @@ -21,7 +25,6 @@ OrbiterOperator --> "many" OrbiterConnection OrbiterOperator --> "many" OrbiterVariable OrbiterOperator --> "many" OrbiterEnvVar -OrbiterOperator --> "many" OrbiterTaskDependency --8<-- [end:mermaid-dag-relationships] --8<-- [start:mermaid-task-relationships] @@ -37,57 +40,64 @@ def task_add_downstream( self, task_id: str | List[str] | OrbiterTaskDependency ) -> "OrbiterOperator" | "OrbiterTaskGroup": # noqa: F821 + # noinspection PyProtectedMember """ Add a downstream task dependency - """ + + ```pycon + >>> from orbiter.objects.operators.empty import OrbiterEmptyOperator + >>> from orbiter.objects.task import task_add_downstream, OrbiterTaskDependency + >>> from orbiter.ast_helper import render_ast + >>> render_ast(task_add_downstream(OrbiterEmptyOperator(task_id="task_id"), "downstream")._downstream_to_ast()) + 'task_id_task >> downstream_task' + >>> render_ast( + ... task_add_downstream(OrbiterEmptyOperator(task_id="task_id"), + ... OrbiterTaskDependency(task_id="task_id", downstream="downstream"))._downstream_to_ast() + ... ) + 'task_id_task >> downstream_task' + >>> render_ast(task_add_downstream( + ... OrbiterEmptyOperator(task_id="task_id"), + ... ["downstream"] + ... )._downstream_to_ast()) + 'task_id_task >> downstream_task' + >>> render_ast(task_add_downstream( + ... OrbiterEmptyOperator(task_id="task_id"), + ... ["downstream", "downstream2"] + ... )._downstream_to_ast()) + 'task_id_task >> [downstream2_task, downstream_task]' + + ``` + """ # noqa: E501 if isinstance(task_id, OrbiterTaskDependency): task_dependency = task_id if task_dependency.task_id != self.task_id: raise ValueError( f"task_dependency={task_dependency} has a different task_id than {self.task_id}" ) - self.downstream.add(task_dependency) - return self + # do normal parsing logic, but with these downstream items + task_id = task_dependency.downstream - if not len(task_id): + if isinstance(task_id, str): + self.downstream |= {to_task_id(task_id)} return self + else: + if not len(task_id): + return self - if len(task_id) == 1: - task_id = task_id[0] - downstream_task_id = ( - [to_task_id(t) for t in task_id] - if isinstance(task_id, list) - else to_task_id(task_id) - ) - logger.debug(f"Adding downstream {downstream_task_id} to {self.task_id}") - self.downstream.add( - OrbiterTaskDependency(task_id=self.task_id, downstream=downstream_task_id) - ) - return self + downstream_task_id = {to_task_id(t) for t in task_id} + logger.debug(f"Adding downstream {downstream_task_id} to {self.task_id}") + self.downstream |= downstream_task_id + return self -class OrbiterTaskDependency(OrbiterASTBase, BaseModel, extra="forbid"): +class OrbiterTaskDependency(BaseModel, extra="forbid"): """Represents a task dependency, which is added to either an [`OrbiterOperator`][orbiter.objects.task.OrbiterOperator] or an [`OrbiterTaskGroup`][orbiter.objects.task_group.OrbiterTaskGroup]. - Can take a single downstream `task_id` - ```pycon - >>> OrbiterTaskDependency(task_id="task_id", downstream="downstream") - task_id_task >> downstream_task - - ``` - - or a list of downstream `task_ids` - ```pycon - >>> OrbiterTaskDependency(task_id="task_id", downstream=["a", "b"]) - task_id_task >> [a_task, b_task] - - ``` - :param task_id: The task_id for the operator :type task_id: str - :param downstream: downstream tasks + :param downstream: downstream task(s) :type downstream: str | List[str] """ @@ -96,25 +106,22 @@ class OrbiterTaskDependency(OrbiterASTBase, BaseModel, extra="forbid"): downstream: TaskId | List[TaskId] # --8<-- [end:mermaid-td-props] - def _to_ast(self): - if isinstance(self.downstream, str): - return py_bitshift( - to_task_id(self.task_id, "_task"), to_task_id(self.downstream, "_task") - ) - elif isinstance(self.downstream, list): - return py_bitshift( - to_task_id(self.task_id, "_task"), - [to_task_id(t, "_task") for t in self.downstream], - ) + def __str__(self): + return f"{self.task_id} >> {self.downstream}" + + def __repr__(self): + return str(self) class OrbiterOperator(OrbiterASTBase, OrbiterBase, ABC, extra="allow"): """ **Abstract class** representing a - [Task in Airflow](https://airflow.apache.org/docs/apache-airflow/stable/tutorial/fundamentals.html#operators), - must be subclassed (such as [`OrbiterBashOperator`][orbiter.objects.operators.bash.OrbiterBashOperator]) + [Task in Airflow](https://airflow.apache.org/docs/apache-airflow/stable/tutorial/fundamentals.html#operators). + + **Must be subclassed** (such as [`OrbiterBashOperator`][orbiter.objects.operators.bash.OrbiterBashOperator], + or [`OrbiterTask`][orbiter.objects.task.OrbiterTask]). - Instantiation/inheriting: + Subclassing Example: ```pycon >>> from orbiter.objects import OrbiterRequirement >>> class OrbiterMyOperator(OrbiterOperator): @@ -128,15 +135,16 @@ class OrbiterOperator(OrbiterASTBase, OrbiterBase, ABC, extra="allow"): Adding single downstream tasks: ```pycon - >>> foo.add_downstream("downstream").downstream - {task_id_task >> downstream_task} + >>> from orbiter.ast_helper import render_ast + >>> render_ast(foo.add_downstream("downstream")._downstream_to_ast()) + 'task_id_task >> downstream_task' ``` Adding multiple downstream tasks: ```pycon - >>> sorted(list(foo.add_downstream(["a", "b"]).downstream)) - [task_id_task >> [a_task, b_task], task_id_task >> downstream_task] + >>> render_ast(foo.add_downstream(["a", "b"])._downstream_to_ast()) + 'task_id_task >> [a_task, b_task, downstream_task]' ``` @@ -167,7 +175,7 @@ class OrbiterOperator(OrbiterASTBase, OrbiterBase, ABC, extra="allow"): :param operator: Operator name :type operator: str, optional :param downstream: Downstream tasks, defaults to `set()` - :type downstream: Set[OrbiterTaskDependency], optional + :type downstream: Set[str], optional :param **kwargs: Other properties that may be passed to operator :param **OrbiterBase: [OrbiterBase][orbiter.objects.OrbiterBase] inherited properties """ @@ -182,7 +190,7 @@ class OrbiterOperator(OrbiterASTBase, OrbiterBase, ABC, extra="allow"): pool: str | None = None pool_slots: int | None = None orbiter_pool: OrbiterPool | None = None - downstream: Set[OrbiterTaskDependency] = set() + downstream: Set[str] = set() render_attributes: RenderAttributes = [ "task_id", @@ -199,7 +207,7 @@ class OrbiterOperator(OrbiterASTBase, OrbiterBase, ABC, extra="allow"): pool: str | None pool_slots: int | None trigger_rule: str | None - downstream: Set[OrbiterTaskDependency] + downstream: Set[str] add_downstream(str | List[str] | OrbiterTaskDependency) --8<-- [end:mermaid-op-props] """ @@ -209,6 +217,21 @@ def add_downstream( ) -> "OrbiterOperator": return task_add_downstream(self, task_id) + def _downstream_to_ast(self) -> List[ast.stmt]: + if not self.downstream: + return [] + elif len(self.downstream) == 1: + (t,) = tuple(self.downstream) + return py_bitshift( + to_task_id(self.task_id, ORBITER_TASK_SUFFIX), + to_task_id(t, ORBITER_TASK_SUFFIX), + ) + else: + return py_bitshift( + to_task_id(self.task_id, ORBITER_TASK_SUFFIX), + sorted([to_task_id(t, ORBITER_TASK_SUFFIX) for t in self.downstream]), + ) + def _to_ast(self) -> ast.stmt: def prop(k): attr = getattr(self, k, None) or getattr(self.model_extra, k, None) @@ -216,7 +239,7 @@ def prop(k): # foo = Bar(x=x,y=y, z=z) return py_assigned_object( - to_task_id(self.task_id, "_task"), + to_task_id(self.task_id, ORBITER_TASK_SUFFIX), self.operator, **{k: prop(k) for k in self.render_attributes if k and getattr(self, k)}, **{k: prop(k) for k in (self.model_extra.keys() or [])}, @@ -225,11 +248,10 @@ def prop(k): class OrbiterTask(OrbiterOperator, extra="allow"): """ - A generic Airflow [`OrbiterOperator`][orbiter.objects.task.OrbiterOperator] that can be instantiated directly. - - The operator that is instantiated is inferred from the `imports` field. + A generic version of [`OrbiterOperator`][orbiter.objects.task.OrbiterOperator] that can be instantiated directly. - The first `*Operator` or `*Sensor` import is used. + The operator that is instantiated is inferred from the `imports` field, + via the first `*Operator` or `*Sensor` import. [View info for specific operators at the Astronomer Registry.](https://registry.astronomer.io/) @@ -291,7 +313,7 @@ def prop(k): [operator] = operator_names self_as_ast = py_assigned_object( - to_task_id(self.task_id, "_task"), + to_task_id(self.task_id, ORBITER_TASK_SUFFIX), operator, **{ k: prop(k) @@ -313,24 +335,22 @@ def prop(k): @validate_call -def to_task_id(task_id: str, assignment_suffix: Literal["", "_task"] = "") -> str: +def to_task_id(task_id: str, assignment_suffix: str = "") -> str: # noinspection PyTypeChecker """General utiltty function - turns MyTaskId into my_task_id (or my_task_id_task suffix is `_task`) - :param task_id: - :param assignment_suffix: e.g. `_task` for `task_id_task = MyOperator(...)` - ```pycon >>> to_task_id("MyTaskId") 'my_task_id' >>> to_task_id("MyTaskId", "_task") 'my_task_id_task' - >>> to_task_id("MyTaskId", "_other") - Traceback (most recent call last): - pydantic_core._pydantic_core.ValidationError: ... >>> to_task_id("my_task_id_task", "_task") 'my_task_id_task' ``` + :param task_id: + :type task_id: str + :param assignment_suffix: e.g. `_task` for `task_id_task = MyOperator(...)` + :type assignment_suffix: str """ task_id = clean_value(task_id) return task_id + ( diff --git a/orbiter/objects/task_group.py b/orbiter/objects/task_group.py index e35cf82..aa75820 100644 --- a/orbiter/objects/task_group.py +++ b/orbiter/objects/task_group.py @@ -6,13 +6,20 @@ from pydantic import field_validator -from orbiter.ast_helper import OrbiterASTBase, py_with, py_object +from orbiter import ORBITER_TASK_SUFFIX +from orbiter.ast_helper import ( + OrbiterASTBase, + py_with, + py_object, + py_bitshift, +) from orbiter.objects import OrbiterBase, ImportList, OrbiterRequirement from orbiter.objects.task import ( TaskId, OrbiterTaskDependency, OrbiterOperator, task_add_downstream, + to_task_id, ) __mermaid__ = """ @@ -32,15 +39,19 @@ class OrbiterTaskGroup(OrbiterASTBase, OrbiterBase, ABC, extra="forbid"): ```pycon >>> from orbiter.objects.operators.bash import OrbiterBashOperator + >>> from orbiter.ast_helper import render_ast >>> OrbiterTaskGroup(task_group_id="foo", tasks=[ ... OrbiterBashOperator(task_id="b", bash_command="b"), ... OrbiterBashOperator(task_id="a", bash_command="a").add_downstream("b"), - ... ]) + ... ], downstream={"c"}) with TaskGroup(group_id='foo') as foo: b_task = BashOperator(task_id='b', bash_command='b') a_task = BashOperator(task_id='a', bash_command='a') a_task >> b_task + >>> render_ast(OrbiterTaskGroup(task_group_id="foo", tasks=[], downstream={"c"})._downstream_to_ast()) + 'foo >> c_task' + ``` :param task_group_id: The id of the TaskGroup @@ -67,7 +78,17 @@ class OrbiterTaskGroup(OrbiterASTBase, OrbiterBase, ABC, extra="forbid"): ] task_group_id: TaskId tasks: List[Any] - downstream: Set[OrbiterTaskDependency] = set() + downstream: Set[str] = set() + + @property + def task_id(self): + # task_id property, so it can be treated like an OrbiterOperator more easily + return self.task_group_id + + @task_id.setter + def task_id(self, value): + # task_id property, so it can be treated like an OrbiterOperator more easily + self.task_group_id = value # noinspection PyNestedDecorators @field_validator("tasks") @@ -89,11 +110,25 @@ def add_downstream( ) -> "OrbiterTaskGroup": return task_add_downstream(self, task_id) + def _downstream_to_ast(self): + if not self.downstream: + return + elif len(self.downstream) == 1: + (t,) = tuple(self.downstream) + return py_bitshift( + to_task_id(self.task_id), to_task_id(t, ORBITER_TASK_SUFFIX) + ) + else: + return py_bitshift( + to_task_id(self.task_id), + sorted([to_task_id(t, ORBITER_TASK_SUFFIX) for t in self.downstream]), + ) + def _to_ast(self) -> ast.stmt: # noinspection PyProtectedMember return py_with( py_object("TaskGroup", group_id=self.task_group_id).value, [operator._to_ast() for operator in self.tasks] - + [dep._to_ast() for operator in self.tasks for dep in operator.downstream], + + [operator._downstream_to_ast() for operator in self.tasks], self.task_group_id, ) diff --git a/orbiter/rules/__init__.py b/orbiter/rules/__init__.py index a2a3929..6225190 100644 --- a/orbiter/rules/__init__.py +++ b/orbiter/rules/__init__.py @@ -2,19 +2,20 @@ The brain of the Orbiter framework is in it's [`Rules`][orbiter.rules.Rule] and the [`Rulesets`][orbiter.rules.rulesets.Ruleset] that contain them. -- A [`Rule`][orbiter.rules.Rule] contains a python function that is evaluated and produces something -(typically an [Object](../objects)) or nothing +- A [`Rule`][orbiter.rules.Rule] is a python function that is evaluated and produces **something** +(typically an [Object](../objects)) or **nothing** - A [`Ruleset`][orbiter.rules.rulesets.Ruleset] is a collection of [`Rules`][orbiter.rules.Rule] that are evaluated in priority order - A [`TranslationRuleset`][orbiter.rules.rulesets.TranslationRuleset] is a collection of [`Rulesets`][orbiter.rules.rulesets.Ruleset], - relating to an [Origin](../origins) and [FileType][orbiter.rules.rulesets.load_filetype], + relating to an [Origin](../origins) and `FileType`, with a [`translate_fn`][orbiter.rules.rulesets.translate] which determines how to apply the rulesets. -Different [`Rules`][orbiter.rules.Rule] are applied in different scenarios; -such as for converting input to a DAG ([`@dag_rule`][orbiter.rules.DAGRule]), -or a specific Airflow Operator ([`@task_rule`][orbiter.rules.TaskRule]), -or for filtering entries from the input data +Different [`Rules`][orbiter.rules.Rule] are applied in different scenarios, such as: + +- converting input to an Airflow DAG ([`@dag_rule`][orbiter.rules.DAGRule]), +- converting input to a specific Airflow Operator ([`@task_rule`][orbiter.rules.TaskRule]), +- filtering entries from the input data ([`@dag_filter_rule`][orbiter.rules.DAGFilterRule], [`@task_filter_rule`][orbiter.rules.TaskFilterRule]). !!! tip @@ -36,6 +37,8 @@ def my_rule(val): if 'command' in val: return OrbiterBashOperator(task_id=val['id'], bash_command=val['command']) + else: + return None ``` This returns a @@ -47,11 +50,14 @@ def my_rule(val): from __future__ import annotations import functools +import json import re from typing import Callable, Any, Collection, TYPE_CHECKING, List from pydantic import BaseModel, Field +from loguru import logger + from orbiter.objects.task import OrbiterOperator, OrbiterTaskDependency if TYPE_CHECKING: @@ -146,11 +152,17 @@ class Rule(BaseModel, Callable, extra="forbid"): priority: int = Field(0, ge=0) def __call__(self, *args, **kwargs): - result = self.rule(*args, **kwargs) - # Save the original kwargs under orbiter_kwargs - if result: - if kwargs and hasattr(result, "orbiter_kwargs"): - setattr(result, "orbiter_kwargs", kwargs) + try: + result = self.rule(*args, **kwargs) + # Save the original kwargs under orbiter_kwargs + if result: + if kwargs and hasattr(result, "orbiter_kwargs"): + setattr(result, "orbiter_kwargs", kwargs) + except Exception as e: + logger.warning( + f"[RULE]: {self.rule.__name__}\n[ERROR]:\n{type(e)} - {e}\n[INPUT]:\n{args}\n{kwargs}" + ) + result = None return result @@ -291,6 +303,21 @@ def foo(val: OrbiterProject) -> None: post_processing_rule: Callable[[...], PostProcessingRule] = rule +@task_rule(priority=1) +def cannot_map_rule(val: dict) -> OrbiterOperator | None: + """Can be used in a TaskRuleset. + Returns an `OrbiterEmptyOperator` with a doc string that says it cannot map the task. + Useful to ensure that tasks that cannot be mapped are still visible in the output. + """ + from orbiter.objects.operators.empty import OrbiterEmptyOperator + + # noinspection PyArgumentList + return OrbiterEmptyOperator( + task_id="UNKNOWN", + doc_md=f"""Input did not translate: `{json.dumps(val, default=str)}`""", + ) + + EMPTY_RULE = Rule(rule=lambda _: None, priority=0) """Empty rule, for testing""" diff --git a/orbiter/rules/rulesets.py b/orbiter/rules/rulesets.py index 229d60a..aba7cbe 100644 --- a/orbiter/rules/rulesets.py +++ b/orbiter/rules/rulesets.py @@ -3,16 +3,19 @@ import functools import inspect import re +import uuid from _operator import add from abc import ABC from itertools import chain from pathlib import Path -from typing import List, Any, Collection, Annotated, Callable, Union +from tempfile import TemporaryDirectory +from typing import List, Any, Collection, Annotated, Callable, Union, Generator from loguru import logger from pydantic import BaseModel, AfterValidator, validate_call -from orbiter import FileType, import_from_qualname +from orbiter import FileType +from orbiter import import_from_qualname from orbiter.objects.dag import OrbiterDAG from orbiter.objects.project import OrbiterProject from orbiter.objects.task import OrbiterOperator, OrbiterTaskDependency @@ -95,6 +98,117 @@ def validate_qualified_imports(qualified_imports: List[str]) -> List[str]: ] +# noinspection t +def xmltodict_parse(input_str: str) -> Any: + """Calls `xmltodict.parse` and does post-processing fixes. + + !!! note + + The original [`xmltodict.parse`](https://pypi.org/project/xmltodict/) method returns EITHER: + + - a dict (one child element of type) + - or a list of dict (many child element of type) + + This behavior can be confusing, and is an issue with the original xml spec being referenced. + + **This method deviates by standardizing to the latter case (always a `list[dict]`).** + + **All XML elements will be a list of dictionaries, even if there's only one element.** + + ```pycon + >>> xmltodict_parse("") + Traceback (most recent call last): + xml.parsers.expat.ExpatError: no element found: line 1, column 0 + >>> xmltodict_parse("") + {'a': None} + >>> xmltodict_parse("") + {'a': [{'@foo': 'bar'}]} + >>> xmltodict_parse("") # Singleton - gets modified + {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}]}]} + >>> xmltodict_parse("") # Nested Singletons - modified + {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz', 'bar': [{'bop': None}]}]}]} + >>> xmltodict_parse("") + {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}, {'@bing': 'bop'}]}]} + + ``` + :param input_str: The XML string to parse + :type input_str: str + :return: The parsed XML + :rtype: dict + """ + import xmltodict + + # noinspection t + def _fix(d): + """fix the dict in place, recursively, standardizing on a list of dict even if there's only one entry.""" + # if it's a dict, descend to fix + if isinstance(d, dict): + for k, v in d.items(): + # @keys are properties of elements, non-@keys are elements + if not k.startswith("@"): + if isinstance(v, dict): + # THE FIX + # any non-@keys should be a list of dict, even if there's just one of the element + d[k] = [v] + _fix(v) + else: + _fix(v) + # if it's a list, descend to fix + if isinstance(d, list): + for v in d: + _fix(v) + + output = xmltodict.parse(input_str) + _fix(output) + return output + + +def _add_task_deduped(_task, _tasks, n=""): + """ + If this task_id doesn't already exist, add it as normal to the tasks dictionary. + If this task_id does exist - add a number to the end and try again + + ```pycon + >>> from pydantic import BaseModel + >>> class Task(BaseModel): + ... task_id: str + >>> tasks = {} + >>> _add_task_deduped(Task(task_id="foo"), tasks); tasks + {'foo': Task(task_id='foo')} + >>> _add_task_deduped(Task(task_id="foo"), tasks); tasks + {'foo': Task(task_id='foo'), 'foo1': Task(task_id='foo1')} + >>> _add_task_deduped(Task(task_id="foo"), tasks); tasks + {'foo': Task(task_id='foo'), 'foo1': Task(task_id='foo1'), 'foo2': Task(task_id='foo2')} + + ``` + """ + if hasattr(_task, "task_id"): + _id = "task_id" + elif hasattr(_task, "task_group_id"): + _id = "task_group_id" + else: + raise TypeError( + "Attempting to add a task without a `task_id` or `task_group_id` attribute" + ) + + old_task_id = getattr(_task, _id) + new_task_id = old_task_id + n + if new_task_id not in _tasks: + if n != "": + logger.warning( + f"{old_task_id} encountered more than once, task IDs must be unique! " + f"Modifying task ID to '{new_task_id}'!" + ) + setattr(_task, _id, new_task_id) + _tasks[new_task_id] = _task + else: + try: + n = str(int(n) + 1) + except ValueError: + n = "1" + _add_task_deduped(_task, _tasks, n) + + # noinspection t @validate_call def translate(translation_ruleset, input_dir: Path) -> OrbiterProject: @@ -104,13 +218,14 @@ def translate(translation_ruleset, input_dir: Path) -> OrbiterProject: {"": { ..., "": { ...} }} ``` - The default translation function (`orbiter.rules.rulesets.translate`) performs the following steps: + The standard translation function performs the following steps: ![Diagram of Orbiter Translation](../orbiter_diagram.png) - 1. **Find all files** with the expected + 1. [**Find all files**][orbiter.rules.rulesets.TranslationRuleset.get_files_with_extension] with the expected [`TranslationRuleset.file_type`][orbiter.rules.rulesets.TranslationRuleset] - (`.json`, `.xml`, `.yaml`, etc.) in the input folder. Load each file and turn it into a Python Dictionary. + (`.json`, `.xml`, `.yaml`, etc.) in the input folder. + - [**Load each file**][orbiter.rules.rulesets.TranslationRuleset.loads] and turn it into a Python Dictionary. 2. **For each file:** Apply the [`TranslationRuleset.dag_filter_ruleset`][orbiter.rules.rulesets.DAGFilterRuleset] to filter down to entries that can translate to a DAG, in priority order. - **For each**: Apply the [`TranslationRuleset.dag_ruleset`][orbiter.rules.rulesets.DAGRuleset], @@ -137,15 +252,6 @@ def translate(translation_ruleset, input_dir: Path) -> OrbiterProject: """ - - def _get_files_with_extension(_extension: str, _input_dir: Path) -> List[Path]: - return [ - directory / file - for (directory, _, files) in _input_dir.walk() - for file in files - if _extension == file.lower()[-len(_extension) :] - ] - if not isinstance(translation_ruleset, TranslationRuleset): raise RuntimeError( f"Error! type(translation_ruleset)=={type(translation_ruleset)}!=TranslationRuleset! Exiting!" @@ -154,22 +260,10 @@ def _get_files_with_extension(_extension: str, _input_dir: Path) -> List[Path]: # Create an initial OrbiterProject project = OrbiterProject() - extension = translation_ruleset.file_type.value.lower() - - logger.info(f"Finding files with extension={extension} in {input_dir}") - files = _get_files_with_extension(extension, input_dir) - - # .yaml is sometimes '.yml' - if extension == "yaml": - files.extend(_get_files_with_extension("yml", input_dir)) - - logger.info(f"Found {len(files)} files with extension={extension} in {input_dir}") - - for file in files: - logger.info(f"Translating file={file.resolve()}") - - # Load the file and convert it into a python dict - input_dict = load_filetype(file.read_text(), translation_ruleset.file_type) + for i, (file, input_dict) in enumerate( + translation_ruleset.get_files_with_extension(input_dir) + ): + logger.info(f"Translating [File {i}]={file.resolve()}") # DAG FILTER Ruleset - filter down to keys suspected of being translatable to a DAG, in priority order. dag_dicts = functools.reduce( @@ -316,15 +410,16 @@ def apply_many( {'a': {'Type': 'Folder'}, 'c': {'Type': 'Folder'}} ``` + !!! tip - You cannot pass input without length - ```pycon - >>> ruleset.apply_many({}) - ... # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - RuntimeError: Input is not Collection[Any] with length! + You cannot pass input without length + ```pycon + >>> ruleset.apply_many({}) + ... # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + RuntimeError: Input is not Collection[Any] with length! - ``` + ``` :param input_val: List to evaluate ruleset over :type input_val: Collection[Any] :param take_first: Only take the first (if any) result from each ruleset application @@ -512,6 +607,10 @@ class PostProcessingRuleset(Ruleset): ruleset: List[PostProcessingRule | Rule | Callable[[OrbiterProject], None] | dict] +EMPTY_RULESET = {"ruleset": [EMPTY_RULE]} +"""Empty ruleset, for testing""" + + class TranslationRuleset(BaseModel, ABC, extra="forbid"): """ A `Ruleset` is a collection of [`Rules`][orbiter.rules.Rule] that are @@ -543,7 +642,7 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"): :type dag_filter_ruleset: DAGFilterRuleset | dict :param dag_ruleset: [`DAGRuleset`][orbiter.rules.rulesets.DAGRuleset] (of [`DAGRules`][orbiter.rules.DAGRule]) :type dag_ruleset: DAGRuleset | dict - :param task_filter_ruleset: [`TaskFilterRule`][orbiter.rules.rulesets.TaskFilterRule] + :param task_filter_ruleset: [`TaskFilterRuleset`][orbiter.rules.rulesets.TaskFilterRuleset] (of [`TaskFilterRule`][orbiter.rules.TaskFilterRule]) :type task_filter_ruleset: TaskFilterRuleset | dict :param task_ruleset: [`TaskRuleset`][orbiter.rules.rulesets.TaskRuleset] (of [`TaskRules`][orbiter.rules.TaskRule]) @@ -570,137 +669,132 @@ class TranslationRuleset(BaseModel, ABC, extra="forbid"): post_processing_ruleset: PostProcessingRuleset | dict translate_fn: TranslateFn = translate + @validate_call + def loads(self, input_str: str) -> dict: + """ + Converts all files of type into a Python dictionary "intermediate representation" form, + prior to any rulesets being applied. + + | FileType | Conversion Method | + |----------|-------------------------------------------------------------| + | `XML` | [`xmltodict_parse`][orbiter.rules.rulesets.xmltodict_parse] | + | `YAML` | `yaml.safe_load` | + | `JSON` | `json.loads` | + + :param input_str: The string to convert to a dictionary + :type input_str: str + :return: The dictionary representation of the input_str + :rtype: dict + """ -def _add_task_deduped(_task, _tasks, n=""): - """ - If this task_id doesn't already exist, add it as normal to the tasks dictionary. - If this task_id does exist - add a number to the end and try again - - ```pycon - >>> from pydantic import BaseModel - >>> class Task(BaseModel): - ... task_id: str - >>> tasks = {} - >>> _add_task_deduped(Task(task_id="foo"), tasks); tasks - {'foo': Task(task_id='foo')} - >>> _add_task_deduped(Task(task_id="foo"), tasks); tasks - {'foo': Task(task_id='foo'), 'foo1': Task(task_id='foo1')} - >>> _add_task_deduped(Task(task_id="foo"), tasks); tasks - {'foo': Task(task_id='foo'), 'foo1': Task(task_id='foo1'), 'foo2': Task(task_id='foo2')} - - ``` - """ - new_task_id = _task.task_id + n - if new_task_id not in _tasks: - if n != "": - logger.warning( - f"{_task.task_id} encountered more than once, task IDs must be unique! " - f"Modifying task ID to '{new_task_id}'!" - ) - _task.task_id = new_task_id - _tasks[new_task_id] = _task - else: - try: - n = str(int(n) + 1) - except ValueError: - n = "1" - _add_task_deduped(_task, _tasks, n) - - -EMPTY_RULESET = {"ruleset": [EMPTY_RULE]} -"""Empty ruleset, for testing""" - - -@validate_call -def load_filetype(input_str: str, file_type: FileType) -> dict: - """ - Orbiter converts all file types into a Python dictionary "intermediate representation" form, - prior to any rulesets being applied. - - | FileType | Conversion Method | - |----------|-------------------------------------------------------------| - | `XML` | [`xmltodict_parse`][orbiter.rules.rulesets.xmltodict_parse] | - | `YAML` | `yaml.safe_load` | - | `JSON` | `json.loads` | - """ - - if file_type == FileType.JSON: - import json - - return json.loads(input_str) - elif file_type == FileType.YAML: - import yaml - - return yaml.safe_load(input_str) - elif file_type == FileType.XML: - return xmltodict_parse(input_str) - else: - raise NotImplementedError(f"Cannot load file_type={file_type}") - - -# noinspection t -def xmltodict_parse(input_str: str) -> Any: - """Calls `xmltodict.parse` and does post-processing fixes. - - !!! note + if self.file_type == FileType.JSON: + import json - The original [`xmltodict.parse`](https://pypi.org/project/xmltodict/) method returns EITHER: + return json.loads(input_str) + elif self.file_type == FileType.YAML: + import yaml - - a dict (one child element of type) - - or a list of dict (many child element of type) + return yaml.safe_load(input_str) + elif self.file_type == FileType.XML: + return xmltodict_parse(input_str) + else: + raise NotImplementedError(f"Cannot load file_type={self.file_type}") - This behavior can be confusing, and is an issue with the original xml spec being referenced. + @validate_call + def dumps(self, input_dict: dict) -> str: + """ + Convert Python dictionary back to source string form, useful for testing + + | FileType | Conversion Method | + |----------|---------------------| + | `XML` | `xmltodict.unparse` | + | `YAML` | `yaml.safe_dump` | + | `JSON` | `json.dumps` | + + :param input_dict: The dictionary to convert to a string + :type input_dict: dict + :return str: The string representation of the input_dict, in the file_type format + :rtype: str + """ + if self.file_type == FileType.JSON: + import json - **This method deviates by standardizing to the latter case (always a `list[dict]`).** + return json.dumps(input_dict, indent=2) + elif self.file_type == FileType.YAML: + import yaml - **All XML elements will be a list of dictionaries, even if there's only one element.** + return yaml.safe_dump(input_dict) + elif self.file_type == FileType.XML: + import xmltodict - ```pycon - >>> xmltodict_parse("") - Traceback (most recent call last): - xml.parsers.expat.ExpatError: no element found: line 1, column 0 - >>> xmltodict_parse("") - {'a': None} - >>> xmltodict_parse("") - {'a': [{'@foo': 'bar'}]} - >>> xmltodict_parse("") # Singleton - gets modified - {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}]}]} - >>> xmltodict_parse("") # Nested Singletons - modified - {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz', 'bar': [{'bop': None}]}]}]} - >>> xmltodict_parse("") - {'a': [{'@foo': 'bar', 'foo': [{'@bar': 'baz'}, {'@bing': 'bop'}]}]} + return xmltodict.unparse(input_dict) + else: + raise NotImplementedError(f"Cannot dump file_type={self.file_type}") - ``` - :param input_str: The XML string to parse - :type input_str: str - :return: The parsed XML - :rtype: dict - """ - import xmltodict + def get_files_with_extension(self, input_dir: Path) -> Generator[Path, dict]: + """ + A generator that yields files with a specific extension(s) in a directory - # noinspection t - def _fix(d): - """fix the dict in place, recursively, standardizing on a list of dict even if there's only one entry.""" - # if it's a dict, descend to fix - if isinstance(d, dict): - for k, v in d.items(): - # @keys are properties of elements, non-@keys are elements - if not k.startswith("@"): - if isinstance(v, dict): - # THE FIX - # any non-@keys should be a list of dict, even if there's just one of the element - d[k] = [v] - _fix(v) - else: - _fix(v) - # if it's a list, descend to fix - if isinstance(d, list): - for v in d: - _fix(v) + :param input_dir: The directory to search in + :type input_dir: Path + :return: Generator item of (Path, dict) for each file found + :rtype: Generator[Path, dict] + """ + extension = f".{self.file_type.value.lower()}" + extensions = [extension] + + # YAML and YML are both valid extensions + extension_sub = { + "yaml": "yml", + } + if other_extension := extension_sub.get(self.file_type.value.lower()): + extensions.append(f".{other_extension}") + + logger.debug(f"Finding files with extension={extensions} in {input_dir}") + + def backport_walk(input_dir: Path): + """Path.walk() is only available in Python 3.12+, so, backport""" + import os + + for result in os.walk(input_dir): + yield Path(result[0]), result[1], result[2] + + for directory, _, files in ( + input_dir.walk() if hasattr(input_dir, "walk") else backport_walk(input_dir) + ): + logger.debug(f"Checking directory={directory}") + for file in files: + file = directory / file + if file.suffix.lower() in extensions: + logger.debug(f"File={file} matches extension={extensions}") + yield ( + # Return the file path + file, + # and load the file and convert it into a python dict + self.loads(file.read_text()), + ) - output = xmltodict.parse(input_str) - _fix(output) - return output + def test(self, input_value: str | dict) -> OrbiterProject: + """ + Test an input against the whole ruleset. + - 'input_dict' (a parsed python dict) + - or 'input_str' (raw value) to test against the ruleset. + + :param input_value: The input to test + can be either a dict (passed to `translate_ruleset.dumps()` before `translate_ruleset.loads()`) + or a string (read directly by `translate_ruleset.loads()`) + :type input_value: str | dict + :return: OrbiterProject produced after applying the ruleset + :rtype: OrbiterProject + """ + with TemporaryDirectory() as tempdir: + file = Path(tempdir) / f"{uuid.uuid4()}.{self.file_type.value}" + file.write_text( + self.dumps(input_value) + if isinstance(input_value, dict) + else input_value + ) + return self.translate_fn(translation_ruleset=self, input_dir=file.parent) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index 1426062..d4a85ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ dev = [ "mkdocstrings-python", "markdown-exec", # for rendering the csv table of origins to share w/ CLI "pygments", + "griffe-inherited-docstrings", # test "pytest>=7.4", diff --git a/tests/conftest.py b/tests/conftest.py index 43c51e9..c199236 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,13 @@ +import os from pathlib import Path import orbiter import pytest +manual_tests = pytest.mark.skipif( + not bool(os.getenv("MANUAL_TESTS")), reason="requires env setup" +) + @pytest.fixture(scope="session") def project_root() -> Path: @@ -11,4 +16,5 @@ def project_root() -> Path: @pytest.fixture(scope="session") def project_version() -> str: + # noinspection PyUnresolvedReferences return orbiter.__version__ diff --git a/tests/integration_test.py b/tests/integration_test.py index 506484d..bd0634f 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -1,13 +1,19 @@ +import pytest + from orbiter.__main__ import run +from tests.conftest import manual_tests +# noinspection PyUnreachableCode +@pytest.mark.skip("Relies on orbiter-community-translations, not yet published") +@manual_tests def test_integration(): run("just docker-build-binary", shell=True, capture_output=True, text=True) output = run("just docker-run-binary", shell=True, capture_output=True, text=True) assert "Available Origins" in output.stdout assert ( - "Adding local .pyz files ['/data/astronomer_orbiter_translations.pyz'] to sys.path" + "Adding local .pyz files ['/data/orbiter_community_translations.pyz'] to sys.path" in output.stdout ) - assert "Finding files with extension=xml in /data/workflow" in output.stdout + assert "Finding files with extension=['.xml'] in /data/workflow" in output.stdout diff --git a/tests/orbiter/rules/rulesets_test.py b/tests/orbiter/rules/rulesets_test.py new file mode 100644 index 0000000..622741e --- /dev/null +++ b/tests/orbiter/rules/rulesets_test.py @@ -0,0 +1,34 @@ +from orbiter import FileType +from orbiter.rules.rulesets import TranslationRuleset, EMPTY_RULESET + + +def test__get_files_with_extension(project_root): + translation_ruleset = TranslationRuleset( + file_type=FileType.YAML, + dag_ruleset=EMPTY_RULESET, + dag_filter_ruleset=EMPTY_RULESET, + task_filter_ruleset=EMPTY_RULESET, + task_ruleset=EMPTY_RULESET, + task_dependency_ruleset=EMPTY_RULESET, + post_processing_ruleset=EMPTY_RULESET, + ) + actual = translation_ruleset.get_files_with_extension( + project_root / "tests/resources/test_get_files_with_extension" + ) + expected = [ + ( + project_root + / "tests/resources/test_get_files_with_extension/foo/bar/three.yaml", + {"three": "baz"}, + ), + ( + project_root + / "tests/resources/test_get_files_with_extension/foo/bar/two.yml", + {"two": "bar"}, + ), + ( + project_root / "tests/resources/test_get_files_with_extension/one.YAML", + {"one": "foo"}, + ), + ] + assert sorted(list(actual)) == sorted(expected) diff --git a/tests/resources/test_get_files_with_extension/foo/bar/three.yaml b/tests/resources/test_get_files_with_extension/foo/bar/three.yaml new file mode 100644 index 0000000..05edea8 --- /dev/null +++ b/tests/resources/test_get_files_with_extension/foo/bar/three.yaml @@ -0,0 +1 @@ +three: baz diff --git a/tests/resources/test_get_files_with_extension/foo/bar/two.yml b/tests/resources/test_get_files_with_extension/foo/bar/two.yml new file mode 100644 index 0000000..a7f64e7 --- /dev/null +++ b/tests/resources/test_get_files_with_extension/foo/bar/two.yml @@ -0,0 +1 @@ +two: bar diff --git a/tests/resources/test_get_files_with_extension/one.YAML b/tests/resources/test_get_files_with_extension/one.YAML new file mode 100644 index 0000000..432abea --- /dev/null +++ b/tests/resources/test_get_files_with_extension/one.YAML @@ -0,0 +1 @@ +one: foo diff --git a/tests/resources/translation_template.py b/tests/resources/translation_template.py index ea23eb6..65769b6 100644 --- a/tests/resources/translation_template.py +++ b/tests/resources/translation_template.py @@ -12,6 +12,7 @@ task_rule, task_dependency_rule, post_processing_rule, + cannot_map_rule, ) from orbiter.rules.rulesets import ( DAGFilterRuleset, @@ -58,19 +59,6 @@ def basic_task_rule(val: dict) -> OrbiterOperator | OrbiterTaskGroup | None: return None -@task_rule(priority=1) -def cannot_map_rule(val: dict) -> OrbiterOperator | OrbiterTaskGroup | None: - """This rule returns an `OrbiterEmptyOperator` with a doc string that says it cannot map the task, - so we can still see the task in the output. With a priority=1 it will be applied last - """ - import json - - # noinspection PyArgumentList - return OrbiterEmptyOperator( - task_id=val["task_id"], doc_md=f"Cannot map task! input: {json.dumps(val)}" - ) - - @task_dependency_rule def basic_task_dependency_rule(val: OrbiterDAG) -> list | None: """Translate input into a list of task dependencies"""