Skip to content

Commit

Permalink
Use consistent field names between render tests (apache#45843)
Browse files Browse the repository at this point in the history
Template field rendering tests for expand() and expand_kwargs() use the
same field names to test different things. Although both tests on their
own are correct, the field naming is confusing when you read both tests
side by side.

This changes the field names to be more specific, so different things
are always tested with different names.
  • Loading branch information
uranusjr authored Jan 22, 2025
1 parent 3edd78a commit 537ca7b
Showing 1 changed file with 68 additions and 61 deletions.
129 changes: 68 additions & 61 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,45 +255,65 @@ def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
assert indices == [(-1, TaskInstanceState.SKIPPED)]


class _RenderTemplateFieldsValidationOperator(BaseOperator):
template_fields = (
"partial_template",
"map_template_xcom",
"map_template_literal",
"map_template_file",
)
template_ext = (".ext",)

fields_to_test = [
"partial_template",
"partial_static",
"map_template_xcom",
"map_template_literal",
"map_static",
"map_template_file",
]

def __init__(
self,
partial_template,
partial_static,
map_template_xcom,
map_template_literal,
map_static,
map_template_file,
**kwargs,
):
for field in self.fields_to_test:
setattr(self, field, value := locals()[field])
assert isinstance(value, str), "value should have been resolved before unmapping"
super().__init__(**kwargs)

def execute(self, context):
pass


def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path):
file_template_dir = tmp_path / "path" / "to"
file_template_dir.mkdir(parents=True, exist_ok=True)
file_template = file_template_dir / "file.ext"
file_template.write_text("loaded data")

with set_current_task_instance_session(session=session):

class MyOperator(BaseOperator):
template_fields = ("partial_template", "map_template", "file_template")
template_ext = (".ext",)

def __init__(
self, partial_template, partial_static, map_template, map_static, file_template, **kwargs
):
for value in [partial_template, partial_static, map_template, map_static, file_template]:
assert isinstance(value, str), "value should have been resolved before unmapping"
super().__init__(**kwargs)
self.partial_template = partial_template
self.partial_static = partial_static
self.map_template = map_template
self.map_static = map_static
self.file_template = file_template

def execute(self, context):
pass

with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()):
task1 = BaseOperator(task_id="op1")
output1 = task1.output
mapped = MyOperator.partial(
mapped = _RenderTemplateFieldsValidationOperator.partial(
task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}"
).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"])
).expand(
map_static=output1,
map_template_literal=["{{ ds }}"],
map_template_xcom=output1,
map_template_file=["/path/to/file.ext"],
)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)

ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session)

session.add(
TaskMap(
dag_id=dr.dag_id,
Expand All @@ -308,16 +328,16 @@ def execute(self, context):

mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session)
mapped_ti.map_index = 0

assert isinstance(mapped_ti.task, MappedOperator)
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(mapped_ti.task, MyOperator)
assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator)

assert mapped_ti.task.partial_template == "a", "Should be templated!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!"
assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.file_template == "loaded data", "Should be templated!"
assert mapped_ti.task.partial_template == "a", "Should be rendered!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!"
assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!"
assert mapped_ti.task.map_template_xcom == "{{ ds }}", "XCom resolved but not double rendered!"
assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!"


def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path):
Expand All @@ -327,46 +347,33 @@ def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_mak
file_template.write_text("loaded data")

with set_current_task_instance_session(session=session):

class MyOperator(BaseOperator):
template_fields = ("partial_template", "map_template", "file_template")
template_ext = (".ext",)

def __init__(
self, partial_template, partial_static, map_template, map_static, file_template, **kwargs
):
for value in [partial_template, partial_static, map_template, map_static, file_template]:
assert isinstance(value, str), "value should have been resolved before unmapping"
super().__init__(**kwargs)
self.partial_template = partial_template
self.partial_static = partial_static
self.map_template = map_template
self.map_static = map_static
self.file_template = file_template

def execute(self, context):
pass

with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()):
mapped = MyOperator.partial(
mapped = _RenderTemplateFieldsValidationOperator.partial(
task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}"
).expand_kwargs(
[{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}]
[
{
"map_template_literal": "{{ ds }}",
"map_static": "{{ ds }}",
"map_template_file": "/path/to/file.ext",
# This field is not tested since XCom inside a literal list
# is not rendered (matching BaseOperator rendering behavior).
"map_template_xcom": "",
}
]
)

dr = dag_maker.create_dagrun()

mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0)

assert isinstance(mapped_ti.task, MappedOperator)
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(mapped_ti.task, MyOperator)
assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator)

assert mapped_ti.task.partial_template == "a", "Should be templated!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!"
assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.file_template == "loaded data", "Should be templated!"
assert mapped_ti.task.partial_template == "a", "Should be rendered!"
assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!"
assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!"
assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!"
assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!"


def test_mapped_render_nested_template_fields(dag_maker, session):
Expand Down

0 comments on commit 537ca7b

Please sign in to comment.