diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index cdf5b402b5..26dde21be9 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -67,7 +67,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequenceWithArgs): dataclasses.field(default=past_process_args.past_process_args) ) past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( - dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + dataclasses.field(default_factory=lambda: past_to_itir.PastToItirFactory(cached=True)) ) foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = ( @@ -123,7 +123,7 @@ class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs): ) ) past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( - dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + dataclasses.field(default_factory=lambda: past_to_itir.PastToItirFactory(cached=True)) ) @@ -167,3 +167,4 @@ def __gt_allocator__( self, ) -> next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: return self.allocator + return self.allocator diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index a9021a27be..af5b7e19ec 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -31,31 +31,42 @@ from gt4py.next.type_system import type_info, type_specifications as ts +@workflow.make_step +def past_to_itir(inp: ffront_stages.PastProgramDefinition) -> stages.ProgramCall: + all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) + offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( + all_closure_vars, fbuiltins.FieldOffset, common.Dimension + ) + grid_type = transform_utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) + + gt_callables = transform_utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).values() + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + + itir_program = ProgramLowering.apply( + inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + ) + + return stages.ProgramCall( + program=itir_program, args=tuple(), kwargs={"column_axis": _column_axis(all_closure_vars)} + ) + + @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): - def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: - all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) - offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( - all_closure_vars, fbuiltins.FieldOffset, common.Dimension - ) - grid_type = transform_utils._deduce_grid_type( - inp.grid_type, offsets_and_dimensions.values() - ) - - gt_callables = transform_utils._filter_closure_vars_by_type( - all_closure_vars, gtcallable.GTCallable - ).values() - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + inner: workflow.Workflow[ffront_stages.PastProgramDefinition, stages.ProgramCall] = past_to_itir - itir_program = ProgramLowering.apply( - inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: + program_call = self.inner( + ffront_stages.PastProgramDefinition(inp.past_node, inp.closure_vars, inp.grid_type) ) if config.DEBUG or "debug" in inp.kwargs: - devtools.debug(itir_program) + devtools.debug(program_call.program) - return stages.ProgramCall( - itir_program, inp.args, inp.kwargs | {"column_axis": _column_axis(all_closure_vars)} + return dataclasses.replace( + program_call, args=inp.args, kwargs=inp.kwargs | program_call.kwargs ) @@ -63,6 +74,13 @@ class PastToItirFactory(factory.Factory): class Meta: model = PastToItir + class Params: + cached = factory.Trait( + inner=workflow.CachedStep(past_to_itir, hash_function=ffront_stages.fingerprint_stage) + ) + + inner = past_to_itir + def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]: # construct mapping from column axis to scan operators defined on diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 7402922ae9..7b7560751f 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -85,6 +85,21 @@ class PastClosure: kwargs: dict[str, Any] +def fingerprint_past_closure_noargs( + past_closure: PastClosure, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None +) -> str: + return fingerprint_stage( + obj={ + "closure_vars": past_closure.closure_vars, + "past_node": past_closure.past_node, + "grid_type": past_closure.grid_type, + "args": past_closure.args, + "kwargs": past_closure.kwargs, + }, + algorithm=algorithm, + ) + + def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str: hasher: xtyping.HashlibAlgorithm if not algorithm: