diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py
index 5240a8ed6..b6caef585 100644
--- a/cosmos/dbt/graph.py
+++ b/cosmos/dbt/graph.py
@@ -281,7 +281,23 @@ def save_dbt_ls_cache(self, dbt_ls_output: str) -> None:
             "last_modified": datetime.datetime.now(datetime.timezone.utc).isoformat(),
             **self.airflow_metadata,
         }
-        Variable.set(self.dbt_ls_cache_key, cache_dict, serialize_json=True)
+        if settings.remote_cache_path:
+            remote_cache_key_path = settings.remote_cache_path / self.dbt_ls_cache_key / "dbt_ls_cache.json"
+            with remote_cache_key_path.open("w") as fp:
+                json.dump(cache_dict, fp)
+        else:
+            Variable.set(self.dbt_ls_cache_key, cache_dict, serialize_json=True)
+
+    def _get_dbt_ls_remote_cache(self) -> dict[str, str]:
+        """Loads the remote cache for dbt ls."""
+        cache_dict: dict[str, str] = {}
+        if settings.remote_cache_path is None:
+            return cache_dict
+        remote_cache_key_path = settings.remote_cache_path / self.dbt_ls_cache_key / "dbt_ls_cache.json"
+        if remote_cache_key_path.exists():
+            with remote_cache_key_path.open("r") as fp:
+                cache_dict = json.load(fp)
+        return cache_dict
 
     def get_dbt_ls_cache(self) -> dict[str, str]:
         """
@@ -296,7 +312,11 @@ def get_dbt_ls_cache(self) -> dict[str, str]:
         """
         cache_dict: dict[str, str] = {}
         try:
-            cache_dict = Variable.get(self.dbt_ls_cache_key, deserialize_json=True)
+            cache_dict = (
+                self._get_dbt_ls_remote_cache()
+                if settings.remote_cache_path
+                else Variable.get(self.dbt_ls_cache_key, deserialize_json=True)
+            )
         except (json.decoder.JSONDecodeError, KeyError):
             return cache_dict
         else:
diff --git a/cosmos/settings.py b/cosmos/settings.py
index 71387de6e..1259a15f2 100644
--- a/cosmos/settings.py
+++ b/cosmos/settings.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import tempfile
 from pathlib import Path
@@ -7,7 +9,12 @@
 from airflow.version import version as airflow_version
 from packaging.version import Version
 
-from cosmos.constants import DEFAULT_COSMOS_CACHE_DIR_NAME, DEFAULT_OPENLINEAGE_NAMESPACE
+from cosmos.constants import (
+    DEFAULT_COSMOS_CACHE_DIR_NAME,
+    DEFAULT_OPENLINEAGE_NAMESPACE,
+    FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP,
+)
+from cosmos.exceptions import CosmosValueError
 
 # In MacOS users may want to set the envvar `TMPDIR` if they do not want the value of the temp directory to change
 DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), DEFAULT_COSMOS_CACHE_DIR_NAME)
@@ -29,3 +36,34 @@
 LINEAGE_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", DEFAULT_OPENLINEAGE_NAMESPACE)
 
 AIRFLOW_IO_AVAILABLE = Version(airflow_version) >= Version("2.8.0")
+
+
+def _configure_remote_cache_path() -> Path | None:
+    remote_cache_path_str = str(conf.get("cosmos", "remote_cache_path", fallback=""))
+    remote_cache_conn_id = str(conf.get("cosmos", "remote_cache_conn_id", fallback=""))
+    cache_path = None
+
+    if remote_cache_path_str and not AIRFLOW_IO_AVAILABLE:
+        raise CosmosValueError(
+            f"You're trying to specify dbt_ls_cache_remote_path {remote_cache_path_str}, but the required Object "
+            f"Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to "
+            f"Airflow 2.8 or later."
+        )
+    elif remote_cache_path_str:
+        from airflow.io.path import ObjectStoragePath
+
+        if not remote_cache_conn_id:
+            remote_cache_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(
+                remote_cache_path_str.split("://")[0], None
+            )
+
+        cache_path = ObjectStoragePath(remote_cache_path_str, conn_id=remote_cache_conn_id)
+        if not cache_path.exists():  # type: ignore[no-untyped-call]
+            raise CosmosValueError(
+                f"`remote_cache_path` {remote_cache_path_str} does not exist or is not accessible using "
+                f"`remote_cache_conn_id` {remote_cache_conn_id}"
+            )
+    return cache_path
+
+
+remote_cache_path = _configure_remote_cache_path()