From 6a3c93cdf421bcd8a75c68e4e793dff5cdb5a19a Mon Sep 17 00:00:00 2001 From: DropD Date: Thu, 13 Jun 2024 10:24:29 +0200 Subject: [PATCH 1/3] make past_to_itir cached in default transforms --- src/gt4py/next/backend.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 3c0d19853e..5f5fe65ba7 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -73,7 +73,11 @@ class FieldopTransformWorkflow(workflow.NamedStepSequenceWithArgs): dataclasses.field(default=past_process_args.past_process_args) ) past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( - dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + dataclasses.field( + default_factory=lambda: workflow.CachedStep( + past_to_itir.PastToItirFactory(), hash_function=ffront_stages.fingerprint_stage + ) + ) ) foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = ( @@ -129,7 +133,11 @@ class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs): ) ) past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( - dataclasses.field(default_factory=past_to_itir.PastToItirFactory) + dataclasses.field( + default_factory=lambda: workflow.CachedStep( + past_to_itir.PastToItirFactory(), hash_function=ffront_stages.fingerprint_stage + ) + ) ) @@ -173,3 +181,4 @@ def __gt_allocator__( self, ) -> next_allocators.FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: return self.allocator + return self.allocator From cc086fdc32517fb37d9ce1e2d4d35063eb498ffa Mon Sep 17 00:00:00 2001 From: DropD Date: Mon, 17 Jun 2024 14:18:57 +0200 Subject: [PATCH 2/3] do not fingerprint PastClosure args for cached PastToItir --- src/gt4py/next/backend.py | 3 ++- src/gt4py/next/ffront/stages.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index 5f5fe65ba7..ef05a620de 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -135,7 +135,8 @@ class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs): past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( dataclasses.field( default_factory=lambda: workflow.CachedStep( - past_to_itir.PastToItirFactory(), hash_function=ffront_stages.fingerprint_stage + past_to_itir.PastToItirFactory(), + hash_function=ffront_stages.fingerprint_past_closure_noargs, ) ) ) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 1da6c85981..6ce8497ede 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -91,6 +91,19 @@ class PastClosure: kwargs: dict[str, Any] +def fingerprint_past_closure_noargs( + past_closure: PastClosure, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None +) -> str: + return fingerprint_stage( + obj={ + "closure_vars": past_closure.closure_vars, + "past_node": past_closure.past_node, + "grid_type": past_closure.grid_type, + }, + algorithm=algorithm, + ) + + def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str: hasher: xtyping.HashlibAlgorithm if not algorithm: From 8d1c63cda1ebeb87a197ff3b1f679fa2bc756e5a Mon Sep 17 00:00:00 2001 From: DropD Date: Mon, 17 Jun 2024 15:38:10 +0200 Subject: [PATCH 3/3] decouple past_to_itir from args / kwargs and cache separately --- src/gt4py/next/backend.py | 13 +------ src/gt4py/next/ffront/past_to_itir.py | 54 ++++++++++++++++++--------- src/gt4py/next/ffront/stages.py | 2 + 3 files changed, 40 insertions(+), 29 deletions(-) diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index ef05a620de..55eeb389a2 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -73,11 +73,7 @@ class FieldopTransformWorkflow(workflow.NamedStepSequenceWithArgs): dataclasses.field(default=past_process_args.past_process_args) ) past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( - dataclasses.field( - default_factory=lambda: workflow.CachedStep( - past_to_itir.PastToItirFactory(), hash_function=ffront_stages.fingerprint_stage - ) - ) + dataclasses.field(default_factory=lambda: past_to_itir.PastToItirFactory(cached=True)) ) foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = ( @@ -133,12 +129,7 @@ class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs): ) ) past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = ( - dataclasses.field( - default_factory=lambda: workflow.CachedStep( - past_to_itir.PastToItirFactory(), - hash_function=ffront_stages.fingerprint_past_closure_noargs, - ) - ) + dataclasses.field(default_factory=lambda: past_to_itir.PastToItirFactory(cached=True)) ) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index fb5c1a6882..084a12d821 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -37,31 +37,42 @@ from gt4py.next.type_system import type_info, type_specifications as ts +@workflow.make_step +def past_to_itir(inp: ffront_stages.PastProgramDefinition) -> stages.ProgramCall: + all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) + offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( + all_closure_vars, fbuiltins.FieldOffset, common.Dimension + ) + grid_type = transform_utils._deduce_grid_type(inp.grid_type, offsets_and_dimensions.values()) + + gt_callables = transform_utils._filter_closure_vars_by_type( + all_closure_vars, gtcallable.GTCallable + ).values() + lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + + itir_program = ProgramLowering.apply( + inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + ) + + return stages.ProgramCall( + program=itir_program, args=tuple(), kwargs={"column_axis": _column_axis(all_closure_vars)} + ) + + @dataclasses.dataclass(frozen=True) class PastToItir(workflow.ChainableWorkflowMixin): - def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: - all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars) - offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( - all_closure_vars, fbuiltins.FieldOffset, common.Dimension - ) - grid_type = transform_utils._deduce_grid_type( - inp.grid_type, offsets_and_dimensions.values() - ) - - gt_callables = transform_utils._filter_closure_vars_by_type( - all_closure_vars, gtcallable.GTCallable - ).values() - lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables] + inner: workflow.Workflow[ffront_stages.PastProgramDefinition, stages.ProgramCall] = past_to_itir - itir_program = ProgramLowering.apply( - inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type + def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall: + program_call = self.inner( + ffront_stages.PastProgramDefinition(inp.past_node, inp.closure_vars, inp.grid_type) ) if config.DEBUG or "debug" in inp.kwargs: - devtools.debug(itir_program) + devtools.debug(program_call.program) - return stages.ProgramCall( - itir_program, inp.args, inp.kwargs | {"column_axis": _column_axis(all_closure_vars)} + return dataclasses.replace( + program_call, args=inp.args, kwargs=inp.kwargs | program_call.kwargs ) @@ -69,6 +80,13 @@ class PastToItirFactory(factory.Factory): class Meta: model = PastToItir + class Params: + cached = factory.Trait( + inner=workflow.CachedStep(past_to_itir, hash_function=ffront_stages.fingerprint_stage) + ) + + inner = past_to_itir + def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]: # construct mapping from column axis to scan operators defined on diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 6ce8497ede..427c702591 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -99,6 +99,8 @@ def fingerprint_past_closure_noargs( "closure_vars": past_closure.closure_vars, "past_node": past_closure.past_node, "grid_type": past_closure.grid_type, + "args": past_closure.args, + "kwargs": past_closure.kwargs, }, algorithm=algorithm, )