diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py index 4f1319395..c2d6dff7f 100644 --- a/cosmos/airflow/graph.py +++ b/cosmos/airflow/graph.py @@ -409,6 +409,27 @@ def _get_dbt_dag_task_group_identifier(dag: DAG, task_group: TaskGroup | None) - return dag_task_group_identifier +def identify_detached_nodes( + nodes: dict[str, DbtNode], + test_behavior: TestBehavior, + detached_nodes: dict[str, DbtNode], + detached_from_parent: dict[str, list[DbtNode]], +) -> None: + """ + Given the nodes that represent a dbt project and the test_behavior, identify the detached test nodes + (test nodes that have multiple dependencies and should run independently). + + Change in-place the dictionaries detached_nodes (detached node ID : node) and detached_from_parent (parent node ID that + is upstream to this test and the test node). + """ + if test_behavior in (TestBehavior.BUILD, TestBehavior.AFTER_EACH): + for node_id, node in nodes.items(): + if is_detached_test(node): + detached_nodes[node_id] = node + for parent_id in node.depends_on: + detached_from_parent[parent_id].append(node) + + def build_airflow_graph( nodes: dict[str, DbtNode], dag: DAG, # Airflow-specific - parent DAG where to associate tasks and (optional) task groups @@ -453,13 +474,9 @@ def build_airflow_graph( # Identify test nodes that should be run detached from the associated dbt resource nodes because they # have multiple parents - detached_from_parent = defaultdict(list) - detached_nodes = {} - for node_id, node in nodes.items(): - if is_detached_test(node): - detached_nodes[node_id] = node - for parent_id in node.depends_on: - detached_from_parent[parent_id].append(node) + detached_nodes: dict[str, DbtNode] = {} + detached_from_parent: dict[str, list[DbtNode]] = defaultdict(list) + identify_detached_nodes(nodes, test_behavior, detached_nodes, detached_from_parent) for node_id, node in nodes.items(): conversion_function = node_converters.get(node.resource_type, generate_task_or_group) @@ -487,20 +504,6 @@ def build_airflow_graph( logger.debug(f"Conversion of <{node.unique_id}> was successful!") tasks_map[node_id] = task_or_group - # Handle detached test nodes - for node_id, node in detached_nodes.items(): - test_meta = create_test_task_metadata( - f"{node.resource_name.split('.')[0]}_test", - execution_mode, - test_indirect_selection, - task_args=task_args, - on_warning_callback=on_warning_callback, - render_config=render_config, - node=node, - ) - test_task = create_airflow_task(test_meta, dag, task_group=task_group) - tasks_map[node_id] = test_task - # If test_behaviour=="after_all", there will be one test task, run by the end of the DAG # The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks) if test_behavior == TestBehavior.AFTER_ALL: @@ -516,6 +519,20 @@ def build_airflow_graph( leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes) for leaf_node_id in leaves_ids: tasks_map[leaf_node_id] >> test_task + elif test_behavior in (TestBehavior.BUILD, TestBehavior.AFTER_EACH): + # Handle detached test nodes + for node_id, node in detached_nodes.items(): + test_meta = create_test_task_metadata( + f"{node.resource_name.split('.')[0]}_test", + execution_mode, + test_indirect_selection, + task_args=task_args, + on_warning_callback=on_warning_callback, + render_config=render_config, + node=node, + ) + test_task = create_airflow_task(test_meta, dag, task_group=task_group) + tasks_map[node_id] = test_task create_airflow_task_dependencies(nodes, tasks_map) _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map, task_group) diff --git a/tests/test_converter.py b/tests/test_converter.py index bc25917c3..34ef266c0 100644 --- a/tests/test_converter.py +++ b/tests/test_converter.py @@ -205,6 +205,82 @@ def test_converter_creates_dag_with_test_with_multiple_parents(): assert args[1:] == ["test", "--select", "custom_test_combined_model_combined_model_.c6e4587380"] +@pytest.mark.integration +def test_converter_creates_dag_with_test_with_multiple_parents_test_afterall(): + """ + Validate topology of a project that uses the MULTIPLE_PARENTS_TEST_DBT_PROJECT project + """ + project_config = ProjectConfig(dbt_project_path=MULTIPLE_PARENTS_TEST_DBT_PROJECT) + execution_config = ExecutionConfig(execution_mode=ExecutionMode.LOCAL) + render_config = RenderConfig(test_behavior=TestBehavior.AFTER_ALL) + profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), + ) + with DAG("sample_dag", start_date=datetime(2024, 4, 16)) as dag: + converter = DbtToAirflowConverter( + dag=dag, + project_config=project_config, + profile_config=profile_config, + execution_config=execution_config, + render_config=render_config, + ) + tasks = converter.tasks_map + + assert len(converter.tasks_map) == 3 + + assert tasks["model.my_dbt_project.combined_model"].task_id == "combined_model_run" + assert tasks["model.my_dbt_project.model_a"].task_id == "model_a_run" + assert tasks["model.my_dbt_project.model_b"].task_id == "model_b_run" + assert tasks["model.my_dbt_project.combined_model"].downstream_task_ids == {"multiple_parents_test_test"} + assert tasks["model.my_dbt_project.model_a"].downstream_task_ids == {"combined_model_run"} + assert tasks["model.my_dbt_project.model_b"].downstream_task_ids == {"combined_model_run"} + multiple_parents_test_test_args = tasks["model.my_dbt_project.combined_model"].downstream_list[0].build_cmd({})[0] + assert multiple_parents_test_test_args[1:] == ["test"] + + +@pytest.mark.integration +def test_converter_creates_dag_with_test_with_multiple_parents_test_none(): + """ + Validate topology of a project that uses the MULTIPLE_PARENTS_TEST_DBT_PROJECT project + """ + project_config = ProjectConfig(dbt_project_path=MULTIPLE_PARENTS_TEST_DBT_PROJECT) + execution_config = ExecutionConfig(execution_mode=ExecutionMode.LOCAL) + render_config = RenderConfig(test_behavior=TestBehavior.NONE) + profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), + ) + with DAG("sample_dag", start_date=datetime(2024, 4, 16)) as dag: + converter = DbtToAirflowConverter( + dag=dag, + project_config=project_config, + profile_config=profile_config, + execution_config=execution_config, + render_config=render_config, + ) + tasks = converter.tasks_map + + assert len(converter.tasks_map) == 3 + + assert tasks["model.my_dbt_project.combined_model"].task_id == "combined_model_run" + assert tasks["model.my_dbt_project.model_a"].task_id == "model_a_run" + assert tasks["model.my_dbt_project.model_b"].task_id == "model_b_run" + assert tasks["model.my_dbt_project.combined_model"].downstream_task_ids == set() + assert tasks["model.my_dbt_project.model_b"].downstream_task_ids == {"combined_model_run"} + assert tasks["model.my_dbt_project.model_b"].downstream_task_ids == {"combined_model_run"} + + @pytest.mark.integration def test_converter_creates_dag_with_test_with_multiple_parents_and_build(): """