Skip to content

Commit df67f84

Browse files
authored
fix(serialized_dag): handle compressing serialized_dag for get_dag_dependencies (#48924)
1 parent 00aec39 commit df67f84

File tree

2 files changed

+54
-27
lines changed

2 files changed

+54
-27
lines changed

airflow-core/src/airflow/models/serialized_dag.py

Lines changed: 31 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
import zlib
2424
from collections.abc import Iterable, Iterator, Sequence
2525
from datetime import timedelta
26-
from typing import TYPE_CHECKING, Any, Literal
26+
from typing import TYPE_CHECKING, Any, Callable, Literal
2727

2828
import sqlalchemy_jsonfield
2929
import uuid6
@@ -635,39 +635,43 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[
635635
636636
:param session: ORM Session
637637
"""
638+
load_json: Callable | None
639+
if COMPRESS_SERIALIZED_DAGS is False:
640+
if session.bind.dialect.name in ["sqlite", "mysql"]:
641+
data_col_to_select = func.json_extract(cls._data, "$.dag.dag_dependencies")
642+
643+
def load_json(deps_data):
644+
return json.loads(deps_data) if deps_data else []
645+
else:
646+
data_col_to_select = func.json_extract_path(cls._data, "dag", "dag_dependencies")
647+
load_json = None
648+
else:
649+
data_col_to_select = cls._data_compressed
650+
651+
def load_json(deps_data):
652+
return json.loads(zlib.decompress(deps_data))["dag"]["dag_dependencies"] if deps_data else []
653+
638654
latest_sdag_subquery = (
639655
select(cls.dag_id, func.max(cls.created_at).label("max_created")).group_by(cls.dag_id).subquery()
640656
)
641-
if session.bind.dialect.name in ["sqlite", "mysql"]:
642-
query = session.execute(
643-
select(cls.dag_id, func.json_extract(cls._data, "$.dag.dag_dependencies"))
644-
.join(
645-
latest_sdag_subquery,
646-
(cls.dag_id == latest_sdag_subquery.c.dag_id)
647-
& (cls.created_at == latest_sdag_subquery.c.max_created),
648-
)
649-
.join(cls.dag_model)
650-
.where(~DagModel.is_stale)
657+
query = session.execute(
658+
select(cls.dag_id, data_col_to_select)
659+
.join(
660+
latest_sdag_subquery,
661+
(cls.dag_id == latest_sdag_subquery.c.dag_id)
662+
& (cls.created_at == latest_sdag_subquery.c.max_created),
651663
)
652-
iterator = [(dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query]
653-
else:
654-
iterator = session.execute(
655-
select(
656-
cls.dag_id,
657-
func.json_extract_path(cls._data, "dag", "dag_dependencies"),
658-
)
659-
.join(
660-
latest_sdag_subquery,
661-
(cls.dag_id == latest_sdag_subquery.c.dag_id)
662-
& (cls.created_at == latest_sdag_subquery.c.max_created),
663-
)
664-
.join(cls.dag_model)
665-
.where(~DagModel.is_stale)
666-
).all()
664+
.join(cls.dag_model)
665+
.where(~DagModel.is_stale)
666+
)
667+
iterator = (
668+
[(dag_id, load_json(deps_data)) for dag_id, deps_data in query]
669+
if load_json is not None
670+
else query.all()
671+
)
667672

668673
resolver = _DagDependenciesResolver(dag_id_dependencies=iterator, session=session)
669674
dag_depdendencies_by_dag = resolver.resolve()
670-
671675
return dag_depdendencies_by_dag
672676

673677
@staticmethod

airflow-core/tests/unit/models/test_serialized_dag.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
from airflow.providers.standard.operators.empty import EmptyOperator
3636
from airflow.providers.standard.operators.python import PythonOperator
3737
from airflow.sdk import DAG, Asset, task as task_decorator
38+
from airflow.serialization.dag_dependency import DagDependency
3839
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
3940
from airflow.settings import json
4041
from airflow.utils.hashlib_wrapper import md5
@@ -386,6 +387,28 @@ def test_get_dependencies(self, session):
386387
dependencies = SDM.get_dag_dependencies(session=session)
387388
assert dag_id not in dependencies
388389

390+
def test_get_dependencies_with_asset_ref(self, dag_maker, session):
391+
with dag_maker(
392+
dag_id="test_get_dependencies_with_asset_ref_example",
393+
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
394+
schedule=[Asset.ref(uri="test://asset1")],
395+
) as dag:
396+
BashOperator(task_id="any", bash_command="sleep 5")
397+
dag.sync_to_db()
398+
SDM.write_dag(dag, bundle_name="testing")
399+
dependencies = SDM.get_dag_dependencies(session=session)
400+
assert dependencies == {
401+
"test_get_dependencies_with_asset_ref_example": [
402+
DagDependency(
403+
source="asset-uri-ref",
404+
target="test_get_dependencies_with_asset_ref_example",
405+
label="test://asset1",
406+
dependency_type="asset-uri-ref",
407+
dependency_id="test://asset1",
408+
)
409+
]
410+
}
411+
389412
@pytest.mark.parametrize("min_update_interval", [0, 10])
390413
@mock.patch.object(DagVersion, "get_latest_version")
391414
def test_min_update_interval_is_respected(

0 commit comments

Comments
 (0)