Skip to content

Commit

Permalink
Enrique's suggestions/comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Jun 5, 2024
1 parent 35dc1f1 commit 4499cfb
Show file tree
Hide file tree
Showing 16 changed files with 401 additions and 433 deletions.
4 changes: 1 addition & 3 deletions CODING_GUIDELINES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@ We deviate from the [Google Python Style Guide][google-style-guide] only in the
- We use [`ruff-linter`][ruff-linter] instead of [`pylint`][pylint].
- We use [`ruff-formatter`][ruff-formatter] for source code and imports formatting, which may work differently than indicated by the guidelines in section [_3. Python Style Rules_](https://google.github.io/styleguide/pyguide.html#3-python-style-rules). For example, maximum line length is set to 100 instead of 79 (although docstring lines should still be limited to 79).
- According to subsection [_2.19 Power Features_](https://google.github.io/styleguide/pyguide.html#219-power-features), direct use of _power features_ (e.g. custom metaclasses, import hacks, reflection) should be avoided, but standard library classes that internally use these power features are accepted. Following the same spirit, we allow the use of power features in infrastructure code with similar functionality and scope as the Python standard library.
- For readability purposes, when a docstring contains more than the required summary line, we prefer indenting the first line at the same cursor position as the first opening quote, although this is not explicitly considered in the doctring conventions described in subsection [_3.8.1 Docstrings_](https://google.github.io/styleguide/pyguide.html#381-docstrings). Example:

```python
# single line docstring
"""A one-line summary of the module or program, terminated by a period."""

# multi-line docstring
"""
A one-line summary of the module or program, terminated by a period.
""" A one-line summary of the module or program, terminated by a period.
Leave one blank line. The rest of this docstring should contain an
overall description of the module or program.
Expand Down
11 changes: 6 additions & 5 deletions src/jace/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ def jit(
) -> stages.JaCeWrapped | Callable[[Callable], stages.JaCeWrapped]:
"""JaCe's replacement for `jax.jit` (just-in-time) wrapper.
It works the same way as `jax.jit` does, but instead of using XLA the computation is lowered
to DaCe. In addition it accepts some JaCe specific arguments, it accepts the same arguments
as `jax.jit` does.
It works the same way as `jax.jit` does, but instead of using XLA the
computation is lowered to DaCe. In addition it accepts some JaCe specific
arguments.
Args:
primitive_translators: Use these primitive translators for the lowering to SDFG.
If not specified the translators in the global registry are used.
primitive_translators: Use these primitive translators for the lowering
to SDFG. If not specified the translators in the global registry are
used.
Notes:
After constructions any change to `primitive_translators` has no effect.
Expand Down
11 changes: 6 additions & 5 deletions src/jace/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,14 @@ def jace_optimize(
) -> None:
"""Performs optimization of the translated SDFG _in place_.
It is recommended to use the `CompilerOptions` `TypedDict` to pass options to the function.
However, any option that is not specified will be interpreted as to be disabled.
It is recommended to use the `CompilerOptions` `TypedDict` to pass options
to the function. However, any option that is not specified will be
interpreted as to be disabled.
Args:
tsdfg: The translated SDFG that should be optimized.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline (currently does nothing)
tsdfg: The translated SDFG that should be optimized.
simplify: Run the simplification pipeline.
auto_optimize: Run the auto optimization pipeline (currently does nothing)
"""
# Currently this function exists primarily for the same of existing.

Expand Down
104 changes: 52 additions & 52 deletions src/jace/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,17 @@
As in Jax JaCe has different stages, the terminology is taken from
[Jax' AOT-Tutorial](https://jax.readthedocs.io/en/latest/aot.html).
- Stage out:
In this phase we translate an executable python function into Jaxpr.
In this phase an executable Python function is translated to Jaxpr.
- Lower:
This will transform the Jaxpr into an SDFG equivalent. As a implementation note,
currently this and the previous step are handled as a single step.
This will transform the Jaxpr into an SDFG equivalent. As a implementation
note, currently this and the previous step are handled as a single step.
- Compile:
This will turn the SDFG into an executable object, see `dace.codegen.CompiledSDFG`.
- Execution:
This is the actual running of the computation.
As in Jax the `stages` module give access to the last three stages, but not the first one.
As in Jax the `stages` module give access to the last three stages, but not
the first one.
"""

from __future__ import annotations
Expand Down Expand Up @@ -54,26 +55,28 @@
class JaCeWrapped(tcache.CachingStage["JaCeLowered"]):
"""A function ready to be specialized, lowered, and compiled.
This class represents the output of functions such as `jace.jit()` and is the first stage in
the translation/compilation chain of JaCe. A user should never create a `JaCeWrapped` object
directly, instead `jace.jit` should be used for that.
While it supports just-in-time lowering and compilation, by just calling it, these steps can
also be performed explicitly. The lowering performed by this stage is cached, thus if a
`JaCeWrapped` object is lowered later, with the same argument the result is taken from the
cache. Furthermore, a `JaCeWrapped` object is composable with all Jax transformations.
This class represents the output of functions such as `jace.jit()` and is
the first stage in the translation/compilation chain of JaCe. A user should
never create a `JaCeWrapped` object directly, instead `jace.jit` should be
used for that. While it supports just-in-time lowering and compilation, by
just calling it, these steps can also be performed explicitly. The lowering
performed by this stage is cached, thus if a `JaCeWrapped` object is lowered
later, with the same argument the result is taken from the cache.
Furthermore, a `JaCeWrapped` object is composable with all Jax transformations.
Args:
fun: The function that is wrapped.
primitive_translators: The list of primitive translators that that should be used.
jit_options: Options to influence the jit process.
fun: The function that is wrapped.
primitive_translators: The list of primitive translators that that should be used.
jit_options: Options to influence the jit process.
Todo:
- Handle pytrees.
- Handle all options to `jax.jit`.
- Support pytrees.
- Support keyword arguments and default values of the wrapped function.
- Support static arguments.
Note:
The tracing of function will always happen with enabled `x64` mode, which is implicitly
and temporary activated while tracing.
The tracing of function will always happen with enabled `x64` mode,
which is implicitly and temporary activated while tracing.
"""

_fun: Callable
Expand Down Expand Up @@ -122,13 +125,14 @@ def lower(
) -> JaCeLowered:
"""Lower this function explicitly for the given arguments.
Performs the first two steps of the AOT steps described above, i.e. trace the wrapped
function with the given arguments and stage it out to a Jaxpr. Then translate it to SDFG.
The result is encapsulated inside a `JaCeLowered` object which can later be compiled.
Performs the first two steps of the AOT steps described above, i.e.
trace the wrapped function with the given arguments and stage it out
to a Jaxpr. Then translate it to SDFG. The result is encapsulated
inside a `JaCeLowered` object which can later be compiled.
Note:
The call to the function is cached. As key an abstract description of the call,
similar to the tracers used by Jax, is used.
The call to the function is cached. As key an abstract description
of the call, similar to the tracers used by Jax, is used.
The tracing is always done with activated `x64` mode.
"""
if len(kwargs) != 0:
Expand Down Expand Up @@ -175,10 +179,6 @@ def _make_call_description(
"""This function computes the key for the `JaCeWrapped.lower()` call inside the cache.
The function will compute a full abstract description on its argument.
Todo:
- Support keyword arguments and default values of the wrapped function.
- Support static arguments.
"""
call_args = tuple(tcache._AbstractCallArgument.from_value(x) for x in args)
return tcache.StageTransformationSpec(stage_id=id(self), call_args=call_args)
Expand All @@ -187,19 +187,18 @@ def _make_call_description(
class JaCeLowered(tcache.CachingStage["JaCeCompiled"]):
"""Represents the original computation as an SDFG.
This class represents the output of `JaCeWrapped.lower()` and represents the originally wrapped
computation as an SDFG. This stage is followed by the `JaCeCompiled` stage.
This class is the output type of `JaCeWrapped.lower()` and represents the
originally wrapped computation as an SDFG. This stage is followed by the
`JaCeCompiled` stage.
Args:
tsdfg: The translated SDFG object representing the computation.
tsdfg: The translated SDFG object representing the computation.
Note:
`self` will manage the passed `tsdfg` object. Modifying it results in undefined behavior.
Although `JaCeWrapped` is composable with Jax transformations `JaCeLowered` is not.
A user should never create such an object, instead `JaCeWrapped.lower()` should be used.
Todo:
- Handle pytrees.
`self` will manage the passed `tsdfg` object. Modifying it results in
undefined behavior. Although `JaCeWrapped` is composable with Jax
transformations `JaCeLowered` is not. A user should never create such
an object, instead `JaCeWrapped.lower()` should be used.
"""

_translated_sdfg: translator.TranslatedJaxprSDFG
Expand All @@ -218,13 +217,14 @@ def compile(
) -> JaCeCompiled:
"""Optimize and compile the lowered SDFG using `compiler_options`.
Returns an object that encapsulates a compiled SDFG object. To influence the various
optimizations and compile options of JaCe you can use the `compiler_options` argument.
If nothing is specified `jace.optimization.DEFAULT_OPTIMIZATIONS` will be used.
Returns an object that encapsulates a compiled SDFG object. To influence
the various optimizations and compile options of JaCe you can use the
`compiler_options` argument. If nothing is specified
`jace.optimization.DEFAULT_OPTIMIZATIONS` will be used.
Note:
Before `compiler_options` is forwarded to `jace_optimize()` it will be merged with
the default arguments.
Before `compiler_options` is forwarded to `jace_optimize()` it
will be merged with the default arguments.
"""
# We **must** deepcopy before we do any optimization, because all optimizations are in
# place, however, to properly cache stages, stages needs to be immutable.
Expand All @@ -240,8 +240,8 @@ def compile(
def compiler_ir(self, dialect: str | None = None) -> translator.TranslatedJaxprSDFG:
"""Returns the internal SDFG.
The function returns a `TranslatedJaxprSDFG` object. Direct modification of the returned
object is forbidden and will cause undefined behaviour.
The function returns a `TranslatedJaxprSDFG` object. Direct modification
of the returned object is forbidden and will cause undefined behaviour.
"""
if (dialect is None) or (dialect.upper() == "SDFG"):
return self._translated_sdfg
Expand All @@ -264,8 +264,8 @@ def _make_call_description(
) -> tcache.StageTransformationSpec:
"""This function computes the key for the `self.compile()` call inside the cache.
The key that is computed by this function is based on the concrete values of the passed
compiler options.
The key that is computed by this function is based on the concrete
values of the passed compiler options.
"""
options = self._make_compiler_options(compiler_options)
call_args = tuple(sorted(options.items(), key=lambda x: x[0]))
Expand All @@ -281,13 +281,13 @@ def _make_compiler_options(
class JaCeCompiled:
"""Compiled version of the SDFG.
This is the last stage of the jit chain. A user should never create a `JaCeCompiled` instance,
instead `JaCeLowered.compile()` should be used.
This is the last stage of the jit chain. A user should never create a
`JaCeCompiled` instance, instead `JaCeLowered.compile()` should be used.
Args:
csdfg: The compiled SDFG object.
inp_names: Names of the SDFG variables used as inputs.
out_names: Names of the SDFG variables used as outputs.
csdfg: The compiled SDFG object.
inp_names: Names of the SDFG variables used as inputs.
out_names: Names of the SDFG variables used as outputs.
Note:
The class assumes ownership of its input arguments.
Expand Down Expand Up @@ -319,8 +319,8 @@ def __call__(
) -> Any:
"""Calls the embedded computation.
The arguments must be the same as for the wrapped function, but with all static arguments
removed.
The arguments must be the same as for the wrapped function, but with
all static arguments removed.
"""
return util.run_jax_sdfg(
self._csdfg,
Expand Down
4 changes: 2 additions & 2 deletions src/jace/translator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

"""Subpackage containing all the code related to the Jaxpr to SDFG translation.
The concrete primitive translators that ships with JaCe are inside the `primitive_translators`
subpackage.
The concrete primitive translators that ships with JaCe are inside the
`primitive_translators` subpackage.
"""

from __future__ import annotations
Expand Down
Loading

0 comments on commit 4499cfb

Please sign in to comment.