diff --git a/sqlmesh/dbt/adapter.py b/sqlmesh/dbt/adapter.py index 7f7c7eb4fb..187126a8b4 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -122,48 +122,30 @@ def dispatch( ) -> t.Callable: """Returns a dialect-specific version of a macro with the given name.""" target_type = self.jinja_globals["target"]["type"] - macro_suffix = f"__{macro_name}" - - def _relevance(package_name_pair: t.Tuple[t.Optional[str], str]) -> t.Tuple[int, int]: - """Lower scores more relevant.""" - macro_package, name = package_name_pair - - package_score = 0 if macro_package == macro_namespace else 1 - name_score = 1 - - if name.startswith("default"): - name_score = 2 - elif name.startswith(target_type): - name_score = 0 - - return name_score, package_score jinja_env = self.jinja_macros.build_environment(**self.jinja_globals).globals packages_to_check: t.List[t.Optional[str]] = [None] if macro_namespace is not None: - if macro_namespace in jinja_env: + if dispatch := self.jinja_globals.get("dispatch"): + for entry in dispatch.get(self.jinja_macros.root_package_name, []): + if entry.get("macro_namespace") == macro_namespace: + packages_to_check = entry.get("search_order") + break + if packages_to_check == [None] and macro_namespace in jinja_env: packages_to_check = [self.jinja_macros.root_package_name, macro_namespace] # Add dbt packages as fallback packages_to_check.extend(k for k in jinja_env if k.startswith("dbt")) - candidates = {} for macro_package in packages_to_check: macros = jinja_env.get(macro_package, {}) if macro_package else jinja_env if not isinstance(macros, dict): continue - candidates.update( - { - (macro_package, macro_name): macro_callable - for macro_name, macro_callable in macros.items() - if macro_name.endswith(macro_suffix) - } - ) - if candidates: - sorted_candidates = sorted(candidates, key=_relevance) - return candidates[sorted_candidates[0]] + for prefix in (f"{target_type}__", "default__", ""): + if macro := macros.get(f"{prefix}{macro_name}"): + return macro raise ConfigError(f"Macro '{macro_name}', package '{macro_namespace}' was not found.") diff --git a/sqlmesh/dbt/context.py b/sqlmesh/dbt/context.py index 67e70d3c79..c68dd30d2e 100644 --- a/sqlmesh/dbt/context.py +++ b/sqlmesh/dbt/context.py @@ -54,6 +54,7 @@ class DbtContext: _seeds: t.Dict[str, SeedConfig] = field(default_factory=dict) _sources: t.Dict[str, SourceConfig] = field(default_factory=dict) _refs: t.Dict[str, t.Union[ModelConfig, SeedConfig]] = field(default_factory=dict) + _dispatch: t.Dict[str, t.List[t.Dict[str, t.Any]]] = field(default_factory=dict) _target: t.Optional[TargetConfig] = None @@ -136,6 +137,14 @@ def add_macros(self, macros: t.Dict[str, MacroInfo], package: str) -> None: self.jinja_macros.add_macros(macros, package=package) self._jinja_environment = None + @property + def dispatch(self) -> t.Dict[str, t.List[t.Dict[str, t.Any]]]: + return self._dispatch + + def add_dispatch(self, dispatch: t.List[t.Dict[str, t.Any]], package: str) -> None: + self._dispatch[package] = dispatch + self._jinja_environment = None + @property def models(self) -> t.Dict[str, ModelConfig]: return self._models @@ -249,6 +258,8 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]: # Pass flat graph structure like dbt if self._manifest is not None: output["flat_graph"] = AttributeDict(self.manifest.flat_graph) + if self._dispatch is not None: + output["dispatch"] = AttributeDict(self._dispatch) return output def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext: diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index eb117a3e40..e63935c82a 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -218,6 +218,7 @@ def _load_projects(self) -> t.List[Project]: context.add_sources(package.sources) context.add_seeds(package.seeds) context.add_models(package.models) + context.add_dispatch(package.dispatch, package_name) macros_mtimes.extend( [ self._path_mtimes[m.path] diff --git a/sqlmesh/dbt/package.py b/sqlmesh/dbt/package.py index 420cf3cb73..3d7e16156d 100644 --- a/sqlmesh/dbt/package.py +++ b/sqlmesh/dbt/package.py @@ -50,6 +50,7 @@ class Package(PydanticModel): on_run_start: t.Dict[str, HookConfig] on_run_end: t.Dict[str, HookConfig] files: t.Set[Path] + dispatch: t.List[t.Dict[str, t.Any]] @property def macro_infos(self) -> t.Dict[str, MacroInfo]: @@ -90,6 +91,8 @@ def load(self, package_root: Path) -> Package: var: value for var, value in all_variables.items() if not isinstance(value, dict) } + dispatch = project_yaml.get("dispatch") or [] + tests = _fix_paths(self._context.manifest.tests(package_name), package_root) models = _fix_paths(self._context.manifest.models(package_name), package_root) seeds = _fix_paths(self._context.manifest.seeds(package_name), package_root) @@ -113,6 +116,7 @@ def load(self, package_root: Path) -> Package: sources=sources, seeds=seeds, variables=package_variables, + dispatch=dispatch, macros=macros, files=config_paths, on_run_start=on_run_start, diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index 381401ce73..dfdf576e99 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -242,7 +242,11 @@ def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Calla assert renderer("{{ adapter.dispatch('current_engine', 'customers')() }}") == "duckdb" assert renderer("{{ adapter.dispatch('current_timestamp')() }}") == "now()" assert renderer("{{ adapter.dispatch('current_timestamp', 'dbt')() }}") == "now()" - assert renderer("{{ adapter.dispatch('select_distinct', 'customers')() }}") == "distinct" + + # Macros in root project overrides macros in dependent packages + assert ( + renderer("{{ adapter.dispatch('hello_world', 'my_helpers')() }}") == "hello from sushi_test" + ) # test with keyword arguments assert ( @@ -276,6 +280,14 @@ def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Calla renderer("{{ adapter.dispatch('current_engine')() }}") +@pytest.mark.slow +def test_adapter_dispatch_search_order(sushi_test_project: Project, runtime_renderer: t.Callable): + context = sushi_test_project.context + renderer = runtime_renderer(context) + assert renderer("{{ adapter.dispatch('current_package', 'my_helpers')() }}") == "my_helpers" + assert renderer("{{ adapter.dispatch('current_package', 'customers')() }}") == "my_helpers" + + @pytest.mark.parametrize("project_dialect", ["duckdb", "bigquery"]) @pytest.mark.slow def test_adapter_map_snapshot_tables( diff --git a/tests/fixtures/dbt/sushi_test/dbt_packages/my_helpers b/tests/fixtures/dbt/sushi_test/dbt_packages/my_helpers new file mode 120000 index 0000000000..aecd64e71e --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/dbt_packages/my_helpers @@ -0,0 +1 @@ +../packages/my_helpers \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_project.yml index 0b5f6b0f83..a7b8a8026f 100644 --- a/tests/fixtures/dbt/sushi_test/dbt_project.yml +++ b/tests/fixtures/dbt/sushi_test/dbt_project.yml @@ -79,3 +79,8 @@ on-run-end: - '{{ create_tables(schemas) }}' - 'DROP TABLE to_be_executed_last;' - '{{ graph_usage() }}' + + +dispatch: + - macro_namespace: customers + search_order: ["my_helpers", "customers"] diff --git a/tests/fixtures/dbt/sushi_test/macros/distinct.sql b/tests/fixtures/dbt/sushi_test/macros/distinct.sql deleted file mode 100644 index 1b339a9349..0000000000 --- a/tests/fixtures/dbt/sushi_test/macros/distinct.sql +++ /dev/null @@ -1 +0,0 @@ -{% macro default__select_distinct() %}distinct{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/macros/hello_world.sql b/tests/fixtures/dbt/sushi_test/macros/hello_world.sql new file mode 100644 index 0000000000..8e7319a862 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/hello_world.sql @@ -0,0 +1,2 @@ +{% macro hello_world() %}hello from sushi_test{% endmacro %} + diff --git a/tests/fixtures/dbt/sushi_test/packages.yml b/tests/fixtures/dbt/sushi_test/packages.yml index 34cb31e0a6..08a3f172bc 100644 --- a/tests/fixtures/dbt/sushi_test/packages.yml +++ b/tests/fixtures/dbt/sushi_test/packages.yml @@ -1,3 +1,3 @@ packages: - local: packages/customers - + - local: dbt_packages/my_helpers diff --git a/tests/fixtures/dbt/sushi_test/packages/customers/macros/current_package.sql b/tests/fixtures/dbt/sushi_test/packages/customers/macros/current_package.sql new file mode 100644 index 0000000000..9a8f527ac3 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/packages/customers/macros/current_package.sql @@ -0,0 +1,3 @@ +{% macro current_package() %}{{ return(adapter.dispatch('current_package', 'customers')) }}{% endmacro %} + +{% macro default__current_package() %}customers{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/packages/my_helpers/dbt_project.yml b/tests/fixtures/dbt/sushi_test/packages/my_helpers/dbt_project.yml new file mode 100644 index 0000000000..2e04cde28c --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/packages/my_helpers/dbt_project.yml @@ -0,0 +1,17 @@ + +name: 'my_helpers' +version: '1.0.0' +config-version: 2 +profile: 'my_helpers' + +model-paths: ["models"] +analysis-paths: ["analyses"] +test-paths: ["tests"] +seed-paths: ["seeds"] +macro-paths: ["macros"] +snapshot-paths: ["snapshots"] + +target-path: "target" # directory which will store compiled SQL files +clean-targets: # directories to be removed by `dbt clean` + - "target" + - "dbt_packages" diff --git a/tests/fixtures/dbt/sushi_test/packages/my_helpers/macros/current_package.sql b/tests/fixtures/dbt/sushi_test/packages/my_helpers/macros/current_package.sql new file mode 100644 index 0000000000..396c111cb5 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/packages/my_helpers/macros/current_package.sql @@ -0,0 +1,3 @@ +{% macro current_package() %}{{ return(adapter.dispatch('current_package', 'my_helpers')) }}{% endmacro %} + +{% macro default__current_package() %}my_helpers{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/packages/my_helpers/macros/hello_world.sql b/tests/fixtures/dbt/sushi_test/packages/my_helpers/macros/hello_world.sql new file mode 100644 index 0000000000..b719354da9 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/packages/my_helpers/macros/hello_world.sql @@ -0,0 +1 @@ +{% macro hello_world() %}hello from my_helpers{% endmacro %}