|
23 | 23 | import zlib |
24 | 24 | from collections.abc import Iterable, Iterator, Sequence |
25 | 25 | from datetime import timedelta |
26 | | -from typing import TYPE_CHECKING, Any, Literal |
| 26 | +from typing import TYPE_CHECKING, Any, Callable, Literal |
27 | 27 |
|
28 | 28 | import sqlalchemy_jsonfield |
29 | 29 | import uuid6 |
@@ -635,39 +635,43 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[ |
635 | 635 |
|
636 | 636 | :param session: ORM Session |
637 | 637 | """ |
| 638 | + load_json: Callable | None |
| 639 | + if COMPRESS_SERIALIZED_DAGS is False: |
| 640 | + if session.bind.dialect.name in ["sqlite", "mysql"]: |
| 641 | + data_col_to_select = func.json_extract(cls._data, "$.dag.dag_dependencies") |
| 642 | + |
| 643 | + def load_json(deps_data): |
| 644 | + return json.loads(deps_data) if deps_data else [] |
| 645 | + else: |
| 646 | + data_col_to_select = func.json_extract_path(cls._data, "dag", "dag_dependencies") |
| 647 | + load_json = None |
| 648 | + else: |
| 649 | + data_col_to_select = cls._data_compressed |
| 650 | + |
| 651 | + def load_json(deps_data): |
| 652 | + return json.loads(zlib.decompress(deps_data))["dag"]["dag_dependencies"] if deps_data else [] |
| 653 | + |
638 | 654 | latest_sdag_subquery = ( |
639 | 655 | select(cls.dag_id, func.max(cls.created_at).label("max_created")).group_by(cls.dag_id).subquery() |
640 | 656 | ) |
641 | | - if session.bind.dialect.name in ["sqlite", "mysql"]: |
642 | | - query = session.execute( |
643 | | - select(cls.dag_id, func.json_extract(cls._data, "$.dag.dag_dependencies")) |
644 | | - .join( |
645 | | - latest_sdag_subquery, |
646 | | - (cls.dag_id == latest_sdag_subquery.c.dag_id) |
647 | | - & (cls.created_at == latest_sdag_subquery.c.max_created), |
648 | | - ) |
649 | | - .join(cls.dag_model) |
650 | | - .where(~DagModel.is_stale) |
| 657 | + query = session.execute( |
| 658 | + select(cls.dag_id, data_col_to_select) |
| 659 | + .join( |
| 660 | + latest_sdag_subquery, |
| 661 | + (cls.dag_id == latest_sdag_subquery.c.dag_id) |
| 662 | + & (cls.created_at == latest_sdag_subquery.c.max_created), |
651 | 663 | ) |
652 | | - iterator = [(dag_id, json.loads(deps_data) if deps_data else []) for dag_id, deps_data in query] |
653 | | - else: |
654 | | - iterator = session.execute( |
655 | | - select( |
656 | | - cls.dag_id, |
657 | | - func.json_extract_path(cls._data, "dag", "dag_dependencies"), |
658 | | - ) |
659 | | - .join( |
660 | | - latest_sdag_subquery, |
661 | | - (cls.dag_id == latest_sdag_subquery.c.dag_id) |
662 | | - & (cls.created_at == latest_sdag_subquery.c.max_created), |
663 | | - ) |
664 | | - .join(cls.dag_model) |
665 | | - .where(~DagModel.is_stale) |
666 | | - ).all() |
| 664 | + .join(cls.dag_model) |
| 665 | + .where(~DagModel.is_stale) |
| 666 | + ) |
| 667 | + iterator = ( |
| 668 | + [(dag_id, load_json(deps_data)) for dag_id, deps_data in query] |
| 669 | + if load_json is not None |
| 670 | + else query.all() |
| 671 | + ) |
667 | 672 |
|
668 | 673 | resolver = _DagDependenciesResolver(dag_id_dependencies=iterator, session=session) |
669 | 674 | dag_depdendencies_by_dag = resolver.resolve() |
670 | | - |
671 | 675 | return dag_depdendencies_by_dag |
672 | 676 |
|
673 | 677 | @staticmethod |
|
0 commit comments