diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 42da4c83a6..66321ad58c 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -202,7 +202,9 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait): closures: List[StencilClosure] implicit_domain: bool = False - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS] + _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ + Sym(id=name) for name in sorted(BUILTINS) + ] # sorted for serialization stability class Stmt(Node): ... diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index d729a5ba2f..8c31ea7b65 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -10,13 +10,17 @@ import dataclasses import functools +import os +import pathlib +import pickle from typing import Any, Callable, Final, Optional import factory import numpy as np +import gt4py.next.config from gt4py._core import definitions as core_defs -from gt4py.eve import codegen +from gt4py.eve import codegen, utils from gt4py.next import common from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins @@ -37,6 +41,27 @@ def get_param_description(name: str, type_: Any) -> interface.Parameter: return interface.Parameter(name, type_) +def _generate_stencil_source_cache_file_path( + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + column_axis: Optional[common.Dimension], +) -> pathlib.Path: + program_hash = utils.content_hash( + ( + program, + sorted(offset_provider.items(), key=lambda el: el[0]), + column_axis, + ) + ) + + if not os.path.exists(gt4py.next.config.BUILD_CACHE_DIR): + os.mkdir(gt4py.next.config.BUILD_CACHE_DIR) + + cache_path = gt4py.next.config.BUILD_CACHE_DIR / ("gtfn_" + program.id + "_" + program_hash) + + return cache_path + + @dataclasses.dataclass(frozen=True) class GTFNTranslationStep( workflow.ReplaceEnabledWorkflowMixin[ @@ -199,6 +224,13 @@ def generate_stencil_source( offset_provider: dict[str, Connectivity | Dimension], column_axis: Optional[common.Dimension], ) -> str: + # TODO(tehrengruber): write a proper caching mechanism + cache_path = _generate_stencil_source_cache_file_path(program, offset_provider, column_axis) + + if os.path.exists(cache_path): + with open(cache_path, "rb") as f: + return pickle.load(f) + new_program = self._preprocess_program(program, offset_provider) gtfn_ir = GTFN_lowering.apply( new_program, offset_provider=offset_provider, column_axis=column_axis @@ -209,7 +241,11 @@ def generate_stencil_source( generated_code = GTFNIMCodegen.apply(gtfn_im_ir) else: generated_code = GTFNCodegen.apply(gtfn_ir) - return codegen.format_source("cpp", generated_code, style="LLVM") + result = codegen.format_source("cpp", generated_code, style="LLVM") + + with open(cache_path, "wb") as f: + pickle.dump(result, f) + return result def __call__( self, inp: stages.CompilableProgram diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index e3e0ee474f..eb76abbed7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -10,6 +10,7 @@ import pytest import gt4py.next as gtx +import gt4py.next.config from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.otf import arguments, languages, stages @@ -71,3 +72,44 @@ def test_codegen(fencil_example): assert module.entry_point.name == fencil.id assert any(d.name == "gridtools_cpu" for d in module.library_deps) assert module.language is languages.CPP + + +def test_transformation_caching(fencil_example): + program, _ = fencil_example + args = dict( + program=program, + offset_provider={}, + column_axis=gtx.Dimension("K", kind=gtx.DimensionKind.VERTICAL), + ) + + # test cache file written is what the function returns + with tempfile.TemporaryDirectory() as cache_dir: + try: + prev_cache_dir = gt4py.next.config.BUILD_CACHE_DIR + gt4py.next.config.BUILD_CACHE_DIR = pathlib.Path(cache_dir) + + cache_file_path = gtfn_module._generate_stencil_source_cache_file_path(**args) + assert not os.path.exists(cache_file_path) + stencil_source = gtfn_module.translate_program_cpu.generate_stencil_source(**args) + assert os.path.exists(cache_file_path) + with open(cache_file_path, "rb") as f: + stencil_source_from_cache = pickle.load(f) + assert stencil_source == stencil_source_from_cache + except Exception as e: + raise e + finally: + gt4py.next.config.BUILD_CACHE_DIR = prev_cache_dir + + # test cache file is deterministic + assert gtfn_module._generate_stencil_source_cache_file_path( + **args + ) == gtfn_module._generate_stencil_source_cache_file_path(**args) + + # test cache file changes for a different program + altered_program = copy.deepcopy(program) + altered_program.id = "example2" + assert gtfn_module._generate_stencil_source_cache_file_path( + **args + ) != gtfn_module._generate_stencil_source_cache_file_path( + **(args | {"program": altered_program}) + )