diff --git a/.github/workflows/python_format.yml b/.github/workflows/python_format.yml
new file mode 100644
index 00000000..f7f8b035
--- /dev/null
+++ b/.github/workflows/python_format.yml
@@ -0,0 +1,12 @@
+name: Python-Black
+
+on: [push, pull_request]
+
+jobs:
+  lint:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: psf/black@stable
+        with:
+          options: "--check --verbose"
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index 99a9cfa8..56bfc3c3 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -22,22 +22,43 @@
 LANG_LLVM = enzyme_call.Language.LLVM
 LANG_MHLO = enzyme_call.Language.MHLO

+
 def resource_dir():
-  import os
-  dn = os.path.dirname(enzyme_call.__file__)
-  res = os.path.join(dn, "..", "..", "clang", "staging")
-  return res
+    import os
+
+    dn = os.path.dirname(enzyme_call.__file__)
+    res = os.path.join(dn, "..", "..", "clang", "staging")
+    return res
+

 def cflags():
-  import platform
-  import os
-  if platform.system() == 'Darwin':
-    res = ('-isysroot', '/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk', "-isystem", "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1", "-internal-isystem", os.path.join(resource_dir(), "include"), "-internal-externc-isystem", "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include", "-internal-externc-isystem", "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include", "-fgnuc-version=4.2.1")
-  else:
-    res = ()
-  if os.getenv("ENABLE_GDBLISTENER") is not None:
-    res = res + ('-debug-info-kind=standalone', '-dwarf-version=5', '-debugger-tuning=gdb',)
-  return res
+    import platform
+    import os
+
+    if platform.system() == "Darwin":
+        res = (
+            "-isysroot",
+            "/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk",
+            "-isystem",
+            "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1",
+            "-internal-isystem",
+            os.path.join(resource_dir(), "include"),
+            "-internal-externc-isystem",
+            "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include",
+            "-internal-externc-isystem",
+            "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include",
+            "-fgnuc-version=4.2.1",
+        )
+    else:
+        res = ()
+    if os.getenv("ENABLE_GDBLISTENER") is not None:
+        res = res + (
+            "-debug-info-kind=standalone",
+            "-dwarf-version=5",
+            "-debugger-tuning=gdb",
+        )
+    return res
+

 def _enzyme_primal_impl(
     *args_flat: jax.Array,
@@ -45,10 +66,11 @@ def _enzyme_primal_impl(
     fn: str,
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
-    lang: enzyme_call.Language
+    lang: enzyme_call.Language,
 ) -> Sequence[jax.Array]:
-  del args_flat, source, out_shapes
-  raise RuntimeError("must be JIT'ed")
+    del args_flat, source, out_shapes
+    raise RuntimeError("must be JIT'ed")
+

 def _enzyme_fwd_impl(
     *args_flat: jax.Array,
@@ -56,10 +78,11 @@ def _enzyme_fwd_impl(
     fn: str,
     argv: Sequence[str],
     out_shapes: Sequence[jax.core.ShapedArray],
-    lang: enzyme_call.Language
+    lang: enzyme_call.Language,
 ) -> Sequence[jax.Array]:
-  del args_flat, source, out_shapes
-  raise RuntimeError("must be JIT'ed")
+    del args_flat, source, out_shapes
+    raise RuntimeError("must be JIT'ed")
+

 def _enzyme_aug_impl(
     *args_flat: jax.Array,
@@ -67,10 +90,11 @@ def _enzyme_aug_impl(
     fn: str,
     argv: Sequence[str],
     out_shapes:
Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.Array]: - del args_flat, source, out_shapes - raise RuntimeError("must be JIT'ed") + del args_flat, source, out_shapes + raise RuntimeError("must be JIT'ed") + def _enzyme_shadow_aug_impl( *args_flat: jax.Array, @@ -78,10 +102,11 @@ def _enzyme_shadow_aug_impl( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.Array]: - del args_flat, source, out_shapes - raise RuntimeError("must be JIT'ed") + del args_flat, source, out_shapes + raise RuntimeError("must be JIT'ed") + def _enzyme_rev_impl( *args_flat: jax.Array, @@ -89,10 +114,11 @@ def _enzyme_rev_impl( fn: str, argv: Sequence[str], in_shapes, - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.Array]: - del args_flat, source, out_shapes - raise RuntimeError("must be JIT'ed") + del args_flat, source, out_shapes + raise RuntimeError("must be JIT'ed") + def _enzyme_primal_abstract_eval( *args_flat: jax.core.ShapedArray, @@ -100,11 +126,12 @@ def _enzyme_primal_abstract_eval( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: - # TODO: we may attempt some lightweight parsing of source to extract the - # result types instead. - return out_shapes + # TODO: we may attempt some lightweight parsing of source to extract the + # result types instead. + return out_shapes + def _enzyme_fwd_abstract_eval( *args_flat: jax.core.ShapedArray, @@ -114,13 +141,15 @@ def _enzyme_fwd_abstract_eval( out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: - del source, fn, args_flat - return tuple(o for o in out_shapes for _ in range(2)) + del source, fn, args_flat + return tuple(o for o in out_shapes for _ in range(2)) + def absmaketup(ty): - tystr = ty.dtype.__str__() - tystr = {'float32':'float','float64':'double'}[tystr] - return (tystr, ty.shape) + tystr = ty.dtype.__str__() + tystr = {"float32": "float", "float64": "double"}[tystr] + return (tystr, ty.shape) + def _enzyme_aug_abstract_eval( *args_flat: jax.core.ShapedArray, @@ -128,31 +157,34 @@ def _enzyme_aug_abstract_eval( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang : enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: + in_shapes = args_flat - in_shapes = args_flat + prev_out_shapes = out_shapes - prev_out_shapes = out_shapes + out_shapes = [absmaketup(a) for a in out_shapes] - out_shapes = [absmaketup(a) for a in out_shapes] + in_shapes = [absmaketup(a) for a in in_shapes] - in_shapes = [absmaketup(a) for a in in_shapes] + if lang == LANG_MHLO: + (in_tree, func) = source + avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat) + lowered_func = jax.jit(func).lower(*avals_in) + mhlo = lowered_func.compiler_ir(dialect="mhlo") + source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] - if lang == LANG_MHLO: - (in_tree, func) = source - avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat) - lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') - source = str(mhlo) - kept = lowered_func.compile()._executable._kept_var_idx - in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in 
kept] + argv = argv + ("-resource-dir", resource_dir()) + cflags() - argv = argv + ( "-resource-dir", resource_dir()) + cflags() - - tapeSize, tmpSize = enzyme_call.tape_and_tmp_size(source, fn, out_shapes, in_shapes, argv, lang) - res = tuple(prev_out_shapes) + (jax.core.ShapedArray((tapeSize,), (jax.numpy.int8)),) - return res + tapeSize, tmpSize = enzyme_call.tape_and_tmp_size( + source, fn, out_shapes, in_shapes, argv, lang + ) + res = tuple(prev_out_shapes) + ( + jax.core.ShapedArray((tapeSize,), (jax.numpy.int8)), + ) + return res def _enzyme_shadow_aug_abstract_eval( @@ -161,9 +193,10 @@ def _enzyme_shadow_aug_abstract_eval( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: - return out_shapes + return out_shapes + def _enzyme_rev_abstract_eval( *args_flat: jax.core.ShapedArray, @@ -171,20 +204,25 @@ def _enzyme_rev_abstract_eval( fn: str, argv: Sequence[str], in_shapes, - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: - return tuple(jax.core.ShapedArray(shape, dejaxify(tyid)) for (shape, tyid) in in_shapes) + return tuple( + jax.core.ShapedArray(shape, dejaxify(tyid)) for (shape, tyid) in in_shapes + ) + def maketup(ty): - ty = ir.RankedTensorType(ty) - tystr = ty.element_type.__str__() - tystr = {'f32':'float','f64':'double','i32':'int32_t','i64':'int64_t'}[tystr] - return (tystr, ty.shape) + ty = ir.RankedTensorType(ty) + tystr = ty.element_type.__str__() + tystr = {"f32": "float", "f64": "double", "i32": "int32_t", "i64": "int64_t"}[tystr] + return (tystr, ty.shape) + def to_jax(ty): - tystr = ty.__str__() - return {'f32':jnp.float32,'f64':jnp.float64}[tystr] - + tystr = ty.__str__() + return {"f32": jnp.float32, "f64": jnp.float64}[tystr] + + def _enzyme_primal_lowering( ctx: jax_mlir.LoweringRuleContext, *args_flat: ir.Value, @@ -192,50 +230,51 @@ def _enzyme_primal_lowering( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[ir.Value]: - del out_shapes + del out_shapes + + out_types = tuple(itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out))) - out_types = tuple( - itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) - ) + out_shapes = list(map(maketup, out_types)) + in_shapes = list(map(lambda x: maketup(x.type), args_flat)) - out_shapes = list(map(maketup, out_types)) - in_shapes = list(map(lambda x: maketup(x.type), args_flat)) + in_args = (*args_flat,) - in_args = (*args_flat,) + if lang == LANG_MHLO: + (in_tree, func) = source + avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) + lowered_func = jax.jit(func).lower(*avals_in) + mhlo = lowered_func.compiler_ir(dialect="mhlo") + source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] - if lang == LANG_MHLO: - (in_tree, func) = source - avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) - lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') - source = str(mhlo) - kept = lowered_func.compile()._executable._kept_var_idx - in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) - in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] + argv = argv + ("-resource-dir", resource_dir()) + 
cflags() + identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Primal, lang + ) + identifier_attr = jax_mlir.dense_int_elements([identifier]) + identifier_op = stablehlo.ConstantOp(identifier_attr) - argv = argv + ( "-resource-dir", resource_dir() ) + cflags() - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Primal, lang) - identifier_attr = jax_mlir.dense_int_elements([identifier]) - identifier_op = stablehlo.ConstantOp(identifier_attr) + mlir_args = (identifier_op,) + in_args - mlir_args = (identifier_op,) + in_args + if tmpBuf != 0: + sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) + out_types = out_types + (sa,) - if tmpBuf != 0: - sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) - out_types = out_types + (sa,) + custom_call = stablehlo.CustomCallOp( + out_types, mlir_args, call_target_name="jaxzyme.primal" + ) - custom_call = stablehlo.CustomCallOp( - out_types, mlir_args, call_target_name="jaxzyme.primal" - ) + results = custom_call.results - results = custom_call.results + if tmpBuf != 0: + results = results[:-1] - if tmpBuf != 0: - results = results[:-1] + return results - return results def _enzyme_fwd_lowering( ctx: jax_mlir.LoweringRuleContext, @@ -244,50 +283,50 @@ def _enzyme_fwd_lowering( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[ir.Value]: - del out_shapes + del out_shapes - out_types = tuple( - itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) - ) + out_types = tuple(itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out))) - out_shapes = list(map(maketup, out_types[::2])) + out_shapes = list(map(maketup, out_types[::2])) - in_shapes = list(map(lambda x: maketup(x.type), args_flat[::2])) - - in_args = (*args_flat,) + in_shapes = list(map(lambda x: maketup(x.type), args_flat[::2])) - if lang == LANG_MHLO: - (in_tree, func) = source - avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2]) - lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') - source = str(mhlo) - kept = lowered_func.compile()._executable._kept_var_idx - in_args = tuple(arg for (i, arg) in enumerate(in_args) if i//2 in kept) - in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] + in_args = (*args_flat,) - argv = argv + ( "-resource-dir", resource_dir() ) + cflags() - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Forward, lang) - identifier_attr = jax_mlir.dense_int_elements([identifier]) - identifier_op = stablehlo.ConstantOp(identifier_attr) + if lang == LANG_MHLO: + (in_tree, func) = source + avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2]) + lowered_func = jax.jit(func).lower(*avals_in) + mhlo = lowered_func.compiler_ir(dialect="mhlo") + source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + in_args = tuple(arg for (i, arg) in enumerate(in_args) if i // 2 in kept) + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] - mlir_args = (identifier_op,) + in_args + argv = argv + ("-resource-dir", resource_dir()) + cflags() + identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Forward, lang + ) + identifier_attr = 
jax_mlir.dense_int_elements([identifier]) + identifier_op = stablehlo.ConstantOp(identifier_attr) - if tmpBuf != 0: - sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) - out_types = out_types + (sa,sa) + mlir_args = (identifier_op,) + in_args - custom_call = stablehlo.CustomCallOp( - out_types, mlir_args, call_target_name="jaxzyme.fwd" - ) + if tmpBuf != 0: + sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) + out_types = out_types + (sa, sa) - results = custom_call.results - if tmpBuf != 0: - results = results[:-2] + custom_call = stablehlo.CustomCallOp( + out_types, mlir_args, call_target_name="jaxzyme.fwd" + ) - return results + results = custom_call.results + if tmpBuf != 0: + results = results[:-2] + + return results def _enzyme_aug_lowering( @@ -297,48 +336,49 @@ def _enzyme_aug_lowering( fn: str, argv: Sequence[str], out_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[ir.Value]: - del out_shapes - - out_types = tuple( - itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) - ) - - out_shapes = list(map(maketup, out_types[:len(out_types)-1])) - - in_shapes = list(map(lambda x: maketup(x.type), args_flat)) - - in_args = (*args_flat,) - - if lang == LANG_MHLO: - (in_tree, func) = source - avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) - lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') - source = str(mhlo) - kept = lowered_func.compile()._executable._kept_var_idx - in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) - in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] - - argv = argv + ( "-resource-dir", resource_dir()) + cflags() - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Augmented, lang) - identifier_attr = jax_mlir.dense_int_elements([identifier]) - identifier_op = stablehlo.ConstantOp(identifier_attr) - - if tmpBuf != 0: - sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) - out_types = out_types + (sa,) - - mlir_args = (identifier_op,) + in_args - custom_call = stablehlo.CustomCallOp( - out_types, mlir_args, call_target_name="jaxzyme.aug" - ) - - results = custom_call.results - if tmpBuf != 0: - results = results[:-1] - return results + del out_shapes + + out_types = tuple(itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out))) + + out_shapes = list(map(maketup, out_types[: len(out_types) - 1])) + + in_shapes = list(map(lambda x: maketup(x.type), args_flat)) + + in_args = (*args_flat,) + + if lang == LANG_MHLO: + (in_tree, func) = source + avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) + lowered_func = jax.jit(func).lower(*avals_in) + mhlo = lowered_func.compiler_ir(dialect="mhlo") + source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] + + argv = argv + ("-resource-dir", resource_dir()) + cflags() + identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Augmented, lang + ) + identifier_attr = jax_mlir.dense_int_elements([identifier]) + identifier_op = stablehlo.ConstantOp(identifier_attr) + + if tmpBuf != 0: + sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) + out_types = out_types + (sa,) + + 
mlir_args = (identifier_op,) + in_args + custom_call = stablehlo.CustomCallOp( + out_types, mlir_args, call_target_name="jaxzyme.aug" + ) + + results = custom_call.results + if tmpBuf != 0: + results = results[:-1] + return results + def _enzyme_rev_lowering( ctx: jax_mlir.LoweringRuleContext, @@ -347,73 +387,103 @@ def _enzyme_rev_lowering( fn: str, argv: Sequence[str], in_shapes: Sequence[jax.core.ShapedArray], - lang: enzyme_call.Language + lang: enzyme_call.Language, ) -> Sequence[ir.Value]: - del in_shapes - - pre_in_types = tuple( - itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) - ) - - in_shapes = list(map(maketup, pre_in_types)) - pre_in_shapes = in_shapes - - out_shapes = list(map(lambda x: maketup(x.type), args_flat[1:])) - - in_args = (*args_flat,) - - rev_return_types = pre_in_types - - kept = None - if lang == LANG_MHLO: - (in_tree, func) = source - avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out) - lowered_func = jax.jit(func).lower(*avals_in) - mhlo = lowered_func.compiler_ir(dialect='mhlo') - source = str(mhlo) - kept = lowered_func.compile()._executable._kept_var_idx - # in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) - in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] - rev_return_types = tuple(retty for (i, retty) in enumerate(rev_return_types) if i in kept) - - argv = tuple(argv) + ( "-resource-dir", resource_dir()) + cflags() - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Reverse, lang) - identifier_attr = jax_mlir.dense_int_elements([identifier]) - identifier_op = stablehlo.ConstantOp(identifier_attr) - - mlir_args = (identifier_op,) + in_args - - if tmpBuf != 0: - sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) - rev_return_types = rev_return_types + (sa,) - - custom_call = stablehlo.CustomCallOp( - rev_return_types, mlir_args, call_target_name="jaxzyme.rev" - ) - results = custom_call.results - if tmpBuf != 0: - results = results[:-1] - if kept != None: - results = [] - cur_idx = 0 - for i, ty in enumerate(pre_in_types): - if i in kept: - results.append(custom_call.results[cur_idx]) - cur_idx += 1 - else: - ty = ir.RankedTensorType(ty) - shape = ty.shape - element_type = ty.element_type - import numpy as np - results.append(stablehlo.ConstantOp(ir.DenseElementsAttr.get(np.zeros(shape, dtype=to_jax(element_type)))).results[0]) - return results - -def ffi_call(*args, out_shapes: Sequence[jax.core.ShapedArray], source, fn:str="f", argv: tuple[str]=(), lang:int=LANG_CPP): - return _enzyme_primal_p.bind( - *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=lang) - -def cpp_call(*args, out_shapes: Sequence[jax.core.ShapedArray], source: str, fn:str="f", argv: tuple[str]=()): - return ffi_call(*args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=LANG_CPP) + del in_shapes + + pre_in_types = tuple( + itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) + ) + + in_shapes = list(map(maketup, pre_in_types)) + pre_in_shapes = in_shapes + + out_shapes = list(map(lambda x: maketup(x.type), args_flat[1:])) + + in_args = (*args_flat,) + + rev_return_types = pre_in_types + + kept = None + if lang == LANG_MHLO: + (in_tree, func) = source + avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out) + lowered_func = jax.jit(func).lower(*avals_in) + mhlo = lowered_func.compiler_ir(dialect="mhlo") + source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx 
+ # in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] + rev_return_types = tuple( + retty for (i, retty) in enumerate(rev_return_types) if i in kept + ) + + argv = tuple(argv) + ("-resource-dir", resource_dir()) + cflags() + identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Reverse, lang + ) + identifier_attr = jax_mlir.dense_int_elements([identifier]) + identifier_op = stablehlo.ConstantOp(identifier_attr) + + mlir_args = (identifier_op,) + in_args + + if tmpBuf != 0: + sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) + rev_return_types = rev_return_types + (sa,) + + custom_call = stablehlo.CustomCallOp( + rev_return_types, mlir_args, call_target_name="jaxzyme.rev" + ) + results = custom_call.results + if tmpBuf != 0: + results = results[:-1] + if kept != None: + results = [] + cur_idx = 0 + for i, ty in enumerate(pre_in_types): + if i in kept: + results.append(custom_call.results[cur_idx]) + cur_idx += 1 + else: + ty = ir.RankedTensorType(ty) + shape = ty.shape + element_type = ty.element_type + import numpy as np + + results.append( + stablehlo.ConstantOp( + ir.DenseElementsAttr.get( + np.zeros(shape, dtype=to_jax(element_type)) + ) + ).results[0] + ) + return results + + +def ffi_call( + *args, + out_shapes: Sequence[jax.core.ShapedArray], + source, + fn: str = "f", + argv: tuple[str] = (), + lang: int = LANG_CPP, +): + return _enzyme_primal_p.bind( + *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=lang + ) + + +def cpp_call( + *args, + out_shapes: Sequence[jax.core.ShapedArray], + source: str, + fn: str = "f", + argv: tuple[str] = (), +): + return ffi_call( + *args, source=source, fn=fn, argv=argv, out_shapes=out_shapes, lang=LANG_CPP + ) + _enzyme_primal_p = jax.core.Primitive("enzyme_primal") _enzyme_primal_p.multiple_results = True @@ -435,26 +505,36 @@ def cpp_call(*args, out_shapes: Sequence[jax.core.ShapedArray], source: str, fn: "jaxzyme.fwd", enzyme_call.get_cpu_callback(), platform="cpu" ) + def enzyme_jvp(arg_primals, arg_tangents, **kwargs): - - # TODO propagate activity info rather than make_zero - def make_zero(tan, prim): - return lax.zeros_like_array(prim) if type(tan) is ad.Zero else tan - - arg_tangents = tuple(make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals)) - args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t) - shadconv = _enzyme_fwd_p.bind( - *args, source=kwargs['source'], fn=kwargs['fn'], argv=kwargs['argv'], out_shapes=kwargs['out_shapes'], lang=kwargs['lang']) - res = (shadconv[0::2], shadconv[1::2]) - return res + # TODO propagate activity info rather than make_zero + def make_zero(tan, prim): + return lax.zeros_like_array(prim) if type(tan) is ad.Zero else tan + + arg_tangents = tuple(make_zero(t, p) for (t, p) in zip(arg_tangents, arg_primals)) + args = tuple(v for t in zip(arg_primals, arg_tangents) for v in t) + shadconv = _enzyme_fwd_p.bind( + *args, + source=kwargs["source"], + fn=kwargs["fn"], + argv=kwargs["argv"], + out_shapes=kwargs["out_shapes"], + lang=kwargs["lang"], + ) + res = (shadconv[0::2], shadconv[1::2]) + return res + ad.primitive_jvps[_enzyme_primal_p] = enzyme_jvp + def jaxify(x): - return {'float32':0, 'float64':1}[x.__str__()] + return {"float32": 0, "float64": 1}[x.__str__()] + def dejaxify(x): - return {0:jnp.float32, 1:jnp.float64}[x] + return {0: jnp.float32, 1: jnp.float64}[x] + _enzyme_aug_p = 
jax.core.Primitive("enzyme_aug") _enzyme_aug_p.multiple_results = True @@ -484,55 +564,72 @@ def dejaxify(x): from jax._src.interpreters import partial_eval as pe + def fwd_partial_eval(trace, *args, **kwargs): - assert len(args) % 2 == 0 - nr_primals = len(args) // 2 - primals, tangents = args[0::2], args[1::2] - all_primals_known = all(p.is_known() for p in primals) - some_tangents_unknown = any(not t.is_known() for t in tangents) + assert len(args) % 2 == 0 + nr_primals = len(args) // 2 + primals, tangents = args[0::2], args[1::2] + all_primals_known = all(p.is_known() for p in primals) + some_tangents_unknown = any(not t.is_known() for t in tangents) + + if not (all_primals_known and some_tangents_unknown): + return trace.default_process_primitive(_enzyme_fwd_p, args, kwargs) - if not (all_primals_known and some_tangents_unknown): - return trace.default_process_primitive(_enzyme_fwd_p, args, kwargs) + outs_known = trace.default_process_primitive(_enzyme_aug_p, primals, kwargs) - outs_known = trace.default_process_primitive( - _enzyme_aug_p, primals, kwargs) + shadow_aug_args = (trace.full_raise(outs_known[-1]),) + primals + tangents + shadows_known = trace.default_process_primitive( + _enzyme_shadow_aug_p, shadow_aug_args, kwargs + ) - shadow_aug_args = (trace.full_raise(outs_known[-1]),) + primals + tangents - shadows_known = trace.default_process_primitive( - _enzyme_shadow_aug_p, shadow_aug_args, - kwargs) + outs = tuple(v for tup in zip(outs_known[:-1], shadows_known) for v in tup) + return outs - outs = tuple(v for tup in zip(outs_known[:-1], shadows_known) for v in tup) - return outs pe.custom_partial_eval_rules[_enzyme_fwd_p] = fwd_partial_eval + def enzyme_vjp(shadow_rets, *prim_args, **kwargs): - out_shapes = kwargs['out_shapes'] - del kwargs['out_shapes'] - shadows = [ad.is_undefined_primal(x) for x in prim_args] - tape = prim_args[0] - prim_args = prim_args[1:1+(len(prim_args)-1)//2] - prim_args = tuple(jnp.ones(x.aval.shape, x.aval.dtype) if ad.is_undefined_primal(x) else x for x in prim_args) - in_shapes = tuple((a.shape, jaxify(a.dtype)) for a in prim_args) - - args = (tape, ) + tuple(shadow_rets) - shadconv = _enzyme_rev_p.bind( - *args, **kwargs, in_shapes=in_shapes) - res = (None,) + tuple(None for _ in range(len(shadconv))) + tuple(shadconv) - return res + out_shapes = kwargs["out_shapes"] + del kwargs["out_shapes"] + shadows = [ad.is_undefined_primal(x) for x in prim_args] + tape = prim_args[0] + prim_args = prim_args[1 : 1 + (len(prim_args) - 1) // 2] + prim_args = tuple( + jnp.ones(x.aval.shape, x.aval.dtype) if ad.is_undefined_primal(x) else x + for x in prim_args + ) + in_shapes = tuple((a.shape, jaxify(a.dtype)) for a in prim_args) + + args = (tape,) + tuple(shadow_rets) + shadconv = _enzyme_rev_p.bind(*args, **kwargs, in_shapes=in_shapes) + res = (None,) + tuple(None for _ in range(len(shadconv))) + tuple(shadconv) + return res + ad.primitive_transposes[_enzyme_shadow_aug_p] = enzyme_vjp + def enzyme_jax_ir(argv=()): - def decorator(func: Callable[..., Any]) -> Callable[..., Any]: - @jax.jit - def wrapped(*args: Any): - args_flat, in_tree = jax.tree_util.tree_flatten(args) - out_shape = jax.eval_shape(func, *args) - out_shape_flat, out_tree = jax.tree_util.tree_flatten(out_shape) - out_shape_flat = [jax.core.ShapedArray(o.shape, o.dtype) for o in out_shape_flat] - out_flat = ffi_call(*args_flat, source=(in_tree, func), fn="", out_shapes=out_shape_flat, argv=argv, lang=LANG_MHLO) - return jax.tree_util.tree_unflatten(out_tree, out_flat) - return wrapped - return 
decorator + def decorator(func: Callable[..., Any]) -> Callable[..., Any]: + @jax.jit + def wrapped(*args: Any): + args_flat, in_tree = jax.tree_util.tree_flatten(args) + out_shape = jax.eval_shape(func, *args) + out_shape_flat, out_tree = jax.tree_util.tree_flatten(out_shape) + out_shape_flat = [ + jax.core.ShapedArray(o.shape, o.dtype) for o in out_shape_flat + ] + out_flat = ffi_call( + *args_flat, + source=(in_tree, func), + fn="", + out_shapes=out_shape_flat, + argv=argv, + lang=LANG_MHLO, + ) + return jax.tree_util.tree_unflatten(out_tree, out_flat) + + return wrapped + + return decorator diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index 63a2bd9d..bcef48ac 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -5,189 +5,257 @@ @enzyme_jax_ir() def add_one(x: jax.Array, y) -> jax.Array: - return x + 1 + y + return x + 1 + y + @jax.jit def add_one_plain(x: jax.Array, y) -> jax.Array: - return x + 1 + y + return x + 1 + y + @enzyme_jax_ir() def add_two(x: jax.Array, z, y) -> jax.Array: - return x + y + return x + y + @jax.jit def add_two_plain(x: jax.Array, z, y) -> jax.Array: - return x + y + return x + y -in0, in1, in2 = jnp.array([1., 2., 3.]), jnp.array([10., 20., 30.]), jnp.array([100., 200., 300.]) + +in0, in1, in2 = ( + jnp.array([1.0, 2.0, 3.0]), + jnp.array([10.0, 20.0, 30.0]), + jnp.array([100.0, 200.0, 300.0]), +) # TODO: this currently throws NYI as it is not yet connected to JIT and runtime. # But it should print LLVM IR in the process. ao = add_one(in0, in1) aop = add_one_plain(in0, in1) -assert (jnp.abs(ao-aop) < 1e-6).all() +assert (jnp.abs(ao - aop) < 1e-6).all() print("Primal success") at = add_two(in0, in1, in2) atp = add_two_plain(in0, in1, in2) -assert (jnp.abs(at-atp) < 1e-6).all() +assert (jnp.abs(at - atp) < 1e-6).all() print("Primal Deadarg success") import timeit -print(timeit.Timer('add_one(in0, in1)', globals={'add_one':add_one, 'in0':in0, 'in1':in1}).timeit()) -print(timeit.Timer('add_one_plain(in0, in1)', globals={'add_one_plain':add_one_plain, 'in0':in0, 'in1':in1}).timeit()) +print( + timeit.Timer( + "add_one(in0, in1)", globals={"add_one": add_one, "in0": in0, "in1": in1} + ).timeit() +) +print( + timeit.Timer( + "add_one_plain(in0, in1)", + globals={"add_one_plain": add_one_plain, "in0": in0, "in1": in1}, + ).timeit() +) + +din0, din1, din2 = ( + jnp.array([0.1, 0.2, 0.3]), + jnp.array([50.0, 70.0, 110.0]), + jnp.array([1300.0, 1700.0, 1900.0]), +) -din0, din1, din2 = (jnp.array([.1, .2, .3]), jnp.array([50., 70., 110.]), jnp.array([1300., 1700., 1900.])) @jax.jit def fwd(in0, in1, din0, din1): - return jax.jvp(add_one, (in0, in1), (din0, din1)) + return jax.jvp(add_one, (in0, in1), (din0, din1)) + @jax.jit def fwd_plain(in0, in1, din0, din1): - return jax.jvp(add_one_plain, (in0, in1), (din0, din1)) + return jax.jvp(add_one_plain, (in0, in1), (din0, din1)) + primals, tangents = fwd(in0, in1, din0, din1) primals_p, tangents_p = fwd_plain(in0, in1, din0, din1) -assert (jnp.abs(primals-primals_p) < 1e-6).all() +assert (jnp.abs(primals - primals_p) < 1e-6).all() for t, t_p in zip(tangents, tangents_p): - assert (jnp.abs(t-t_p) < 1e-6).all() + assert (jnp.abs(t - t_p) < 1e-6).all() print("Tangent success") + @jax.jit def fwd2(in0, in1, in2, din0, din1, din2): - return jax.jvp(add_two, (in0, in1, in2), (din0, din1, din2)) + return jax.jvp(add_two, (in0, in1, in2), (din0, din1, din2)) + @jax.jit def fwd2_plain(in0, in1, in2, din0, din1, din2): - return jax.jvp(add_two_plain, (in0, in1, in2), (din0, din1, din2)) + return 
jax.jvp(add_two_plain, (in0, in1, in2), (din0, din1, din2)) + primals, tangents = fwd2(in0, in1, in2, din0, din1, din2) primals_p, tangents_p = fwd2_plain(in0, in1, in2, din0, din1, din2) print(primals, primals_p) -assert (jnp.abs(primals-primals_p) < 1e-6).all() +assert (jnp.abs(primals - primals_p) < 1e-6).all() for i, (t, t_p) in enumerate(zip(tangents, tangents_p)): print(i, t, t_p) - assert (jnp.abs(t-t_p) < 1e-6).all() + assert (jnp.abs(t - t_p) < 1e-6).all() print("Tangent deadarg success") -print(timeit.Timer('fwd(in0, in1, din0, din1)', globals={'fwd':fwd, 'in0':in0, 'in1':in1, 'din0':din0, 'din1':din1}).timeit()) -print(timeit.Timer('fwd_plain(in0, in1, din0, din1)', globals={'fwd_plain':fwd_plain, 'in0':in0, 'in1':in1, 'din0':din0, 'din1':din1}).timeit()) +print( + timeit.Timer( + "fwd(in0, in1, din0, din1)", + globals={"fwd": fwd, "in0": in0, "in1": in1, "din0": din0, "din1": din1}, + ).timeit() +) +print( + timeit.Timer( + "fwd_plain(in0, in1, din0, din1)", + globals={ + "fwd_plain": fwd_plain, + "in0": in0, + "in1": in1, + "din0": din0, + "din1": din1, + }, + ).timeit() +) @jax.jit def rev(in0, in1, dout): - primals, f_vjp = jax.vjp(add_one, in0, in1) - grads = f_vjp(dout) - return primals, grads + primals, f_vjp = jax.vjp(add_one, in0, in1) + grads = f_vjp(dout) + return primals, grads + @jax.jit def rev_plain(in0, in1, dout): - primals, f_vjp = jax.vjp(add_one_plain, in0, in1) - grads = f_vjp(dout) - return primals, grads + primals, f_vjp = jax.vjp(add_one_plain, in0, in1) + grads = f_vjp(dout) + return primals, grads + -dout = jnp.array([500., 700., 110.]) +dout = jnp.array([500.0, 700.0, 110.0]) primals, grads = rev(in0, in1, dout) # TODO enzyme will in place 0 the gradient inputs, which may not be expected print(dout) -dout = jnp.array([500., 700., 110.]) +dout = jnp.array([500.0, 700.0, 110.0]) primals_p, grads_p = rev_plain(in0, in1, dout) -assert (jnp.abs(primals-primals_p) < 1e-6).all() +assert (jnp.abs(primals - primals_p) < 1e-6).all() for g, g_p in zip(grads, grads_p): print(i, g, g_p) - assert (jnp.abs(g-g_p) < 1e-6).all() + assert (jnp.abs(g - g_p) < 1e-6).all() print("Gradient success") + @jax.jit def rev2(in0, in1, in2, dout): - primals, f_vjp = jax.vjp(add_two, in0, in1, in2) - grads = f_vjp(dout) - return primals, grads + primals, f_vjp = jax.vjp(add_two, in0, in1, in2) + grads = f_vjp(dout) + return primals, grads + @jax.jit def rev2_plain(in0, in1, in2, dout): - primals, f_vjp = jax.vjp(add_two_plain, in0, in1, in2) - grads = f_vjp(dout) - return primals, grads + primals, f_vjp = jax.vjp(add_two_plain, in0, in1, in2) + grads = f_vjp(dout) + return primals, grads -dout = jnp.array([500., 700., 110.]) +dout = jnp.array([500.0, 700.0, 110.0]) primals, grads = rev2(in0, in1, in2, dout) # TODO enzyme will in place 0 the gradient inputs, which may not be expected print(dout) -dout = jnp.array([500., 700., 110.]) +dout = jnp.array([500.0, 700.0, 110.0]) primals_p, grads_p = rev2_plain(in0, in1, in2, dout) -assert (jnp.abs(primals-primals_p) < 1e-6).all() +assert (jnp.abs(primals - primals_p) < 1e-6).all() for g, g_p in zip(grads, grads_p): print(i, g, g_p) - assert (jnp.abs(g-g_p) < 1e-6).all() + assert (jnp.abs(g - g_p) < 1e-6).all() print("Gradient deadarg success") -print(timeit.Timer('rev(in0, in1, dout)', globals={'rev':rev, 'in0':in0, 'in1':in1, 'dout':dout}).timeit()) -print(timeit.Timer('rev_plain(in0, in1, dout)', globals={'rev_plain':rev_plain, 'in0':in0, 'in1':in1, 'dout':dout}).timeit()) +print( + timeit.Timer( + "rev(in0, in1, dout)", + 
globals={"rev": rev, "in0": in0, "in1": in1, "dout": dout}, + ).timeit() +) +print( + timeit.Timer( + "rev_plain(in0, in1, dout)", + globals={"rev_plain": rev_plain, "in0": in0, "in1": in1, "dout": dout}, + ).timeit() +) + +x = jnp.array(range(50), dtype=jnp.float32) +dx = jnp.array([i * i for i in range(50)], dtype=jnp.float32) -x = jnp.array(range(50), dtype=jnp.float32) -dx = jnp.array([i*i for i in range(50)], dtype=jnp.float32) @enzyme_jax_ir() def esum(x): return jnp.sum(x) + eres = esum(x) print(eres) -assert jnp.abs(eres-50*49/2)<1e-6 +assert jnp.abs(eres - 50 * 49 / 2) < 1e-6 + @jax.jit def sumfwd(in0, din0): - return jax.jvp(esum, (in0,), (din0,)) + return jax.jvp(esum, (in0,), (din0,)) + primals, tangents = sumfwd(x, dx) print(primals, tangents) -assert jnp.abs(primals-50*49/2)<1e-6 -assert jnp.abs(tangents-50*49*99/6)<1e-6 +assert jnp.abs(primals - 50 * 49 / 2) < 1e-6 +assert jnp.abs(tangents - 50 * 49 * 99 / 6) < 1e-6 + @jax.jit def sumrev_p(in0): - primals, f_vjp = jax.vjp(jnp.sum, in0) - grads = f_vjp(1.0) - return primals, grads + primals, f_vjp = jax.vjp(jnp.sum, in0) + grads = f_vjp(1.0) + return primals, grads + primals, grads = sumrev_p(x) print(primals, grads) + @jax.jit def sumrev(in0): - primals, f_vjp = jax.vjp(esum, in0) - grads = f_vjp(1.0) - return primals, grads + primals, f_vjp = jax.vjp(esum, in0) + grads = f_vjp(1.0) + return primals, grads + primals, grads = sumrev(x) print(primals, grads) -assert jnp.abs(primals-50*49/2)<1e-6 -assert (jnp.abs(grads[0]-1) <1e-6).all() +assert jnp.abs(primals - 50 * 49 / 2) < 1e-6 +assert (jnp.abs(grads[0] - 1) < 1e-6).all() + @enzyme_jax_ir() def ecache(x): return x * x[0] + @jax.jit def cacherev(in0, din0): - primals, f_vjp = jax.vjp(ecache, in0) - grads = f_vjp(din0) - return grads + primals, f_vjp = jax.vjp(ecache, in0) + grads = f_vjp(din0) + return grads + dim = 288 @@ -195,5 +263,5 @@ def cacherev(in0, din0): dx = jnp.array(range(dim), dtype=jnp.float32) grads = cacherev(x, dx) -assert jnp.abs(grads[0][0]-287*288*(2*287+1)/6)<1e-6 -assert (jnp.abs(grads[0][1:]) <1e-6).all() +assert jnp.abs(grads[0][0] - 287 * 288 * (2 * 287 + 1) / 6) < 1e-6 +assert (jnp.abs(grads[0][1:]) < 1e-6).all() diff --git a/test/lit_tests/ir.pyt b/test/lit_tests/ir.pyt index 71e538c2..8982f1ad 100644 --- a/test/lit_tests/ir.pyt +++ b/test/lit_tests/ir.pyt @@ -4,10 +4,15 @@ import jax import jax.numpy as jnp from enzyme_ad.jax import cpp_call + def do_something(ones, twos): shape = jax.core.ShapedArray(tuple(3 * s for s in ones.shape), ones.dtype) shape2 = jax.core.ShapedArray(tuple(2 * s for s in ones.shape), ones.dtype) - a, b = cpp_call(ones, twos, out_shapes=[shape, shape2], source=""" + a, b = cpp_call( + ones, + twos, + out_shapes=[shape, shape2], + source=""" template void myfn(enzyme::tensor& out0, enzyme::tensor& out1, const enzyme::tensor& in0, const enzyme::tensor& in1) { for (int j=0; j {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg2: tensor<6x9xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg3: tensor<4x6xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<2x3xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default"}, tensor<5x7xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default"}) { # CHECK-NEXT: %0 = 
mhlo.constant dense<3> : tensor<1xi64> diff --git a/test/lit_tests/lit.cfg.py b/test/lit_tests/lit.cfg.py index e29923d0..ab5c05d0 100644 --- a/test/lit_tests/lit.cfg.py +++ b/test/lit_tests/lit.cfg.py @@ -12,17 +12,17 @@ # Configuration file for the 'lit' test runner. # name: The name of this test suite. -config.name = 'Enzyme-JaX' +config.name = "Enzyme-JaX" # testFormat: The test format to use to interpret tests. # # For now we require '&&' between commands, until they get globally killed and # the test runner updated. -execute_external = platform.system() != 'Windows' +execute_external = platform.system() != "Windows" config.test_format = lit.formats.ShTest(execute_external) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.pyt'] +config.suffixes = [".pyt"] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) @@ -30,11 +30,22 @@ # test_exec_root: The root path where tests should be run. config.test_exec_root = os.path.dirname(__file__) -#ToolSubst('%lli', FindTool('lli'), post='.', extra_args=lli_args), +# ToolSubst('%lli', FindTool('lli'), post='.', extra_args=lli_args), # Tweak the PATH to include the tools dir and the scripts dir. -base_paths = [os.path.join(os.path.dirname(__file__), '..', '..', 'bazel-bin', 'external', 'llvm-project', 'llvm'), config.environment['PATH']] -path = os.path.pathsep.join(base_paths) # + config.extra_paths) -config.environment['PATH'] = path - -config.substitutions.append(('python', sys.executable)) +base_paths = [ + os.path.join( + os.path.dirname(__file__), + "..", + "..", + "bazel-bin", + "external", + "llvm-project", + "llvm", + ), + config.environment["PATH"], +] +path = os.path.pathsep.join(base_paths) # + config.extra_paths) +config.environment["PATH"] = path + +config.substitutions.append(("python", sys.executable)) diff --git a/test/llama.py b/test/llama.py index b0fe50b7..cc3649ce 100644 --- a/test/llama.py +++ b/test/llama.py @@ -3,202 +3,235 @@ import jax.lax import enzyme_ad.jax as enzyme_jax + def rmsnorm(x, weight): - ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5) - return weight * x * ss + ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5) + return weight * x * ss + def softmax(x): - max_val = jnp.max(x) - x = jnp.exp(x - max_val) - return x / sum(x) + max_val = jnp.max(x) + x = jnp.exp(x - max_val) + return x / sum(x) + def sigmoid(x): - return 1 / (1 + jnp.exp(-x)) + return 1 / (1 + jnp.exp(-x)) + def silu(x): - return x * sigmoid(x) + return x * sigmoid(x) # Token is token value asserts = True -def forward(x, config, weights, key_cache, value_cache): - pos = key_cache.shape[1] - assert pos == key_cache.shape[1] - assert pos == value_cache.shape[1] - - n_layers = config['n_layers'] - seq_len = config['seq_len'] - n_heads = config['n_heads'] - vocab_size = config['vocab_size'] - - # Total number of parameters of the recurrent state - dim = config['dim'] - - n_kv_heads = config['n_kv_heads'] - - # number of hidden dimensions? 
- hidden_dim = config['hidden_dim'] - - - # Number of parameters per head - head_size = dim // n_heads - - # Number of heads per kv - kv_mul = n_heads // n_kv_heads - - # Number of parameters in a kv - kv_dim = dim // n_heads * n_kv_heads - - - wo = weights['wo'] - if asserts: assert wo.shape == (n_layers, dim, dim) - rms_ffn_weight = weights['rms_ffn_weight'] - if asserts: assert rms_ffn_weight.shape == (n_layers, dim) - w1 = weights['w1'] - if asserts: assert w1.shape == (n_layers, hidden_dim, dim) - w3 = weights['w3'] - if asserts: assert w3.shape == (n_layers, hidden_dim, dim) - w2 = weights['w2'] - if asserts: assert w2.shape == (n_layers, dim, hidden_dim) - rms_att_weight = weights['rms_att_weight'] - if asserts: assert rms_att_weight.shape == (n_layers,dim) - rms_final_weight = weights['rms_final_weight'] - if asserts: assert rms_final_weight.shape == (dim,) - wcls = weights['wcls'] - if asserts: assert wcls.shape == (vocab_size, dim) +def forward(x, config, weights, key_cache, value_cache): + pos = key_cache.shape[1] + assert pos == key_cache.shape[1] + assert pos == value_cache.shape[1] + + n_layers = config["n_layers"] + seq_len = config["seq_len"] + n_heads = config["n_heads"] + vocab_size = config["vocab_size"] + + # Total number of parameters of the recurrent state + dim = config["dim"] + + n_kv_heads = config["n_kv_heads"] + + # number of hidden dimensions? + hidden_dim = config["hidden_dim"] + + # Number of parameters per head + head_size = dim // n_heads + + # Number of heads per kv + kv_mul = n_heads // n_kv_heads + + # Number of parameters in a kv + kv_dim = dim // n_heads * n_kv_heads + + wo = weights["wo"] + if asserts: + assert wo.shape == (n_layers, dim, dim) + rms_ffn_weight = weights["rms_ffn_weight"] + if asserts: + assert rms_ffn_weight.shape == (n_layers, dim) + w1 = weights["w1"] + if asserts: + assert w1.shape == (n_layers, hidden_dim, dim) + w3 = weights["w3"] + if asserts: + assert w3.shape == (n_layers, hidden_dim, dim) + w2 = weights["w2"] + if asserts: + assert w2.shape == (n_layers, dim, hidden_dim) + + rms_att_weight = weights["rms_att_weight"] + if asserts: + assert rms_att_weight.shape == (n_layers, dim) + + rms_final_weight = weights["rms_final_weight"] + if asserts: + assert rms_final_weight.shape == (dim,) + wcls = weights["wcls"] + if asserts: + assert wcls.shape == (vocab_size, dim) + + # token_embedding_table = weights['token_embedding_table'] + # if asserts: assert token_embedding_table.shape == (vocab_size, dim) + + # x = token_embedding_table[token, :] + # if asserts: assert x.shape == (dim, ) + + wq = weights["wq"] + if asserts: + assert wq.shape == (n_layers, dim, dim) - # token_embedding_table = weights['token_embedding_table'] - # if asserts: assert token_embedding_table.shape == (vocab_size, dim) + wk = weights["wk"] + if asserts: + assert wk.shape == (n_layers, kv_dim, dim) - # x = token_embedding_table[token, :] - # if asserts: assert x.shape == (dim, ) + wv = weights["wv"] + if asserts: + assert wv.shape == (n_layers, kv_dim, dim) - wq = weights['wq'] - if asserts: assert wq.shape == (n_layers, dim, dim) + toconv = [] - wk = weights['wk'] - if asserts: assert wk.shape == (n_layers, kv_dim, dim) + for i in range(0, dim, 2): + freq = 1 / jnp.power(10000, (i % head_size) / head_size) + val = pos * freq + fcr = jnp.cos(val) + fci = jnp.sin(val) - wv = weights['wv'] - if asserts: assert wv.shape == (n_layers, kv_dim, dim) + rotM = jnp.array([[fcr, -fci], [fci, fcr]]) + toconv.append(rotM) + toconv2 = toconv[: kv_dim // 2] + [jnp.eye(2)] * 
(dim // 2 - kv_dim // 2) - toconv = [] - - for i in range(0, dim, 2): - freq = 1 / jnp.power(10000, (i % head_size) / head_size) - val = pos * freq - fcr = jnp.cos(val) - fci = jnp.sin(val) + toconv = jnp.array(toconv) + toconv2 = jnp.array(toconv2) - rotM = jnp.array([[fcr, -fci], - [fci, fcr]]) - toconv.append(rotM) - toconv2 = toconv[:kv_dim//2] + [jnp.eye(2)] * (dim//2 - kv_dim//2) - - toconv = jnp.array(toconv) - toconv2 = jnp.array(toconv2) + keys2 = [] + values2 = [] + for l in range(n_layers): + xb = rmsnorm(x, rms_att_weight[l, :]) + if asserts: + assert xb.shape == (dim,) - keys2 = [] - values2 = [] - for l in range(n_layers): - xb = rmsnorm(x, rms_att_weight[l, :]) - if asserts: assert xb.shape == (dim, ) + q = wq[l, :, :] @ xb + if asserts: + assert q.shape == (dim,) - q = wq[l, :, :] @ xb - if asserts: assert q.shape == (dim, ) + k = wk[l, :, :] @ xb + if asserts: + assert q.shape == (kv_dim,) - k = wk[l, :, :] @ xb - if asserts: assert q.shape == (kv_dim, ) + v = wv[l, :, :] @ xb + if asserts: + assert q.shape == (kv_dim,) - v = wv[l, :, :] @ xb - if asserts: assert q.shape == (kv_dim, ) - - q_tmp = jnp.reshape(q, (dim // 2, 2)) - k_tmp = jnp.reshape(k, (dim // 2, 2)) + q_tmp = jnp.reshape(q, (dim // 2, 2)) + k_tmp = jnp.reshape(k, (dim // 2, 2)) - # dim == head_size * n_heads + # dim == head_size * n_heads - # Batched gemv - k = jnp.reshape(jnp.einsum('ijk,ik -> ij', toconv2, k_tmp), (dim,)) - q = jnp.reshape(jnp.einsum('ijk,ik -> ij', toconv, q_tmp), (dim,)) + # Batched gemv + k = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv2, k_tmp), (dim,)) + q = jnp.reshape(jnp.einsum("ijk,ik -> ij", toconv, q_tmp), (dim,)) - key_cache_l = key_cache[l, :, :] - key_cache_l = jnp.append(key_cache_l, jnp.reshape(k, (1, dim)), axis=0) - value_cache_l = value_cache[l, :, :] - value_cache_l = jnp.append(value_cache_l, jnp.reshape(v, (1, dim)), axis=0) - keys2.append(key_cache_l) - values2.append(value_cache_l) - - xbs2 = [] - for h in range(n_heads): + key_cache_l = key_cache[l, :, :] + key_cache_l = jnp.append(key_cache_l, jnp.reshape(k, (1, dim)), axis=0) + value_cache_l = value_cache[l, :, :] + value_cache_l = jnp.append(value_cache_l, jnp.reshape(v, (1, dim)), axis=0) + keys2.append(key_cache_l) + values2.append(value_cache_l) - q2 = q[head_size*h:head_size*(h+1)] - if asserts: assert q2.shape == (head_size,) + xbs2 = [] + for h in range(n_heads): + q2 = q[head_size * h : head_size * (h + 1)] + if asserts: + assert q2.shape == (head_size,) - # For kv_mul consecutive heads, they share the same kv cache - # reshape key_cache last dim from (kv_dim,) to (kv_mul, head_size) - # generalized einsum reducing the last dim, the rest are batch - att = [] + # For kv_mul consecutive heads, they share the same kv cache + # reshape key_cache last dim from (kv_dim,) to (kv_mul, head_size) + # generalized einsum reducing the last dim, the rest are batch + att = [] - key_index = h // kv_mul - - att = jnp.einsum('ij,j->i', key_cache_l[:, key_index * head_size : (key_index+1) * head_size], q2) + key_index = h // kv_mul - att = att / jnp.sqrt(head_size) + att = jnp.einsum( + "ij,j->i", + key_cache_l[:, key_index * head_size : (key_index + 1) * head_size], + q2, + ) - att = softmax(att) - - x_tmp = jnp.einsum('ij,i->j', value_cache_l[:, key_index * head_size : (key_index+1) * head_size], att) + att = att / jnp.sqrt(head_size) - xbs2.append(x_tmp) + att = softmax(att) - # Todo right concat - xb = jnp.concatenate(xbs2, axis=None) + x_tmp = jnp.einsum( + "ij,i->j", + value_cache_l[:, key_index * head_size : 
(key_index + 1) * head_size], + att, + ) - xb2 = wo[l, :, :] @ xb - if asserts: assert xb2.shape == (dim, ) + xbs2.append(x_tmp) - x += xb2 + # Todo right concat + xb = jnp.concatenate(xbs2, axis=None) - # Rmsnorm and feedforward swiglu + xb2 = wo[l, :, :] @ xb + if asserts: + assert xb2.shape == (dim,) - xb = rmsnorm(x, rms_ffn_weight[l, :]) + x += xb2 - # Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) - # first calculate self.w1(x) and self.w3(x) + # Rmsnorm and feedforward swiglu + xb = rmsnorm(x, rms_ffn_weight[l, :]) - hb = w1[l, :, :] @ xb - hb2 = w3[l, :, :] @ xb + # Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) + # first calculate self.w1(x) and self.w3(x) - hb = silu(hb) + hb = w1[l, :, :] @ xb + hb2 = w3[l, :, :] @ xb - hb = hb * hb2 + hb = silu(hb) + hb = hb * hb2 - xb = w2[l, :, :] @ hb + xb = w2[l, :, :] @ hb - x += xb + x += xb + x = rmsnorm(x, rms_final_weight) + logits = wcls @ x - x = rmsnorm(x, rms_final_weight) - logits = wcls @ x + return x - return x import numpy as np -config = {'dim': 288, 'hidden_dim': 768, 'n_layers': 6, 'n_heads': 6, 'n_kv_heads': 6, 'vocab_size': 32000, 'seq_len': 256} - -n_layers = config['n_layers'] -seq_len = config['seq_len'] -n_heads = config['n_heads'] -dim = config['dim'] -n_kv_heads = config['n_kv_heads'] -vocab_size = config['vocab_size'] -hidden_dim = config['hidden_dim'] +config = { + "dim": 288, + "hidden_dim": 768, + "n_layers": 6, + "n_heads": 6, + "n_kv_heads": 6, + "vocab_size": 32000, + "seq_len": 256, +} + +n_layers = config["n_layers"] +seq_len = config["seq_len"] +n_heads = config["n_heads"] +dim = config["dim"] +n_kv_heads = config["n_kv_heads"] +vocab_size = config["vocab_size"] +hidden_dim = config["hidden_dim"] kv_dim = dim // n_heads * n_kv_heads head_size = dim // n_heads @@ -206,52 +239,59 @@ def forward(x, config, weights, key_cache, value_cache): weights = {} dweights = {} -for name, shape in [("rms_att_weight", (n_layers, dim)), - ("wq", (n_layers, dim, n_heads * head_size)), - ("wk", (n_layers, dim, n_kv_heads * head_size)), - ("wv", (n_layers, dim, n_kv_heads * head_size)), - ("wo", (n_layers, dim, dim)), - ("rms_ffn_weight", (n_layers, dim)), - ("w1", (n_layers, hidden_dim, dim)), - ("w2", (n_layers, dim, hidden_dim)), - ("w3", (n_layers, hidden_dim, dim)), - ("rms_final_weight", (dim,)), - ("wcls", (vocab_size, dim)) - ]: - key, subkey = jax.random.split(key) - key, subkey2 = jax.random.split(key) - weights[name] = jax.random.uniform(subkey, shape=shape) - dweights[name] = jax.random.uniform(subkey2, shape=shape) +for name, shape in [ + ("rms_att_weight", (n_layers, dim)), + ("wq", (n_layers, dim, n_heads * head_size)), + ("wk", (n_layers, dim, n_kv_heads * head_size)), + ("wv", (n_layers, dim, n_kv_heads * head_size)), + ("wo", (n_layers, dim, dim)), + ("rms_ffn_weight", (n_layers, dim)), + ("w1", (n_layers, hidden_dim, dim)), + ("w2", (n_layers, dim, hidden_dim)), + ("w3", (n_layers, hidden_dim, dim)), + ("rms_final_weight", (dim,)), + ("wcls", (vocab_size, dim)), +]: + key, subkey = jax.random.split(key) + key, subkey2 = jax.random.split(key) + weights[name] = jax.random.uniform(subkey, shape=shape) + dweights[name] = jax.random.uniform(subkey2, shape=shape) key, subkey = jax.random.split(key) x = jax.random.uniform(subkey, shape=(dim,)) key, subkey = jax.random.split(key) dx = jax.random.uniform(subkey, shape=(dim,)) + def partial(func, config): def sfn(x, weights, key_cache, value_cache): return func(x, config, weights, key_cache, value_cache) + return sfn 
+ pos = 1 -key_cache = jnp.zeros((n_layers, pos,kv_dim)) -value_cache = jnp.zeros((n_layers, pos,kv_dim)) +key_cache = jnp.zeros((n_layers, pos, kv_dim)) +value_cache = jnp.zeros((n_layers, pos, kv_dim)) key, subkey = jax.random.split(key) -dkc = jax.random.uniform(subkey, shape=(n_layers,pos+1,kv_dim)) +dkc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim)) key, subkey = jax.random.split(key) -dvc = jax.random.uniform(subkey, shape=(n_layers,pos+1,kv_dim)) +dvc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim)) func = partial(forward, config) + @jax.jit def jfunc(x, weights, key_cache, value_cache): return func(x, weights, key_cache, value_cache) + @enzyme_jax.enzyme_jax_ir() def efunc(x, weights, key_cache, value_cache): return func(x, weights, key_cache, value_cache) + # eres = efunc(x, weights, key_cache, value_cache) # print("Enzyme primal", eres) # res = func(x, weights, key_cache, value_cache) @@ -259,16 +299,19 @@ def efunc(x, weights, key_cache, value_cache): # print (" max error", jnp.max(jnp.abs(eres-res))) # assert (jnp.abs(eres - res) < 1e-3).all() -#jfunc = jax.jit(partial(forward, config)) +# jfunc = jax.jit(partial(forward, config)) # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo") + @jax.jit def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc): - return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) + return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) + @jax.jit def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc): - return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) + return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) + # print("pre fwd diff") # eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache) @@ -279,16 +322,17 @@ def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc): @jax.jit def jrev(x, weights, kc, vc, dx, dkc, dvc): - primals, f_vjp = jax.vjp(jfunc, x, weights, kc, vc) - return f_vjp(dx) #, dkc, dvc) + primals, f_vjp = jax.vjp(jfunc, x, weights, kc, vc) + return f_vjp(dx) # , dkc, dvc) + @jax.jit def erev(x, weights, kc, vc, dx, dkc, dvc): - primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc) - return f_vjp(dx) #, dkc, dvc) + primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc) + return f_vjp(dx) # , dkc, dvc) + eres = erev(x, weights, key_cache, value_cache, dx, dkc, dvc) print("Enzyme rev", eres) jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc) print("Jax rev", jres) - diff --git a/test/test.py b/test/test.py index 2eb5866d..1553f7d6 100644 --- a/test/test.py +++ b/test/test.py @@ -2,10 +2,14 @@ import jax.numpy as jnp from enzyme_ad.jax import cpp_call + @jax.jit def do_something(ones): shape = jax.core.ShapedArray(ones.shape, ones.dtype) - a, b = cpp_call(ones, out_shapes=[shape, shape], source=""" + a, b = cpp_call( + ones, + out_shapes=[shape, shape], + source=""" template void myfn(enzyme::tensor& out0, enzyme::tensor& out1, const enzyme::tensor& in0) { for (int j=0; j void f(T1& out0, const T2& in1) { out0 = 56.0f; } - """) + """, + ) return a, b, c + ones = jnp.ones((2, 3), jnp.float32) x, y, z = do_something(ones) @@ -35,7 +46,7 @@ def do_something(ones): print(y) print(z) -primals, tangents = jax.jvp(do_something, (ones,), (ones,) ) +primals, tangents = jax.jvp(do_something, (ones,), (ones,)) print(primals) print(tangents) @@ -51,18 +62,24 @@ def do_something(ones): @enzyme_jax_ir() def add_one(x: jax.Array, y) -> jax.Array: - return x + 1 + y + return x + 1 
+ y # TODO: this currently throws NYI as it is not yet connected to JIT and runtime. # But it should print LLVM IR in the process. -add_one(jnp.array([1., 2., 3.]), jnp.array([10., 20., 30.])) +add_one(jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])) -primals, tangents = jax.jvp(add_one, (jnp.array([1., 2., 3.]), jnp.array([10., 20., 30.])), (jnp.array([.1, .2, .3]), jnp.array([50., 70., 110.])) ) +primals, tangents = jax.jvp( + add_one, + (jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])), + (jnp.array([0.1, 0.2, 0.3]), jnp.array([50.0, 70.0, 110.0])), +) print(primals) print(tangents) -primals, f_vjp = jax.vjp(add_one, jnp.array([1., 2., 3.]), jnp.array([10., 20., 30.])) -grads = f_vjp(jnp.array([500., 700., 110.])) +primals, f_vjp = jax.vjp( + add_one, jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0]) +) +grads = f_vjp(jnp.array([500.0, 700.0, 110.0])) print(primals) print(grads)
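The Python-Black job added at the top of this patch runs black with "--check --verbose", so an equivalent local check before pushing keeps CI green. Below is a minimal sketch of one way to do that through Black's Python API; it assumes black is installed (pip install black), and the scanned directories ("src", "test") and the helper name check_black are illustrative assumptions rather than project conventions. Running "black --check ." from the repository root is the command-line equivalent.

# Local pre-push formatting check mirroring .github/workflows/python_format.yml.
# Assumption: Black is installed; format_str and Mode are Black's public Python API.
import pathlib
import sys

import black


def check_black(roots=("src", "test")):
    """Return the .py files under `roots` that Black would reformat."""
    unformatted = []
    for root in roots:
        for path in sorted(pathlib.Path(root).rglob("*.py")):
            src = path.read_text()
            # format_str returns Black-formatted source; any difference means
            # the CI job (black --check) would flag this file.
            if black.format_str(src, mode=black.Mode()) != src:
                unformatted.append(path)
    return unformatted


if __name__ == "__main__":
    bad = check_black()
    for p in bad:
        print(f"would reformat {p}")
    sys.exit(1 if bad else 0)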