diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 92130bd3471c8..2d06dc6216f84 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -255,6 +255,43 @@ def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): assert indices == [(-1, TaskInstanceState.SKIPPED)] +class _RenderTemplateFieldsValidationOperator(BaseOperator): + template_fields = ( + "partial_template", + "map_template_xcom", + "map_template_literal", + "map_template_file", + ) + template_ext = (".ext",) + + fields_to_test = [ + "partial_template", + "partial_static", + "map_template_xcom", + "map_template_literal", + "map_static", + "map_template_file", + ] + + def __init__( + self, + partial_template, + partial_static, + map_template_xcom, + map_template_literal, + map_static, + map_template_file, + **kwargs, + ): + for field in self.fields_to_test: + setattr(self, field, value := locals()[field]) + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + + def execute(self, context): + pass + + def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): file_template_dir = tmp_path / "path" / "to" file_template_dir.mkdir(parents=True, exist_ok=True) @@ -262,38 +299,21 @@ def test_mapped_render_template_fields_validating_operator(dag_maker, session, t file_template.write_text("loaded data") with set_current_task_instance_session(session=session): - - class MyOperator(BaseOperator): - template_fields = ("partial_template", "map_template", "file_template") - template_ext = (".ext",) - - def __init__( - self, partial_template, partial_static, map_template, map_static, file_template, **kwargs - ): - for value in [partial_template, partial_static, map_template, map_static, file_template]: - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - self.partial_template = partial_template - self.partial_static = partial_static - self.map_template = map_template - self.map_static = map_static - self.file_template = file_template - - def execute(self, context): - pass - with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): task1 = BaseOperator(task_id="op1") output1 = task1.output - mapped = MyOperator.partial( + mapped = _RenderTemplateFieldsValidationOperator.partial( task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) + ).expand( + map_static=output1, + map_template_literal=["{{ ds }}"], + map_template_xcom=output1, + map_template_file=["/path/to/file.ext"], + ) dr = dag_maker.create_dagrun() ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) - session.add( TaskMap( dag_id=dr.dag_id, @@ -308,16 +328,16 @@ def execute(self, context): mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) mapped_ti.map_index = 0 - assert isinstance(mapped_ti.task, MappedOperator) mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, MyOperator) + assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) - assert mapped_ti.task.partial_template == "a", "Should be templated!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" - assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.file_template == "loaded data", "Should be templated!" + assert mapped_ti.task.partial_template == "a", "Should be rendered!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" + assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" + assert mapped_ti.task.map_template_xcom == "{{ ds }}", "XCom resolved but not double rendered!" + assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): @@ -327,46 +347,33 @@ def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_mak file_template.write_text("loaded data") with set_current_task_instance_session(session=session): - - class MyOperator(BaseOperator): - template_fields = ("partial_template", "map_template", "file_template") - template_ext = (".ext",) - - def __init__( - self, partial_template, partial_static, map_template, map_static, file_template, **kwargs - ): - for value in [partial_template, partial_static, map_template, map_static, file_template]: - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - self.partial_template = partial_template - self.partial_static = partial_static - self.map_template = map_template - self.map_static = map_static - self.file_template = file_template - - def execute(self, context): - pass - with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): - mapped = MyOperator.partial( + mapped = _RenderTemplateFieldsValidationOperator.partial( task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" ).expand_kwargs( - [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + [ + { + "map_template_literal": "{{ ds }}", + "map_static": "{{ ds }}", + "map_template_file": "/path/to/file.ext", + # This field is not tested since XCom inside a literal list + # is not rendered (matching BaseOperator rendering behavior). + "map_template_xcom": "", + } + ] ) dr = dag_maker.create_dagrun() - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) - assert isinstance(mapped_ti.task, MappedOperator) mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, MyOperator) + assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) - assert mapped_ti.task.partial_template == "a", "Should be templated!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" - assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.file_template == "loaded data", "Should be templated!" + assert mapped_ti.task.partial_template == "a", "Should be rendered!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" + assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" + assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" def test_mapped_render_nested_template_fields(dag_maker, session):