Commit ae0f455

Add support to store and fetch dbt ls cache in remote stores

pankajkoti committed Aug 9, 2024
1 parent e847f19 commit ae0f455
Showing 2 changed files with 61 additions and 3 deletions.
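
The feature is driven by two new options read in cosmos/settings.py below: remote_cache_path and remote_cache_conn_id in the [cosmos] config section. As a minimal, hypothetical sketch of enabling it (the bucket path and connection ID are placeholders, not part of this commit), using Airflow's standard AIRFLOW__SECTION__KEY environment-variable override:

    # Hypothetical setup, not part of this commit: point the dbt ls cache at an
    # S3 prefix via Airflow's AIRFLOW__<SECTION>__<KEY> env-var convention.
    import os

    os.environ["AIRFLOW__COSMOS__REMOTE_CACHE_PATH"] = "s3://my-bucket/cosmos/dbt_ls_cache"
    os.environ["AIRFLOW__COSMOS__REMOTE_CACHE_CONN_ID"] = "aws_default"  # optional, see fallback below

With remote_cache_path unset, behavior is unchanged and the cache continues to live in an Airflow Variable.
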
24 changes: 22 additions & 2 deletions cosmos/dbt/graph.py
@@ -281,7 +281,23 @@ def save_dbt_ls_cache(self, dbt_ls_output: str) -> None:
             "last_modified": datetime.datetime.now(datetime.timezone.utc).isoformat(),
             **self.airflow_metadata,
         }
-        Variable.set(self.dbt_ls_cache_key, cache_dict, serialize_json=True)
+        if settings.remote_cache_path:
+            remote_cache_key_path = settings.remote_cache_path / self.dbt_ls_cache_key / "dbt_ls_cache.json"
+            with remote_cache_key_path.open("w") as fp:
+                json.dump(cache_dict, fp)
+        else:
+            Variable.set(self.dbt_ls_cache_key, cache_dict, serialize_json=True)
+
+    def _get_dbt_ls_remote_cache(self) -> dict[str, str]:
+        """Loads the remote cache for dbt ls."""
+        cache_dict: dict[str, str] = {}
+        if settings.remote_cache_path is None:
+            return cache_dict
+        remote_cache_key_path = settings.remote_cache_path / self.dbt_ls_cache_key / "dbt_ls_cache.json"
+        if remote_cache_key_path.exists():
+            with remote_cache_key_path.open("r") as fp:
+                cache_dict = json.load(fp)
+        return cache_dict
 
     def get_dbt_ls_cache(self) -> dict[str, str]:
         """
@@ -296,7 +312,11 @@ def get_dbt_ls_cache(self) -> dict[str, str]:
         """
         cache_dict: dict[str, str] = {}
         try:
-            cache_dict = Variable.get(self.dbt_ls_cache_key, deserialize_json=True)
+            cache_dict = (
+                self._get_dbt_ls_remote_cache()
+                if settings.remote_cache_path
+                else Variable.get(self.dbt_ls_cache_key, deserialize_json=True)
+            )
         except (json.decoder.JSONDecodeError, KeyError):
             return cache_dict
         else:
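
Writer and reader agree on a single object layout: <remote_cache_path>/<dbt_ls_cache_key>/dbt_ls_cache.json. A standalone sketch of that round trip with Airflow's ObjectStoragePath (the base path, conn_id, cache key, and payload here are illustrative, not from the commit):

    # Illustrative round trip mirroring save_dbt_ls_cache / _get_dbt_ls_remote_cache.
    import json

    from airflow.io.path import ObjectStoragePath

    base = ObjectStoragePath("s3://my-bucket/cosmos/dbt_ls_cache", conn_id="aws_default")
    cache_file = base / "cosmos_cache__example_dag" / "dbt_ls_cache.json"

    # Write path (save_dbt_ls_cache); payload keys other than last_modified elided.
    with cache_file.open("w") as fp:
        json.dump({"last_modified": "2024-08-09T00:00:00+00:00"}, fp)

    # Read path (_get_dbt_ls_remote_cache): a missing file yields an empty dict,
    # no exception, so callers can treat it as a cache miss.
    cache_dict = {}
    if cache_file.exists():
        with cache_file.open("r") as fp:
            cache_dict = json.load(fp)
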
40 changes: 39 additions & 1 deletion cosmos/settings.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import tempfile
 from pathlib import Path
@@ -7,7 +9,12 @@
 from airflow.version import version as airflow_version
 from packaging.version import Version
 
-from cosmos.constants import DEFAULT_COSMOS_CACHE_DIR_NAME, DEFAULT_OPENLINEAGE_NAMESPACE
+from cosmos.constants import (
+    DEFAULT_COSMOS_CACHE_DIR_NAME,
+    DEFAULT_OPENLINEAGE_NAMESPACE,
+    FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP,
+)
+from cosmos.exceptions import CosmosValueError
 
 # In MacOS users may want to set the envvar `TMPDIR` if they do not want the value of the temp directory to change
 DEFAULT_CACHE_DIR = Path(tempfile.gettempdir(), DEFAULT_COSMOS_CACHE_DIR_NAME)
@@ -29,3 +36,34 @@
 LINEAGE_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", DEFAULT_OPENLINEAGE_NAMESPACE)
 
 AIRFLOW_IO_AVAILABLE = Version(airflow_version) >= Version("2.8.0")
+
+
+def _configure_remote_cache_path() -> Path | None:
+    remote_cache_path_str = str(conf.get("cosmos", "remote_cache_path", fallback=""))
+    remote_cache_conn_id = str(conf.get("cosmos", "remote_cache_conn_id", fallback=""))
+    cache_path = None
+
+    if remote_cache_path_str and not AIRFLOW_IO_AVAILABLE:
+        raise CosmosValueError(
+            f"You're trying to specify remote_cache_path {remote_cache_path_str}, but the required Object "
+            f"Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to "
+            f"Airflow 2.8 or later."
+        )
+    elif remote_cache_path_str:
+        from airflow.io.path import ObjectStoragePath
+
+        if not remote_cache_conn_id:
+            remote_cache_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(
+                remote_cache_path_str.split("://")[0], None
+            )
+
+        cache_path = ObjectStoragePath(remote_cache_path_str, conn_id=remote_cache_conn_id)
+        if not cache_path.exists():  # type: ignore[no-untyped-call]
+            raise CosmosValueError(
+                f"`remote_cache_path` {remote_cache_path_str} does not exist or is not accessible using "
+                f"`remote_cache_conn_id` {remote_cache_conn_id}"
+            )
+    return cache_path
+
+
+remote_cache_path = _configure_remote_cache_path()
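
When remote_cache_conn_id is not configured, the code falls back to a default connection keyed by the URL scheme of remote_cache_path. A sketch of that lookup (the mapping shown is an illustrative stand-in; the real FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP lives in cosmos.constants and may differ):

    # Illustrative stand-in for cosmos.constants.FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.
    FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP = {"s3": "aws_default", "gs": "google_cloud_default"}

    remote_cache_path_str = "s3://my-bucket/cosmos/dbt_ls_cache"
    scheme = remote_cache_path_str.split("://")[0]  # -> "s3"
    conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(scheme, None)  # -> "aws_default"

If no mapping exists for the scheme either, conn_id stays None and is passed through to ObjectStoragePath as-is.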
