diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 049072fb6..0a5c4340f 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -59,7 +59,7 @@ steps: python -m ensurepip --upgrade python -m pip install --user numpy wheel mkdir baztmp - bazel --output_user_root=`pwd`/baztmp build :enzyme_ad + bazel --output_user_root=`pwd`/baztmp build :enzymead cp bazel-bin/*.whl . python -m pip install *.whl cd test && python -m pip install "jax[cpu]" && python test.py diff --git a/.buildkite/secure_pipeline.yml b/.buildkite/secure_pipeline.yml index 3f313dd05..9c1a92acf 100644 --- a/.buildkite/secure_pipeline.yml +++ b/.buildkite/secure_pipeline.yml @@ -67,7 +67,7 @@ steps: mkdir baztmp export TAG=`echo $BUILDKITE_TAG | cut -c2-` sed -i.bak "s~version = \"[0-9.]*\"~version = \"\$TAG\"~g" BUILD - bazel --output_user_root=`pwd`/baztmp build :enzyme_ad + bazel --output_user_root=`pwd`/baztmp build :enzymead cp bazel-bin/*.whl . python -m pip install *.whl cd test && python -m pip install "jax[cpu]" && python test.py && cd .. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7bb36e907..b2c76a640 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -36,8 +36,8 @@ jobs: key: bazel-${{ matrix.os }} - run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \; - run: | - bazel build :enzyme_ad @llvm-project//llvm:FileCheck - bazel cquery "allpaths(//src/enzyme_ad/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps + bazel build :enzymead @llvm-project//llvm:FileCheck + bazel cquery "allpaths(//src/enzymead/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps bazel --version nm -C $(find bazel-out/ -name enzyme_call.so -type f) | grep ExecutorCache:: - run: cp bazel-bin/*.whl . @@ -46,7 +46,7 @@ jobs: run: | python3 -m pip install --user --force-reinstall "jax[cpu]" *.whl cd test - nm -C $(python3 -c "from enzyme_ad.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache:: + nm -C $(python3 -c "from enzymead.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache:: python3 test.py cd lit_tests lit . --verbose diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml index a51000b14..cf30cc39e 100644 --- a/.github/workflows/tag.yml +++ b/.github/workflows/tag.yml @@ -36,7 +36,7 @@ jobs: path: "~/.cache/bazel" key: bazel-${{ matrix.os }} - run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \; - - run: bazel build :enzyme_ad + - run: bazel build :enzymead - env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} diff --git a/BUILD b/BUILD index 409ea10c7..cf7dc5b26 100644 --- a/BUILD +++ b/BUILD @@ -11,16 +11,16 @@ package( py_package( name = "enzyme_jax_data", deps = [ - "//src/enzyme_ad/jax:enzyme_call.so", + "//src/enzymead/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen", ], # Only include these Python packages. - packages = ["@//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"], + packages = ["@//src/enzymead/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"], ) py_wheel( - name = "enzyme_ad", - distribution = "enzyme_ad", + name = "enzymead", + distribution = "enzymead", summary = "Enzyme automatic differentiation tool.", homepage = "https://enzyme.mit.edu/", project_urls = { @@ -39,7 +39,7 @@ py_wheel( "@bazel_tools//src/conditions:linux_x86_64": "manylinux2014_x86_64", "@bazel_tools//src/conditions:linux_ppc64le": "manylinux2014_ppc64le", }), - deps = ["//src/enzyme_ad/jax:enzyme_jax_internal", ":enzyme_jax_data"], + deps = ["//src/enzymead/jax:enzyme_jax_internal", ":enzyme_jax_data"], strip_path_prefixes = ["src/"], requires = [ "absl_py >= 2.0.0", diff --git a/README.md b/README.md index 1ae9f714a..1aa2dd6b0 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ languages (Julia, Swift, Fortran, Rust, and even Python)! You can use ```python -from enzyme_ad.jax import cpp_call +from enzymead.jax import cpp_call # Forward-mode C++ AD example @@ -49,13 +49,13 @@ Requirements: `bazel-6.2.1`, `clang++`, `python`, `python-virtualenv`, Build our extension with: ```sh -# Will create a whl in bazel-bin/enzyme_ad-VERSION-SYSTEM.whl -bazel build :enzyme_ad +# Will create a whl in bazel-bin/enzymead-VERSION-SYSTEM.whl +bazel build :enzymead ``` Finally, install the built library with: ```sh -pip install bazel-bin/enzyme_ad-VERSION-SYSTEM.whl +pip install bazel-bin/enzymead-VERSION-SYSTEM.whl ``` Note that you cannot run code from the root of the git directory. For instance, in the code below, you have to first run `cd test` before running `test.py`. diff --git a/src/enzyme_ad/jax/__init__.py b/src/enzyme_ad/jax/__init__.py deleted file mode 100644 index 0519e33ab..000000000 --- a/src/enzyme_ad/jax/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from enzyme_ad.jax.primitives import cpp_call, enzyme_jax_ir diff --git a/src/enzyme_ad/jax/BUILD b/src/enzymead/jax/BUILD similarity index 100% rename from src/enzyme_ad/jax/BUILD rename to src/enzymead/jax/BUILD diff --git a/src/enzymead/jax/__init__.py b/src/enzymead/jax/__init__.py new file mode 100644 index 000000000..61df660ec --- /dev/null +++ b/src/enzymead/jax/__init__.py @@ -0,0 +1 @@ +from enzymead.jax.primitives import cpp_call, enzyme_jax_ir diff --git a/src/enzyme_ad/jax/clang_compile.cc b/src/enzymead/jax/clang_compile.cc similarity index 100% rename from src/enzyme_ad/jax/clang_compile.cc rename to src/enzymead/jax/clang_compile.cc diff --git a/src/enzyme_ad/jax/clang_compile.h b/src/enzymead/jax/clang_compile.h similarity index 100% rename from src/enzyme_ad/jax/clang_compile.h rename to src/enzymead/jax/clang_compile.h diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzymead/jax/compile_with_xla.cc similarity index 100% rename from src/enzyme_ad/jax/compile_with_xla.cc rename to src/enzymead/jax/compile_with_xla.cc diff --git a/src/enzyme_ad/jax/compile_with_xla.h b/src/enzymead/jax/compile_with_xla.h similarity index 100% rename from src/enzyme_ad/jax/compile_with_xla.h rename to src/enzymead/jax/compile_with_xla.h diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzymead/jax/enzyme_call.cc similarity index 100% rename from src/enzyme_ad/jax/enzyme_call.cc rename to src/enzymead/jax/enzyme_call.cc diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzymead/jax/primitives.py similarity index 100% rename from src/enzyme_ad/jax/primitives.py rename to src/enzymead/jax/primitives.py diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index 957a37745..7039b6a78 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -1,6 +1,6 @@ import jax import jax.numpy as jnp -from enzyme_ad.jax import enzyme_jax_ir +from enzymead.jax import enzyme_jax_ir from absl.testing import absltest import timeit diff --git a/test/lit_tests/ir.pyt b/test/lit_tests/ir.pyt index 8982f1ad7..cd783631b 100644 --- a/test/lit_tests/ir.pyt +++ b/test/lit_tests/ir.pyt @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp -from enzyme_ad.jax import cpp_call +from enzymead.jax import cpp_call def do_something(ones, twos): diff --git a/test/llama.py b/test/llama.py index 53994d19a..c84c16153 100644 --- a/test/llama.py +++ b/test/llama.py @@ -2,9 +2,7 @@ import jax.numpy as jnp import jax.random import jax.lax -import enzyme_ad.jax as enzyme_jax -import numpy as np - +import enzymead.jax as enzyme_jax def rmsnorm(x, weight): ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5)