Skip to content

Commit

Permalink
Allied some reformating.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Jun 6, 2024
1 parent 4499cfb commit 971a025
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 167 deletions.
19 changes: 14 additions & 5 deletions src/jace/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,15 @@ def compile(
optimization.jace_optimize(tsdfg=tsdfg, **self._make_compiler_options(compiler_options))

return JaCeCompiled(
csdfg=util.compile_jax_sdfg(tsdfg),
csdfg=dace_helper.compile_jax_sdfg(tsdfg),
inp_names=tsdfg.inp_names,
out_names=tsdfg.out_names,
)

def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprSDFG:
def compiler_ir(
self,
dialect: str | None = None,
) -> translator.TranslatedJaxprSDFG:
"""Returns the internal SDFG.
The function returns a `TranslatedJaxprSDFG` object. Direct modification
Expand All @@ -247,8 +250,14 @@ def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprS
return self._translated_sdfg
raise ValueError(f"Unknown dialect '{dialect}'.")

def as_html(self, filename: str | None = None) -> None:
"""Runs the `view()` method of the underlying SDFG."""
def view(
self,
filename: str | None = None,
) -> None:
"""Runs the `view()` method of the underlying SDFG.
This will open a browser and display the SDFG.
"""
self.compiler_ir().sdfg.view(filename=filename, verbose=False)

def as_sdfg(self) -> dace.SDFG:
Expand Down Expand Up @@ -322,7 +331,7 @@ def __call__(
The arguments must be the same as for the wrapped function, but with
all static arguments removed.
"""
return util.run_jax_sdfg(
return dace_helper.run_jax_sdfg(
self._csdfg,
self._inp_names,
self._out_names,
Expand Down
14 changes: 8 additions & 6 deletions src/jace/translator/jaxpr_translator_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ def map_jax_var_to_sdfg(

@overload
def map_jax_var_to_sdfg(
self, jax_var: jax_core.Atom | util.JaCeVar, allow_fail: Literal[True]
self,
jax_var: jax_core.Atom | util.JaCeVar,
allow_fail: Literal[True],
) -> str | None: ...

def map_jax_var_to_sdfg(
Expand Down Expand Up @@ -568,14 +570,14 @@ def _translate_single_eqn(
update_var_mapping=True,
)

pname: str = eqn.primitive.name
if pname not in self._primitive_translators:
raise NotImplementedError(f"No translator known to handle '{pname}'.")
ptranslator = self._primitive_translators[pname]
primitive_name: str = eqn.primitive.name
if primitive_name not in self._primitive_translators:
raise NotImplementedError(f"No translator known to handle '{primitive_name}'.")
ptranslator = self._primitive_translators[primitive_name]

# Create the state into which the equation should be translated
eqn_state = self.append_new_state(
label=f"{pname}_{'_'.join(out_var_names)}",
label=f"{primitive_name}_{'_'.join(out_var_names)}",
prev_state=None, # forces the creation of a new terminal state
)

Expand Down
6 changes: 0 additions & 6 deletions src/jace/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@

from __future__ import annotations

from .compiling import (
compile_jax_sdfg,
run_jax_sdfg,
)
from .jax_helper import (
JaCeVar,
get_jax_var_dtype,
Expand Down Expand Up @@ -42,7 +38,6 @@
"VALID_SDFG_OBJ_NAME",
"VALID_SDFG_VAR_NAME",
"JaCeVar",
"compile_jax_sdfg",
"dataclass_with_default_init",
"get_jax_var_dtype",
"get_jax_var_name",
Expand All @@ -55,6 +50,5 @@
"is_scalar",
"is_tracing_ongoing",
"propose_jax_name",
"run_jax_sdfg",
"translate_dtype",
]
138 changes: 0 additions & 138 deletions src/jace/util/compiling.py

This file was deleted.

142 changes: 136 additions & 6 deletions src/jace/util/dace_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,144 @@
#
# SPDX-License-Identifier: BSD-3-Clause

"""Implements all utility functions that are related to DaCe.
Most of the functions defined here allow an unified access to DaCe's internals
in a consistent and stable way.
"""
"""Implements all utility functions that are related to DaCe."""

from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any

import dace
import numpy as np
from dace import data as dace_data

# The compiled SDFG is not available in the dace namespace or anywhere else
# Thus we import it here directly
from dace.codegen.compiled_sdfg import CompiledSDFG as CompiledSDFG
from dace.codegen.compiled_sdfg import CompiledSDFG

from jace import util


if TYPE_CHECKING:
from collections.abc import Mapping, Sequence

from jace import translator
from jace.util import dace_helper

__all__ = [
"CompiledSDFG",
"compile_jax_sdfg",
"run_jax_sdfg",
]


def compile_jax_sdfg(
tsdfg: translator.TranslatedJaxprSDFG,
) -> dace_helper.CompiledSDFG:
"""Compiles the SDFG embedded in `tsdfg` and return the resulting `CompiledSDFG` object."""
if any( # We do not support the DaCe return mechanism
array_name.startswith("__return")
for array_name in tsdfg.sdfg.arrays.keys() # noqa: SIM118 # we can not use `in` because we are also interested in `__return_`!
):
raise ValueError("Only support SDFGs without '__return' members.")

# To ensure that the SDFG is compiled and to get rid of a warning we must modify
# some settings of the SDFG. To fake an immutable SDFG, we will restore them later.
sdfg = tsdfg.sdfg
org_sdfg_name = sdfg.name
org_recompile = sdfg._recompile
org_regenerate_code = sdfg._regenerate_code

try:
# We need to give the SDFG another name, this is needed to prevent a DaCe error/warning.
# This happens if we compile the same lowered SDFG multiple times with different options.
sdfg.name = f"{sdfg.name}__comp_{int(time.time() * 1000)}"

with dace.config.temporary_config():
sdfg._recompile = True
sdfg._regenerate_code = True
dace.Config.set("compiler", "use_cache", value=False)
csdfg: dace_helper.CompiledSDFG = sdfg.compile()

finally:
sdfg.name = org_sdfg_name
sdfg._recompile = org_recompile
sdfg._regenerate_code = org_regenerate_code

return csdfg


def run_jax_sdfg(
csdfg: dace_helper.CompiledSDFG,
inp_names: Sequence[str],
out_names: Sequence[str],
call_args: Sequence[Any],
call_kwargs: Mapping[str, Any],
) -> tuple[Any, ...] | Any:
"""Run the compiled SDFG.
The function assumes that the SDFG was finalized and then compiled by
`compile_jax_sdfg()`. For running the SDFG you also have to pass the input
names (`inp_names`) and output names (`out_names`) that were inside the
`TranslatedJaxprSDFG` from which `csdfg` was compiled from.
Args:
csdfg: The `CompiledSDFG` object.
inp_names: List of names of the input arguments.
out_names: List of names of the output arguments.
call_args: All positional arguments of the call.
call_kwargs: All keyword arguments of the call.
Note:
There is no pytree mechanism jet, thus the return values are returned
inside a `tuple` or in case of one value, directly, in the order
determined by Jax. Furthermore, DaCe does not support scalar return
values, thus they are silently converted into arrays of length 1, the
same holds for inputs.
Todo:
- Implement non C strides.
"""
sdfg: dace.SDFG = csdfg.sdfg

if len(call_kwargs) != 0:
raise NotImplementedError("No kwargs are supported yet.")
if len(inp_names) != len(call_args):
raise RuntimeError("Wrong number of arguments.")
if sdfg.free_symbols: # This is a simplification that makes our life simple.
raise NotImplementedError(
f"No externally defined symbols are allowed, found: {sdfg.free_symbols}"
)

# Build the argument list that we will pass to the compiled object.
sdfg_call_args: dict[str, Any] = {}
for in_name, in_val in zip(inp_names, call_args, strict=True):
if util.is_scalar(in_val):
# Currently the translator makes scalar into arrays, this has to be reflected here
in_val = np.array([in_val])
sdfg_call_args[in_name] = in_val

for out_name, sdfg_array in ((out_name, sdfg.arrays[out_name]) for out_name in out_names):
if out_name in sdfg_call_args:
if util.is_jax_array(sdfg_call_args[out_name]):
# Jax arrays are immutable, so they can not be return values too.
raise ValueError("Passed a Jax array as output.")
else:
sdfg_call_args[out_name] = dace_data.make_array_from_descriptor(sdfg_array)

assert len(sdfg_call_args) == len(csdfg.argnames), (
"Failed to construct the call arguments,"
f" expected {len(csdfg.argnames)} but got {len(call_args)}."
f"\nExpected: {csdfg.argnames}\nGot: {list(sdfg_call_args.keys())}"
)

# Calling the SDFG
with dace.config.temporary_config():
dace.Config.set("compiler", "allow_view_arguments", value=True)
csdfg(**sdfg_call_args)

# Handling the output (pytrees are missing)
if not out_names:
return None
ret_val: tuple[Any] = tuple(sdfg_call_args[out_name] for out_name in out_names)
return ret_val[0] if len(out_names) == 1 else ret_val
Loading

0 comments on commit 971a025

Please sign in to comment.