Skip to content

Commit

Permalink
Rename to enzymead
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 19, 2023
1 parent 5a15457 commit 00e2ea9
Show file tree
Hide file tree
Showing 18 changed files with 19 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ steps:
python -m ensurepip --upgrade
python -m pip install --user numpy wheel
mkdir baztmp
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
bazel --output_user_root=`pwd`/baztmp build :enzymead
cp bazel-bin/*.whl .
python -m pip install *.whl
cd test && python -m pip install "jax[cpu]" && python test.py
Expand Down
2 changes: 1 addition & 1 deletion .buildkite/secure_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ steps:
mkdir baztmp
export TAG=`echo $BUILDKITE_TAG | cut -c2-`
sed -i.bak "s~version = \"[0-9.]*\"~version = \"\$TAG\"~g" BUILD
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
bazel --output_user_root=`pwd`/baztmp build :enzymead
cp bazel-bin/*.whl .
python -m pip install *.whl
cd test && python -m pip install "jax[cpu]" && python test.py && cd ..
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ jobs:
key: bazel-${{ matrix.os }}
- run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
- run: |
bazel build :enzyme_ad @llvm-project//llvm:FileCheck
bazel cquery "allpaths(//src/enzyme_ad/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps
bazel build :enzymead @llvm-project//llvm:FileCheck
bazel cquery "allpaths(//src/enzymead/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps
bazel --version
nm -C $(find bazel-out/ -name enzyme_call.so -type f) | grep ExecutorCache::
- run: cp bazel-bin/*.whl .
Expand All @@ -46,7 +46,7 @@ jobs:
run: |
python3 -m pip install --user --force-reinstall "jax[cpu]" *.whl
cd test
nm -C $(python3 -c "from enzyme_ad.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache::
nm -C $(python3 -c "from enzymead.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache::
python3 test.py
cd lit_tests
lit . --verbose
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tag.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
path: "~/.cache/bazel"
key: bazel-${{ matrix.os }}
- run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
- run: bazel build :enzyme_ad
- run: bazel build :enzymead
- env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
Expand Down
10 changes: 5 additions & 5 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ package(
py_package(
name = "enzyme_jax_data",
deps = [
"//src/enzyme_ad/jax:enzyme_call.so",
"//src/enzymead/jax:enzyme_call.so",
"@llvm-project//clang:builtin_headers_gen",
],
# Only include these Python packages.
packages = ["@//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"],
packages = ["@//src/enzymead/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"],
)

py_wheel(
name = "enzyme_ad",
distribution = "enzyme_ad",
name = "enzymead",
distribution = "enzymead",
summary = "Enzyme automatic differentiation tool.",
homepage = "https://enzyme.mit.edu/",
project_urls = {
Expand All @@ -39,7 +39,7 @@ py_wheel(
"@bazel_tools//src/conditions:linux_x86_64": "manylinux2014_x86_64",
"@bazel_tools//src/conditions:linux_ppc64le": "manylinux2014_ppc64le",
}),
deps = ["//src/enzyme_ad/jax:enzyme_jax_internal", ":enzyme_jax_data"],
deps = ["//src/enzymead/jax:enzyme_jax_internal", ":enzyme_jax_data"],
strip_path_prefixes = ["src/"],
requires = [
"absl_py >= 2.0.0",
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ languages (Julia, Swift, Fortran, Rust, and even Python)!
You can use

```python
from enzyme_ad.jax import cpp_call
from enzymead.jax import cpp_call

# Forward-mode C++ AD example

Expand Down Expand Up @@ -49,13 +49,13 @@ Requirements: `bazel-6.2.1`, `clang++`, `python`, `python-virtualenv`,

Build our extension with:
```sh
# Will create a whl in bazel-bin/enzyme_ad-VERSION-SYSTEM.whl
bazel build :enzyme_ad
# Will create a whl in bazel-bin/enzymead-VERSION-SYSTEM.whl
bazel build :enzymead
```

Finally, install the built library with:
```sh
pip install bazel-bin/enzyme_ad-VERSION-SYSTEM.whl
pip install bazel-bin/enzymead-VERSION-SYSTEM.whl
```
Note that you cannot run code from the root of the git directory. For instance, in the code below, you have to first run `cd test` before running `test.py`.

Expand Down
1 change: 0 additions & 1 deletion src/enzyme_ad/jax/__init__.py

This file was deleted.

File renamed without changes.
1 change: 1 addition & 0 deletions src/enzymead/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from enzymead.jax.primitives import cpp_call, enzyme_jax_ir
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion test/bench_vs_xla.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
import jax.numpy as jnp
from enzyme_ad.jax import enzyme_jax_ir
from enzymead.jax import enzyme_jax_ir
from absl.testing import absltest
import timeit

Expand Down
2 changes: 1 addition & 1 deletion test/lit_tests/ir.pyt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax
import jax.numpy as jnp
from enzyme_ad.jax import cpp_call
from enzymead.jax import cpp_call


def do_something(ones, twos):
Expand Down
4 changes: 1 addition & 3 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
import jax.numpy as jnp
import jax.random
import jax.lax
import enzyme_ad.jax as enzyme_jax
import numpy as np

import enzymead.jax as enzyme_jax

def rmsnorm(x, weight):
ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5)
Expand Down

0 comments on commit 00e2ea9

Please sign in to comment.