diff --git a/.buildkite/gpu_pipeline.yml b/.buildkite/gpu_pipeline.yml index f122e394..424baf4d 100644 --- a/.buildkite/gpu_pipeline.yml +++ b/.buildkite/gpu_pipeline.yml @@ -16,19 +16,15 @@ steps: mv bazel* .local/bin/bazel chmod +x .local/bin/bazel - wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-x86_64.sh - chmod +x Miniconda*.sh - ./Miniconda*.sh -b -p `pwd`/conda - rm Miniconda*.sh - python -m ensurepip --upgrade - python -m pip install --user numpy wheel mkdir -p .baztmp - rm -f bazel-bin/*.whl - HERMETIC_PYTHON_VERSION=`python -c "import sys; print('{0[0]}.{0[1]}'.format(sys.version_info))"` bazel --output_user_root=`pwd`/.baztmp build --python_path=`which python` --define=no_nccl_support=true :enzyme_ad - cp bazel-bin/*.whl . - python -m pip install --user *.whl "jax[cuda12]" echo "--- :python: Test" - HERMETIC_PYTHON_VERSION=`python -c "import sys; print('{0[0]}.{0[1]}'.format(sys.version_info))"` bazel --output_user_root=`pwd`/.baztmp test --python_path=`which python` --test_output=errors //test/... + HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --test_output=errors //test/... + HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --cache_test_results=no //test:bench_vs_xla + HERMETIC_PYTHON_VERSION="3.12" bazel --output_user_root=`pwd`/.baztmp test --cache_test_results=no //test:llama + cat bazel-out/*/testlogs/test/llama/test.log + artifact_paths: + - "bazel-out/*/testlogs/test/llama/test.log" + - "bazel-out/*/testlogs/test/llama/bench_vs_xla.log" diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 5e1d6f90..77a435b7 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,5 +1,5 @@ steps: - - name: "CI {{matrix.arch}} -- {{matrix.os}}" + - name: "CI {{matrix.arch}} -- {{matrix.os}} python {{matrix.python}}" matrix: setup: arch: @@ -7,6 +7,8 @@ steps: - x86_64 os: - macos + python: + - "3.12" agents: queue: "juliaecosystem" os: "{{matrix.os}}" @@ -30,52 +32,31 @@ steps: chmod +x .local/bin/md5 if [ "{{matrix.os}}" == "macos" ]; then - if [ "{{matrix.arch}}" == "aarch64" ]; then - sed -i.bak 's~targets = \[.*\]~targets = \[\"AArch64\", \"AMDGPU\"]~g' WORKSPACE - curl -fLO https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-MacOSX-arm64.sh - else - sed -i.bak 's~targets = \[.*\]~targets = \[\"X86\", \"AMDGPU\"]~g' WORKSPACE - curl -fLO https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-MacOSX-{{matrix.arch}}.sh - fi curl -fLO "https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-darwin" mv bazelisk-darwin .local/bin/bazel chmod +x .local/bin/bazel - chmod +x Miniconda*.sh - ./Miniconda*.sh -b -p `pwd`/conda - rm Miniconda*.sh elif [ "{{matrix.os}}" == "linux" ]; then - if [ "{{matrix.arch}}" == "aarch64" ]; then - curl -fLO https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-linux-arm64 - else - curl -fLO https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-linux-amd64 - fi mv bazel* .local/bin/bazel chmod +x .local/bin/bazel - wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Linux-{{matrix.arch}}.sh - chmod +x Miniconda*.sh - ./Miniconda*.sh -b -p `pwd`/conda - rm Miniconda*.sh else - wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.3.1-0-Windows-{{matrix.arch}}.exe if [ "{{matrix.arch}}" == "aarch64" ]; then wget https://github.com/bazelbuild/bazel/releases/download/6.2.1/bazel-6.2.1-windows-arm64.exe else wget https://github.com/bazelbuild/bazel/releases/download/6.2.1/bazel-6.2.1-windows-x86_64.exe fi mv bazel* .local/bin/bazel.exe - start /wait "" Miniconda3*.exe /InstallationType=JustMe /RegisterPython=0 /S /D=`pwd`/conda - rm Miniconda*.exe fi - # conda install -c conda-forge cxx-compiler -y - python -m ensurepip --upgrade - python -m pip install --user numpy wheel mkdir -p .baztmp + HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/.baztmp test --test_output=errors //test/... + HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/.baztmp test --cache_test_results=no //test:bench_vs_xla + HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/.baztmp test --cache_test_results=no //test:llama + cat bazel-out/*/testlogs/test/llama/test.log rm -f bazel-bin/*.whl - HERMETIC_PYTHON_VERSION=`python -c "import sys; print('{0[0]}.{0[1]}'.format(sys.version_info))"` bazel --output_user_root=`pwd`/.baztmp build --python_path=`which python` --define=no_nccl_support=true :enzyme_ad + HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/.baztmp build :wheel cp bazel-bin/*.whl . - python -m pip install --user *.whl "jax[cpu]" - HERMETIC_PYTHON_VERSION=`python -c "import sys; print('{0[0]}.{0[1]}'.format(sys.version_info))"` bazel --output_user_root=`pwd`/.baztmp test --python_path=`which python` --test_output=errors //test/... artifact_paths: - "*.whl" + - "bazel-out/*/testlogs/test/llama/test.log" + - "bazel-out/*/testlogs/test/llama/bench_vs_xla.log" timeout_in_minutes: 180 diff --git a/.buildkite/secure_pipeline.yml b/.buildkite/secure_pipeline.yml index 22f389a2..97bcff7a 100644 --- a/.buildkite/secure_pipeline.yml +++ b/.buildkite/secure_pipeline.yml @@ -1,5 +1,5 @@ steps: - - name: "Tag {{matrix.arch}} -- {{matrix.os}}" + - name: "Tag {{matrix.arch}} -- {{matrix.os}} python {{matrix.python}}" matrix: setup: arch: @@ -30,19 +30,9 @@ steps: chmod +x .local/bin/md5 if [ "{{matrix.os}}" == "macos" ]; then - if [ "{{matrix.arch}}" == "aarch64" ]; then - sed -i.bak 's~targets = \[.*\]~targets = \[\"AArch64\", \"AMDGPU\"]~g' WORKSPACE - curl -fLO https://repo.anaconda.com/miniconda/Miniconda3-py3`echo {{matrix.python}} | cut -c 3-`_24.7.1-0-MacOSX-arm64.sh - else - sed -i.bak 's~targets = \[.*\]~targets = \[\"X86\", \"AMDGPU\"]~g' WORKSPACE - curl -fLO https://repo.anaconda.com/miniconda/Miniconda3-py3`echo {{matrix.python}} | cut -c 3-`_24.7.1-0-MacOSX-{{matrix.arch}}.sh - fi curl -fLO "https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-darwin" mv bazelisk-darwin .local/bin/bazel chmod +x .local/bin/bazel - chmod +x Miniconda*.sh - ./Miniconda*.sh -b -p `pwd`/conda - rm Miniconda*.sh elif [ "{{matrix.os}}" == "linux" ]; then if [ "{{matrix.arch}}" == "aarch64" ]; then curl -fLO https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-linux-arm64 @@ -51,12 +41,7 @@ steps: fi mv bazel* .local/bin/bazel chmod +x .local/bin/bazel - wget https://repo.anaconda.com/miniconda/Miniconda3-py3`echo {{matrix.python}} | cut -c 3-`_24.7.1-0-Linux-{{matrix.arch}}.sh - chmod +x Miniconda*.sh - ./Miniconda*.sh -b -p `pwd`/conda - rm Miniconda*.sh else - wget https://repo.anaconda.com/miniconda/Miniconda3-py3`echo {{matrix.python}} | cut -c 3-`_24.7.1-0-Windows-{{matrix.arch}}.exe if [ "{{matrix.arch}}" == "aarch64" ]; then wget https://github.com/bazelbuild/bazel/releases/download/6.2.1/bazel-6.2.1-windows-arm64.exe else @@ -64,12 +49,12 @@ steps: fi mv bazel* .local/bin/bazel.exe fi - python -m ensurepip --upgrade - python -m pip install --user numpy wheel mkdir baztmp export TAG=`echo $BUILDKITE_TAG | cut -c2-` sed -i.bak "s~version = \"[0-9.]*\"~version = \"\$TAG\"~g" BUILD - HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/baztmp build --define=no_nccl_support=true :enzyme_ad + HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/baztmp build @llvm-project//llvm:FileCheck + rm bazel-bin/*.whl + HERMETIC_PYTHON_VERSION={{matrix.python}} bazel --output_user_root=`pwd`/baztmp build :wheel cp bazel-bin/*.whl . python -m pip install *.whl python -m pip install --user twine diff --git a/.buildkite/secure_pipeline.yml.signature b/.buildkite/secure_pipeline.yml.signature index f480d94d..9ca6983b 100644 --- a/.buildkite/secure_pipeline.yml.signature +++ b/.buildkite/secure_pipeline.yml.signature @@ -1 +1 @@ -Salted__QÄ "‚¨ì=xÒÕ:ŽW…Ô‡.Þ<+ƒ°¹9× E¸gÏh¢d;h?}éïŸeõ^VrÕ?nÓTë˜WdˤûÀNÑ.Tî±Aû¼Y=q#V:îæM7ô \ No newline at end of file +Salted__¸F¢ ¾ÜŒɈLY·o¶jžvQÖÀZ÷¬4ý*wÐî5µ5Wúy#‰¡ÒõWY΀ۇÕÈÐ86‡è9×ûÌõŘ;î–ÔIŒQ©]PvkU'ébæIJ \ No newline at end of file diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 086bc8ea..28fd0023 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,22 +9,23 @@ on: jobs: build: - name: Build ${{ matrix.os }} + name: Build ${{ matrix.os }} python ${{ matrix.python }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: [openstack22] + python: ["3.12"] timeout-minutes: 500 steps: - name: add llvm run: | if [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then sudo apt-get update - sudo apt-get install -y git gcc g++ python3 python3-dev python3-pip - sudo python3 -m pip install --upgrade lit + sudo apt-get install -y git gcc g++ #python3 python3-dev python3-pip + #sudo python3 -m pip install --upgrade lit + #- run: python3 -m pip install --user numpy fi - - run: python3 -m pip install --user numpy - uses: actions/checkout@v3 with: submodules: recursive @@ -42,20 +43,15 @@ jobs: repository-cache: true bazelisk-version: 1.x - - run: | - HERMETIC_PYTHON_VERSION=`python3 -c "import sys; print('{0[0]}.{0[1]}'.format(sys.version_info))"` bazel build @llvm-project//llvm:FileCheck - sudo rm bazel-bin/*.whl || echo - HERMETIC_PYTHON_VERSION=`python3 -c "import sys; print('{0[0]}.{0[1]}'.format(sys.version_info))"` bazel build :enzyme_ad - - - run: cp bazel-bin/*.whl . - - name: test run: | - ls -all . - ls -all bazel-bin - python3 -m pip uninstall enzyme-ad -y || echo - python3 -m pip install --user --force-reinstall "jax[cpu]" *.whl - HERMETIC_PYTHON_VERSION=`python3 -c "import sys; print('{0[0]}.{0[1]}'.format(sys.version_info))"` bazel test --test_output=errors ... + HERMETIC_PYTHON_VERSION=${{ matrix.python }} bazel test --test_output=errors ... + + - name: Build Wheel + run: | + sudo rm bazel-bin/*.whl || echo + HERMETIC_PYTHON_VERSION=${{ matrix.python }} bazel build :wheel + cp bazel-bin/*.whl . - name: Upload Build uses: actions/upload-artifact@v3 diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml index c920d8ab..51f8bcc5 100644 --- a/.github/workflows/tag.yml +++ b/.github/workflows/tag.yml @@ -6,7 +6,7 @@ on: jobs: build: - name: Build ${{ matrix.os }} python ${{ matrix.python }} + name: Tag ${{ matrix.os }} python ${{ matrix.python }} runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -45,7 +45,7 @@ jobs: - run: | HERMETIC_PYTHON_VERSION=${{ matrix.python }} bazel build @llvm-project//llvm:FileCheck sudo rm bazel-bin/*.whl || echo - HERMETIC_PYTHON_VERSION=${{ matrix.python }} bazel build :enzyme_ad + HERMETIC_PYTHON_VERSION=${{ matrix.python }} bazel build :wheel - env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} diff --git a/BUILD b/BUILD index 62422914..aaac16ef 100644 --- a/BUILD +++ b/BUILD @@ -66,8 +66,22 @@ cc_binary( ], ) -py_wheel( +py_library( name = "enzyme_ad", + visibility = ["//visibility:public"], + deps = [ + "@pypi_jax//:pkg", + "@pypi_absl_py//:pkg", + ], + imports=["src"], + data = [ + "//:enzyme_jax_data", + "//src/enzyme_ad/jax:enzyme_jax_internal", + ] +) + +py_wheel( + name = "wheel", author = "Enzyme Authors", author_email = "wmoses@mit.edu, zinenko@google.com", distribution = "enzyme_ad", diff --git a/README.md b/README.md index f93e1449..90dbec18 100644 --- a/README.md +++ b/README.md @@ -44,13 +44,13 @@ pip install enzyme-ad ## Building from source -Requirements: `bazel-6.2.1`, `clang++`, `python`, `python-virtualenv`, +Requirements: `bazel-6.5`, `clang++`, `python`, `python-virtualenv`, `python3-dev`. Build our extension with: ```sh # Will create a whl in bazel-bin/enzyme_ad-VERSION-SYSTEM.whl -bazel build :enzyme_ad +bazel build :wheel ``` Finally, install the built library with: @@ -61,6 +61,12 @@ Note that you cannot run code from the root of the git directory. For instance, ## Running the test +To run tests, you can simply execute the following bazel commands (this does not require building or installing the wheel). +```sh +bazel test //test/... +``` + +Alternatively, if you have installed the wheel, you can manually invoke the tests as follows ```sh cd test && python test.py ``` diff --git a/WORKSPACE b/WORKSPACE index 2d13b104..7cba04fb 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -41,12 +41,13 @@ python_init_rules() load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") python_init_repositories( requirements = { - "3.9": "//build:requirements_lock_3_9.txt", - "3.10": "//build:requirements_lock_3_10.txt", - "3.11": "//build:requirements_lock_3_11.txt", - "3.12": "//build:requirements_lock_3_12.txt", - "3.13": "//build:requirements_lock_3_13.txt", + "3.10": "//builddeps:requirements_lock_3_10.txt", + "3.11": "//builddeps:requirements_lock_3_11.txt", + "3.12": "//builddeps:requirements_lock_3_12.txt", }, + local_wheel_inclusion_list = [ + "enzyme_ad*", + ] ) load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") @@ -55,16 +56,8 @@ python_init_toolchains() load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") python_init_pip() -load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") -python_init_rules() - -load("@rules_python//python:repositories.bzl", "py_repositories") - -py_repositories() - -load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") - -pip_install_dependencies() +load("@pypi//:requirements.bzl", "install_deps") +install_deps() http_archive( name = "enzyme", diff --git a/builddeps/BUILD b/builddeps/BUILD new file mode 100644 index 00000000..be0ae35f --- /dev/null +++ b/builddeps/BUILD @@ -0,0 +1,17 @@ +licenses(["notice"]) + +load("@python//:defs.bzl", "compile_pip_requirements") +load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") + +compile_pip_requirements( + name = "requirements", + extra_args = [ + "--allow-unsafe", + "--build-isolation", + "--rebuild", + ], + requirements_in = "requirements.in", + requirements_txt = REQUIREMENTS, + generate_hashes = True, + data = ["test-requirements.txt"] +) diff --git a/builddeps/requirements.in b/builddeps/requirements.in new file mode 100644 index 00000000..58fa9a1e --- /dev/null +++ b/builddeps/requirements.in @@ -0,0 +1,8 @@ +# +# test deps +# +-r test-requirements.txt + +jax >= 0.4.21 +jaxlib >= 0.4.21 +absl_py >= 2.0.0 diff --git a/builddeps/requirements_lock_3_10.txt b/builddeps/requirements_lock_3_10.txt new file mode 100644 index 00000000..fb0f8bb2 --- /dev/null +++ b/builddeps/requirements_lock_3_10.txt @@ -0,0 +1,172 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# bazel run //builddeps:requirements.update +# +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt +jax==0.4.31 \ + --hash=sha256:5688703735133d0dc537e99a1d646198a49c9d472d4715fde4bd437c44151bd7 \ + --hash=sha256:fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287 + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt +jax-cuda12-pjrt==0.4.31 \ + --hash=sha256:3e77d1cfebeca06517254eb568f082037e1a2aa3ed8f63c543492ad8ab5a1585 \ + --hash=sha256:8961abb381d893a3c2392ad76ab2067a81f8f2514f3b47d2da3ac24283293fe0 + # via jax-cuda12-plugin +jax-cuda12-plugin==0.4.31 ; sys_platform == "linux" \ + --hash=sha256:146f26928ca719a0daa14bb9a9f5a5cbfa20211e76ea05b7bb534277a658a995 \ + --hash=sha256:5048acdf29755303b2887d948137fd82891b48cdbc086e67640ebb976457cfcc \ + --hash=sha256:5cfa46d4106f70c31944c9fafcfa7b04c6f9886a1b35df5477cde7d9c43eb8bd \ + --hash=sha256:7a179e5e80dd9890972d777d597f3c902c6876e42bcf2edcfe4f3ec5a610472e \ + --hash=sha256:a3727a332fbeac625ab6d5ae63d0ed9e62d1e0b5011f72130c9bc96e797395d5 \ + --hash=sha256:cdb6d0c4009438a6a6bd7997ab8a1194beda9ae7322b8d265eea9e551b6c2b4e + # via -r builddeps/test-requirements.txt +jaxlib==0.4.31 \ + --hash=sha256:185fb615ab6bd95315fbcbd951d84e71f9835d603db8c03c91faee98ce95ff4d \ + --hash=sha256:1b8e9e6970ecc08bd8b4d80c03d882f4dcd4ac119cb2959811ebc58fce1c263d \ + --hash=sha256:1db6f8ea35b884f9e7761b006ee9c60ed05be6c75d2e527551f74579cbe11677 \ + --hash=sha256:1f1afa5fd58a60f67f0ca586e26714aece62eaa2c8334c24d0e8285afc4a7ccd \ + --hash=sha256:1fd838ff91ea58ec2bdc7b4ecbb921ad501a318fafdeae120e6e7f88f5c20b17 \ + --hash=sha256:2d2639d210b0b1918dfaabbcc504fc668326e1a6fd1f0eb427c40b039188bbce \ + --hash=sha256:48ea73cb78341bd4aabbb15e1a076ed61505ec80ab8eb4810e2d34758c400f80 \ + --hash=sha256:4d867a1a0565b31cfdaabbec81e0302c6461bb2ac4b92c04670328d795819803 \ + --hash=sha256:86340df8b37729f6fc5742f17761857bb9e59c418c9453e9b090f49f6194cdf9 \ + --hash=sha256:bacb86012f9104dd71706266420fd1e5d179d826d0635c95fe31506d605b4537 \ + --hash=sha256:c4bfd15315e30525514b7262d555bea00745b09ac9818bb14c20ef8afbbab072 \ + --hash=sha256:c9f89c185287e40ee8173a7142d6495311e772cd139a93dca93f0d99c1872832 \ + --hash=sha256:ceec494df08aaf65b8bbcbd40dd21a6579fa76ca5b851cce46fd7ce0388c0449 \ + --hash=sha256:d019023f71dba65127a3016ddc755de4b30f5bc9bd5b632a716a5fb3b00c5e53 \ + --hash=sha256:d3540a557c188d23ef93760da482b158ca910124a0445263c3b17c09c114538a + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt + # jax +ml-dtypes==0.4.0 \ + --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ + --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ + --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ + --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ + --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ + --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ + --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ + --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ + --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ + --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ + --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ + --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ + --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ + --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ + --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ + --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ + --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 + # via + # jax + # jaxlib +numpy==2.1.0 \ + --hash=sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b \ + --hash=sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911 \ + --hash=sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977 \ + --hash=sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84 \ + --hash=sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b \ + --hash=sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33 \ + --hash=sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b \ + --hash=sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d \ + --hash=sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111 \ + --hash=sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673 \ + --hash=sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06 \ + --hash=sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36 \ + --hash=sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f \ + --hash=sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd \ + --hash=sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e \ + --hash=sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62 \ + --hash=sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb \ + --hash=sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300 \ + --hash=sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b \ + --hash=sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb \ + --hash=sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8 \ + --hash=sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195 \ + --hash=sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2 \ + --hash=sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce \ + --hash=sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6 \ + --hash=sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2 \ + --hash=sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33 \ + --hash=sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b \ + --hash=sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667 \ + --hash=sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1 \ + --hash=sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a \ + --hash=sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e \ + --hash=sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745 \ + --hash=sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc \ + --hash=sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324 \ + --hash=sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0 \ + --hash=sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8 \ + --hash=sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02 \ + --hash=sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574 \ + --hash=sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd \ + --hash=sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1 \ + --hash=sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5 \ + --hash=sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d \ + --hash=sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883 \ + --hash=sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575 \ + --hash=sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529 \ + --hash=sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa \ + --hash=sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3 \ + --hash=sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211 \ + --hash=sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1 \ + --hash=sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736 \ + --hash=sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e + # via + # -r builddeps/test-requirements.txt + # jax + # jaxlib + # ml-dtypes + # opt-einsum + # scipy +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via jax +scipy==1.14.1 \ + --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ + --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ + --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ + --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ + --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ + --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ + --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ + --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ + --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ + --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ + --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ + --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ + --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ + --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ + --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ + --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ + --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ + --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ + --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ + --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ + --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ + --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ + --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ + --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ + --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ + --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ + --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ + --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ + --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ + --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ + --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ + --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ + --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 + # via + # jax + # jaxlib diff --git a/builddeps/requirements_lock_3_11.txt b/builddeps/requirements_lock_3_11.txt new file mode 100644 index 00000000..b64d3b71 --- /dev/null +++ b/builddeps/requirements_lock_3_11.txt @@ -0,0 +1,285 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# bazel run //builddeps:requirements.update +# +--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 + # via requests +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 + # via requests +idna==3.8 \ + --hash=sha256:050b4e5baadcd44d760cedbd2b8e639f2ff89bbc7a5730fcc662954303377aac \ + --hash=sha256:d838c2c0ed6fced7693d5e8ab8e734d5f8fda53a039c0164afb0b82e771e3603 + # via requests +jax==0.4.31 \ + --hash=sha256:5688703735133d0dc537e99a1d646198a49c9d472d4715fde4bd437c44151bd7 \ + --hash=sha256:fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287 + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt +jax-cuda12-pjrt==0.4.31 \ + --hash=sha256:3e77d1cfebeca06517254eb568f082037e1a2aa3ed8f63c543492ad8ab5a1585 \ + --hash=sha256:8961abb381d893a3c2392ad76ab2067a81f8f2514f3b47d2da3ac24283293fe0 + # via jax-cuda12-plugin +jax-cuda12-plugin==0.4.31 ; sys_platform == "linux" \ + --hash=sha256:146f26928ca719a0daa14bb9a9f5a5cbfa20211e76ea05b7bb534277a658a995 \ + --hash=sha256:5048acdf29755303b2887d948137fd82891b48cdbc086e67640ebb976457cfcc \ + --hash=sha256:5cfa46d4106f70c31944c9fafcfa7b04c6f9886a1b35df5477cde7d9c43eb8bd \ + --hash=sha256:7a179e5e80dd9890972d777d597f3c902c6876e42bcf2edcfe4f3ec5a610472e \ + --hash=sha256:a3727a332fbeac625ab6d5ae63d0ed9e62d1e0b5011f72130c9bc96e797395d5 \ + --hash=sha256:cdb6d0c4009438a6a6bd7997ab8a1194beda9ae7322b8d265eea9e551b6c2b4e + # via -r builddeps/test-requirements.txt +jaxlib==0.4.31 \ + --hash=sha256:185fb615ab6bd95315fbcbd951d84e71f9835d603db8c03c91faee98ce95ff4d \ + --hash=sha256:1b8e9e6970ecc08bd8b4d80c03d882f4dcd4ac119cb2959811ebc58fce1c263d \ + --hash=sha256:1db6f8ea35b884f9e7761b006ee9c60ed05be6c75d2e527551f74579cbe11677 \ + --hash=sha256:1f1afa5fd58a60f67f0ca586e26714aece62eaa2c8334c24d0e8285afc4a7ccd \ + --hash=sha256:1fd838ff91ea58ec2bdc7b4ecbb921ad501a318fafdeae120e6e7f88f5c20b17 \ + --hash=sha256:2d2639d210b0b1918dfaabbcc504fc668326e1a6fd1f0eb427c40b039188bbce \ + --hash=sha256:48ea73cb78341bd4aabbb15e1a076ed61505ec80ab8eb4810e2d34758c400f80 \ + --hash=sha256:4d867a1a0565b31cfdaabbec81e0302c6461bb2ac4b92c04670328d795819803 \ + --hash=sha256:86340df8b37729f6fc5742f17761857bb9e59c418c9453e9b090f49f6194cdf9 \ + --hash=sha256:bacb86012f9104dd71706266420fd1e5d179d826d0635c95fe31506d605b4537 \ + --hash=sha256:c4bfd15315e30525514b7262d555bea00745b09ac9818bb14c20ef8afbbab072 \ + --hash=sha256:c9f89c185287e40ee8173a7142d6495311e772cd139a93dca93f0d99c1872832 \ + --hash=sha256:ceec494df08aaf65b8bbcbd40dd21a6579fa76ca5b851cce46fd7ce0388c0449 \ + --hash=sha256:d019023f71dba65127a3016ddc755de4b30f5bc9bd5b632a716a5fb3b00c5e53 \ + --hash=sha256:d3540a557c188d23ef93760da482b158ca910124a0445263c3b17c09c114538a + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt + # jax +libtpu-nightly==0.1.dev20240729 ; sys_platform == "linux" \ + --hash=sha256:e3e4a4305e673d95cdd0076785e506234fc63b51124eeb3e301727f68b6b0126 + # via -r builddeps/test-requirements.txt +ml-dtypes==0.4.0 \ + --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ + --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ + --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ + --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ + --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ + --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ + --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ + --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ + --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ + --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ + --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ + --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ + --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ + --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ + --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ + --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ + --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 + # via + # jax + # jaxlib +numpy==2.1.0 \ + --hash=sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b \ + --hash=sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911 \ + --hash=sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977 \ + --hash=sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84 \ + --hash=sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b \ + --hash=sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33 \ + --hash=sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b \ + --hash=sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d \ + --hash=sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111 \ + --hash=sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673 \ + --hash=sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06 \ + --hash=sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36 \ + --hash=sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f \ + --hash=sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd \ + --hash=sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e \ + --hash=sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62 \ + --hash=sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb \ + --hash=sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300 \ + --hash=sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b \ + --hash=sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb \ + --hash=sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8 \ + --hash=sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195 \ + --hash=sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2 \ + --hash=sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce \ + --hash=sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6 \ + --hash=sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2 \ + --hash=sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33 \ + --hash=sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b \ + --hash=sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667 \ + --hash=sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1 \ + --hash=sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a \ + --hash=sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e \ + --hash=sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745 \ + --hash=sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc \ + --hash=sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324 \ + --hash=sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0 \ + --hash=sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8 \ + --hash=sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02 \ + --hash=sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574 \ + --hash=sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd \ + --hash=sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1 \ + --hash=sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5 \ + --hash=sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d \ + --hash=sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883 \ + --hash=sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575 \ + --hash=sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529 \ + --hash=sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa \ + --hash=sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3 \ + --hash=sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211 \ + --hash=sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1 \ + --hash=sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736 \ + --hash=sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e + # via + # -r builddeps/test-requirements.txt + # jax + # jaxlib + # ml-dtypes + # opt-einsum + # scipy +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via jax +requests==2.32.3 ; sys_platform == "linux" \ + --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ + --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 + # via -r builddeps/test-requirements.txt +scipy==1.14.1 \ + --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ + --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ + --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ + --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ + --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ + --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ + --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ + --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ + --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ + --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ + --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ + --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ + --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ + --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ + --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ + --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ + --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ + --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ + --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ + --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ + --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ + --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ + --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ + --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ + --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ + --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ + --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ + --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ + --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ + --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ + --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ + --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ + --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 + # via + # jax + # jaxlib +urllib3==2.2.2 \ + --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ + --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 + # via requests diff --git a/builddeps/requirements_lock_3_12.txt b/builddeps/requirements_lock_3_12.txt new file mode 100644 index 00000000..a223e5db --- /dev/null +++ b/builddeps/requirements_lock_3_12.txt @@ -0,0 +1,285 @@ +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# bazel run //builddeps:requirements.update +# +--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html + +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt +certifi==2024.7.4 \ + --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ + --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 + # via requests +charset-normalizer==3.3.2 \ + --hash=sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027 \ + --hash=sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087 \ + --hash=sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786 \ + --hash=sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8 \ + --hash=sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09 \ + --hash=sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185 \ + --hash=sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574 \ + --hash=sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e \ + --hash=sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519 \ + --hash=sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898 \ + --hash=sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269 \ + --hash=sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3 \ + --hash=sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f \ + --hash=sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6 \ + --hash=sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8 \ + --hash=sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a \ + --hash=sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73 \ + --hash=sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc \ + --hash=sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714 \ + --hash=sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2 \ + --hash=sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc \ + --hash=sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce \ + --hash=sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d \ + --hash=sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e \ + --hash=sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6 \ + --hash=sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269 \ + --hash=sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96 \ + --hash=sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d \ + --hash=sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a \ + --hash=sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4 \ + --hash=sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77 \ + --hash=sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d \ + --hash=sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0 \ + --hash=sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed \ + --hash=sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068 \ + --hash=sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac \ + --hash=sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25 \ + --hash=sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8 \ + --hash=sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab \ + --hash=sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26 \ + --hash=sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2 \ + --hash=sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db \ + --hash=sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f \ + --hash=sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5 \ + --hash=sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99 \ + --hash=sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c \ + --hash=sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d \ + --hash=sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811 \ + --hash=sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa \ + --hash=sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a \ + --hash=sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03 \ + --hash=sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b \ + --hash=sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04 \ + --hash=sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c \ + --hash=sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001 \ + --hash=sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458 \ + --hash=sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389 \ + --hash=sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99 \ + --hash=sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985 \ + --hash=sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537 \ + --hash=sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238 \ + --hash=sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f \ + --hash=sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d \ + --hash=sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796 \ + --hash=sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a \ + --hash=sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143 \ + --hash=sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8 \ + --hash=sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c \ + --hash=sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5 \ + --hash=sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5 \ + --hash=sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711 \ + --hash=sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4 \ + --hash=sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6 \ + --hash=sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c \ + --hash=sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7 \ + --hash=sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4 \ + --hash=sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b \ + --hash=sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae \ + --hash=sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12 \ + --hash=sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c \ + --hash=sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae \ + --hash=sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8 \ + --hash=sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887 \ + --hash=sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b \ + --hash=sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4 \ + --hash=sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f \ + --hash=sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5 \ + --hash=sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33 \ + --hash=sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519 \ + --hash=sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561 + # via requests +idna==3.8 \ + --hash=sha256:050b4e5baadcd44d760cedbd2b8e639f2ff89bbc7a5730fcc662954303377aac \ + --hash=sha256:d838c2c0ed6fced7693d5e8ab8e734d5f8fda53a039c0164afb0b82e771e3603 + # via requests +jax==0.4.31 \ + --hash=sha256:5688703735133d0dc537e99a1d646198a49c9d472d4715fde4bd437c44151bd7 \ + --hash=sha256:fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287 + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt +jax-cuda12-pjrt==0.4.31 \ + --hash=sha256:3e77d1cfebeca06517254eb568f082037e1a2aa3ed8f63c543492ad8ab5a1585 \ + --hash=sha256:8961abb381d893a3c2392ad76ab2067a81f8f2514f3b47d2da3ac24283293fe0 + # via jax-cuda12-plugin +jax-cuda12-plugin==0.4.31 ; sys_platform == "linux" \ + --hash=sha256:146f26928ca719a0daa14bb9a9f5a5cbfa20211e76ea05b7bb534277a658a995 \ + --hash=sha256:5048acdf29755303b2887d948137fd82891b48cdbc086e67640ebb976457cfcc \ + --hash=sha256:5cfa46d4106f70c31944c9fafcfa7b04c6f9886a1b35df5477cde7d9c43eb8bd \ + --hash=sha256:7a179e5e80dd9890972d777d597f3c902c6876e42bcf2edcfe4f3ec5a610472e \ + --hash=sha256:a3727a332fbeac625ab6d5ae63d0ed9e62d1e0b5011f72130c9bc96e797395d5 \ + --hash=sha256:cdb6d0c4009438a6a6bd7997ab8a1194beda9ae7322b8d265eea9e551b6c2b4e + # via -r builddeps/test-requirements.txt +jaxlib==0.4.31 \ + --hash=sha256:185fb615ab6bd95315fbcbd951d84e71f9835d603db8c03c91faee98ce95ff4d \ + --hash=sha256:1b8e9e6970ecc08bd8b4d80c03d882f4dcd4ac119cb2959811ebc58fce1c263d \ + --hash=sha256:1db6f8ea35b884f9e7761b006ee9c60ed05be6c75d2e527551f74579cbe11677 \ + --hash=sha256:1f1afa5fd58a60f67f0ca586e26714aece62eaa2c8334c24d0e8285afc4a7ccd \ + --hash=sha256:1fd838ff91ea58ec2bdc7b4ecbb921ad501a318fafdeae120e6e7f88f5c20b17 \ + --hash=sha256:2d2639d210b0b1918dfaabbcc504fc668326e1a6fd1f0eb427c40b039188bbce \ + --hash=sha256:48ea73cb78341bd4aabbb15e1a076ed61505ec80ab8eb4810e2d34758c400f80 \ + --hash=sha256:4d867a1a0565b31cfdaabbec81e0302c6461bb2ac4b92c04670328d795819803 \ + --hash=sha256:86340df8b37729f6fc5742f17761857bb9e59c418c9453e9b090f49f6194cdf9 \ + --hash=sha256:bacb86012f9104dd71706266420fd1e5d179d826d0635c95fe31506d605b4537 \ + --hash=sha256:c4bfd15315e30525514b7262d555bea00745b09ac9818bb14c20ef8afbbab072 \ + --hash=sha256:c9f89c185287e40ee8173a7142d6495311e772cd139a93dca93f0d99c1872832 \ + --hash=sha256:ceec494df08aaf65b8bbcbd40dd21a6579fa76ca5b851cce46fd7ce0388c0449 \ + --hash=sha256:d019023f71dba65127a3016ddc755de4b30f5bc9bd5b632a716a5fb3b00c5e53 \ + --hash=sha256:d3540a557c188d23ef93760da482b158ca910124a0445263c3b17c09c114538a + # via + # -r builddeps/requirements.in + # -r builddeps/test-requirements.txt + # jax +libtpu-nightly==0.1.dev20240729 ; sys_platform == "linux" \ + --hash=sha256:e3e4a4305e673d95cdd0076785e506234fc63b51124eeb3e301727f68b6b0126 + # via -r builddeps/test-requirements.txt +ml-dtypes==0.4.0 \ + --hash=sha256:03e7cda6ef164eed0abb31df69d2c00c3a5ab3e2610b6d4c42183a43329c72a5 \ + --hash=sha256:2bb83fd064db43e67e67d021e547698af4c8d5c6190f2e9b1c53c09f6ff5531d \ + --hash=sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675 \ + --hash=sha256:41affb38fdfe146e3db226cf2953021184d6f0c4ffab52136613e9601706e368 \ + --hash=sha256:43cf4356a0fe2eeac6d289018d0734e17a403bdf1fd911953c125dd0358edcc0 \ + --hash=sha256:723af6346447268a3cf0b7356e963d80ecb5732b5279b2aa3fa4b9fc8297c85e \ + --hash=sha256:75b4faf99d0711b81f393db36d210b4255fd419f6f790bc6c1b461f95ffb7a9e \ + --hash=sha256:93afe37f3a879d652ec9ef1fc47612388890660a2657fbb5747256c3b818fd81 \ + --hash=sha256:a15d96d090aebb55ee85173d1775ae325a001aab607a76c8ea0b964ccd6b5364 \ + --hash=sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49 \ + --hash=sha256:bdf689be7351cc3c95110c910c1b864002f113e682e44508910c849e144f3df1 \ + --hash=sha256:c83e4d443962d891d51669ff241d5aaad10a8d3d37a81c5532a45419885d591c \ + --hash=sha256:e1e2f4237b459a63c97c2c9f449baa637d7e4c20addff6a9bac486f22432f3b6 \ + --hash=sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259 \ + --hash=sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb \ + --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ + --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 + # via + # jax + # jaxlib +numpy==2.1.0 \ + --hash=sha256:08801848a40aea24ce16c2ecde3b756f9ad756586fb2d13210939eb69b023f5b \ + --hash=sha256:0937e54c09f7a9a68da6889362ddd2ff584c02d015ec92672c099b61555f8911 \ + --hash=sha256:0ab32eb9170bf8ffcbb14f11613f4a0b108d3ffee0832457c5d4808233ba8977 \ + --hash=sha256:0abb3916a35d9090088a748636b2c06dc9a6542f99cd476979fb156a18192b84 \ + --hash=sha256:0af3a5987f59d9c529c022c8c2a64805b339b7ef506509fba7d0556649b9714b \ + --hash=sha256:10e2350aea18d04832319aac0f887d5fcec1b36abd485d14f173e3e900b83e33 \ + --hash=sha256:15ef8b2177eeb7e37dd5ef4016f30b7659c57c2c0b57a779f1d537ff33a72c7b \ + --hash=sha256:1f817c71683fd1bb5cff1529a1d085a57f02ccd2ebc5cd2c566f9a01118e3b7d \ + --hash=sha256:24003ba8ff22ea29a8c306e61d316ac74111cebf942afbf692df65509a05f111 \ + --hash=sha256:30014b234f07b5fec20f4146f69e13cfb1e33ee9a18a1879a0142fbb00d47673 \ + --hash=sha256:343e3e152bf5a087511cd325e3b7ecfd5b92d369e80e74c12cd87826e263ec06 \ + --hash=sha256:378cb4f24c7d93066ee4103204f73ed046eb88f9ad5bb2275bb9fa0f6a02bd36 \ + --hash=sha256:398049e237d1aae53d82a416dade04defed1a47f87d18d5bd615b6e7d7e41d1f \ + --hash=sha256:3a3336fbfa0d38d3deacd3fe7f3d07e13597f29c13abf4d15c3b6dc2291cbbdd \ + --hash=sha256:442596f01913656d579309edcd179a2a2f9977d9a14ff41d042475280fc7f34e \ + --hash=sha256:44e44973262dc3ae79e9063a1284a73e09d01b894b534a769732ccd46c28cc62 \ + --hash=sha256:54139e0eb219f52f60656d163cbe67c31ede51d13236c950145473504fa208cb \ + --hash=sha256:5474dad8c86ee9ba9bb776f4b99ef2d41b3b8f4e0d199d4f7304728ed34d0300 \ + --hash=sha256:54c6a63e9d81efe64bfb7bcb0ec64332a87d0b87575f6009c8ba67ea6374770b \ + --hash=sha256:624884b572dff8ca8f60fab591413f077471de64e376b17d291b19f56504b2bb \ + --hash=sha256:6326ab99b52fafdcdeccf602d6286191a79fe2fda0ae90573c5814cd2b0bc1b8 \ + --hash=sha256:652e92fc409e278abdd61e9505649e3938f6d04ce7ef1953f2ec598a50e7c195 \ + --hash=sha256:6c1de77ded79fef664d5098a66810d4d27ca0224e9051906e634b3f7ead134c2 \ + --hash=sha256:76368c788ccb4f4782cf9c842b316140142b4cbf22ff8db82724e82fe1205dce \ + --hash=sha256:7a894c51fd8c4e834f00ac742abad73fc485df1062f1b875661a3c1e1fb1c2f6 \ + --hash=sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2 \ + --hash=sha256:848c6b5cad9898e4b9ef251b6f934fa34630371f2e916261070a4eb9092ffd33 \ + --hash=sha256:899da829b362ade41e1e7eccad2cf274035e1cb36ba73034946fccd4afd8606b \ + --hash=sha256:8ab81ccd753859ab89e67199b9da62c543850f819993761c1e94a75a814ed667 \ + --hash=sha256:8fb49a0ba4d8f41198ae2d52118b050fd34dace4b8f3fb0ee34e23eb4ae775b1 \ + --hash=sha256:9156ca1f79fc4acc226696e95bfcc2b486f165a6a59ebe22b2c1f82ab190384a \ + --hash=sha256:9523f8b46485db6939bd069b28b642fec86c30909cea90ef550373787f79530e \ + --hash=sha256:a0756a179afa766ad7cb6f036de622e8a8f16ffdd55aa31f296c870b5679d745 \ + --hash=sha256:a0cdef204199278f5c461a0bed6ed2e052998276e6d8ab2963d5b5c39a0500bc \ + --hash=sha256:ab83adc099ec62e044b1fbb3a05499fa1e99f6d53a1dde102b2d85eff66ed324 \ + --hash=sha256:b34fa5e3b5d6dc7e0a4243fa0f81367027cb6f4a7215a17852979634b5544ee0 \ + --hash=sha256:b47c551c6724960479cefd7353656498b86e7232429e3a41ab83be4da1b109e8 \ + --hash=sha256:c4cd94dfefbefec3f8b544f61286584292d740e6e9d4677769bc76b8f41deb02 \ + --hash=sha256:c4f982715e65036c34897eb598d64aef15150c447be2cfc6643ec7a11af06574 \ + --hash=sha256:d8f699a709120b220dfe173f79c73cb2a2cab2c0b88dd59d7b49407d032b8ebd \ + --hash=sha256:dd94ce596bda40a9618324547cfaaf6650b1a24f5390350142499aa4e34e53d1 \ + --hash=sha256:de844aaa4815b78f6023832590d77da0e3b6805c644c33ce94a1e449f16d6ab5 \ + --hash=sha256:e5f0642cdf4636198a4990de7a71b693d824c56a757862230454629cf62e323d \ + --hash=sha256:f07fa2f15dabe91259828ce7d71b5ca9e2eb7c8c26baa822c825ce43552f4883 \ + --hash=sha256:f15976718c004466406342789f31b6673776360f3b1e3c575f25302d7e789575 \ + --hash=sha256:f358ea9e47eb3c2d6eba121ab512dfff38a88db719c38d1e67349af210bc7529 \ + --hash=sha256:f505264735ee074250a9c78247ee8618292091d9d1fcc023290e9ac67e8f1afa \ + --hash=sha256:f5ebbf9fbdabed208d4ecd2e1dfd2c0741af2f876e7ae522c2537d404ca895c3 \ + --hash=sha256:f6b26e6c3b98adb648243670fddc8cab6ae17473f9dc58c51574af3e64d61211 \ + --hash=sha256:f8e93a01a35be08d31ae33021e5268f157a2d60ebd643cfc15de6ab8e4722eb1 \ + --hash=sha256:fe76d75b345dc045acdbc006adcb197cc680754afd6c259de60d358d60c93736 \ + --hash=sha256:ffbd6faeb190aaf2b5e9024bac9622d2ee549b7ec89ef3a9373fa35313d44e0e + # via + # -r builddeps/test-requirements.txt + # jax + # jaxlib + # ml-dtypes + # opt-einsum + # scipy +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via jax +requests==2.32.3 ; sys_platform == "linux" \ + --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ + --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 + # via -r builddeps/test-requirements.txt +scipy==1.14.1 \ + --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ + --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ + --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ + --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ + --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ + --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ + --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ + --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ + --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ + --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ + --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ + --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ + --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ + --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ + --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ + --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ + --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ + --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ + --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ + --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ + --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ + --hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ + --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ + --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ + --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ + --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ + --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ + --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ + --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ + --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ + --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ + --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ + --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 + # via + # jax + # jaxlib +urllib3==2.2.2 \ + --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ + --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 + # via requests diff --git a/builddeps/test-requirements.txt b/builddeps/test-requirements.txt new file mode 100644 index 00000000..b27f584d --- /dev/null +++ b/builddeps/test-requirements.txt @@ -0,0 +1,8 @@ +absl-py +jax +numpy +jaxlib +jax-cuda12-plugin; sys_platform == 'linux' +requests; sys_platform == 'linux' +-f https://storage.googleapis.com/jax-releases/libtpu_releases.html +libtpu-nightly == 0.1.dev20240729; sys_platform == 'linux' diff --git a/src/enzyme_ad/jax/__init__.py b/src/enzyme_ad/jax/__init__.py index 6a5b459e..b0da6758 100644 --- a/src/enzyme_ad/jax/__init__.py +++ b/src/enzyme_ad/jax/__init__.py @@ -6,4 +6,5 @@ JaXPipeline, optimize_module, export, + hlo_opts, ) diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc index 08b484eb..88452c8f 100644 --- a/src/enzyme_ad/jax/enzyme_call.cc +++ b/src/enzyme_ad/jax/enzyme_call.cc @@ -78,6 +78,8 @@ class CpuKernel { uint64_t addr; public: + static constexpr size_t UNKNOWN_PLATFORM = 0x1000000000; + CpuKernel(int64_t identifier, size_t num_out, uint64_t addr) : identifier(identifier), num_out(num_out), addr(addr) {} @@ -893,7 +895,10 @@ class CpuKernel { llvm::ArrayRef out_names, llvm::ArrayRef> in_shapes, llvm::ArrayRef in_names, PyObject *pyargv, ABI mode, - Language lang, bool xla_runtime, const std::string &pass_pipeline) { + Language lang, bool xla_runtime, const std::string &pass_pipeline, + const std::string &platform) { + if (platform != "cpu") + return std::make_tuple(UNKNOWN_PLATFORM, 0); llvm::sys::SmartScopedWriter lock(kernel_mutex); size_t identifier = last_identifier++; @@ -993,10 +998,14 @@ std::unique_ptr CpuKernel::JIT = nullptr; // CpuKernel::ES(std::move(*llvm::orc::SelfExecutorProcessControl::Create())); } // namespace -void CpuCallback(void *out, void **ins) { +void Callback(void *out, void **ins) { int64_t identifier = *reinterpret_cast(ins[0]); CpuKernel *kernel = CpuKernel::get(identifier); if (!kernel) { + if (identifier == CpuKernel::UNKNOWN_PLATFORM) { + throw pybind11::value_error( + "Unknown platform callback could not be executed"); + } // TODO: find a way to fail more gracefully. llvm::report_fatal_error("couldn't find enzyme kernel"); } @@ -1047,12 +1056,13 @@ PYBIND11_MODULE(enzyme_call, m) { .value("Reverse", ABI::Reverse) .value("Tape", ABI::Tape); - m.def("create_enzyme_cpu_kernel", + m.def("create_enzyme_kernel", [](const std::string &source, const std::string &fn, const pybind11::list &py_out_shapes, const pybind11::list &py_in_shapes, pybind11::object pyargv, ABI mode, Language lang, bool xla_runtime, - const std::string &pass_pipeline) -> std::tuple { + const std::string &pass_pipeline, + const std::string &platform) -> std::tuple { llvm::SmallVector> out_shapes; out_shapes.reserve(pybind11::len(py_out_shapes)); llvm::SmallVector> in_shapes; @@ -1088,7 +1098,7 @@ PYBIND11_MODULE(enzyme_call, m) { } return CpuKernel::create(fn, source, out_shapes, out_types, in_shapes, in_types, pyargv.ptr(), mode, (Language)lang, - xla_runtime, pass_pipeline); + xla_runtime, pass_pipeline, platform); }); m.def("tmp_size", @@ -1193,8 +1203,8 @@ PYBIND11_MODULE(enzyme_call, m) { pyargv.ptr(), (Language)lang, xla_runtime, pass_pipeline); }); - m.def("get_cpu_callback", []() { - return pybind11::capsule(reinterpret_cast(&CpuCallback), + m.def("get_callback", []() { + return pybind11::capsule(reinterpret_cast(&Callback), "xla._CUSTOM_CALL_TARGET"); }); diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index c9e85a29..d40b2285 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -427,7 +427,12 @@ def resource_dir(): import os dn = os.path.dirname(enzyme_call.__file__) - res = os.path.join(dn, "..", "..", "clang", "staging") + if os.getenv("ENZYME_BAZEL_NOWHEEL", None) is None: + res = os.path.join( + dn, "..", "..", "..", "external", "llvm-project", "clang", "staging" + ) + else: + res = os.path.join(dn, "..", "..", "clang", "staging") return res @@ -573,10 +578,7 @@ def absmaketup(ty): def lower(fn, vals, parameters=None): if hasattr(fn, "trace"): - if parameters is not None: - return fn.trace(*vals).lower(_private_parameters=parameters) - else: - return fn.trace(*vals).lower() + return fn.trace(*vals).lower(_private_parameters=parameters) else: if parameters is not None: return fn.lower(*vals, _experimental_lowering_parameters=parameters) @@ -829,7 +831,8 @@ def _enzyme_primal_lowering( print(out_shapes, "\n", results, "\n", nmod) assert len(results) == len(out_shapes) else: - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + assert len(ctx.module_context.platforms) == 1 + identifier, tmpBuf = enzyme_call.create_enzyme_kernel( source, fn, out_shapes, @@ -839,6 +842,7 @@ def _enzyme_primal_lowering( lang, pipeline_options.xla_runtime(), pass_pipeline, + ctx.module_context.platforms[0], ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -874,7 +878,8 @@ def _enzyme_primal_lowering( results = tuple(results2) else: - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + assert len(ctx.module_context.platforms) == 1 + identifier, tmpBuf = enzyme_call.create_enzyme_kernel( source, fn, out_shapes, @@ -884,6 +889,7 @@ def _enzyme_primal_lowering( lang, pipeline_options.xla_runtime(), pass_pipeline, + ctx.module_context.platforms[0], ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -940,7 +946,8 @@ def _enzyme_fwd_lowering( in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] argv = argv + ("-resource-dir", resource_dir()) + cflags() - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + assert len(ctx.module_context.platforms) == 1 + identifier, tmpBuf = enzyme_call.create_enzyme_kernel( source, fn, out_shapes, @@ -950,6 +957,7 @@ def _enzyme_fwd_lowering( lang, pipeline_options.xla_runtime(), pipeline_options.pass_pipeline(), + ctx.module_context.platforms[0], ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -1004,7 +1012,8 @@ def _enzyme_aug_lowering( in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] argv = argv + ("-resource-dir", resource_dir()) + cflags() - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + assert len(ctx.module_context.platforms) == 1 + identifier, tmpBuf = enzyme_call.create_enzyme_kernel( source, fn, out_shapes, @@ -1014,6 +1023,7 @@ def _enzyme_aug_lowering( lang, pipeline_options.xla_runtime(), pipeline_options.pass_pipeline(), + ctx.module_context.platforms[0], ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -1075,7 +1085,8 @@ def _enzyme_rev_lowering( ) argv = tuple(argv) + ("-resource-dir", resource_dir()) + cflags() - identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel( + assert len(ctx.module_context.platforms) == 1 + identifier, tmpBuf = enzyme_call.create_enzyme_kernel( source, fn, out_shapes, @@ -1085,6 +1096,7 @@ def _enzyme_rev_lowering( lang, pipeline_options.xla_runtime(), pipeline_options.pass_pipeline(), + ctx.module_context.platforms[0], ) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) @@ -1159,19 +1171,15 @@ def cpp_call( _enzyme_primal_p.def_abstract_eval(_enzyme_primal_abstract_eval) jax_mlir.register_lowering(_enzyme_primal_p, _enzyme_primal_lowering) -xla_client.register_custom_call_target( - "jaxzyme.primal", enzyme_call.get_cpu_callback(), platform="cpu" -) +xla_client.register_custom_call_target("jaxzyme.primal", enzyme_call.get_callback()) _enzyme_fwd_p = jax.core.Primitive("enzyme_fwd") _enzyme_fwd_p.multiple_results = True _enzyme_fwd_p.def_impl(_enzyme_fwd_impl) _enzyme_fwd_p.def_abstract_eval(_enzyme_fwd_abstract_eval) -jax_mlir.register_lowering(_enzyme_fwd_p, _enzyme_fwd_lowering, platform="cpu") +jax_mlir.register_lowering(_enzyme_fwd_p, _enzyme_fwd_lowering) -xla_client.register_custom_call_target( - "jaxzyme.fwd", enzyme_call.get_cpu_callback(), platform="cpu" -) +xla_client.register_custom_call_target("jaxzyme.fwd", enzyme_call.get_callback()) def enzyme_jvp(arg_primals, arg_tangents, **kwargs): @@ -1279,10 +1287,19 @@ def dejaxify(x): _enzyme_aug_p.multiple_results = True _enzyme_aug_p.def_impl(_enzyme_aug_impl) _enzyme_aug_p.def_abstract_eval(_enzyme_aug_abstract_eval) -jax_mlir.register_lowering(_enzyme_aug_p, _enzyme_aug_lowering, platform="cpu") +jax_mlir.register_lowering(_enzyme_aug_p, _enzyme_aug_lowering) xla_client.register_custom_call_target( - "jaxzyme.aug", enzyme_call.get_cpu_callback(), platform="cpu" + "jaxzyme.aug", enzyme_call.get_callback(), platform="cpu" +) +xla_client.register_custom_call_target( + "jaxzyme.aug", enzyme_call.get_callback(), platform="CUDA" +) +xla_client.register_custom_call_target( + "jaxzyme.aug", enzyme_call.get_callback(), platform="ROCM" +) +xla_client.register_custom_call_target( + "jaxzyme.aug", enzyme_call.get_callback(), platform="tpu" ) _enzyme_shadow_aug_p = jax.core.Primitive("enzyme_shadow_aug") @@ -1294,10 +1311,19 @@ def dejaxify(x): _enzyme_rev_p.multiple_results = True _enzyme_rev_p.def_impl(_enzyme_rev_impl) _enzyme_rev_p.def_abstract_eval(_enzyme_rev_abstract_eval) -jax_mlir.register_lowering(_enzyme_rev_p, _enzyme_rev_lowering, platform="cpu") +jax_mlir.register_lowering(_enzyme_rev_p, _enzyme_rev_lowering) xla_client.register_custom_call_target( - "jaxzyme.rev", enzyme_call.get_cpu_callback(), platform="cpu" + "jaxzyme.rev", enzyme_call.get_callback(), platform="cpu" +) +xla_client.register_custom_call_target( + "jaxzyme.rev", enzyme_call.get_callback(), platform="CUDA" +) +xla_client.register_custom_call_target( + "jaxzyme.rev", enzyme_call.get_callback(), platform="ROCM" +) +xla_client.register_custom_call_target( + "jaxzyme.rev", enzyme_call.get_callback(), platform="tpu" ) diff --git a/test/BUILD b/test/BUILD index 5addaaef..6558078e 100644 --- a/test/BUILD +++ b/test/BUILD @@ -31,17 +31,61 @@ exports_files( ":lit.cfg.py", ":lit_site_cfg_py", "//:enzymexlamlir-opt", - "//src/enzyme_ad/jax:enzyme_jax_internal", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:not", + ] + glob(["**/*.h"]), + ) + for src in glob( + [ + "**/*.mlir", + ], + ) +] + +load("@bazel_skylib//rules:common_settings.bzl", "string_flag") + +string_flag( + name = "test_tpu", + build_setting_default = "False", +) + +config_setting( + name = "use_tpu", + flag_values = { + ":test_tpu": "True", + }, +) + +TEST_DEPS = [ + "//:enzyme_ad", + "@pypi_jax//:pkg", + "@pypi_absl_py//:pkg", + ] + select({ + ":use_tpu": ["@pypi_libtpu_nightly//:pkg", "@pypi_requests//:pkg"], + "@bazel_tools//src/conditions:linux_x86_64": ["@pypi_jax_cuda12_plugin//:pkg"], + "//conditions:default": [] + }) + + +[ + lit_test( + name = "%s.test" % src, + srcs = [src], + data = [ + ":lit.cfg.py", + ":lit_site_cfg_py", + "//:enzyme_ad", "@llvm-project//clang:builtin_headers_gen", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:count", "@llvm-project//llvm:not", ] + glob(["**/*.h"]), + deps = TEST_DEPS ) for src in glob( [ "**/*.pyt", - "**/*.mlir", ], ) ] @@ -51,27 +95,28 @@ py_test( srcs = [ "test.py", ], - deps = [ - "//src/enzyme_ad/jax:enzyme_jax_internal", - ], + deps = TEST_DEPS, + imports = ["site-packages"], + tags = ["pypi_name=enzyme-ad"] ) py_test( name = "bench_vs_xla", srcs = [ "bench_vs_xla.py", + "test_utils.py" ], - deps = [ - "//src/enzyme_ad/jax:enzyme_jax_internal", - ], + imports = ["."], + deps = TEST_DEPS, ) py_test( name = "llama", srcs = [ "llama.py", + "test_utils.py", ], - deps = [ - "//src/enzyme_ad/jax:enzyme_jax_internal", - ], + imports = ["."], + deps = TEST_DEPS, + timeout='long' ) diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index b124ea62..748783ce 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -3,230 +3,7 @@ from enzyme_ad.jax import enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline, JaXPipeline from absl.testing import absltest import timeit - -argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11") -number = 1000 - -AllPipelines = [ - ("JaXPipeline", JaXPipeline()), - # ("NewXLAMLIR", NewXLAPipeline(mlirad=True)), - # ("NewXLA", NewXLAPipeline()), - ("OldXLA", OldXLAPipeline()), -] -PrimalPipelines = AllPipelines -FwdPipelines = AllPipelines -RevPipelines = AllPipelines - - -def no_newxla(x): - return [(name, a) for (name, a) in x if name != "NewXLAMLIR" and name != "NewXLA"] - - -def no_newxlamlir(x): - return [(name, a) for (name, a) in x if name != "NewXLAMLIR"] - - -def justjax(x): - return [ - (name, a) - for (name, a) in x - if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" - ] - - -# @jax.jit -# def fwd_jax(in0, in1, din0, din1): -# . return jax.jvp(add_one_jax, (in0, in1), (din0, din1)) -def splatjvp(in_fn): - def fwd(*args): - assert len(args) % 2 == 0 - return jax.jvp( - in_fn, tuple(args[: len(args) // 2]), tuple(args[len(args) // 2 :]) - ) - - return fwd - - -# @jax.jit -# def rev_jax(dout, in0, in1): -# primals, f_vjp = jax.vjp(add_one_jax, in0, in1) -# grads = f_vjp(dout) -# return primals, grads -def splatvjp(in_fn): - def rev(dout, *args): - primals, f_vjp = jax.vjp(in_fn, *args) - grads = f_vjp(dout) - return primals, grads - - return rev - - -class EnzymeJaxTest(absltest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.primfilter = lambda x: x - self.fwdfilter = lambda x: x - self.revfilter = lambda x: x - - def setUp(self): - self.name = None - - def test(self): - if self.name is None: - return - self.harness(self.name, self.fn, self.ins, self.dins, self.douts) - - def harness(self, name, in_fn, ins, dins, douts): - assert len(ins) == len(dins) - rfn_jax = jax.jit(in_fn) - - aop = rfn_jax(*ins) - assert 1 == len(douts) - - primalstr = "fn(" + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")" - primalins = {("in" + str(i)): ins[0] for i in range(len(ins))} - - print( - name + " JaX Primal: ", - timeit.Timer( - primalstr, - globals={ - "fn": rfn_jax, - } - | primalins, - ).timeit(number) - / number, - ) - - fwd_jax = jax.jit(splatjvp(rfn_jax)) - - primals_p, tangents_p = fwd_jax(*(ins + dins)) - print(primals_p) - print((jnp.abs(aop - primals_p) < 1e-6).all()) - self.assertTrue((jnp.abs(aop - primals_p) < 1e-6).all()) - - fwdstr = ( - "fwd(" - + (", ".join(["in" + str(i) for i in range(len(ins))])) - + ", " - + (", ".join(["din" + str(i) for i in range(len(dins))])) - + ")" - ) - fwdins = primalins | {("din" + str(i)): dins[0] for i in range(len(dins))} - print( - name + " JaX Fwd: ", - timeit.Timer( - fwdstr, - globals={ - "fwd": fwd_jax, - } - | fwdins, - ).timeit(number) - / number, - ) - - assert len(douts) == 1 - - rev_jax = jax.jit(splatvjp(rfn_jax)) - - primals_p, grads_p = rev_jax(*douts, *ins) - - print(primals_p) - print((jnp.abs(aop - primals_p) < 1e-6).all()) - self.assertTrue((jnp.abs(aop - primals_p) < 1e-6).all()) - - revstr = ( - "rev(dout, " + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")" - ) - revins = primalins | {"dout": douts[0]} - - print( - name + " JaX Rev: ", - timeit.Timer( - revstr, - globals={ - "rev": rev_jax, - } - | revins, - ).timeit(number) - / number, - ) - - for pname, pipeline in AllPipelines: - rfn_enzyme = jax.jit( - enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(in_fn) - ) - - if (pname, pipeline) in self.primfilter(PrimalPipelines): - ao = rfn_enzyme(*ins) - print(aop) - print((jnp.abs(aop - aop) < 1e-6).all()) - self.assertTrue((jnp.abs(ao - aop) < 1e-6).all()) - - print( - name + " EnzymeMLIR(", - pname, - ") Primal: ", - timeit.Timer( - primalstr, - globals={ - "fn": rfn_enzyme, - } - | primalins, - ).timeit(number) - / number, - ) - - if (pname, pipeline) in self.fwdfilter(FwdPipelines): - fwd_enzyme = jax.jit(splatjvp(rfn_enzyme)) - - primals, tangents = fwd_jax(*(ins + dins)) - - self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all()) - - if len(tangents.shape) == 0: - self.assertTrue((jnp.abs(tangents - tangents_p) < 1e-6).all()) - else: - for t, t_p in zip(tangents, tangents_p): - self.assertTrue((jnp.abs(t - t_p) < 1e-6).all()) - - print( - name + " EnzymeMLIR(", - pname, - ") Fwd: ", - timeit.Timer( - fwdstr, - globals={ - "fwd": fwd_enzyme, - } - | fwdins, - ).timeit(number) - / number, - ) - - if (pname, pipeline) in self.revfilter(RevPipelines): - rev_enzyme = jax.jit(splatvjp(rfn_enzyme)) - - primals, grads = rev_enzyme(*douts, *ins) - self.assertTrue((jnp.abs(primals - primals_p) < 1e-6).all()) - - for i, (g, g_p) in enumerate(zip(grads, grads_p)): - print(i, g, g_p) - self.assertTrue((jnp.abs(g - g_p) < 1e-6).all()) - - print( - name + " EnzymeMLIR(", - pname, - ") Rev: ", - timeit.Timer( - revstr, - globals={ - "rev": rev_enzyme, - } - | revins, - ).timeit(number) - / number, - ) +from test_utils import * class AddOne(EnzymeJaxTest): @@ -284,7 +61,7 @@ def sum(x): return jnp.sum(x) self.fn = sum - self.name = "sum" + self.name = "sum " class Cache(EnzymeJaxTest): @@ -349,7 +126,7 @@ def f(x): return kcl self.fn = f - self.name = "activitymismatch" + self.name = "actmtch" class GenDot(EnzymeJaxTest): @@ -363,13 +140,6 @@ def setUp(self): ) ] - def nomlir(x): - return [ - (name, a) - for (name, a) in x - if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" - ] - self.primfilter = no_newxla self.fwdfilter = no_newxla # No new xla runs but gets wrong answer @@ -406,14 +176,8 @@ def setUp(self): ] self.douts = [jnp.array([i * i for i in range(2 * dim)], dtype=jnp.float32)] - def nomlir(x): - return [ - (name, a) - for (name, a) in x - if name != "NewXLAMLIR" and name != "NewXLA" and name != "OldXLA" - ] - - self.revfilter = nomlir + self.revfilter = justjax + # self.revfilter = nomlir def f(x, y): return jnp.concat([x, y], axis=None) @@ -432,33 +196,66 @@ def f(x, y): filt = justjax - for pname, pipeline in filt(AllPipelines): - args = ( - 3 * jnp.ones((1,), dtype=jnp.float32), - ( - 5 * jnp.ones((1,), dtype=jnp.float64), - 7 * jnp.ones((1,), dtype=jnp.int32), - ), - ) - - g = jax.value_and_grad( - jax.jit(enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(f)), - has_aux=True, - allow_int=True, - ) - g2 = jax.value_and_grad(f, has_aux=True, allow_int=True) - - res = g(*args) - res2 = g2(*args) - - name = "valueandgrad" - print(name + " JaX(", pname, "): ", res2) - print(name + " EnzymeMLIR(", pname, "): ", res) - self.assertTrue((jnp.abs(res[0][0] - res2[0][0]) < 1e-6).all()) - self.assertTrue((jnp.abs(res[0][1][0] - res2[0][1][0]) < 1e-6).all()) - self.assertTrue((jnp.abs(res[0][1][1] - res2[0][1][1]) < 1e-6).all()) - - self.assertTrue((jnp.abs(res[1] - res2[1]) < 1e-6).all()) + for pname, pipeline, backends in AllPipelines: + prevres = None + for backend in backends: + if (pname, pipeline) in filt(AllPipelines): + args = ( + to_backend(3 * jnp.ones((1,), dtype=jnp.float32), backend), + ( + to_backend(5 * jnp.ones((1,), dtype=jnp.float64), backend), + to_backend(7 * jnp.ones((1,), dtype=jnp.int32), backend), + ), + ) + + g = jax.value_and_grad( + ( + f + if pipeline is None + else jax.jit( + enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(f), + # backend=backend + ) + ), + has_aux=True, + allow_int=True, + ) + + res = g(*args) + if prevres is None: + prevres = res + else: + name = "valueandgrad" + print(name + " JaX(", pname, "): ", prevres) + print(name + " EnzymeMLIR(", pname, "): ", res) + self.assertTrue( + ( + jnp.abs(res[0][0] - to_backend(prevres[0][0], backend)) + < 1e-6 + ).all() + ) + self.assertTrue( + ( + jnp.abs( + res[0][1][0] - to_backend(prevres[0][1][0], backend) + ) + < 1e-6 + ).all() + ) + self.assertTrue( + ( + jnp.abs( + res[0][1][1] - to_backend(prevres[0][1][1], backend) + ) + < 1e-6 + ).all() + ) + + self.assertTrue( + ( + jnp.abs(res[1] - to_backend(prevres[1], backend)) < 1e-6 + ).all() + ) if __name__ == "__main__": diff --git a/test/lit.cfg.py b/test/lit.cfg.py index d46739f9..d9359edc 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -40,5 +40,6 @@ ] path = os.path.pathsep.join(base_paths) # + config.extra_paths) config.environment["PATH"] = path - +config.environment["ENZYME_TEST_NOWHEEL"] = "1" +config.environment["PYTHONPATH"] = os.environ["PYTHONPATH"] config.substitutions.append(("python", sys.executable)) diff --git a/test/lit_tests/ir.pyt b/test/lit_tests/ir.pyt index 14be0038..d45d72dd 100644 --- a/test/lit_tests/ir.pyt +++ b/test/lit_tests/ir.pyt @@ -49,15 +49,14 @@ ones = jnp.ones((2, 3), jnp.float32) twos = jnp.ones((5, 7), jnp.float32) -@jax.jit def fwdmode(a, b, c, d): return jax.jvp(do_something, (a, b), (c, d)) -print(lower(fwdmode, (ones, twos, ones, twos)).compiler_ir(dialect="stablehlo")) +print(lower(jax.jit(fwdmode, backend='cpu'), (ones, twos, ones, twos)).compiler_ir(dialect="stablehlo")) # CHECK: module @jit_fwdmode attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { -# CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<5x7xf32> {mhlo.layout_mode = "default"}, %arg2: tensor<2x3xf32> {mhlo.layout_mode = "default"}, %arg3: tensor<5x7xf32> {mhlo.layout_mode = "default"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<6x9xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default"}) +# CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg2: tensor<2x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg3: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, tensor<4x6xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, tensor<6x9xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, tensor<4x6xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) # CHECK-NEXT: %[[i0:.+]] = stablehlo.constant dense<1> : tensor<1xi64> # CHECK-NEXT: %[[i1:.+]]:4 = stablehlo.custom_call @jaxzyme.fwd(%[[i0]], %arg0, %arg2, %arg1, %arg3) : (tensor<1xi64>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<5x7xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<6x9xf32>, tensor<4x6xf32>, tensor<4x6xf32>) # CHECK-NEXT: return %[[i1]]#0, %[[i1]]#2, %[[i1]]#1, %[[i1]]#3 : tensor<6x9xf32>, tensor<4x6xf32>, tensor<6x9xf32>, tensor<4x6xf32> @@ -65,12 +64,11 @@ print(lower(fwdmode, (ones, twos, ones, twos)).compiler_ir(dialect="stablehlo")) # CHECK-NEXT: } -@jax.jit def f(a, b): return jax.vjp(do_something, a, b) -print(lower(f, (ones, twos)).compiler_ir(dialect="stablehlo")) +print(lower(jax.jit(f, backend='cpu'), (ones, twos)).compiler_ir(dialect="stablehlo")) # CHECK: module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { # CHECK-NEXT: func.func public @main @@ -84,12 +82,10 @@ x = jnp.ones((6, 9), jnp.float32) y = jnp.ones((4, 6), jnp.float32) -@jax.jit def g(a, b, x, y): primals, f_vjp = jax.vjp(do_something, a, b) return primals, f_vjp((x, y)) - # CHECK: module @jit_g attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { # CHECK-NEXT: func.func public @main # CHECK-NEXT: %[[i0:.+]] = stablehlo.constant dense<3> : tensor<1xi64> @@ -100,11 +96,11 @@ def g(a, b, x, y): # CHECK-NEXT: } # CHECK-NEXT: } -print(lower(g, (ones, twos, x, y)).compiler_ir(dialect="stablehlo")) +print(lower(jax.jit(g, backend='cpu'), (ones, twos, x, y)).compiler_ir(dialect="stablehlo")) -primals, f_vjp = jax.vjp(jax.jit(do_something), ones, twos) +primals, f_vjp = jax.vjp(jax.jit(do_something, backend='cpu'), ones, twos) -print(lower(jax.jit(f_vjp), ((x, y),)).compiler_ir(dialect="stablehlo")) +print(lower(jax.jit(f_vjp, backend='cpu'), ((x, y),)).compiler_ir(dialect="stablehlo")) # CHECK: module @jit__unnamed_wrapped_function_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { # CHECK-NEXT: func.func public @main # CHECK-NEXT: %[[i0:.+]] = stablehlo.constant dense<[0, 0, -128, 63, 0, 0, -128, 63, 0, 0, -128, 63, 0, 0, -128, 63]> : tensor<16xi8> @@ -112,8 +108,12 @@ print(lower(jax.jit(f_vjp), ((x, y),)).compiler_ir(dialect="stablehlo")) # CHECK-NEXT: return %[[i1]]#0, %[[i1]]#1 : tensor<2x3xf32>, tensor<5x7xf32> # CHECK-NEXT: } # CHECK: func.func private @do_something +# CHECK-NEXT: %[[shard1:.+]] = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<6x9xf32>) -> tensor<6x9xf32> +# CHECK-NEXT: %[[shard2:.+]] = stablehlo.custom_call @Sharding(%arg2) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<4x6xf32>) -> tensor<4x6xf32> # CHECK-NEXT: %[[i0:.+]] = stablehlo.constant dense<6> : tensor<1xi64> -# CHECK-NEXT: %[[i1:.+]]:2 = stablehlo.custom_call @jaxzyme.rev(%[[i0]], %arg0, %arg1, %arg2) : (tensor<1xi64>, tensor<16xi8>, tensor<6x9xf32>, tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>) -# CHECK-NEXT: return %[[i1]]#0, %[[i1]]#1 : tensor<2x3xf32>, tensor<5x7xf32> +# CHECK-NEXT: %[[i1:.+]]:2 = stablehlo.custom_call @jaxzyme.rev(%[[i0]], %arg0, %[[shard1]], %[[shard2]]) : (tensor<1xi64>, tensor<16xi8>, tensor<6x9xf32>, tensor<4x6xf32>) -> (tensor<2x3xf32>, tensor<5x7xf32>) +# CHECK-NEXT: %[[res1:.+]] = stablehlo.custom_call @Sharding(%[[i1]]#0) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<2x3xf32>) -> tensor<2x3xf32> +# CHECK-NEXT: %[[res2:.+]] = stablehlo.custom_call @Sharding(%[[i1]]#1) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<5x7xf32>) -> tensor<5x7xf32> +# CHECK-NEXT: return %[[res1]], %[[res2]] : tensor<2x3xf32>, tensor<5x7xf32> # CHECK-NEXT: } # CHECK-NEXT: } diff --git a/test/llama.py b/test/llama.py index 204a8f16..ec1578ae 100644 --- a/test/llama.py +++ b/test/llama.py @@ -3,8 +3,16 @@ import jax.random import jax.lax import enzyme_ad.jax as enzyme_jax +from enzyme_ad.jax import ( + enzyme_jax_ir, + NewXLAPipeline, + OldXLAPipeline, + JaXPipeline, + hlo_opts, +) import numpy as np import timeit +from test_utils import * argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11") @@ -240,199 +248,9 @@ def forward(x, config, weights, key_cache, value_cache): return x -class Llama(absltest.TestCase): - def test_llama_random(self): - config = { - "dim": 288, - "hidden_dim": 768, - "n_layers": 6, - "n_heads": 6, - "n_kv_heads": 6, - "vocab_size": 32000, - "seq_len": 256, - } - - n_layers = config["n_layers"] - seq_len = config["seq_len"] - n_heads = config["n_heads"] - dim = config["dim"] - n_kv_heads = config["n_kv_heads"] - vocab_size = config["vocab_size"] - hidden_dim = config["hidden_dim"] - kv_dim = dim // n_heads * n_kv_heads - head_size = dim // n_heads - - key = jax.random.PRNGKey(0) - weights = {} - dweights = {} - - for name, shape in [ - ("rms_att_weight", (n_layers, dim)), - ("wq", (n_layers, dim, n_heads * head_size)), - ("wk", (n_layers, dim, n_kv_heads * head_size)), - ("wv", (n_layers, dim, n_kv_heads * head_size)), - ("wo", (n_layers, dim, dim)), - ("rms_ffn_weight", (n_layers, dim)), - ("w1", (n_layers, hidden_dim, dim)), - ("w2", (n_layers, dim, hidden_dim)), - ("w3", (n_layers, hidden_dim, dim)), - ("rms_final_weight", (dim,)), - ("wcls", (vocab_size, dim)), - ]: - key, subkey = jax.random.split(key) - key, subkey2 = jax.random.split(key) - weights[name] = jax.random.uniform(subkey, shape=shape) - dweights[name] = jax.random.uniform(subkey2, shape=shape) - - key, subkey = jax.random.split(key) - x = jax.random.uniform(subkey, shape=(dim,)) - key, subkey = jax.random.split(key) - dx = jax.random.uniform(subkey, shape=(dim,)) - - def partial(func, config): - def sfn(x, weights, key_cache, value_cache): - return func(x, config, weights, key_cache, value_cache) - - return sfn - - pos = 1 - key_cache = jnp.zeros((n_layers, pos, kv_dim)) - value_cache = jnp.zeros((n_layers, pos, kv_dim)) - - key, subkey = jax.random.split(key) - dkc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim)) - key, subkey = jax.random.split(key) - dvc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim)) - - func = partial(forward, config) - - jfunc = jax.jit(func) - - efunc = jax.jit( - enzyme_jax.enzyme_jax_ir(argv=argv, pipeline_options=pipeline)(func) - ) - - number = 1000 - if False: - eres = efunc(x, weights, key_cache, value_cache) - print("Enzyme primal", eres) - res = jfunc(x, weights, key_cache, value_cache) - print("Jax primal", res) - print(" max error", jnp.max(jnp.abs(eres - res))) - assert (jnp.abs(eres - res) < 1e-3).all() - - print( - "Enzyme primal", - timeit.Timer( - "efunc(x, weights, key_cache, value_cache)", - globals={ - "efunc": efunc, - "x": x, - "weights": weights, - "key_cache": key_cache, - "value_cache": value_cache, - }, - ).timeit(number), - ) - print( - "JaX primal", - timeit.Timer( - "jfunc(x, weights, key_cache, value_cache)", - globals={ - "jfunc": jfunc, - "x": x, - "weights": weights, - "key_cache": key_cache, - "value_cache": value_cache, - }, - ).timeit(number), - ) - # jfunc = jax.jit(partial(forward, config)) - # mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo") - - if False: - - @jax.jit - def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc): - return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) - - @jax.jit - def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc): - return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) - - eres = efwd( - x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache - ) - print("Enzyme fwd", eres) - jres = jfwd( - x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache - ) - print("Jax fwd", jres) - print( - "Enzyme fwd", - timeit.Timer( - "efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)", - globals={ - "efwd": efwd, - "x": x, - "dx": dx, - "weights": weights, - "dweights": dweights, - "key_cache": key_cache, - "value_cache": value_cache, - }, - ).timeit(number), - ) - print( - "JaX fwd", - timeit.Timer( - "jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)", - globals={ - "jfwd": jfwd, - "x": x, - "dx": dx, - "weights": weights, - "dweights": dweights, - "key_cache": key_cache, - "value_cache": value_cache, - }, - ).timeit(number), - ) - - @jax.jit - def jrev(x, weights, kc, vc, dx, dkc, dvc): - primals, f_vjp = jax.vjp(jfunc, x, weights, kc, vc) - return f_vjp(dx) # , dkc, dvc) - - @jax.jit - def erev(x, weights, kc, vc, dx, dkc, dvc): - primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc) - return f_vjp(dx) # , dkc, dvc) - - eres = erev(x, weights, key_cache, value_cache, dx, dkc, dvc) - # print("Enzyme rev", eres) - jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc) - # print("Jax rev", jres) - - jrev2 = jax.jit( - enzyme_jax.enzyme_jax_ir( - argv=argv, - pipeline_options=enzyme_jax.JaXPipeline( - "inline{default-pipeline=canonicalize max-iterations=4}," - + "canonicalize,cse,enzyme-hlo-opt,cse" - ), - )(jrev) - ) - - jres2 = jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc) - # print("Jax2 rev", jres2) - - jrev3 = jax.jit( - enzyme_jax.enzyme_jax_ir( - argv=argv, - pipeline_options=enzyme_jax.JaXPipeline( - "inline{default-pipeline=canonicalize max-iterations=4}," - + """canonicalize,cse, +partialopt = ( + "inline{default-pipeline=canonicalize max-iterations=4}," + + """canonicalize,cse, enzyme-hlo-generate-td{ patterns=compare_op_canon<16>; transpose_transpose<16>; @@ -534,138 +352,99 @@ def erev(x, weights, kc, vc, dx, dkc, dvc): }, transform-interpreter, enzyme-hlo-remove-transform,cse""" - ), - )(jrev) - ) - unused = """ - - - -reshape_iota<16>; -slice_reshape_slice<1>; -dot_general_simplify<16>; -transpose_simplify<16>; -reshape_empty_broadcast<1>; -add_pad_pad_to_concat<1>; -broadcast_reshape<1>; - -slice_reshape_concat<1>; -slice_reshape_elementwise<1>; -slice_reshape_transpose<1>; -slice_reshape_dot_general<1>; -concat_pad<1>; -reduce_pad<1>; -broadcast_pad<1>; - -zero_product_reshape_pad<1>; -mul_zero_pad<1>; -div_zero_pad<1>; - -binop_const_reshape_pad<1>; -binop_const_pad_add<1>; -binop_const_pad_subtract<1>; -binop_const_pad_mul<1>; -binop_const_pad_div<1>; - -slice_reshape_pad<1>; -binop_binop_pad_pad_add<1>; -binop_binop_pad_pad_mul<1>; -binop_pad_pad_add<1>; -binop_pad_pad_subtract<1>; -binop_pad_pad_mul<1>; -binop_pad_pad_div<1>; -binop_pad_pad_min<1>; -binop_pad_pad_max<1>; - -unary_pad_push_convert<1>; -unary_pad_push_tanh<1>; -unary_pad_push_exp<1>; -transpose_pad<1>; - -transpose_dot_reorder<1>; -dot_transpose<1>; -convert_convert_float<1>; -concat_to_pad<1>; -concat_appending_reshape<1>; -reshape_iota<1>; +) + +pipelines = [ + ("JaX ", None, CurBackends), + ("JaXPipe", JaXPipeline(), CurBackends), + ( + "HLOOpt", + JaXPipeline( + "inline{default-pipeline=canonicalize max-iterations=4}," + + "canonicalize,cse,enzyme-hlo-opt,cse" + ), + CurBackends, + ), + ("PartOpt", JaXPipeline(partialopt), CurBackends), + ("DefOpt", JaXPipeline(hlo_opts()), CurBackends), +] + + +class Llama(EnzymeJaxTest): + def setUp(self): + config = { + "dim": 288, + "hidden_dim": 768, + "n_layers": 6, + "n_heads": 6, + "n_kv_heads": 6, + "vocab_size": 32000, + "seq_len": 256, + } -broadcast_reduce<1>; -slice_dot_general<1>; + n_layers = config["n_layers"] + seq_len = config["seq_len"] + n_heads = config["n_heads"] + dim = config["dim"] + n_kv_heads = config["n_kv_heads"] + vocab_size = config["vocab_size"] + hidden_dim = config["hidden_dim"] + kv_dim = dim // n_heads * n_kv_heads + head_size = dim // n_heads -dot_reshape_pad<1>; -pad_dot_general<1>(0); + key = jax.random.PRNGKey(0) + weights = {} + dweights = {} -dot_reshape_pad<1>; -pad_dot_general<1>(1); -""" - - jres3 = jrev3(x, weights, key_cache, value_cache, dx, dkc, dvc) - # print("Jax3 rev", jres3) - - print( - "Enzyme rev", - timeit.Timer( - "erev(x, weights, key_cache, value_cache, dx, dkc, dvc)", - globals={ - "erev": erev, - "x": x, - "weights": weights, - "key_cache": key_cache, - "value_cache": value_cache, - "dx": dx, - "dkc": dkc, - "dvc": dvc, - }, - ).timeit(number), - ) - print( - "JaX rev", - timeit.Timer( - "jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)", - globals={ - "jrev": jrev, - "x": x, - "weights": weights, - "key_cache": key_cache, - "value_cache": value_cache, - "dx": dx, - "dkc": dkc, - "dvc": dvc, - }, - ).timeit(number), - ) - print( - "JaX2 rev", - timeit.Timer( - "jrev2(x, weights, key_cache, value_cache, dx, dkc, dvc)", - globals={ - "jrev2": jrev2, - "x": x, - "weights": weights, - "key_cache": key_cache, - "value_cache": value_cache, - "dx": dx, - "dkc": dkc, - "dvc": dvc, - }, - ).timeit(number), - ) - print( - "JaX3 rev", - timeit.Timer( - "jrev3(x, weights, key_cache, value_cache, dx, dkc, dvc)", - globals={ - "jrev3": jrev3, - "x": x, - "weights": weights, - "key_cache": key_cache, - "value_cache": value_cache, - "dx": dx, - "dkc": dkc, - "dvc": dvc, - }, - ).timeit(number), - ) + for name, shape in [ + ("rms_att_weight", (n_layers, dim)), + ("wq", (n_layers, dim, n_heads * head_size)), + ("wk", (n_layers, dim, n_kv_heads * head_size)), + ("wv", (n_layers, dim, n_kv_heads * head_size)), + ("wo", (n_layers, dim, dim)), + ("rms_ffn_weight", (n_layers, dim)), + ("w1", (n_layers, hidden_dim, dim)), + ("w2", (n_layers, dim, hidden_dim)), + ("w3", (n_layers, hidden_dim, dim)), + ("rms_final_weight", (dim,)), + ("wcls", (vocab_size, dim)), + ]: + key, subkey = jax.random.split(key) + key, subkey2 = jax.random.split(key) + weights[name] = jax.random.uniform(subkey, shape=shape) + dweights[name] = jax.random.uniform(subkey2, shape=shape) + + key, subkey = jax.random.split(key) + x = jax.random.uniform(subkey, shape=(dim,)) + key, subkey = jax.random.split(key) + dx = jax.random.uniform(subkey, shape=(dim,)) + + def partial(func, config): + def sfn(x, weights, key_cache, value_cache): + return func(x, config, weights, key_cache, value_cache) + + return sfn + + pos = 1 + key_cache = jnp.zeros((n_layers, pos, kv_dim)) + value_cache = jnp.zeros((n_layers, pos, kv_dim)) + + key, subkey = jax.random.split(key) + dkc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim)) + key, subkey = jax.random.split(key) + dvc = jax.random.uniform(subkey, shape=(n_layers, pos + 1, kv_dim)) + + self.fn = partial(forward, config) + self.name = "llama" + self.count = 100 if jax.default_backend() == "cpu" else 1000 + self.revprimal = False + self.AllPipelines = pipelines + self.AllBackends = CurBackends + + self.ins = [x, weights, key_cache, value_cache] + self.dins = [dx, weights, key_cache, value_cache] + self.douts = [dx] + self.tol = 5e-5 if __name__ == "__main__": diff --git a/test/test.py b/test/test.py index 623fe4bd..ad5765b0 100644 --- a/test/test.py +++ b/test/test.py @@ -3,6 +3,8 @@ import jax.numpy as jnp from enzyme_ad.jax import cpp_call, enzyme_jax_ir, optimize_module +jax.config.update("jax_platform_name", "cpu") + argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11") @@ -63,20 +65,36 @@ def do_something(ones): ones = jnp.ones((2, 3), jnp.float32) x, y, z = do_something(ones) - print(x) - print(y) - print(z) + self.assertTrue((x == 43).all()) + self.assertTrue((y == 85).all()) + self.assertTrue((z[0] == 56).all()) # JVP primals, tangents = jax.jvp(do_something, (ones,), (ones,)) - print(primals) - print(tangents) + self.assertTrue((primals[0] == 43).all()) + self.assertTrue((primals[1] == 85).all()) + self.assertTrue((primals[2][0] == 56).all()) + self.assertTrue((tangents[0] == 1).all()) + self.assertTrue((tangents[1] == 1).all()) + self.assertTrue((tangents[2][0] == 0).all()) # VJP primals, f_vjp = jax.vjp(do_something, ones) (grads,) = f_vjp((x, y, z)) - print(primals) - print(grads) + self.assertTrue((primals[0] == 43).all()) + self.assertTrue((primals[1] == 85).all()) + self.assertTrue((primals[2][0] == 56).all()) + + self.assertTrue( + ( + grads[1] + == jnp.array( + [ + [128.0, 128.0, 128.0], + ] + ) + ).all() + ) def test_enzyme_mlir_jit(self): @jax.jit @@ -84,7 +102,6 @@ def test_enzyme_mlir_jit(self): def add_one(x: jax.Array, y) -> jax.Array: return x + 1 + y - # But it should print LLVM IR in the process. add_one(jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])) primals, tangents = jax.jvp( @@ -92,15 +109,61 @@ def add_one(x: jax.Array, y) -> jax.Array: (jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0])), (jnp.array([0.1, 0.2, 0.3]), jnp.array([50.0, 70.0, 110.0])), ) - print(primals) - print(tangents) + self.assertTrue( + ( + primals + == jnp.array( + [ + [12.0, 23.0, 34.0], + ] + ) + ).all() + ) + self.assertTrue( + ( + tangents + == jnp.array( + [ + [50.1, 70.2, 110.3], + ] + ) + ).all() + ) primals, f_vjp = jax.vjp( add_one, jnp.array([1.0, 2.0, 3.0]), jnp.array([10.0, 20.0, 30.0]) ) grads = f_vjp(jnp.array([500.0, 700.0, 110.0])) - print(primals) - print(grads) + self.assertTrue( + ( + primals + == jnp.array( + [ + [12.0, 23.0, 34.0], + ] + ) + ).all() + ) + self.assertTrue( + ( + grads[0] + == jnp.array( + [ + [500.0, 700.0, 110.0], + ] + ) + ).all() + ) + self.assertTrue( + ( + grads[1] + == jnp.array( + [ + [500.0, 700.0, 110.0], + ] + ) + ).all() + ) if __name__ == "__main__": diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..8d037105 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,414 @@ +import jax +import jax.numpy as jnp +from enzyme_ad.jax import ( + enzyme_jax_ir, + NewXLAPipeline, + OldXLAPipeline, + JaXPipeline, + hlo_opts, +) +from absl.testing import absltest +import timeit + +argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11") + +devices = [] +CurBackends = [jax.default_backend()] + +if jax.default_backend() != "cpu": + devices = CurBackends + +AllBackends = ["cpu"] + devices +AllPipelines = [ + ("JaX ", None, AllBackends), + ("JaXPipe", JaXPipeline(), AllBackends), + # ("NewXLAMLIR", NewXLAPipeline(mlirad=True)), + # ("NewXLA", NewXLAPipeline()), + ("OldXLA", OldXLAPipeline(), ["cpu"]), +] + + +def no_newxla(x): + return [ + (name, a, b) for (name, a, b) in x if name != "NewXLAMLIR" and name != "NewXLA" + ] + + +def no_newxlamlir(x): + return [(name, a, b) for (name, a, b) in x if name != "NewXLAMLIR"] + + +def nomlir(x): + return [ + (name, a, b) + for (name, a, b) in x + if name != "NewXLAMLIR" and name != "NewXLA" # and name != "OldXLA" + ] + + +def justjax(x): + return [ + (name, a, b) for (name, a, b) in x if a is None or isinstance(a, JaXPipeline) + ] + + +# @jax.jit +# def fwd_jax(in0, in1, din0, din1): +# . return jax.jvp(add_one_jax, (in0, in1), (din0, din1)) +def splatjvp(in_fn): + def fwd(*args): + assert len(args) % 2 == 0 + return jax.jvp( + in_fn, tuple(args[: len(args) // 2]), tuple(args[len(args) // 2 :]) + ) + + return fwd + + +# @jax.jit +# def rev_jax(dout, in0, in1): +# primals, f_vjp = jax.vjp(add_one_jax, in0, in1) +# grads = f_vjp(dout) +# return primals, grads +def splatvjp(in_fn): + def rev(dout, *args): + primals, f_vjp = jax.vjp(in_fn, *args) + grads = f_vjp(dout) + return primals, grads + + return rev + + +def splatvjp_noprim(in_fn): + def rev(dout, *args): + primals, f_vjp = jax.vjp(in_fn, *args) + grads = f_vjp(dout) + return primals, grads + + return rev + + +def to_backend(x, backend): + dev = jax.local_devices(backend=backend)[0] + return jax.device_put(x, dev) + + +def recursive_check(tester, lhs, rhs, tol=1e-6): + tester.assertEqual(type(lhs), type(rhs)) + if isinstance(lhs, jax.Array): + legal = (jnp.abs(lhs - rhs) < tol).all() + if not legal: + print("lhs", lhs) + print("rhs", rhs) + print("abs", jnp.abs(lhs - rhs)) + print("eq", jnp.abs(lhs - rhs) < tol) + print("max", jnp.max(jnp.abs(lhs - rhs))) + tester.assertTrue(legal) + return + + if isinstance(lhs, tuple): + for i, (g, g_p) in enumerate(zip(lhs, rhs)): + recursive_check(tester, g, g_p, tol) + return + + if isinstance(lhs, dict): + tester.assertEqual(lhs.keys(), rhs.keys()) + for k in lhs.keys(): + recursive_check(tester, lhs[k], rhs[k], tol) + return + + print("Unknown recursive type", type(lhs), " ", type(rhs)) + tester.assertTrue(False) + + +class EnzymeJaxTest(absltest.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.primfilter = lambda x: x + self.fwdfilter = lambda x: x + self.revfilter = lambda x: x + self.count = 10000 + self.AllBackends = AllBackends + self.AllPipelines = AllPipelines + self.revprimal = True + self.tol = 1e-6 + + def setUp(self): + self.name = None + + def test(self): + if self.name is None: + return + self.harness(self.name, self.fn, self.ins, self.dins, self.douts) + + def harness(self, name, in_fn, ins, dins, douts): + assert len(ins) == len(dins) + + assert 1 == len(douts) + + primalstr = "fn(" + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")" + + fwdstr = ( + "fwd(" + + (", ".join(["in" + str(i) for i in range(len(ins))])) + + ", " + + (", ".join(["din" + str(i) for i in range(len(dins))])) + + ")" + ) + + revstr = ( + "rev(dout, " + (", ".join(["in" + str(i) for i in range(len(ins))])) + ")" + ) + + for backend in self.AllBackends: + ins_backend = [to_backend(x, backend) for x in ins] + dins_backend = [to_backend(x, backend) for x in dins] + douts_backend = [to_backend(x, backend) for x in douts] + + primalins = {("in" + str(i)): ins_backend[i] for i in range(len(ins))} + fwdins = primalins | { + ("din" + str(i)): dins_backend[i] for i in range(len(dins)) + } + revins = primalins | {"dout": douts_backend[0]} + + primres = None + + for pname, pipeline, pbackends in self.primfilter(self.AllPipelines): + if backend in pbackends: + rfn_enzyme = jax.jit( + ( + in_fn + if pipeline is None + else enzyme_jax_ir(pipeline_options=pipeline, argv=argv)( + in_fn + ) + ), + # backend=backend + ) + ao = rfn_enzyme(*ins_backend) + if primres is None: + primres = ao + else: + recursive_check(self, ao, primres, self.tol) + + print( + name, + ",", + pname, + ",", + backend, + ",", + "Primal", + ",", + timeit.Timer( + primalstr, + globals={ + "fn": rfn_enzyme, + } + | primalins, + ).timeit(self.count) + / self.count, + sep="\t", + ) + + # assert primres is not None + fwdres = None + + for pname, pipeline, pbackends in self.fwdfilter(self.AllPipelines): + if backend in pbackends: + rfn_enzyme = ( + in_fn + if pipeline is None + else jax.jit( + enzyme_jax_ir(pipeline_options=pipeline, argv=argv)(in_fn), + # backend=backend + ) + ) + fwd_enzyme = jax.jit( + splatjvp(rfn_enzyme), + # backend=backend + ) + + primals, tangents = fwd_enzyme(*(ins_backend + dins_backend)) + + recursive_check(self, primals, primres, self.tol) + + if fwdres is None: + fwdres = tangents + else: + recursive_check(self, tangents, fwdres, self.tol) + + print( + name, + ",", + pname, + ",", + backend, + ",", + "Forward", + ",", + timeit.Timer( + fwdstr, + globals={ + "fwd": fwd_enzyme, + } + | fwdins, + ).timeit(self.count) + / self.count, + sep="\t", + ) + + # assert fwdres is not None + + revres = None + + revtransform = splatvjp if self.revprimal else splatvjp_noprim + + for pname, pipeline, pbackends in self.revfilter(self.AllPipelines): + if backend in pbackends: + if pipeline is not None: + rfn_enzyme = ( + in_fn + if pipeline is None + else enzyme_jax_ir(pipeline_options=pipeline, argv=argv)( + in_fn + ) + ) + rev_enzyme = jax.jit( + revtransform(rfn_enzyme), + # backend=backend + ) + + if self.revprimal: + primals, grads = rev_enzyme(*douts_backend, *ins_backend) + else: + grads = rev_enzyme(*douts_backend, *ins_backend) + assert grads is not None + + if self.revprimal and primres is not None: + recursive_check(self, primals, primres, self.tol) + + if revres is None: + revres = grads + else: + recursive_check(self, grads, revres, self.tol) + + print( + name, + ",", + pname, + ",", + backend, + ",", + "PreRev", + ",", + timeit.Timer( + revstr, + globals={ + "rev": rev_enzyme, + } + | revins, + ).timeit(self.count) + / self.count, + sep="\t", + ) + + rfn_enzyme = in_fn + rev_enzyme = jax.jit( + ( + revtransform(rfn_enzyme) + if pipeline is None + else enzyme_jax_ir( + pipeline_options=pipeline, argv=argv + )(revtransform(rfn_enzyme)) + ), + # backend=backend + ) + + if self.revprimal: + primals, grads = rev_enzyme(*douts_backend, *ins_backend) + else: + grads = rev_enzyme(*douts_backend, *ins_backend) + assert grads is not None + + if self.revprimal and primres is not None: + recursive_check(self, primals, primres, self.tol) + + if revres is None: + revres = grads + else: + recursive_check(self, grads, revres, self.tol) + + print( + name, + ",", + pname, + ",", + backend, + ",", + "PostRev", + ",", + timeit.Timer( + revstr, + globals={ + "rev": rev_enzyme, + } + | revins, + ).timeit(self.count) + / self.count, + sep="\t", + ) + + if pipeline is None or pipeline.mlir_ad(): + rfn_enzyme = ( + in_fn + if pipeline is None + else enzyme_jax_ir(pipeline_options=pipeline, argv=argv)( + in_fn + ) + ) + rev_enzyme = jax.jit( + ( + revtransform(rfn_enzyme) + if pipeline is None + else enzyme_jax_ir( + pipeline_options=pipeline, argv=argv + )(revtransform(rfn_enzyme)) + ), + # backend=backend + ) + + if self.revprimal: + primals, grads = rev_enzyme(*douts_backend, *ins_backend) + else: + grads = rev_enzyme(*douts_backend, *ins_backend) + assert grads is not None + + if self.revprimal and primres is not None: + recursive_check(self, primals, primres, self.tol) + + if revres is None: + revres = grads + else: + recursive_check(self, grads, revres, self.tol) + + print( + name, + ",", + pname, + ",", + backend, + ",", + "BothRev", + ",", + timeit.Timer( + revstr, + globals={ + "rev": rev_enzyme, + } + | revins, + ).timeit(self.count) + / self.count, + sep="\t", + ) + assert revres is not None diff --git a/workspace.bzl b/workspace.bzl index 0c3a2dc8..b1f8d8c9 100644 --- a/workspace.bzl +++ b/workspace.bzl @@ -1,8 +1,8 @@ JAX_COMMIT = "3713b966c2a868e948a663193282deba7ba14842" JAX_SHA256 = "6b0265a1c58b6c050f334e6b91ce1a69b47b1144670b87bcc30fc90172b5cbf4" -ENZYME_COMMIT = "cc65abdb55e5e2e142d773e0508e1083d4b3ac52" -ENZYME_SHA256 = "" +ENZYME_COMMIT = "9bdae194e007a288094e3b069fce7de4aff2e38b" +ENZYME_SHA256 = "f91797ef350ccd4a4b64f20ec9d328132d600a06b9862088e340c7761c4eb59a" XLA_PATCHES = [ """