From 971a025ffd59406c0d7c3e75e2d232d5a88e72f4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 6 Jun 2024 07:24:13 +0200 Subject: [PATCH] Allied some reformating. --- src/jace/stages.py | 19 ++- .../translator/jaxpr_translator_builder.py | 14 +- src/jace/util/__init__.py | 6 - src/jace/util/compiling.py | 138 ----------------- src/jace/util/dace_helper.py | 142 +++++++++++++++++- src/jace/util/jax_helper.py | 21 ++- src/jace/util/traits.py | 4 +- 7 files changed, 177 insertions(+), 167 deletions(-) delete mode 100644 src/jace/util/compiling.py diff --git a/src/jace/stages.py b/src/jace/stages.py index 9dbcb7e..224bc00 100644 --- a/src/jace/stages.py +++ b/src/jace/stages.py @@ -232,12 +232,15 @@ def compile( optimization.jace_optimize(tsdfg=tsdfg, **self._make_compiler_options(compiler_options)) return JaCeCompiled( - csdfg=util.compile_jax_sdfg(tsdfg), + csdfg=dace_helper.compile_jax_sdfg(tsdfg), inp_names=tsdfg.inp_names, out_names=tsdfg.out_names, ) - def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprSDFG: + def compiler_ir( + self, + dialect: str | None = None, + ) -> translator.TranslatedJaxprSDFG: """Returns the internal SDFG. The function returns a `TranslatedJaxprSDFG` object. Direct modification @@ -247,8 +250,14 @@ def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprS return self._translated_sdfg raise ValueError(f"Unknown dialect '{dialect}'.") - def as_html(self, filename: str | None = None) -> None: - """Runs the `view()` method of the underlying SDFG.""" + def view( + self, + filename: str | None = None, + ) -> None: + """Runs the `view()` method of the underlying SDFG. + + This will open a browser and display the SDFG. + """ self.compiler_ir().sdfg.view(filename=filename, verbose=False) def as_sdfg(self) -> dace.SDFG: @@ -322,7 +331,7 @@ def __call__( The arguments must be the same as for the wrapped function, but with all static arguments removed. """ - return util.run_jax_sdfg( + return dace_helper.run_jax_sdfg( self._csdfg, self._inp_names, self._out_names, diff --git a/src/jace/translator/jaxpr_translator_builder.py b/src/jace/translator/jaxpr_translator_builder.py index 4e42262..b143474 100644 --- a/src/jace/translator/jaxpr_translator_builder.py +++ b/src/jace/translator/jaxpr_translator_builder.py @@ -219,7 +219,9 @@ def map_jax_var_to_sdfg( @overload def map_jax_var_to_sdfg( - self, jax_var: jax_core.Atom | util.JaCeVar, allow_fail: Literal[True] + self, + jax_var: jax_core.Atom | util.JaCeVar, + allow_fail: Literal[True], ) -> str | None: ... def map_jax_var_to_sdfg( @@ -568,14 +570,14 @@ def _translate_single_eqn( update_var_mapping=True, ) - pname: str = eqn.primitive.name - if pname not in self._primitive_translators: - raise NotImplementedError(f"No translator known to handle '{pname}'.") - ptranslator = self._primitive_translators[pname] + primitive_name: str = eqn.primitive.name + if primitive_name not in self._primitive_translators: + raise NotImplementedError(f"No translator known to handle '{primitive_name}'.") + ptranslator = self._primitive_translators[primitive_name] # Create the state into which the equation should be translated eqn_state = self.append_new_state( - label=f"{pname}_{'_'.join(out_var_names)}", + label=f"{primitive_name}_{'_'.join(out_var_names)}", prev_state=None, # forces the creation of a new terminal state ) diff --git a/src/jace/util/__init__.py b/src/jace/util/__init__.py index 27bd032..778c645 100644 --- a/src/jace/util/__init__.py +++ b/src/jace/util/__init__.py @@ -9,10 +9,6 @@ from __future__ import annotations -from .compiling import ( - compile_jax_sdfg, - run_jax_sdfg, -) from .jax_helper import ( JaCeVar, get_jax_var_dtype, @@ -42,7 +38,6 @@ "VALID_SDFG_OBJ_NAME", "VALID_SDFG_VAR_NAME", "JaCeVar", - "compile_jax_sdfg", "dataclass_with_default_init", "get_jax_var_dtype", "get_jax_var_name", @@ -55,6 +50,5 @@ "is_scalar", "is_tracing_ongoing", "propose_jax_name", - "run_jax_sdfg", "translate_dtype", ] diff --git a/src/jace/util/compiling.py b/src/jace/util/compiling.py deleted file mode 100644 index 966b09d..0000000 --- a/src/jace/util/compiling.py +++ /dev/null @@ -1,138 +0,0 @@ -# JaCe - JAX Just-In-Time compilation using DaCe (Data Centric Parallel Programming) -# -# Copyright (c) 2024, ETH Zurich -# All rights reserved. -# -# SPDX-License-Identifier: BSD-3-Clause - -"""Contains everything for compiling and running `TranslatedJaxprSDFG` instances.""" - -from __future__ import annotations - -import time -from typing import TYPE_CHECKING, Any - -import dace -import numpy as np -from dace import data as dace_data - -from jace import util - - -if TYPE_CHECKING: - from collections.abc import Mapping, Sequence - - from jace import translator - from jace.util import dace_helper - - -def compile_jax_sdfg( - tsdfg: translator.TranslatedJaxprSDFG, -) -> dace_helper.CompiledSDFG: - """Compiles the SDFG embedded in `tsdfg` and return the resulting `CompiledSDFG` object.""" - if any( # We do not support the DaCe return mechanism - array_name.startswith("__return") - for array_name in tsdfg.sdfg.arrays.keys() # noqa: SIM118 # we can not use `in` because we are also interested in `__return_`! - ): - raise ValueError("Only support SDFGs without '__return' members.") - - # To ensure that the SDFG is compiled and to get rid of a warning we must modify - # some settings of the SDFG. To fake an immutable SDFG, we will restore them later. - sdfg = tsdfg.sdfg - org_sdfg_name = sdfg.name - org_recompile = sdfg._recompile - org_regenerate_code = sdfg._regenerate_code - - try: - # We need to give the SDFG another name, this is needed to prevent a DaCe error/warning. - # This happens if we compile the same lowered SDFG multiple times with different options. - sdfg.name = f"{sdfg.name}__comp_{int(time.time() * 1000)}" - - with dace.config.temporary_config(): - sdfg._recompile = True - sdfg._regenerate_code = True - dace.Config.set("compiler", "use_cache", value=False) - csdfg: dace_helper.CompiledSDFG = sdfg.compile() - - finally: - sdfg.name = org_sdfg_name - sdfg._recompile = org_recompile - sdfg._regenerate_code = org_regenerate_code - - return csdfg - - -def run_jax_sdfg( - csdfg: dace_helper.CompiledSDFG, - inp_names: Sequence[str], - out_names: Sequence[str], - call_args: Sequence[Any], - call_kwargs: Mapping[str, Any], -) -> tuple[Any, ...] | Any: - """Run the compiled SDFG. - - The function assumes that the SDFG was finalized and then compiled by - `compile_jax_sdfg()`. For running the SDFG you also have to pass the input - names (`inp_names`) and output names (`out_names`) that were inside the - `TranslatedJaxprSDFG` from which `csdfg` was compiled from. - - Args: - csdfg: The `CompiledSDFG` object. - inp_names: List of names of the input arguments. - out_names: List of names of the output arguments. - call_args: All positional arguments of the call. - call_kwargs: All keyword arguments of the call. - - Note: - There is no pytree mechanism jet, thus the return values are returned - inside a `tuple` or in case of one value, directly, in the order - determined by Jax. Furthermore, DaCe does not support scalar return - values, thus they are silently converted into arrays of length 1, the - same holds for inputs. - - Todo: - - Implement non C strides. - """ - sdfg: dace.SDFG = csdfg.sdfg - - if len(call_kwargs) != 0: - raise NotImplementedError("No kwargs are supported yet.") - if len(inp_names) != len(call_args): - raise RuntimeError("Wrong number of arguments.") - if sdfg.free_symbols: # This is a simplification that makes our life simple. - raise NotImplementedError( - f"No externally defined symbols are allowed, found: {sdfg.free_symbols}" - ) - - # Build the argument list that we will pass to the compiled object. - sdfg_call_args: dict[str, Any] = {} - for in_name, in_val in zip(inp_names, call_args, strict=True): - if util.is_scalar(in_val): - # Currently the translator makes scalar into arrays, this has to be reflected here - in_val = np.array([in_val]) - sdfg_call_args[in_name] = in_val - - for out_name, sdfg_array in ((out_name, sdfg.arrays[out_name]) for out_name in out_names): - if out_name in sdfg_call_args: - if util.is_jax_array(sdfg_call_args[out_name]): - # Jax arrays are immutable, so they can not be return values too. - raise ValueError("Passed a Jax array as output.") - else: - sdfg_call_args[out_name] = dace_data.make_array_from_descriptor(sdfg_array) - - assert len(sdfg_call_args) == len(csdfg.argnames), ( - "Failed to construct the call arguments," - f" expected {len(csdfg.argnames)} but got {len(call_args)}." - f"\nExpected: {csdfg.argnames}\nGot: {list(sdfg_call_args.keys())}" - ) - - # Calling the SDFG - with dace.config.temporary_config(): - dace.Config.set("compiler", "allow_view_arguments", value=True) - csdfg(**sdfg_call_args) - - # Handling the output (pytrees are missing) - if not out_names: - return None - ret_val: tuple[Any] = tuple(sdfg_call_args[out_name] for out_name in out_names) - return ret_val[0] if len(out_names) == 1 else ret_val diff --git a/src/jace/util/dace_helper.py b/src/jace/util/dace_helper.py index a380272..613a59c 100644 --- a/src/jace/util/dace_helper.py +++ b/src/jace/util/dace_helper.py @@ -5,14 +5,144 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Implements all utility functions that are related to DaCe. - -Most of the functions defined here allow an unified access to DaCe's internals -in a consistent and stable way. -""" +"""Implements all utility functions that are related to DaCe.""" from __future__ import annotations +import time +from typing import TYPE_CHECKING, Any + +import dace +import numpy as np +from dace import data as dace_data + # The compiled SDFG is not available in the dace namespace or anywhere else # Thus we import it here directly -from dace.codegen.compiled_sdfg import CompiledSDFG as CompiledSDFG +from dace.codegen.compiled_sdfg import CompiledSDFG + +from jace import util + + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from jace import translator + from jace.util import dace_helper + +__all__ = [ + "CompiledSDFG", + "compile_jax_sdfg", + "run_jax_sdfg", +] + + +def compile_jax_sdfg( + tsdfg: translator.TranslatedJaxprSDFG, +) -> dace_helper.CompiledSDFG: + """Compiles the SDFG embedded in `tsdfg` and return the resulting `CompiledSDFG` object.""" + if any( # We do not support the DaCe return mechanism + array_name.startswith("__return") + for array_name in tsdfg.sdfg.arrays.keys() # noqa: SIM118 # we can not use `in` because we are also interested in `__return_`! + ): + raise ValueError("Only support SDFGs without '__return' members.") + + # To ensure that the SDFG is compiled and to get rid of a warning we must modify + # some settings of the SDFG. To fake an immutable SDFG, we will restore them later. + sdfg = tsdfg.sdfg + org_sdfg_name = sdfg.name + org_recompile = sdfg._recompile + org_regenerate_code = sdfg._regenerate_code + + try: + # We need to give the SDFG another name, this is needed to prevent a DaCe error/warning. + # This happens if we compile the same lowered SDFG multiple times with different options. + sdfg.name = f"{sdfg.name}__comp_{int(time.time() * 1000)}" + + with dace.config.temporary_config(): + sdfg._recompile = True + sdfg._regenerate_code = True + dace.Config.set("compiler", "use_cache", value=False) + csdfg: dace_helper.CompiledSDFG = sdfg.compile() + + finally: + sdfg.name = org_sdfg_name + sdfg._recompile = org_recompile + sdfg._regenerate_code = org_regenerate_code + + return csdfg + + +def run_jax_sdfg( + csdfg: dace_helper.CompiledSDFG, + inp_names: Sequence[str], + out_names: Sequence[str], + call_args: Sequence[Any], + call_kwargs: Mapping[str, Any], +) -> tuple[Any, ...] | Any: + """Run the compiled SDFG. + + The function assumes that the SDFG was finalized and then compiled by + `compile_jax_sdfg()`. For running the SDFG you also have to pass the input + names (`inp_names`) and output names (`out_names`) that were inside the + `TranslatedJaxprSDFG` from which `csdfg` was compiled from. + + Args: + csdfg: The `CompiledSDFG` object. + inp_names: List of names of the input arguments. + out_names: List of names of the output arguments. + call_args: All positional arguments of the call. + call_kwargs: All keyword arguments of the call. + + Note: + There is no pytree mechanism jet, thus the return values are returned + inside a `tuple` or in case of one value, directly, in the order + determined by Jax. Furthermore, DaCe does not support scalar return + values, thus they are silently converted into arrays of length 1, the + same holds for inputs. + + Todo: + - Implement non C strides. + """ + sdfg: dace.SDFG = csdfg.sdfg + + if len(call_kwargs) != 0: + raise NotImplementedError("No kwargs are supported yet.") + if len(inp_names) != len(call_args): + raise RuntimeError("Wrong number of arguments.") + if sdfg.free_symbols: # This is a simplification that makes our life simple. + raise NotImplementedError( + f"No externally defined symbols are allowed, found: {sdfg.free_symbols}" + ) + + # Build the argument list that we will pass to the compiled object. + sdfg_call_args: dict[str, Any] = {} + for in_name, in_val in zip(inp_names, call_args, strict=True): + if util.is_scalar(in_val): + # Currently the translator makes scalar into arrays, this has to be reflected here + in_val = np.array([in_val]) + sdfg_call_args[in_name] = in_val + + for out_name, sdfg_array in ((out_name, sdfg.arrays[out_name]) for out_name in out_names): + if out_name in sdfg_call_args: + if util.is_jax_array(sdfg_call_args[out_name]): + # Jax arrays are immutable, so they can not be return values too. + raise ValueError("Passed a Jax array as output.") + else: + sdfg_call_args[out_name] = dace_data.make_array_from_descriptor(sdfg_array) + + assert len(sdfg_call_args) == len(csdfg.argnames), ( + "Failed to construct the call arguments," + f" expected {len(csdfg.argnames)} but got {len(call_args)}." + f"\nExpected: {csdfg.argnames}\nGot: {list(sdfg_call_args.keys())}" + ) + + # Calling the SDFG + with dace.config.temporary_config(): + dace.Config.set("compiler", "allow_view_arguments", value=True) + csdfg(**sdfg_call_args) + + # Handling the output (pytrees are missing) + if not out_names: + return None + ret_val: tuple[Any] = tuple(sdfg_call_args[out_name] for out_name in out_names) + return ret_val[0] if len(out_names) == 1 else ret_val diff --git a/src/jace/util/jax_helper.py b/src/jace/util/jax_helper.py index 8fa982f..ca6f60c 100644 --- a/src/jace/util/jax_helper.py +++ b/src/jace/util/jax_helper.py @@ -73,7 +73,10 @@ def __post_init__(self) -> None: def __hash__(self) -> int: return id(self) - def __eq__(self, other: Any) -> bool: + def __eq__( + self, + other: Any, + ) -> bool: if not isinstance(other, JaCeVar): return NotImplemented return id(self) == id(other) @@ -99,7 +102,9 @@ def get_jax_var_name(jax_var: jax_core.Atom | JaCeVar) -> str: ) -def get_jax_var_shape(jax_var: jax_core.Atom | JaCeVar) -> tuple[int | dace.symbol | str, ...]: +def get_jax_var_shape( + jax_var: jax_core.Atom | JaCeVar, +) -> tuple[int | dace.symbol | str, ...]: """Returns the shape of `jax_var`.""" match jax_var: case jax_core.Var() | jax_core.Literal(): @@ -112,7 +117,9 @@ def get_jax_var_shape(jax_var: jax_core.Atom | JaCeVar) -> tuple[int | dace.symb raise TypeError(f"'get_jax_var_shape()` is not implemented for '{type(jax_var)}'.") -def get_jax_var_dtype(jax_var: jax_core.Atom | JaCeVar) -> dace.typeclass: +def get_jax_var_dtype( + jax_var: jax_core.Atom | JaCeVar, +) -> dace.typeclass: """Returns the DaCe equivalent of `jax_var`s datatype.""" match jax_var: case jax_core.Var() | jax_core.Literal(): @@ -140,7 +147,9 @@ def is_tracing_ongoing( return any(isinstance(x, jax_core.Tracer) for x in itertools.chain(args, kwargs.values())) -def translate_dtype(dtype: Any) -> dace.typeclass: +def translate_dtype( + dtype: Any, +) -> dace.typeclass: """Turns a Jax datatype into a DaCe datatype.""" if dtype is None: raise NotImplementedError # Handling a special case in DaCe. @@ -201,7 +210,9 @@ def propose_jax_name( return jax_name -def get_jax_literal_value(lit: jax_core.Atom) -> bool | float | int | np.generic: +def get_jax_literal_value( + lit: jax_core.Atom, +) -> bool | float | int | np.generic: """Returns the value a literal is wrapping. The function guarantees to return a scalar value. diff --git a/src/jace/util/traits.py b/src/jace/util/traits.py index c9e9059..acada34 100644 --- a/src/jace/util/traits.py +++ b/src/jace/util/traits.py @@ -19,7 +19,9 @@ import jace.util as util -def is_drop_var(jax_var: jax_core.Atom | util.JaCeVar) -> TypeGuard[jax_core.DropVar]: +def is_drop_var( + jax_var: jax_core.Atom | util.JaCeVar, +) -> TypeGuard[jax_core.DropVar]: """Tests if `jax_var` is a drop variable, i.e. a variable that is not read from in a Jaxpr.""" if isinstance(jax_var, jax_core.DropVar):