diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 9a723383f..d82083a23 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -139,6 +139,7 @@ def __init__( self.dbt_cmd_global_flags = dbt_cmd_global_flags or [] self.cache_dir = cache_dir self.extra_context = extra_context or {} + kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) def get_env(self, context: Context) -> dict[str, str | bytes | os.PathLike[Any]]: diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 49bf45293..2ba2b18ff 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -145,7 +145,6 @@ def __init__( self._dbt_runner: dbtRunner | None = None if self.invocation_mode: self._set_invocation_methods() - kwargs.pop("full_refresh", None) # usage of this param should be implemented in child classes super().__init__(**kwargs) # For local execution mode, we're consistent with the LoadMode.DBT_LS command in forwarding the environment diff --git a/tests/operators/test_kubernetes.py b/tests/operators/test_kubernetes.py index d0be2acad..8e0dda9c0 100644 --- a/tests/operators/test_kubernetes.py +++ b/tests/operators/test_kubernetes.py @@ -285,3 +285,68 @@ def test_created_pod(): ] assert container.args == expected_container_args assert container.command == [] + + +@pytest.mark.parametrize( + "operator_class,kwargs,expected_cmd", + [ + ( + DbtSeedKubernetesOperator, + {"full_refresh": True}, + ["dbt", "seed", "--full-refresh", "--project-dir", "my/dir"], + ), + ( + DbtBuildKubernetesOperator, + {"full_refresh": True}, + ["dbt", "build", "--full-refresh", "--project-dir", "my/dir"], + ), + ( + DbtRunKubernetesOperator, + {"full_refresh": True}, + ["dbt", "run", "--full-refresh", "--project-dir", "my/dir"], + ), + ( + DbtTestKubernetesOperator, + {}, + ["dbt", "test", "--project-dir", "my/dir"], + ), + ( + DbtTestKubernetesOperator, + {"select": []}, + ["dbt", "test", "--project-dir", "my/dir"], + ), + ( + DbtTestKubernetesOperator, + {"full_refresh": True, "select": ["tag:daily"], "exclude": ["tag:disabled"]}, + ["dbt", "test", "--select", "tag:daily", "--exclude", "tag:disabled", "--project-dir", "my/dir"], + ), + ( + DbtTestKubernetesOperator, + {"full_refresh": True, "selector": "nightly_snowplow"}, + ["dbt", "test", "--selector", "nightly_snowplow", "--project-dir", "my/dir"], + ), + ], +) +def test_operator_execute_with_flags(operator_class, kwargs, expected_cmd): + task = operator_class( + task_id="my-task", + project_dir="my/dir", + **kwargs, + ) + + with patch( + "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.hook", + is_in_cluster=False, + ), patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.cleanup"), patch( + "airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.get_or_create_pod", + side_effect=ValueError("Mock"), + ) as get_or_create_pod: + try: + task.execute(context={}) + except ValueError as e: + if e != get_or_create_pod.side_effect: + raise + + pod_args = get_or_create_pod.call_args.kwargs["pod_request_obj"].to_dict()["spec"]["containers"][0]["args"] + + assert expected_cmd == pod_args