diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d8980fcd1a0d..7fdc7d6c4266d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,751 +1,43 @@ -############################################################# -# WARNING: automatically generated file, DO NOT CHANGE! # -############################################################# +name: Rust CI -# This file was automatically generated by the expand-yaml-anchors tool. The -# source file that generated this one is: -# -# src/ci/github-actions/ci.yml -# -# Once you make changes to that file you need to run: -# -# ./x.py run src/tools/expand-yaml-anchors/ -# -# The CI build will fail if the tool is not run after changes to this file. - ---- -name: CI -"on": +on: push: branches: - - auto - - try - - try-perf - - master + - master pull_request: branches: - - "**" -permissions: - contents: read -defaults: - run: - shell: bash + - master + merge_group: + jobs: - pr: - permissions: - actions: write - name: "PR - ${{ matrix.name }}" - env: - CI_JOB_NAME: "${{ matrix.name }}" - CARGO_REGISTRIES_CRATES_IO_PROTOCOL: sparse - SCCACHE_BUCKET: rust-lang-ci-sccache2 - TOOLSTATE_REPO: "https://github.com/rust-lang-nursery/rust-toolstate" - CACHE_DOMAIN: ci-caches.rust-lang.org - if: "github.event_name == 'pull_request'" - continue-on-error: "${{ matrix.name == 'mingw-check-tidy' }}" + build: + name: Rust Integration CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false matrix: - include: - - name: mingw-check - os: ubuntu-20.04-16core-64gb - env: {} - - name: mingw-check-tidy - os: ubuntu-20.04-16core-64gb - env: {} - - name: x86_64-gnu-llvm-14 - os: ubuntu-20.04-16core-64gb - env: {} - - name: x86_64-gnu-tools - os: ubuntu-20.04-16core-64gb - env: {} - timeout-minutes: 600 - runs-on: "${{ matrix.os }}" - steps: - - name: disable git crlf conversion - run: git config --global core.autocrlf false - - name: checkout the source code - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - name: configure the PR in which the error message will be posted - run: "echo \"[CI_PR_NUMBER=$num]\"" - env: - num: "${{ github.event.number }}" - if: "success() && !env.SKIP_JOB && github.event_name == 'pull_request'" - - name: add extra environment variables - run: src/ci/scripts/setup-environment.sh - env: - EXTRA_VARIABLES: "${{ toJson(matrix.env) }}" - if: success() && !env.SKIP_JOB - - name: decide whether to skip this job - run: src/ci/scripts/should-skip-this.sh - if: success() && !env.SKIP_JOB - - name: ensure the channel matches the target branch - run: src/ci/scripts/verify-channel.sh - if: success() && !env.SKIP_JOB - - name: configure GitHub Actions to kill the build when outdated - uses: rust-lang/simpleinfra/github-actions/cancel-outdated-builds@master - with: - github_token: "${{ secrets.github_token }}" - if: "success() && !env.SKIP_JOB && github.ref != 'refs/heads/try' && github.ref != 'refs/heads/try-perf'" - - name: collect CPU statistics - run: src/ci/scripts/collect-cpu-stats.sh - if: success() && !env.SKIP_JOB - - name: show the current environment - run: src/ci/scripts/dump-environment.sh - if: success() && !env.SKIP_JOB - - name: install sccache - run: src/ci/scripts/install-sccache.sh - if: success() && !env.SKIP_JOB - - name: select Xcode - run: src/ci/scripts/select-xcode.sh - if: success() && !env.SKIP_JOB - - name: install clang - run: src/ci/scripts/install-clang.sh - if: success() && !env.SKIP_JOB - - name: install WIX - run: src/ci/scripts/install-wix.sh - if: success() && !env.SKIP_JOB - - name: disable git crlf conversion - run: src/ci/scripts/disable-git-crlf-conversion.sh - if: success() && !env.SKIP_JOB - - name: checkout submodules - run: src/ci/scripts/checkout-submodules.sh - if: success() && !env.SKIP_JOB - - name: install MSYS2 - run: src/ci/scripts/install-msys2.sh - if: success() && !env.SKIP_JOB - - name: install MinGW - run: src/ci/scripts/install-mingw.sh - if: success() && !env.SKIP_JOB - - name: install ninja - run: src/ci/scripts/install-ninja.sh - if: success() && !env.SKIP_JOB - - name: enable ipv6 on Docker - run: src/ci/scripts/enable-docker-ipv6.sh - if: success() && !env.SKIP_JOB - - name: disable git crlf conversion - run: src/ci/scripts/disable-git-crlf-conversion.sh - if: success() && !env.SKIP_JOB - - name: ensure line endings are correct - run: src/ci/scripts/verify-line-endings.sh - if: success() && !env.SKIP_JOB - - name: ensure backported commits are in upstream branches - run: src/ci/scripts/verify-backported-commits.sh - if: success() && !env.SKIP_JOB - - name: ensure the stable version number is correct - run: src/ci/scripts/verify-stable-version-number.sh - if: success() && !env.SKIP_JOB - - name: run the build - run: src/ci/scripts/run-build-from-ci.sh - env: - AWS_ACCESS_KEY_ID: "${{ env.CACHES_AWS_ACCESS_KEY_ID }}" - AWS_SECRET_ACCESS_KEY: "${{ secrets[format('AWS_SECRET_ACCESS_KEY_{0}', env.CACHES_AWS_ACCESS_KEY_ID)] }}" - TOOLSTATE_REPO_ACCESS_TOKEN: "${{ secrets.TOOLSTATE_REPO_ACCESS_TOKEN }}" - if: success() && !env.SKIP_JOB - - name: upload artifacts to S3 - run: src/ci/scripts/upload-artifacts.sh - env: - AWS_ACCESS_KEY_ID: "${{ env.ARTIFACTS_AWS_ACCESS_KEY_ID }}" - AWS_SECRET_ACCESS_KEY: "${{ secrets[format('AWS_SECRET_ACCESS_KEY_{0}', env.ARTIFACTS_AWS_ACCESS_KEY_ID)] }}" - if: "success() && !env.SKIP_JOB && (github.event_name == 'push' || env.DEPLOY == '1' || env.DEPLOY_ALT == '1')" - auto: - permissions: - actions: write - name: "auto - ${{ matrix.name }}" - env: - CI_JOB_NAME: "${{ matrix.name }}" - CARGO_REGISTRIES_CRATES_IO_PROTOCOL: sparse - SCCACHE_BUCKET: rust-lang-ci-sccache2 - DEPLOY_BUCKET: rust-lang-ci2 - TOOLSTATE_REPO: "https://github.com/rust-lang-nursery/rust-toolstate" - TOOLSTATE_ISSUES_API_URL: "https://api.github.com/repos/rust-lang/rust/issues" - TOOLSTATE_PUBLISH: 1 - CACHES_AWS_ACCESS_KEY_ID: AKIA46X5W6CZI5DHEBFL - ARTIFACTS_AWS_ACCESS_KEY_ID: AKIA46X5W6CZN24CBO55 - AWS_REGION: us-west-1 - CACHE_DOMAIN: ci-caches.rust-lang.org - if: "github.event_name == 'push' && github.ref == 'refs/heads/auto' && github.repository == 'rust-lang-ci/rust'" - strategy: - matrix: - include: - - name: aarch64-gnu - os: - - self-hosted - - ARM64 - - linux - - name: arm-android - os: ubuntu-20.04-8core-32gb - env: {} - - name: armhf-gnu - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-aarch64-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-android - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-arm-linux - os: ubuntu-20.04-16core-64gb - env: {} - - name: dist-armhf-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-armv7-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-i586-gnu-i586-i686-musl - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-i686-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-mips-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-mips64-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-mips64el-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-mipsel-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-powerpc-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-powerpc64-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-powerpc64le-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-riscv64-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-s390x-linux - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-various-1 - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-various-2 - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-x86_64-freebsd - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-x86_64-illumos - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-x86_64-linux - os: ubuntu-20.04-16core-64gb - env: {} - - name: dist-x86_64-linux-alt - env: - IMAGE: dist-x86_64-linux - os: ubuntu-20.04-16core-64gb - - name: dist-x86_64-musl - os: ubuntu-20.04-8core-32gb - env: {} - - name: dist-x86_64-netbsd - os: ubuntu-20.04-8core-32gb - env: {} - - name: i686-gnu - os: ubuntu-20.04-8core-32gb - env: {} - - name: i686-gnu-nopt - os: ubuntu-20.04-8core-32gb - env: {} - - name: mingw-check - os: ubuntu-20.04-4core-16gb - env: {} - - name: test-various - os: ubuntu-20.04-8core-32gb - env: {} - - name: wasm32 - os: ubuntu-20.04-8core-32gb - env: {} - - name: x86_64-gnu - os: ubuntu-20.04-4core-16gb - env: {} - - name: x86_64-gnu-stable - env: - IMAGE: x86_64-gnu - RUST_CI_OVERRIDE_RELEASE_CHANNEL: stable - CI_ONLY_WHEN_CHANNEL: nightly - os: ubuntu-20.04-4core-16gb - - name: x86_64-gnu-aux - os: ubuntu-20.04-4core-16gb - env: {} - - name: x86_64-gnu-debug - os: ubuntu-20.04-8core-32gb - env: {} - - name: x86_64-gnu-distcheck - os: ubuntu-20.04-8core-32gb - env: {} - - name: x86_64-gnu-llvm-16 - env: - RUST_BACKTRACE: 1 - os: ubuntu-20.04-8core-32gb - - name: x86_64-gnu-llvm-15 - env: - RUST_BACKTRACE: 1 - os: ubuntu-20.04-8core-32gb - - name: x86_64-gnu-llvm-14 - env: - RUST_BACKTRACE: 1 - os: ubuntu-20.04-8core-32gb - - name: x86_64-gnu-llvm-14-stage1 - env: - RUST_BACKTRACE: 1 - os: ubuntu-20.04-8core-32gb - - name: x86_64-gnu-nopt - os: ubuntu-20.04-4core-16gb - env: {} - - name: x86_64-gnu-tools - env: - DEPLOY_TOOLSTATES_JSON: toolstates-linux.json - os: ubuntu-20.04-8core-32gb - - name: dist-x86_64-apple - env: - SCRIPT: "./x.py dist bootstrap --include-default-paths --host=x86_64-apple-darwin --target=x86_64-apple-darwin" - RUST_CONFIGURE_ARGS: "--enable-full-tools --enable-sanitizers --enable-profiler --set rust.jemalloc --set llvm.ninja=false --set rust.lto=thin" - RUSTC_RETRY_LINKER_ON_SEGFAULT: 1 - MACOSX_DEPLOYMENT_TARGET: 10.7 - SELECT_XCODE: /Applications/Xcode_13.4.1.app - NO_LLVM_ASSERTIONS: 1 - NO_DEBUG_ASSERTIONS: 1 - NO_OVERFLOW_CHECKS: 1 - DIST_REQUIRE_ALL_TOOLS: 1 - os: macos-latest - - name: dist-apple-various - env: - SCRIPT: "./x.py dist bootstrap --include-default-paths --host='' --target=aarch64-apple-ios,x86_64-apple-ios,aarch64-apple-ios-sim" - RUST_CONFIGURE_ARGS: "--enable-sanitizers --enable-profiler --set rust.jemalloc --set llvm.ninja=false" - RUSTC_RETRY_LINKER_ON_SEGFAULT: 1 - MACOSX_DEPLOYMENT_TARGET: 10.7 - SELECT_XCODE: /Applications/Xcode_13.4.1.app - NO_LLVM_ASSERTIONS: 1 - NO_DEBUG_ASSERTIONS: 1 - NO_OVERFLOW_CHECKS: 1 - os: macos-latest - - name: dist-x86_64-apple-alt - env: - SCRIPT: "./x.py dist bootstrap --include-default-paths" - RUST_CONFIGURE_ARGS: "--enable-extended --enable-profiler --set rust.jemalloc --set llvm.ninja=false" - RUSTC_RETRY_LINKER_ON_SEGFAULT: 1 - MACOSX_DEPLOYMENT_TARGET: 10.7 - SELECT_XCODE: /Applications/Xcode_13.4.1.app - NO_LLVM_ASSERTIONS: 1 - NO_DEBUG_ASSERTIONS: 1 - NO_OVERFLOW_CHECKS: 1 - os: macos-latest - - name: x86_64-apple-1 - env: - SCRIPT: "./x.py --stage 2 test --exclude tests/ui --exclude tests/rustdoc --exclude tests/run-make-fulldeps" - RUST_CONFIGURE_ARGS: "--build=x86_64-apple-darwin --enable-sanitizers --enable-profiler --set rust.jemalloc --set llvm.ninja=false" - RUSTC_RETRY_LINKER_ON_SEGFAULT: 1 - MACOSX_DEPLOYMENT_TARGET: 10.8 - MACOSX_STD_DEPLOYMENT_TARGET: 10.7 - NO_LLVM_ASSERTIONS: 1 - NO_DEBUG_ASSERTIONS: 1 - NO_OVERFLOW_CHECKS: 1 - os: macos-latest - - name: x86_64-apple-2 - env: - SCRIPT: "./x.py --stage 2 test tests/ui tests/rustdoc tests/run-make-fulldeps" - RUST_CONFIGURE_ARGS: "--build=x86_64-apple-darwin --enable-sanitizers --enable-profiler --set rust.jemalloc --set llvm.ninja=false" - RUSTC_RETRY_LINKER_ON_SEGFAULT: 1 - MACOSX_DEPLOYMENT_TARGET: 10.8 - MACOSX_STD_DEPLOYMENT_TARGET: 10.7 - NO_LLVM_ASSERTIONS: 1 - NO_DEBUG_ASSERTIONS: 1 - NO_OVERFLOW_CHECKS: 1 - os: macos-latest - - name: dist-aarch64-apple - env: - SCRIPT: "./x.py dist bootstrap --include-default-paths --stage 2" - RUST_CONFIGURE_ARGS: "--build=x86_64-apple-darwin --host=aarch64-apple-darwin --target=aarch64-apple-darwin --enable-full-tools --enable-sanitizers --enable-profiler --disable-docs --set rust.jemalloc --set llvm.ninja=false" - RUSTC_RETRY_LINKER_ON_SEGFAULT: 1 - SELECT_XCODE: /Applications/Xcode_13.4.1.app - USE_XCODE_CLANG: 1 - MACOSX_DEPLOYMENT_TARGET: 11.0 - MACOSX_STD_DEPLOYMENT_TARGET: 11.0 - NO_LLVM_ASSERTIONS: 1 - NO_DEBUG_ASSERTIONS: 1 - NO_OVERFLOW_CHECKS: 1 - DIST_REQUIRE_ALL_TOOLS: 1 - JEMALLOC_SYS_WITH_LG_PAGE: 14 - os: macos-latest - - name: x86_64-msvc-1 - env: - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-msvc --enable-profiler" - SCRIPT: make ci-subset-1 - os: windows-2019-8core-32gb - - name: x86_64-msvc-2 - env: - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-msvc --enable-profiler" - SCRIPT: make ci-subset-2 - os: windows-2019-8core-32gb - - name: i686-msvc-1 - env: - RUST_CONFIGURE_ARGS: "--build=i686-pc-windows-msvc" - SCRIPT: make ci-subset-1 - os: windows-2019-8core-32gb - - name: i686-msvc-2 - env: - RUST_CONFIGURE_ARGS: "--build=i686-pc-windows-msvc" - SCRIPT: make ci-subset-2 - os: windows-2019-8core-32gb - - name: x86_64-msvc-cargo - env: - SCRIPT: python x.py --stage 2 test src/tools/cargotest src/tools/cargo - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-msvc --enable-lld" - os: windows-2019-8core-32gb - - name: x86_64-msvc-tools - env: - SCRIPT: src/ci/docker/host-x86_64/x86_64-gnu-tools/checktools.sh x.py /tmp/toolstate/toolstates.json windows - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-msvc --save-toolstates=/tmp/toolstate/toolstates.json" - DEPLOY_TOOLSTATES_JSON: toolstates-windows.json - os: windows-2019-8core-32gb - - name: i686-mingw-1 - env: - RUST_CONFIGURE_ARGS: "--build=i686-pc-windows-gnu" - SCRIPT: make ci-mingw-subset-1 - NO_DOWNLOAD_CI_LLVM: 1 - CUSTOM_MINGW: 1 - os: windows-2019-8core-32gb - - name: i686-mingw-2 - env: - RUST_CONFIGURE_ARGS: "--build=i686-pc-windows-gnu" - SCRIPT: make ci-mingw-subset-2 - NO_DOWNLOAD_CI_LLVM: 1 - CUSTOM_MINGW: 1 - os: windows-2019-8core-32gb - - name: x86_64-mingw-1 - env: - SCRIPT: make ci-mingw-subset-1 - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-gnu --enable-profiler" - NO_DOWNLOAD_CI_LLVM: 1 - CUSTOM_MINGW: 1 - os: windows-2019-8core-32gb - - name: x86_64-mingw-2 - env: - SCRIPT: make ci-mingw-subset-2 - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-gnu --enable-profiler" - NO_DOWNLOAD_CI_LLVM: 1 - CUSTOM_MINGW: 1 - os: windows-2019-8core-32gb - - name: dist-x86_64-msvc - env: - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-msvc --host=x86_64-pc-windows-msvc --target=x86_64-pc-windows-msvc --enable-full-tools --enable-profiler" - SCRIPT: PGO_HOST=x86_64-pc-windows-msvc python src/ci/stage-build.py python x.py dist bootstrap --include-default-paths - DIST_REQUIRE_ALL_TOOLS: 1 - os: windows-2019-8core-32gb - - name: dist-i686-msvc - env: - RUST_CONFIGURE_ARGS: "--build=i686-pc-windows-msvc --host=i686-pc-windows-msvc --target=i686-pc-windows-msvc,i586-pc-windows-msvc --enable-full-tools --enable-profiler" - SCRIPT: python x.py dist bootstrap --include-default-paths - DIST_REQUIRE_ALL_TOOLS: 1 - os: windows-2019-8core-32gb - - name: dist-aarch64-msvc - env: - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-msvc --host=aarch64-pc-windows-msvc --enable-full-tools --enable-profiler" - SCRIPT: python x.py dist bootstrap --include-default-paths - DIST_REQUIRE_ALL_TOOLS: 1 - WINDOWS_SDK_20348_HACK: 1 - os: windows-2019-8core-32gb - - name: dist-i686-mingw - env: - RUST_CONFIGURE_ARGS: "--build=i686-pc-windows-gnu --enable-full-tools --enable-profiler" - NO_DOWNLOAD_CI_LLVM: 1 - SCRIPT: python x.py dist bootstrap --include-default-paths - CUSTOM_MINGW: 1 - DIST_REQUIRE_ALL_TOOLS: 1 - os: windows-2019-8core-32gb - - name: dist-x86_64-mingw - env: - SCRIPT: python x.py dist bootstrap --include-default-paths - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-gnu --enable-full-tools --enable-profiler" - NO_DOWNLOAD_CI_LLVM: 1 - CUSTOM_MINGW: 1 - DIST_REQUIRE_ALL_TOOLS: 1 - os: windows-2019-8core-32gb - - name: dist-x86_64-msvc-alt - env: - RUST_CONFIGURE_ARGS: "--build=x86_64-pc-windows-msvc --enable-extended --enable-profiler" - SCRIPT: python x.py dist bootstrap --include-default-paths - os: windows-2019-8core-32gb - timeout-minutes: 600 - runs-on: "${{ matrix.os }}" - steps: - - name: disable git crlf conversion - run: git config --global core.autocrlf false - - name: checkout the source code - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - name: configure the PR in which the error message will be posted - run: "echo \"[CI_PR_NUMBER=$num]\"" - env: - num: "${{ github.event.number }}" - if: "success() && !env.SKIP_JOB && github.event_name == 'pull_request'" - - name: add extra environment variables - run: src/ci/scripts/setup-environment.sh - env: - EXTRA_VARIABLES: "${{ toJson(matrix.env) }}" - if: success() && !env.SKIP_JOB - - name: decide whether to skip this job - run: src/ci/scripts/should-skip-this.sh - if: success() && !env.SKIP_JOB - - name: ensure the channel matches the target branch - run: src/ci/scripts/verify-channel.sh - if: success() && !env.SKIP_JOB - - name: configure GitHub Actions to kill the build when outdated - uses: rust-lang/simpleinfra/github-actions/cancel-outdated-builds@master - with: - github_token: "${{ secrets.github_token }}" - if: "success() && !env.SKIP_JOB && github.ref != 'refs/heads/try' && github.ref != 'refs/heads/try-perf'" - - name: collect CPU statistics - run: src/ci/scripts/collect-cpu-stats.sh - if: success() && !env.SKIP_JOB - - name: show the current environment - run: src/ci/scripts/dump-environment.sh - if: success() && !env.SKIP_JOB - - name: install sccache - run: src/ci/scripts/install-sccache.sh - if: success() && !env.SKIP_JOB - - name: select Xcode - run: src/ci/scripts/select-xcode.sh - if: success() && !env.SKIP_JOB - - name: install clang - run: src/ci/scripts/install-clang.sh - if: success() && !env.SKIP_JOB - - name: install WIX - run: src/ci/scripts/install-wix.sh - if: success() && !env.SKIP_JOB - - name: disable git crlf conversion - run: src/ci/scripts/disable-git-crlf-conversion.sh - if: success() && !env.SKIP_JOB - - name: checkout submodules - run: src/ci/scripts/checkout-submodules.sh - if: success() && !env.SKIP_JOB - - name: install MSYS2 - run: src/ci/scripts/install-msys2.sh - if: success() && !env.SKIP_JOB - - name: install MinGW - run: src/ci/scripts/install-mingw.sh - if: success() && !env.SKIP_JOB - - name: install ninja - run: src/ci/scripts/install-ninja.sh - if: success() && !env.SKIP_JOB - - name: enable ipv6 on Docker - run: src/ci/scripts/enable-docker-ipv6.sh - if: success() && !env.SKIP_JOB - - name: disable git crlf conversion - run: src/ci/scripts/disable-git-crlf-conversion.sh - if: success() && !env.SKIP_JOB - - name: ensure line endings are correct - run: src/ci/scripts/verify-line-endings.sh - if: success() && !env.SKIP_JOB - - name: ensure backported commits are in upstream branches - run: src/ci/scripts/verify-backported-commits.sh - if: success() && !env.SKIP_JOB - - name: ensure the stable version number is correct - run: src/ci/scripts/verify-stable-version-number.sh - if: success() && !env.SKIP_JOB - - name: run the build - run: src/ci/scripts/run-build-from-ci.sh - env: - AWS_ACCESS_KEY_ID: "${{ env.CACHES_AWS_ACCESS_KEY_ID }}" - AWS_SECRET_ACCESS_KEY: "${{ secrets[format('AWS_SECRET_ACCESS_KEY_{0}', env.CACHES_AWS_ACCESS_KEY_ID)] }}" - TOOLSTATE_REPO_ACCESS_TOKEN: "${{ secrets.TOOLSTATE_REPO_ACCESS_TOKEN }}" - if: success() && !env.SKIP_JOB - - name: upload artifacts to S3 - run: src/ci/scripts/upload-artifacts.sh - env: - AWS_ACCESS_KEY_ID: "${{ env.ARTIFACTS_AWS_ACCESS_KEY_ID }}" - AWS_SECRET_ACCESS_KEY: "${{ secrets[format('AWS_SECRET_ACCESS_KEY_{0}', env.ARTIFACTS_AWS_ACCESS_KEY_ID)] }}" - if: "success() && !env.SKIP_JOB && (github.event_name == 'push' || env.DEPLOY == '1' || env.DEPLOY_ALT == '1')" - try: - permissions: - actions: write - name: "try - ${{ matrix.name }}" - env: - CI_JOB_NAME: "${{ matrix.name }}" - CARGO_REGISTRIES_CRATES_IO_PROTOCOL: sparse - SCCACHE_BUCKET: rust-lang-ci-sccache2 - DEPLOY_BUCKET: rust-lang-ci2 - TOOLSTATE_REPO: "https://github.com/rust-lang-nursery/rust-toolstate" - TOOLSTATE_ISSUES_API_URL: "https://api.github.com/repos/rust-lang/rust/issues" - TOOLSTATE_PUBLISH: 1 - CACHES_AWS_ACCESS_KEY_ID: AKIA46X5W6CZI5DHEBFL - ARTIFACTS_AWS_ACCESS_KEY_ID: AKIA46X5W6CZN24CBO55 - AWS_REGION: us-west-1 - CACHE_DOMAIN: ci-caches.rust-lang.org - if: "github.event_name == 'push' && (github.ref == 'refs/heads/try' || github.ref == 'refs/heads/try-perf') && github.repository == 'rust-lang-ci/rust'" - strategy: - matrix: - include: - - name: dist-x86_64-linux - os: ubuntu-20.04-16core-64gb - env: {} + os: [ubuntu-20.04] + timeout-minutes: 600 - runs-on: "${{ matrix.os }}" steps: - - name: disable git crlf conversion - run: git config --global core.autocrlf false - name: checkout the source code uses: actions/checkout@v3 with: fetch-depth: 2 - - name: configure the PR in which the error message will be posted - run: "echo \"[CI_PR_NUMBER=$num]\"" - env: - num: "${{ github.event.number }}" - if: "success() && !env.SKIP_JOB && github.event_name == 'pull_request'" - name: add extra environment variables run: src/ci/scripts/setup-environment.sh env: EXTRA_VARIABLES: "${{ toJson(matrix.env) }}" if: success() && !env.SKIP_JOB - - name: decide whether to skip this job - run: src/ci/scripts/should-skip-this.sh - if: success() && !env.SKIP_JOB - - name: ensure the channel matches the target branch - run: src/ci/scripts/verify-channel.sh - if: success() && !env.SKIP_JOB - - name: configure GitHub Actions to kill the build when outdated - uses: rust-lang/simpleinfra/github-actions/cancel-outdated-builds@master - with: - github_token: "${{ secrets.github_token }}" - if: "success() && !env.SKIP_JOB && github.ref != 'refs/heads/try' && github.ref != 'refs/heads/try-perf'" - - name: collect CPU statistics - run: src/ci/scripts/collect-cpu-stats.sh - if: success() && !env.SKIP_JOB - - name: show the current environment - run: src/ci/scripts/dump-environment.sh - if: success() && !env.SKIP_JOB - - name: install sccache - run: src/ci/scripts/install-sccache.sh - if: success() && !env.SKIP_JOB - - name: select Xcode - run: src/ci/scripts/select-xcode.sh - if: success() && !env.SKIP_JOB - - name: install clang - run: src/ci/scripts/install-clang.sh - if: success() && !env.SKIP_JOB - - name: install WIX - run: src/ci/scripts/install-wix.sh - if: success() && !env.SKIP_JOB - - name: disable git crlf conversion - run: src/ci/scripts/disable-git-crlf-conversion.sh - if: success() && !env.SKIP_JOB - - name: checkout submodules - run: src/ci/scripts/checkout-submodules.sh - if: success() && !env.SKIP_JOB - - name: install MSYS2 - run: src/ci/scripts/install-msys2.sh - if: success() && !env.SKIP_JOB - - name: install MinGW - run: src/ci/scripts/install-mingw.sh - if: success() && !env.SKIP_JOB - - name: install ninja - run: src/ci/scripts/install-ninja.sh - if: success() && !env.SKIP_JOB - - name: enable ipv6 on Docker - run: src/ci/scripts/enable-docker-ipv6.sh - if: success() && !env.SKIP_JOB - - name: disable git crlf conversion - run: src/ci/scripts/disable-git-crlf-conversion.sh - if: success() && !env.SKIP_JOB - - name: ensure line endings are correct - run: src/ci/scripts/verify-line-endings.sh - if: success() && !env.SKIP_JOB - - name: ensure backported commits are in upstream branches - run: src/ci/scripts/verify-backported-commits.sh - if: success() && !env.SKIP_JOB - - name: ensure the stable version number is correct - run: src/ci/scripts/verify-stable-version-number.sh - if: success() && !env.SKIP_JOB - - name: run the build - run: src/ci/scripts/run-build-from-ci.sh - env: - AWS_ACCESS_KEY_ID: "${{ env.CACHES_AWS_ACCESS_KEY_ID }}" - AWS_SECRET_ACCESS_KEY: "${{ secrets[format('AWS_SECRET_ACCESS_KEY_{0}', env.CACHES_AWS_ACCESS_KEY_ID)] }}" - TOOLSTATE_REPO_ACCESS_TOKEN: "${{ secrets.TOOLSTATE_REPO_ACCESS_TOKEN }}" - if: success() && !env.SKIP_JOB - - name: upload artifacts to S3 - run: src/ci/scripts/upload-artifacts.sh - env: - AWS_ACCESS_KEY_ID: "${{ env.ARTIFACTS_AWS_ACCESS_KEY_ID }}" - AWS_SECRET_ACCESS_KEY: "${{ secrets[format('AWS_SECRET_ACCESS_KEY_{0}', env.ARTIFACTS_AWS_ACCESS_KEY_ID)] }}" - if: "success() && !env.SKIP_JOB && (github.event_name == 'push' || env.DEPLOY == '1' || env.DEPLOY_ALT == '1')" - master: - name: master - runs-on: ubuntu-latest - env: - SCCACHE_BUCKET: rust-lang-ci-sccache2 - DEPLOY_BUCKET: rust-lang-ci2 - TOOLSTATE_REPO: "https://github.com/rust-lang-nursery/rust-toolstate" - TOOLSTATE_ISSUES_API_URL: "https://api.github.com/repos/rust-lang/rust/issues" - TOOLSTATE_PUBLISH: 1 - CACHES_AWS_ACCESS_KEY_ID: AKIA46X5W6CZI5DHEBFL - ARTIFACTS_AWS_ACCESS_KEY_ID: AKIA46X5W6CZN24CBO55 - AWS_REGION: us-west-1 - CACHE_DOMAIN: ci-caches.rust-lang.org - if: "github.event_name == 'push' && github.ref == 'refs/heads/master' && github.repository == 'rust-lang-ci/rust'" - steps: - - name: checkout the source code - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - name: publish toolstate - run: src/ci/publish_toolstate.sh - shell: bash - env: - TOOLSTATE_REPO_ACCESS_TOKEN: "${{ secrets.TOOLSTATE_REPO_ACCESS_TOKEN }}" - if: success() && !env.SKIP_JOB - try-success: - needs: - - try - if: "success() && github.event_name == 'push' && (github.ref == 'refs/heads/try' || github.ref == 'refs/heads/try-perf') && github.repository == 'rust-lang-ci/rust'" - steps: - - name: mark the job as a success - run: exit 0 - shell: bash - name: bors build finished - runs-on: ubuntu-latest - try-failure: - needs: - - try - if: "!success() && github.event_name == 'push' && (github.ref == 'refs/heads/try' || github.ref == 'refs/heads/try-perf') && github.repository == 'rust-lang-ci/rust'" - steps: - - name: mark the job as a failure - run: exit 1 - shell: bash - name: bors build finished - runs-on: ubuntu-latest - auto-success: - needs: - - auto - if: "success() && github.event_name == 'push' && github.ref == 'refs/heads/auto' && github.repository == 'rust-lang-ci/rust'" - steps: - - name: mark the job as a success - run: exit 0 - shell: bash - name: bors build finished - runs-on: ubuntu-latest - auto-failure: - needs: - - auto - if: "!success() && github.event_name == 'push' && github.ref == 'refs/heads/auto' && github.repository == 'rust-lang-ci/rust'" - steps: - - name: mark the job as a failure - run: exit 1 - shell: bash - name: bors build finished - runs-on: ubuntu-latest + - name: build + run: | + mkdir build + cd build + ../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs + ../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc + rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 + rustup toolchain install nightly # enables -Z unstable-options + - name: test + run: | + cargo +enzyme test --examples diff --git a/.gitmodules b/.gitmodules index 4596ae17d0238..acf63590fdc82 100644 --- a/.gitmodules +++ b/.gitmodules @@ -32,3 +32,6 @@ [submodule "library/backtrace"] path = library/backtrace url = https://github.com/rust-lang/backtrace-rs.git +[submodule "src/tools/enzyme"] + path = src/tools/enzyme + url = https://github.com/EnzymeAD/Enzyme.git diff --git a/Cargo.lock b/Cargo.lock index e0c72d6899e98..5ac910d0ea720 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3897,6 +3897,7 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", + "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 53331e2869f2e..72c613f69bf0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ exclude = [ "src/tools/x", # stdarch has its own Cargo workspace "library/stdarch", + "library/autodiff", ] [profile.release.package.compiler_builtins] diff --git a/README.md b/README.md index 41b135972af11..0a9449f2e480d 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,68 @@ -# The Rust Programming Language +# The Rust Programming Language +Enzyme [![Rust Community](https://img.shields.io/badge/Rust_Community%20-Join_us-brightgreen?style=plastic&logo=rust)](https://www.rust-lang.org/community) This is the main source code repository for [Rust]. It contains the compiler, -standard library, and documentation. +standard library, and documentation. It is modified to use Enzyme for AutoDiff. + +Please configure this fork using the following command: + +``` +mkdir build +cd build +../configure --enable-llvm-link-shared --enable-llvm-plugins --enable-llvm-enzyme --release-channel=nightly --enable-llvm-assertions --enable-clang --enable-lld --enable-option-checking --enable-ninja --disable-docs +``` + +Afterwards you can build rustc using: +``` +../x.py build --stage 1 library/std library/proc_macro library/test tools/rustdoc +``` + +Afterwards rustc toolchain link will allow you to use it through cargo: +``` +rustup toolchain link enzyme `pwd`/build/`rustup target list --installed`/stage1 +rustup toolchain install nightly # enables -Z unstable-options +``` + +You can then look at examples in the `library/autodiff/examples/*` folder and run them with + +```bash +# rosenbrock forward iteration +cargo +enzyme run --example rosenbrock_fwd_iter --release + +# or all of them +cargo +enzyme test --examples +``` + +## Enzyme Config +To help with debugging, Enzyme can be configured using environment variables. +```bash +export ENZYME_PRINT_TA=1 +export ENZYME_PRINT_AA=1 +export ENZYME_PRINT=1 +export ENZYME_PRINT_MOD=1 +export ENZYME_PRINT_MOD_AFTER=1 +``` +The first three will print TypeAnalysis, ActivityAnalysis and the llvm-ir on a function basis, respectively. +The last two variables will print the whole module directly before and after Enzyme differented the functions. + +When experimenting with flags please make sure that EnzymeStrictAliasing=0 +is not changed, since it is required for Enzyme to handle enums correctly. + +## Bug reporting +Bugs are pretty much expected at this point of the development process. +In order to help us please minimize the Rust code as far as possible. +This tool might be a nicer helper: https://github.com/Nilstrieb/cargo-minimize +If you have some knowledge of LLVM-IR we also greatly appreciate it if you could help +us by compiling your minimized Rust code to LLVM-IR and reducing it further. + +The only exception to this strategy is error based on "Can not deduce type of X", +where reducing your example will make it harder for us to understand the origin of the bug. +In this case please just try to inline all dependencies into a single crate or even file, +without deleting used code. + + + [Rust]: https://www.rust-lang.org/ diff --git a/compiler/rustc_ast/src/mut_visit.rs b/compiler/rustc_ast/src/mut_visit.rs index 66b94d12a32c6..415acac9187d7 100644 --- a/compiler/rustc_ast/src/mut_visit.rs +++ b/compiler/rustc_ast/src/mut_visit.rs @@ -380,7 +380,7 @@ pub fn visit_bounds(bounds: &mut GenericBounds, vis: &mut T) { } // No `noop_` prefix because there isn't a corresponding method in `MutVisitor`. -pub fn visit_fn_sig(FnSig { header, decl, span }: &mut FnSig, vis: &mut T) { +pub fn visit_fn_sig(FnSig { header, decl, span, .. }: &mut FnSig, vis: &mut T) { vis.visit_fn_header(header); vis.visit_fn_decl(decl); vis.visit_span(span); diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index 651d644ebb63d..bfe7b73ee66fd 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -280,6 +280,7 @@ pub fn from_fn_attrs<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); + let autodiff_attrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -297,9 +298,12 @@ pub fn from_fn_attrs<'ll, 'tcx>( let inline = if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint + } else if autodiff_attrs.is_active() { + InlineAttr::Never } else { codegen_fn_attrs.inline }; + to_add.extend(inline_attr(cx, inline)); // The `uwtable` attribute according to LLVM is: diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index d2e01708a37bc..34f649a78a157 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -244,6 +244,7 @@ fn fat_lto( info!("pushing cached module {:?}", wp.cgu_name); (buffer, CString::new(wp.cgu_name).unwrap()) })); + for module in modules { match module { FatLTOInput::InMemory(m) => in_memory.push(m), @@ -254,7 +255,6 @@ fn fat_lto( } } } - // Find the "costliest" module and merge everything into that codegen unit. // All the other modules will be serialized and reparsed into the new // context, so this hopefully avoids serializing and parsing the largest @@ -700,7 +700,7 @@ pub unsafe fn optimize_thin_module( let llcx = llvm::LLVMRustContextCreate(cgcx.fewer_names); let llmod_raw = parse_module(llcx, module_name, thin_module.data(), &diag_handler)? as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm }, + module_llvm: ModuleLlvm { llmod_raw, llcx, tm, typetrees: Default::default() }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index ca2eab28f872b..923db9952dc58 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -2,17 +2,33 @@ use crate::back::lto::ThinBuffer; use crate::back::profiling::{ selfprofile_after_pass_callback, selfprofile_before_pass_callback, LlvmSelfProfiler, }; -use crate::base; + use crate::common; use crate::consts; use crate::errors::{ CopyBitcode, FromLlvmDiag, FromLlvmOptimizationDiag, LlvmError, WithLlvmError, WriteBytecode, }; use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{LLVMReplaceAllUsesWith, LLVMVerifyFunction, Value}; use crate::llvm_util; use crate::type_::Type; +use crate::typetree::to_enzyme_typetree; use crate::LlvmCodegenBackend; use crate::ModuleLlvm; +use crate::{base, DiffTypeTree}; +use llvm::{ + enzyme_rust_forward_diff, enzyme_rust_reverse_diff, BasicBlock, CreateEnzymeLogic, + CreateTypeAnalysis, EnzymeLogicRef, EnzymeTypeAnalysisRef, LLVMAddFunction, + LLVMAppendBasicBlockInContext, LLVMBuildCall2, LLVMBuildExtractValue, LLVMBuildRet, + LLVMCountParams, LLVMCountStructElementTypes, LLVMCreateBuilderInContext, LLVMDeleteFunction, + LLVMDisposeBuilder, LLVMGetBasicBlockTerminator, LLVMGetElementType, LLVMGetModuleContext, + LLVMGetParams, LLVMGetReturnType, LLVMPositionBuilderAtEnd, LLVMSetValueName2, LLVMTypeOf, + LLVMVoidTypeInContext, LLVMGlobalGetValueType, LLVMGetStringAttributeAtIndex, + LLVMIsStringAttribute, LLVMRemoveStringAttributeAtIndex, LLVMRemoveEnumAttributeAtIndex, AttributeKind, + LLVMGetFirstFunction, LLVMGetNextFunction, LLVMGetEnumAttributeAtIndex, LLVMIsEnumAttribute, + LLVMCreateStringAttribute, LLVMRustAddFunctionAttributes, LLVMCreateEnumAttribute, LLVMDumpModule +}; +//use llvm::LLVMRustGetNamedValue; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::write::{ BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig, @@ -20,10 +36,12 @@ use rustc_codegen_ssa::back::write::{ }; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::profiling::SelfProfilerRef; use rustc_data_structures::small_c_str::SmallCStr; use rustc_errors::{FatalError, Handler, Level}; use rustc_fs_util::{link_or_copy, path_to_c_string}; +use rustc_middle::middle::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; use rustc_middle::ty::TyCtxt; use rustc_session::config::{self, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath}; use rustc_session::Session; @@ -33,7 +51,7 @@ use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo}; use crate::llvm::diagnostic::OptimizationDiagnosticKind; use libc::{c_char, c_int, c_uint, c_void, size_t}; -use std::ffi::CString; +use std::ffi::{CStr, CString}; use std::fs; use std::io::{self, Write}; use std::path::{Path, PathBuf}; @@ -436,8 +454,18 @@ pub(crate) unsafe fn llvm_optimize( opt_level: config::OptLevel, opt_stage: llvm::OptStage, ) -> Result<(), FatalError> { - let unroll_loops = - opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + // Enzyme: + // We want to simplify / optimize functions before AD. + // However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore activate them first, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // RIP compile time. + // let unroll_loops = + // opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + let unroll_loops = false; + let vectorize_slp = false; + let vectorize_loop = false; + let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -489,8 +517,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.no_builtins, config.emit_lifetime_markers, sanitizer_options.as_ref(), @@ -512,6 +540,255 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(diag_handler, LlvmError::RunLlvmPasses)) } +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +// TODO: cleanup +unsafe fn create_wrapper<'a>( + llmod: &'a llvm::Module, + //module: &'a ModuleCodegen, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> (&'a Value, &'a BasicBlock, Vec<&'a Value>, Vec<&'a Value>, CString) { + //let llmod = module.module_llvm.llmod(); + let context = LLVMGetModuleContext(llmod); + let inner_fnc_name = "inner_".to_string() + &fnc_name; + let c_inner_fnc_name = CString::new(inner_fnc_name.clone()).unwrap(); + LLVMSetValueName2(fnc, c_inner_fnc_name.as_ptr(), inner_fnc_name.len() as usize); + + let c_outer_fnc_name = CString::new(fnc_name).unwrap(); + let outer_fnc: &Value = + LLVMAddFunction(llmod, c_outer_fnc_name.as_ptr(), LLVMGetElementType(u_type) as &Type); + + let entry = "fnc_entry".to_string(); + let c_entry = CString::new(entry).unwrap(); + let basic_block = LLVMAppendBasicBlockInContext(context, outer_fnc, c_entry.as_ptr()); + + let outer_params: Vec<&Value> = get_params(outer_fnc); + let inner_params: Vec<&Value> = get_params(fnc); + + (outer_fnc, basic_block, outer_params, inner_params, c_inner_fnc_name) +} + +//pub(crate) fn get_type(t: LLVMTypeRef) -> CString { +// unsafe { CString::from_raw(LLVMPrintTypeToString(t)) } +//} + +// TODO: Don't write a wrapper function, just unwrap the struct inside of the same fnc. +// Might help during debugging, if you have one function less to jump trough +pub(crate) unsafe fn extract_return_type<'a>( + llmod: &'a llvm::Module, + fnc: &'a Value, + u_type: &Type, + fnc_name: String, +) -> &'a Value { + //let llmod = module.module_llvm.llmod(); + let context = llvm::LLVMGetModuleContext(llmod); + //dbg!("Unpacking", fnc_name.clone()); + //dbg!("From: ", f_type, " into ", u_type); + + let inner_param_num = LLVMCountParams(fnc); + let (outer_fnc, outer_bb, mut outer_args, _inner_args, c_inner_fnc_name) = + create_wrapper(llmod, fnc, u_type, fnc_name); + + if inner_param_num as usize != outer_args.len() { + panic!("Args len shouldn't differ. Please report this."); + } + + let builder = LLVMCreateBuilderInContext(context); + LLVMPositionBuilderAtEnd(builder, outer_bb); + let struct_ret = LLVMBuildCall2( + builder, + u_type, + fnc, + outer_args.as_mut_ptr(), + outer_args.len(), + c_inner_fnc_name.as_ptr(), + ); + // We can use an arbitrary name here, since it will be used to store a tmp value. + let inner_grad_name = "foo".to_string(); + let c_inner_grad_name = CString::new(inner_grad_name).unwrap(); + let struct_ret = LLVMBuildExtractValue(builder, struct_ret, 0, c_inner_grad_name.as_ptr()); + let _ret = LLVMBuildRet(builder, struct_ret); + let _terminator = LLVMGetBasicBlockTerminator(outer_bb); + //assert!(LLVMIsNull(terminator)!=0, "no terminator"); + LLVMDisposeBuilder(builder); + + let _fnc_ok = + LLVMVerifyFunction(outer_fnc, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + //dbg!(outer_fnc); + //assert!(fnc_ok); + //if let Err(e) = verify_function(outer_fnc) { + // panic!("Creating a wrapper function failed! {}", e); + //} + + outer_fnc +} + +// As unsafe as it can be. +#[allow(unused_variables)] +#[allow(unused)] +pub(crate) unsafe fn enzyme_ad( + llmod: &llvm::Module, + llcx: &llvm::Context, + item: AutoDiffItem, +) -> Result<(), FatalError> { + let autodiff_mode = item.attrs.mode; + let rust_name = item.source; + let rust_name2 = &item.target; + + let args_activity = item.attrs.input_activity.clone(); + let ret_activity: DiffActivity = item.attrs.ret_activity; + + // get target and source function + let name = CString::new(rust_name.to_owned()).unwrap(); + let name2 = CString::new(rust_name2.clone()).unwrap(); + let src_fnc = llvm::LLVMGetNamedFunction(llmod, name.as_c_str().as_ptr()).unwrap(); + let target_fnc = llvm::LLVMGetNamedFunction(llmod, name2.as_ptr()).unwrap(); + + // create enzyme typetrees + let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + + let input_tts = + item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); + let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); + + let opt = 1; + let ret_primary_ret = false; + let diff_primary_ret = false; + let logic_ref: EnzymeLogicRef = CreateEnzymeLogic(opt as u8); + let type_analysis: EnzymeTypeAnalysisRef = + CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0); + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_TA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintType), 1); + } + if std::env::var("ENZYME_PRINT_AA").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintActivity), 1); + } + if std::env::var("ENZYME_PRINT_PERF").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrintPerf), 1); + } + if std::env::var("ENZYME_PRINT").is_ok() { + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymePrint), 1); + } + + let mut res: &Value = match item.attrs.mode { + DiffMode::Forward => enzyme_rust_forward_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + input_tts, + output_tt, + ), + DiffMode::Reverse => enzyme_rust_reverse_diff( + logic_ref, + type_analysis, + src_fnc, + args_activity, + ret_activity, + ret_primary_ret, + diff_primary_ret, + input_tts, + output_tt, + ), + _ => unreachable!(), + }; + let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res)); + + let void_type = LLVMVoidTypeInContext(llcx); + if item.attrs.mode == DiffMode::Reverse && f_return_type != void_type { + //dbg!("Reverse Mode sanitizer"); + //dbg!(f_type); + //dbg!(f_return_type); + let num_elem_in_ret_struct = LLVMCountStructElementTypes(f_return_type); + if num_elem_in_ret_struct == 1 { + let u_type = LLVMTypeOf(target_fnc); + res = extract_return_type(llmod, res, u_type, rust_name2.clone()); // TODO: check if name or name2 + } + } + //dbg!(&target_fnc); + LLVMSetValueName2(res, name2.as_ptr(), rust_name2.len()); + LLVMReplaceAllUsesWith(target_fnc, res); + LLVMDeleteFunction(target_fnc); + + Ok(()) +} + +pub(crate) unsafe fn differentiate( + module: &ModuleCodegen, + _cgcx: &CodegenContext, + diff_items: Vec, + _typetrees: FxHashMap, + _config: &ModuleConfig, +) -> Result<(), FatalError> { + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + + llvm::EnzymeSetCLBool(std::ptr::addr_of_mut!(llvm::EnzymeStrictAliasing), 0); + + if std::env::var("ENZYME_PRINT_MOD").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + if std::env::var("ENZYME_TT_DEPTH").is_ok() { + let depth = std::env::var("ENZYME_TT_DEPTH").unwrap(); + let depth = depth.parse::().unwrap(); + assert!(depth >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::EnzymeMaxTypeDepth), depth); + } + if std::env::var("ENZYME_TT_WIDTH").is_ok() { + let width = std::env::var("ENZYME_TT_WIDTH").unwrap(); + let width = width.parse::().unwrap(); + assert!(width >= 1); + llvm::EnzymeSetCLInteger(std::ptr::addr_of_mut!(llvm::MaxTypeOffset), width); + } + + for item in diff_items { + let res = enzyme_ad(llmod, llcx, item); + assert!(res.is_ok()); + } + + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex(lf, c_uint::MAX, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint); + } else { + LLVMRemoveEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + } + + + } else { + break; + } + } + if std::env::var("ENZYME_PRINT_MOD_AFTER").is_ok() { + unsafe {LLVMDumpModule(llmod);} + } + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -534,6 +811,28 @@ pub(crate) unsafe fn optimize( llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()); } + { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMGetEnumAttributeAtIndex(lf, c_uint::MAX, AttributeKind::SanitizeHWAddress); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute(llcx, myhwattr.as_ptr() as *const c_char, myhwattr.as_bytes().len() as c_uint, myhwv.as_ptr() as *const c_char, myhwv.as_bytes().len() as c_uint); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + let attr = LLVMCreateEnumAttribute(llcx, AttributeKind::SanitizeHWAddress, 0); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } + + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, @@ -543,6 +842,7 @@ pub(crate) unsafe fn optimize( }; return llvm_optimize(cgcx, diag_handler, module, config, opt_level, opt_stage); } + Ok(()) } @@ -554,6 +854,7 @@ pub(crate) fn link( use super::lto::{Linker, ModuleBuffer}; // Sort the modules by name to ensure deterministic behavior. modules.sort_by(|a, b| a.name.cmp(&b.name)); + let (first, elements) = modules.split_first().expect("Bug! modules must contain at least one module."); @@ -566,6 +867,7 @@ pub(crate) fn link( })?; } drop(linker); + Ok(modules.remove(0)) } diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index 5b2bbdb4bde1e..4c049a7342f08 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -1,5 +1,3 @@ -//! Codegen the MIR to the LLVM IR. -//! //! Hopefully useful general knowledge about codegen: //! //! * There's no way to find out the [`Ty`] type of a [`Value`]. Doing so @@ -25,6 +23,7 @@ use rustc_codegen_ssa::base::maybe_create_entry_wrapper; use rustc_codegen_ssa::mono_item::MonoItemExt; use rustc_codegen_ssa::traits::*; use rustc_codegen_ssa::{ModuleCodegen, ModuleKind}; +use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::small_c_str::SmallCStr; use rustc_middle::dep_graph; use rustc_middle::middle::codegen_fn_attrs::CodegenFnAttrs; @@ -82,9 +81,10 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen recorder.record_arg(cgu.size_estimate().to_string()); }); // Instantiate monomorphizations without filling out definitions yet... - let llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str()); - { + let mut llvm_module = ModuleLlvm::new(tcx, cgu_name.as_str()); + let typetrees = { let cx = CodegenCx::new(tcx, cgu, &llvm_module); + let mono_items = cx.codegen_unit.items_in_deterministic_order(cx.tcx); for &(mono_item, (linkage, visibility)) in &mono_items { mono_item.predefine::>(&cx, linkage, visibility); @@ -133,7 +133,30 @@ pub fn compile_codegen_unit(tcx: TyCtxt<'_>, cgu_name: Symbol) -> (ModuleCodegen if cx.sess().opts.debuginfo != DebugInfo::None { cx.debuginfo_finalize(); } - } + + // find autodiff items and build typetrees for them + /*mono_items.iter() + //.filter(|(mono_item, _)| mono_item.def_id().map(|x| tcx.autodiff_attrs(x).is_active()).unwrap_or(false)) + .filter(|(mono_item, _)| mono_item.def_id().map(|x| tcx.autodiff_attrs(x).is_source()).unwrap_or(false)) + .filter_map(|(mono_item, _)| { + let symbol = mono_item.symbol_name(cx.tcx).to_string(); + match mono_item { + MonoItem::Fn(instance) => { + let ty = instance.ty(tcx, ParamEnv::empty()); + + Some(( + symbol, + parse_typetree(tcx, ty, &llvm_module) + )) + }, + _ => None + } + }).collect::>()*/ + + FxHashMap::default() + }; + + llvm_module.typetrees = typetrees; ModuleCodegen { name: cgu_name.to_string(), diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 4d0bcd53d1562..f4447c5db75d4 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -30,6 +30,8 @@ use std::iter; use std::ops::Deref; use std::ptr; +// use libc::rand; + // All Builders must have an llfn associated with them #[must_use] pub struct Builder<'a, 'll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 83101a85435a0..34e5c845c122c 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -597,6 +597,10 @@ impl<'ll, 'tcx> MiscMethods<'tcx> for CodegenCx<'ll, 'tcx> { None } } + + fn create_autodiff(&self) -> Vec { + return vec![]; + } } impl<'ll> CodegenCx<'ll, '_> { @@ -630,17 +634,17 @@ impl<'ll> CodegenCx<'ll, '_> { if key == $name { return Some(self.insert_intrinsic($name, Some(&[]), $ret)); } - ); + ); ($name:expr, fn(...) -> $ret:expr) => ( if key == $name { return Some(self.insert_intrinsic($name, None, $ret)); } - ); + ); ($name:expr, fn($($arg:expr),*) -> $ret:expr) => ( if key == $name { return Some(self.insert_intrinsic($name, Some(&[$($arg),*]), $ret)); } - ); + ); } macro_rules! mk_struct { ($($field_ty:expr),*) => (self.type_struct( &[$($field_ty),*], false)) diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 8305a0a4c286d..15486bccde00e 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -23,6 +23,7 @@ extern crate tracing; use back::write::{create_informational_target_machine, create_target_machine}; use errors::ParseTargetMachineConfig; +use llvm::TypeTree; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; @@ -37,6 +38,7 @@ use rustc_errors::{DiagnosticMessage, ErrorGuaranteed, FatalError, Handler, Subd use rustc_fluent_macro::fluent_messages; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; use rustc_middle::ty::query::Providers; use rustc_middle::ty::TyCtxt; use rustc_session::config::{OptLevel, OutputFilenames, PrintRequest}; @@ -68,6 +70,7 @@ mod debuginfo; mod declare; mod errors; mod intrinsic; +mod typetree; // The following is a workaround that replaces `pub mod llvm;` and that fixes issue 53912. #[path = "llvm/mod.rs"] @@ -175,6 +178,8 @@ impl WriteBackendMethods for LlvmCodegenBackend { type TargetMachineError = crate::errors::LlvmError<'static>; type ThinData = back::lto::ThinData; type ThinBuffer = back::lto::ThinBuffer; + type TypeTree = DiffTypeTree; + fn print_pass_timings(&self) { unsafe { llvm::LLVMRustPrintPassTimings(); @@ -236,6 +241,20 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + unsafe { back::write::differentiate(module, cgcx, diff_fncs, typetrees, config) } + } + + fn typetrees(module: &mut Self::Module) -> FxHashMap { + module.typetrees.drain().collect() + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -385,10 +404,18 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[derive(Clone, Debug)] +pub struct DiffTypeTree { + pub ret_tt: TypeTree, + pub input_tt: Vec, +} + +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, tm: &'static mut llvm::TargetMachine, + typetrees: FxHashMap, } unsafe impl Send for ModuleLlvm {} @@ -399,7 +426,12 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(tcx.sess.fewer_names()); let llmod_raw = context::create_module(tcx, llcx, mod_name) as *const _; - ModuleLlvm { llmod_raw, llcx, tm: create_target_machine(tcx, mod_name) } + ModuleLlvm { + llmod_raw, + llcx, + tm: create_target_machine(tcx, mod_name), + typetrees: Default::default(), + } } } @@ -407,7 +439,12 @@ impl ModuleLlvm { unsafe { let llcx = llvm::LLVMRustContextCreate(tcx.sess.fewer_names()); let llmod_raw = context::create_module(tcx, llcx, mod_name) as *const _; - ModuleLlvm { llmod_raw, llcx, tm: create_informational_target_machine(tcx.sess) } + ModuleLlvm { + llmod_raw, + llcx, + tm: create_informational_target_machine(tcx.sess), + typetrees: Default::default(), + } } } @@ -428,7 +465,7 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm }) + Ok(ModuleLlvm { llmod_raw, llcx, tm, typetrees: Default::default() }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index aefd5b2a13c92..1800b57e32a9a 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1,7 +1,10 @@ #![allow(non_camel_case_types)] #![allow(non_upper_case_globals)] + +use std::ptr; use rustc_codegen_ssa::coverageinfo::map as coverage_map; +use rustc_middle::middle::autodiff_attrs::DiffActivity; use super::debuginfo::{ DIArray, DIBasicType, DIBuilder, DICompositeType, DIDerivedType, DIDescriptor, DIEnumerator, @@ -13,6 +16,8 @@ use super::debuginfo::{ use libc::{c_char, c_int, c_uint, size_t}; use libc::{c_ulonglong, c_void}; +use core::fmt; +use std::ffi::{CStr, CString}; use std::marker::PhantomData; use super::RustString; @@ -184,7 +189,7 @@ pub enum AttributeKind { OptimizeNone = 24, ReturnsTwice = 25, ReadNone = 26, - SanitizeHWAddress = 28, + SanitizeHWAddress = 51, WillReturn = 29, StackProtectReq = 30, StackProtectStrong = 31, @@ -1001,10 +1006,186 @@ pub type SelfProfileBeforePassCallback = unsafe extern "C" fn(*mut c_void, *const c_char, *const c_char); pub type SelfProfileAfterPassCallback = unsafe extern "C" fn(*mut c_void); +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} + +pub(crate) unsafe fn enzyme_rust_forward_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_diffactivity: Vec, + ret_diffactivity: DiffActivity, + mut ret_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_diffactivity); + assert!(ret_activity != CDIFFE_TYPE::DFT_OUT_DIFF); + let mut input_activity: Vec = vec![]; + for input in input_diffactivity { + let act = cdiffe_from(input); + assert!(act == CDIFFE_TYPE::DFT_CONSTANT || act == CDIFFE_TYPE::DFT_DUP_ARG || act == CDIFFE_TYPE::DFT_DUP_NONEED); + input_activity.push(act); + } + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_activity.len()]; + + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + let mut known_values = vec![kv_tmp; input_activity.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + EnzymeCreateForwardDiff( + logic_ref, // Logic + ptr::null(), + ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + CDerivativeMode::DEM_ForwardMode, // return value, dret_used, top_level which was 1 + 1, // free memory + 1, // vector mode width + Option::None, + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + ) +} + +pub(crate) unsafe fn enzyme_rust_reverse_diff( + logic_ref: EnzymeLogicRef, + type_analysis: EnzymeTypeAnalysisRef, + fnc: &Value, + input_activity: Vec, + ret_activity: DiffActivity, + mut ret_primary_ret: bool, + diff_primary_ret: bool, + input_tts: Vec, + output_tt: TypeTree, +) -> &Value { + let ret_activity = cdiffe_from(ret_activity); + assert!(ret_activity == CDIFFE_TYPE::DFT_CONSTANT || ret_activity == CDIFFE_TYPE::DFT_OUT_DIFF); + let input_activity: Vec = input_activity.iter().map(|&x| cdiffe_from(x)).collect(); + + if ret_activity == CDIFFE_TYPE::DFT_DUP_ARG { + if ret_primary_ret != true { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = true; + } else if ret_activity == CDIFFE_TYPE::DFT_DUP_NONEED { + if ret_primary_ret != false { + dbg!("overwriting ret_primary_ret!"); + } + ret_primary_ret = false; + } + + let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + + // We don't support volatile / extern / (global?) values. + // Just because I didn't had time to test them, and it seems less urgent. + let args_uncacheable = vec![0; input_tts.len()]; + let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; + + + let mut known_values = vec![kv_tmp; input_tts.len()]; + + let dummy_type = CFnTypeInfo { + Arguments: args_tree.as_mut_ptr(), + Return: output_tt.inner.clone(), + KnownValues: known_values.as_mut_ptr(), + }; + + EnzymeCreatePrimalAndGradient( + logic_ref, // Logic + ptr::null(), + ptr::null(), + fnc, + ret_activity, // LLVM function, return type + input_activity.as_ptr(), + input_activity.len(), // constant arguments + type_analysis, // type analysis struct + ret_primary_ret as u8, + diff_primary_ret as u8, //0 + CDerivativeMode::DEM_ReverseModeCombined, // return value, dret_used, top_level which was 1 + 1, // vector mode width + 1, // free memory + Option::None, + 0, // do not force anonymous tape + dummy_type, // additional_arg, type info (return + args) + args_uncacheable.as_ptr(), + args_uncacheable.len(), // uncacheable arguments + std::ptr::null_mut(), // write augmented function to this + 0, + ) +} pub type GetSymbolsCallback = unsafe extern "C" fn(*mut c_void, *const c_char) -> *mut c_void; pub type GetSymbolsErrorCallback = unsafe extern "C" fn(*const c_char) -> *mut c_void; extern "C" { + + // Enzyme + //pub fn LLVMReplaceAllUsesWith(old: &Value, new: &Value); + pub fn GibtsNicht(M: &Module) -> bool; + pub fn LLVMIsStructTy(ty: &Type) -> bool; + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMDeleteFunction(V: &Value); + pub fn LLVMRemoveStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint); + pub fn LLVMGetStringAttributeAtIndex(F : &Value, Idx: c_uint, K: *const c_char, KLen : c_uint) -> &Attribute; + pub fn LLVMRemoveEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: AttributeKind); + pub fn LLVMGetEnumAttributeAtIndex(F : &Value, Idx: c_uint, K: AttributeKind) -> &Attribute; + pub fn LLVMIsEnumAttribute(A : &Attribute) -> bool; + pub fn LLVMCreateEnumAttribute(C : &Context, Kind: AttributeKind, val:u64) -> &Attribute; + pub fn LLVMIsStringAttribute(A : &Attribute) -> bool; + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetBasicBlockTerminator(B: &BasicBlock) -> &Value; + pub fn LLVMAddFunction<'a>(M: &Module, Name: *const c_char, Ty: &Type) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMGlobalGetValueType(val: &Value) -> &Type; + + pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; pub fn LLVMRustInstallFatalErrorHandler(); pub fn LLVMRustDisableSystemDialogsOnCrash(); @@ -2262,6 +2443,8 @@ extern "C" { #[allow(improper_ctypes)] pub fn LLVMRustWriteTypeToString(Type: &Type, s: &RustString); #[allow(improper_ctypes)] + pub fn LLVMRustWriteValueNameToString(value_ref: &Value, s: &RustString); + #[allow(improper_ctypes)] pub fn LLVMRustWriteValueToString(value_ref: &Value, s: &RustString); pub fn LLVMIsAConstantInt(value_ref: &Value) -> Option<&ConstantInt>; @@ -2518,7 +2701,6 @@ extern "C" { remark_passes: *const *const c_char, remark_passes_len: usize, ); - #[allow(improper_ctypes)] pub fn LLVMRustGetMangledName(V: &Value, out: &RustString); @@ -2534,3 +2716,301 @@ extern "C" { error_callback: GetSymbolsErrorCallback, ) -> *mut c_void; } +// Manuel +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueTypeAnalysis { + _unused: [u8; 0], +} +pub type EnzymeTypeAnalysisRef = *mut EnzymeOpaqueTypeAnalysis; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueLogic { + _unused: [u8; 0], +} +pub type EnzymeLogicRef = *mut EnzymeOpaqueLogic; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeOpaqueAugmentedReturn { + _unused: [u8; 0], +} +pub type EnzymeAugmentedReturnPtr = *mut EnzymeOpaqueAugmentedReturn; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct IntList { + pub data: *mut i64, + pub size: size_t, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CConcreteType { + DT_Anything = 0, + DT_Integer = 1, + DT_Pointer = 2, + DT_Half = 3, + DT_Float = 4, + DT_Double = 5, + DT_Unknown = 6, +} +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct EnzymeTypeTree { + _unused: [u8; 0], +} +pub type CTypeTreeRef = *mut EnzymeTypeTree; +extern "C" { + fn EnzymeNewTypeTree() -> CTypeTreeRef; +} +extern "C" { + fn EnzymeFreeTypeTree(CTT: CTypeTreeRef); +} +extern "C" { + pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8); +} +extern "C" { + pub fn EnzymeSetCLInteger(arg1: *mut ::std::os::raw::c_void, arg2: i64); +} + +extern "C" { + pub static mut MaxIntOffset: c_void; + pub static mut MaxTypeOffset: c_void; + pub static mut EnzymeMaxTypeDepth: c_void; + + pub static mut EnzymePrintPerf: c_void; + pub static mut EnzymePrintActivity: c_void; + pub static mut EnzymePrintType: c_void; + pub static mut EnzymePrint: c_void; + pub static mut EnzymeStrictAliasing: c_void; +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct CFnTypeInfo { + #[doc = " Types of arguments, assumed of size len(Arguments)"] + pub Arguments: *mut CTypeTreeRef, + #[doc = " Type of return"] + pub Return: CTypeTreeRef, + #[doc = " The specific constant(s) known to represented by an argument, if constant"] + pub KnownValues: *mut IntList, +} +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDIFFE_TYPE { + DFT_OUT_DIFF = 0, + DFT_DUP_ARG = 1, + DFT_CONSTANT = 2, + DFT_DUP_NONEED = 3, +} + +fn cdiffe_from(act: DiffActivity) -> CDIFFE_TYPE { + return match act { + DiffActivity::None => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Active => CDIFFE_TYPE::DFT_OUT_DIFF, + DiffActivity::Const => CDIFFE_TYPE::DFT_CONSTANT, + DiffActivity::Duplicated => CDIFFE_TYPE::DFT_DUP_ARG, + DiffActivity::DuplicatedNoNeed => CDIFFE_TYPE::DFT_DUP_NONEED, + }; +} + +#[repr(u32)] +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub enum CDerivativeMode { + DEM_ForwardMode = 0, + DEM_ReverseModePrimal = 1, + DEM_ReverseModeGradient = 2, + DEM_ReverseModeCombined = 3, + DEM_ForwardModeSplit = 4, +} +extern "C" { + fn EnzymeCreatePrimalAndGradient<'a>( + arg1: EnzymeLogicRef, + builderCtx: *const u8, // &'a Builder<'_>, + callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + dretUsed: u8, + mode: CDerivativeMode, + width: ::std::os::raw::c_uint, + freeMemory: u8, + additionalArg: Option<&Type>, + forceAnonymousTape: u8, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + AtomicAdd: u8, + ) -> &'a Value; + //) -> LLVMValueRef; +} +extern "C" { + fn EnzymeCreateForwardDiff<'a>( + arg1: EnzymeLogicRef, + builderCtx: *const u8,// &'a Builder<'_>, + callerCtx: *const u8,// &'a Value, + todiff: &'a Value, + retType: CDIFFE_TYPE, + constant_args: *const CDIFFE_TYPE, + constant_args_size: size_t, + TA: EnzymeTypeAnalysisRef, + returnValue: u8, + mode: CDerivativeMode, + freeMemory: u8, + width: ::std::os::raw::c_uint, + additionalArg: Option<&Type>, + typeInfo: CFnTypeInfo, + _uncacheable_args: *const u8, + uncacheable_args_size: size_t, + augmented: EnzymeAugmentedReturnPtr, + ) -> &'a Value; +} +pub type CustomRuleType = ::std::option::Option< + unsafe extern "C" fn( + direction: ::std::os::raw::c_int, + ret: CTypeTreeRef, + args: *mut CTypeTreeRef, + known_values: *mut IntList, + num_args: size_t, + fnc: &Value, + ta: *const ::std::os::raw::c_void, + ) -> u8, +>; +extern "C" { + pub fn CreateTypeAnalysis( + Log: EnzymeLogicRef, + customRuleNames: *mut *mut ::std::os::raw::c_char, + customRules: *mut CustomRuleType, + numRules: size_t, + ) -> EnzymeTypeAnalysisRef; +} +extern "C" { + pub fn ClearTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn FreeTypeAnalysis(arg1: EnzymeTypeAnalysisRef); +} +extern "C" { + pub fn CreateEnzymeLogic(PostOpt: u8) -> EnzymeLogicRef; +} +extern "C" { + pub fn ClearEnzymeLogic(arg1: EnzymeLogicRef); +} +extern "C" { + pub fn FreeEnzymeLogic(arg1: EnzymeLogicRef); +} + +extern "C" { + fn EnzymeNewTypeTreeCT(arg1: CConcreteType, ctx: &Context) -> CTypeTreeRef; + fn EnzymeNewTypeTreeTR(arg1: CTypeTreeRef) -> CTypeTreeRef; + fn EnzymeMergeTypeTree(arg1: CTypeTreeRef, arg2: CTypeTreeRef) -> bool; + fn EnzymeTypeTreeOnlyEq(arg1: CTypeTreeRef, pos: i64); + fn EnzymeTypeTreeData0Eq(arg1: CTypeTreeRef); + fn EnzymeTypeTreeShiftIndiciesEq( + arg1: CTypeTreeRef, + data_layout: *const c_char, + offset: i64, + max_size: i64, + add_offset: u64, + ); + fn EnzymeTypeTreeToStringFree(arg1: *const c_char); + fn EnzymeTypeTreeToString(arg1: CTypeTreeRef) -> *const c_char; +} + +pub struct TypeTree { + pub inner: CTypeTreeRef, +} + +impl TypeTree { + pub fn new() -> TypeTree { + let inner = unsafe { EnzymeNewTypeTree() }; + + TypeTree { inner } + } + + #[must_use] + pub fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree { + let inner = unsafe { EnzymeNewTypeTreeCT(t, ctx) }; + + TypeTree { inner } + } + + #[must_use] + pub fn only(self, idx: isize) -> TypeTree { + unsafe { + EnzymeTypeTreeOnlyEq(self.inner, idx as i64); + } + self + } + + #[must_use] + pub fn data0(self) -> TypeTree { + unsafe { + EnzymeTypeTreeData0Eq(self.inner); + } + self + } + + pub fn merge(self, other: Self) -> Self { + unsafe { + EnzymeMergeTypeTree(self.inner, other.inner); + } + drop(other); + + self + } + + #[must_use] + pub fn shift(self, layout: &str, offset: isize, max_size: isize, add_offset: usize) -> Self { + let layout = CString::new(layout).unwrap(); + + unsafe { + EnzymeTypeTreeShiftIndiciesEq( + self.inner, + layout.as_ptr(), + offset as i64, + max_size as i64, + add_offset as u64, + ) + } + + self + } +} + +impl Clone for TypeTree { + fn clone(&self) -> Self { + let inner = unsafe { EnzymeNewTypeTreeTR(self.inner) }; + TypeTree { inner } + } +} + +impl fmt::Display for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let ptr = unsafe { EnzymeTypeTreeToString(self.inner) }; + let cstr = unsafe { CStr::from_ptr(ptr) }; + match cstr.to_str() { + Ok(x) => write!(f, "{}", x)?, + Err(err) => write!(f, "could not parse: {}", err)?, + } + + // delete C string pointer + unsafe { EnzymeTypeTreeToStringFree(ptr) } + + Ok(()) + } +} + +impl fmt::Debug for TypeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} + +impl Drop for TypeTree { + fn drop(&mut self) { + unsafe { EnzymeFreeTypeTree(self.inner) } + } +} diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 4f5cc575da6e5..2cd84daec7b07 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -1,5 +1,7 @@ #![allow(non_snake_case)] +//use crate::enzyme::LLVMValueRef; + pub use self::AtomicRmwBinOp::*; pub use self::CallConv::*; pub use self::CodeGenOptSize::*; @@ -31,6 +33,12 @@ impl LLVMRustResult { } } +// pub fn GetNamedFunction<'ll>(name: &str) -> &'ll LLVMValueRef { +// unsafe { +// LLVMRustGetN +// } +// } + pub fn AddFunctionAttributes<'ll>(llfn: &'ll Value, idx: AttributePlace, attrs: &[&'ll Attribute]) { unsafe { LLVMRustAddFunctionAttributes(llfn, idx.as_uint(), attrs.as_ptr(), attrs.len()); diff --git a/compiler/rustc_codegen_llvm/src/typetree.rs b/compiler/rustc_codegen_llvm/src/typetree.rs new file mode 100644 index 0000000000000..091ddaa3cf213 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/typetree.rs @@ -0,0 +1,33 @@ +use crate::llvm; +use rustc_middle::middle::typetree::{Kind, TypeTree}; + +pub fn to_enzyme_typetree( + tree: TypeTree, + llvm_data_layout: &str, + llcx: &llvm::Context, +) -> llvm::TypeTree { + tree.0.iter().fold(llvm::TypeTree::new(), |obj, x| { + let scalar = match x.kind { + Kind::Integer => llvm::CConcreteType::DT_Integer, + Kind::Float => llvm::CConcreteType::DT_Float, + Kind::Double => llvm::CConcreteType::DT_Double, + Kind::Pointer => llvm::CConcreteType::DT_Pointer, + _ => panic!("Unknown kind {:?}", x.kind), + }; + + let tt = llvm::TypeTree::from_type(scalar, llcx).only(-1); + + let tt = if !x.child.0.is_empty() { + let inner_tt = to_enzyme_typetree(x.child.clone(), llvm_data_layout, llcx); + tt.merge(inner_tt.only(-1)) + } else { + tt + }; + + if x.offset != -1 { + obj.merge(tt.shift(llvm_data_layout, 0, x.size as isize, x.offset as usize)) + } else { + obj.merge(tt) + } + }) +} diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index cb6244050df24..f27b09c8146f3 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,9 +1,11 @@ use super::write::CodegenContext; +use crate::back::write::ModuleConfig; use crate::traits::*; use crate::ModuleCodegen; -use rustc_data_structures::memmap::Mmap; +use rustc_data_structures::{fx::FxHashMap, memmap::Mmap}; use rustc_errors::FatalError; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; use std::ffi::CString; use std::sync::Arc; @@ -76,6 +78,27 @@ impl LtoModuleCodegen { } } + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat { ref module, .. } => { + //let module = module.take().unwrap(); + { + B::autodiff(cgcx, &module, diff_fncs, typetrees, config)?; + } + }, + _ => {}, + } + + Ok(self) + } + /// A "gauge" of how costly it is to optimize this module, used to sort /// biggest modules first. pub fn cost(&self) -> u64 { diff --git a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs index 8f2f829c17c1c..2d8cecb4fc36c 100644 --- a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs +++ b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs @@ -315,7 +315,7 @@ fn exported_symbols_provider_local( // external linkage is enough for monomorphization to be linked to. let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib; - let (_, cgus) = tcx.collect_and_partition_mono_items(()); + let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); for (mono_item, &(linkage, visibility)) in cgus.iter().flat_map(|cgu| cgu.items().iter()) { if linkage != Linkage::External { diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index c323372bda42d..bc11b4e644693 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -25,6 +25,7 @@ use rustc_incremental::{ }; use rustc_metadata::EncodedMetadata; use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; use rustc_middle::middle::exported_symbols::SymbolExportInfo; use rustc_middle::ty::TyCtxt; use rustc_session::cgu_reuse_tracker::CguReuseTracker; @@ -118,6 +119,7 @@ pub struct ModuleConfig { pub inline_threshold: Option, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, + pub enzyme_print_activity: bool, } impl ModuleConfig { @@ -194,6 +196,7 @@ impl ModuleConfig { false ), + enzyme_print_activity: sess.opts.unstable_opts.enzyme_print_activity, sanitizer: if_regular!(sess.opts.unstable_opts.sanitizer, SanitizerSet::empty()), sanitizer_recover: if_regular!( sess.opts.unstable_opts.sanitizer_recover, @@ -376,8 +379,10 @@ impl CodegenContext { } } -fn generate_lto_work( - cgcx: &CodegenContext, +fn generate_lto_work<'tcx, B: ExtraBackendMethods>( + cgcx: &'tcx CodegenContext, + autodiff: Vec, + typetrees: FxHashMap, needs_fat_lto: Vec>, needs_thin_lto: Vec<(String, B::ThinBuffer)>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, @@ -386,8 +391,15 @@ fn generate_lto_work( let (lto_modules, copy_jobs) = if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let lto_module = + + let mut lto_module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); + + if cgcx.lto == Lto::Fat { + let config = cgcx.config(ModuleKind::Regular); + lto_module = unsafe { lto_module.autodiff(cgcx, autodiff, typetrees, config).unwrap() }; + } + (vec![lto_module], vec![]) } else { assert!(needs_fat_lto.is_empty()); @@ -968,6 +980,7 @@ pub enum Message { module_data: SerializedModule, work_product: WorkProduct, }, + AddAutoDiffItems(Vec), CodegenComplete, CodegenItem, CodegenAborted, @@ -1251,6 +1264,8 @@ fn start_executing_work( let mut needs_link = Vec::new(); let mut needs_fat_lto = Vec::new(); let mut needs_thin_lto = Vec::new(); + let mut autodiff_items = Vec::new(); + let mut typetrees = FxHashMap::::default(); let mut lto_import_only_modules = Vec::new(); let mut started_lto = false; let mut codegen_aborted = false; @@ -1346,9 +1361,14 @@ fn start_executing_work( let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); - for (work, cost) in - generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules) - { + for (work, cost) in generate_lto_work( + &cgcx, + autodiff_items.clone(), + typetrees.clone(), + needs_fat_lto, + needs_thin_lto, + import_only_modules, + ) { let insertion_index = work_items .binary_search_by_key(&cost, |&(_, cost)| cost) .unwrap_or_else(|e| e); @@ -1452,13 +1472,24 @@ fn start_executing_work( Err(e) => { let msg = &format!("failed to acquire jobserver token: {}", e); shared_emitter.fatal(msg); + // Exit the coordinator thread + //panic!("{}", msg) codegen_done = true; codegen_aborted = true; } } } - Message::CodegenDone { llvm_work_item, cost } => { + Message::CodegenDone { mut llvm_work_item, cost } => { + //// extract build typetrees + match &mut llvm_work_item { + WorkItem::Optimize(module) => { + let tt = B::typetrees(&mut module.module_llvm); + typetrees.extend(tt); + } + _ => {}, + } + // We keep the queue sorted by estimated processing cost, // so that more expensive items are processed earlier. This // is good for throughput as it gives the main thread more @@ -1496,6 +1527,9 @@ fn start_executing_work( codegen_done = true; codegen_aborted = true; } + Message::AddAutoDiffItems(mut items) => { + autodiff_items.append(&mut items); + } Message::Done { result: Ok(compiled_module), worker_id } => { free_worker(worker_id); match compiled_module.kind { @@ -1895,7 +1929,7 @@ impl OngoingCodegen { sess.abort_if_errors(); panic!("expected abort due to worker thread errors") } - Err(_) => { + Err(_err) => { bug!("panic during codegen/LLVM phase"); } }); @@ -1914,6 +1948,7 @@ impl OngoingCodegen { self.backend.print_pass_timings() } + // HERE ( CodegenResults { metadata: self.metadata, @@ -1946,6 +1981,10 @@ impl OngoingCodegen { drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::))); } + pub fn submit_autodiff_items(&self, items: Vec) { + drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); + } + pub fn check_for_errors(&self, sess: &Session) { self.shared_emitter_main.check(sess, false); } @@ -1970,6 +2009,7 @@ pub fn submit_codegened_module_to_llvm( module: ModuleCodegen, cost: u64, ) { + // BLUB let llvm_work_item = WorkItem::Optimize(module); drop(tx_to_llvm_workers.send(Box::new(Message::CodegenDone:: { llvm_work_item, cost }))); } @@ -2021,8 +2061,8 @@ fn msvc_imps_needed(tcx: TyCtxt<'_>) -> bool { tcx.sess.target.is_like_windows && tcx.sess.crate_types().iter().any(|ct| *ct == CrateType::Rlib) && - // ThinLTO can't handle this workaround in all cases, so we don't - // emit the `__imp_` symbols. Instead we make them unnecessary by disallowing - // dynamic linking when linker plugin LTO is enabled. - !tcx.sess.opts.cg.linker_plugin_lto.enabled() + // ThinLTO can't handle this workaround in all cases, so we don't + // emit the `__imp_` symbols. Instead we make them unnecessary by disallowing + // dynamic linking when linker plugin LTO is enabled. + !tcx.sess.opts.cg.linker_plugin_lto.enabled() } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index ae45ae9d802c8..83f1f974b8998 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -584,7 +584,8 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let codegen_units = tcx.collect_and_partition_mono_items(()).1; + let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(()); + let autodiff_fncs = autodiff_fncs.to_vec(); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -652,6 +653,10 @@ pub fn codegen_crate( ); } + if !autodiff_fncs.is_empty() { + ongoing_codegen.submit_autodiff_items(autodiff_fncs); + } + // For better throughput during parallel processing by LLVM, we used to sort // CGUs largest to smallest. This would lead to better thread utilization // by, for example, preventing a large CGU from being processed last and @@ -965,7 +970,7 @@ pub fn provide(providers: &mut Providers) { config::OptLevel::SizeMin => config::OptLevel::Default, }; - let (defids, _) = tcx.collect_and_partition_mono_items(cratenum); + let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum); let any_for_speed = defids.items().any(|id| { let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index 8dae5dab42972..e33f9970e1e8e 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,10 +1,11 @@ -use rustc_ast::{ast, MetaItemKind, NestedMetaItem}; +use rustc_ast::{ast, MetaItem, MetaItemKind, NestedMetaItem}; use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_errors::struct_span_err; use rustc_hir as hir; use rustc_hir::def::DefKind; use rustc_hir::def_id::{DefId, LocalDefId, LOCAL_CRATE}; use rustc_hir::{lang_items, weak_lang_items::WEAK_LANG_ITEMS, LangItem}; +use rustc_middle::middle::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}; use rustc_middle::mir::mono::Linkage; use rustc_middle::ty::query::Providers; @@ -13,6 +14,7 @@ use rustc_session::{lint, parse::feature_err}; use rustc_span::symbol::Ident; use rustc_span::{sym, Span}; use rustc_target::spec::{abi, SanitizerSet}; +use std::str::FromStr; use crate::errors; use crate::target_features::from_target_feature; @@ -649,6 +651,162 @@ fn check_link_name_xor_ordinal( } } +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { + let attrs = tcx.get_attrs(id, sym::autodiff_into); + + let attrs = attrs + .into_iter() + .filter(|attr| attr.name_or_empty() == sym::autodiff_into) + .collect::>(); + + // check for exactly one autodiff attribute on extern block + let attr = match &attrs[..] { + &[] => return AutoDiffAttrs::inactive(), + &[elm] => elm, + x => { + tcx.sess + .struct_span_err(x[1].span, "autodiff attribute can only be applied once") + .span_label(x[1].span, "more than one") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.len() == 0 { + return AutoDiffAttrs { + mode: DiffMode::Source, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + }; + } + + let mode = match &list[0] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, "attribute must contain autodiff mode") + .span_label(attr.span, "empty argument list") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + // parse mode + let mode = match mode.as_str() { + //map(|x| x.as_str()) { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + _ => { + tcx.sess + .struct_span_err(attr.span, "mode should be either forward or reverse") + .span_label(attr.span, "invalid mode") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let ret_symbol = match &list[1] { + NestedMetaItem::MetaItem(MetaItem { path: ref p2, kind: MetaItemKind::Word, .. }) => { + p2.segments.first().unwrap().ident + } + _ => { + tcx.sess + .struct_span_err(attr.span, "autodiff attribute must contain the return activity") + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { + Ok(x) => x, + Err(_) => { + tcx.sess + .struct_span_err(attr.span, "unknown return activity") + .span_label(attr.span, "invalid return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + let mut arg_activities: Vec = vec![]; + for arg in &list[2..] { + let arg_symbol = match arg { + NestedMetaItem::MetaItem(MetaItem { + path: ref p2, kind: MetaItemKind::Word, .. + }) => p2.segments.first().unwrap().ident, + _ => { + tcx.sess + .struct_span_err( + attr.span, + "autodiff attribute must contain the return activity", + ) + .span_label(attr.span, "missing return activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + tcx.sess + .struct_span_err(attr.span, "unknown return activity") + .span_label(attr.span, "invalid input activity") + .emit(); + + return AutoDiffAttrs::inactive(); + } + } + } + + if mode == DiffMode::Forward { + if ret_activity == DiffActivity::Active { + tcx.sess + .struct_span_err(attr.span, "Forward Mode is incompatible with Active ret") + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + if arg_activities.iter().filter(|&x| *x == DiffActivity::Active).count() > 0 { + tcx.sess + .struct_span_err(attr.span, "Forward Mode is incompatible with Active args") + .span_label(attr.span, "invalid input activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + if mode == DiffMode::Reverse { + if ret_activity == DiffActivity::Duplicated + || ret_activity == DiffActivity::DuplicatedNoNeed + { + tcx.sess + .struct_span_err( + attr.span, + "Reverse Mode is only compatible with Active, None, or Const ret", + ) + .span_label(attr.span, "invalid return activity") + .emit(); + return AutoDiffAttrs::inactive(); + } + } + + AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities } +} + pub fn provide(providers: &mut Providers) { - *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; + *providers = + Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers }; } diff --git a/compiler/rustc_codegen_ssa/src/traits/misc.rs b/compiler/rustc_codegen_ssa/src/traits/misc.rs index 04e2b8796c46a..5f64dd3367661 100644 --- a/compiler/rustc_codegen_ssa/src/traits/misc.rs +++ b/compiler/rustc_codegen_ssa/src/traits/misc.rs @@ -19,4 +19,5 @@ pub trait MiscMethods<'tcx>: BackendTypes { fn apply_target_cpu_attr(&self, llfn: Self::Function); /// Declares the extern "C" main function for the entry point. Returns None if the symbol already exists. fn declare_c_main(&self, fn_type: Self::Type) -> Option; + fn create_autodiff(&self) -> Vec; } diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index 9826256a4c5d5..9b319c8189e50 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -2,8 +2,10 @@ use crate::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use crate::back::write::{CodegenContext, FatLTOInput, ModuleConfig}; use crate::{CompiledModule, ModuleCodegen}; +use rustc_data_structures::fx::FxHashMap; use rustc_errors::{FatalError, Handler}; use rustc_middle::dep_graph::WorkProduct; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; pub trait WriteBackendMethods: 'static + Sized + Clone { type Module: Send + Sync; @@ -12,6 +14,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( @@ -57,6 +60,15 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { ) -> Result; fn prepare_thin(module: ModuleCodegen) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + typetrees: FxHashMap, + config: &ModuleConfig, + ) -> Result<(), FatalError>; + fn typetrees(module: &mut Self::Module) -> FxHashMap; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index fe05d4590e7a6..7b8138bddc367 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -351,6 +351,13 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ ungated!(used, Normal, template!(Word, List: "compiler|linker"), WarnFollowing, @only_local: true), ungated!(link_ordinal, Normal, template!(List: "ordinal"), ErrorPreceding), + // Autodiff + ungated!( + autodiff_into, Normal, + template!(Word, List: r#""...""#), + DuplicatesOk, + ), + // Limits: ungated!(recursion_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), ungated!(type_length_limit, CrateLevel, template!(NameValueStr: "N"), FutureWarnFollowing), diff --git a/compiler/rustc_incremental/src/assert_module_sources.rs b/compiler/rustc_incremental/src/assert_module_sources.rs index c550e553bb032..f2320e4fae9ed 100644 --- a/compiler/rustc_incremental/src/assert_module_sources.rs +++ b/compiler/rustc_incremental/src/assert_module_sources.rs @@ -40,7 +40,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>) { } let available_cgus = - tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect(); + tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect(); let ams = AssertModuleSource { tcx, available_cgus }; diff --git a/compiler/rustc_interface/src/queries.rs b/compiler/rustc_interface/src/queries.rs index 6483d51a0b9a9..c86d57d443778 100644 --- a/compiler/rustc_interface/src/queries.rs +++ b/compiler/rustc_interface/src/queries.rs @@ -375,6 +375,9 @@ impl Linker { } let _timer = sess.prof.verbose_generic_activity("link_crate"); + + // FINAL CALL? + // self.codegen_backend.link(&self.sess, codegen_results, &self.prepare_outputs) } } diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 1bae771e373d4..b024d1bd589aa 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -744,6 +744,7 @@ fn test_unstable_options_tracking_hash() { tracked!(dep_info_omit_d_target, true); tracked!(drop_tracking, true); tracked!(dual_proc_macros, true); + tracked!(enzyme_print_activity, false); tracked!(dwarf_version, Some(5)); tracked!(emit_thin_lto, false); tracked!(export_executable_symbols, true); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 49acd71b3e106..08d1949863ed1 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -90,6 +90,24 @@ extern "C" char *LLVMRustGetLastError(void) { return Ret; } +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + +// Enzyme +// extern "C" bool LLVMRustIsNull(LLVMValueRef V) { +// Value *Val = unwrap(V); +// if (Constant *C = dyn_cast(Val)) +// return C->isNullValue(); +// return false; +// } +// extern "C" LLVMValueRef LLVMRustGetNamedFunction(LLVMModuleRef M, +// const char *Name) { +// Module *Mod = unwrap(M); +// return wrap(Mod->getFunction(Name)); +// } + extern "C" void LLVMRustSetLastError(const char *Err) { free((void *)LastError); LastError = strdup(Err); diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 6a1a2a061ddd6..f8cf0efc5b01b 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -93,6 +93,7 @@ macro_rules! arena_types { [] upvars_mentioned: rustc_data_structures::fx::FxIndexMap, [] object_safety_violations: rustc_middle::traits::ObjectSafetyViolation, [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, + [] autodiff_item: rustc_middle::middle::autodiff_attrs::AutoDiffItem, [decode] attribute: rustc_ast::Attribute, [] name_set: rustc_data_structures::unord::UnordSet, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, diff --git a/compiler/rustc_middle/src/middle/autodiff_attrs.rs b/compiler/rustc_middle/src/middle/autodiff_attrs.rs new file mode 100644 index 0000000000000..2412df725fe2b --- /dev/null +++ b/compiler/rustc_middle/src/middle/autodiff_attrs.rs @@ -0,0 +1,94 @@ +use crate::middle::typetree::TypeTree; +use std::str::FromStr; + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum DiffMode { + Inactive, + Source, + Forward, + Reverse, +} + +#[allow(dead_code)] +#[derive(Clone, Copy, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum DiffActivity { + None, + Active, + Const, + Duplicated, + DuplicatedNoNeed, +} + +impl FromStr for DiffActivity { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "None" => Ok(DiffActivity::None), + "Active" => Ok(DiffActivity::Active), + "Const" => Ok(DiffActivity::Const), + "Duplicated" => Ok(DiffActivity::Duplicated), + "DuplicatedNoNeed" => Ok(DiffActivity::DuplicatedNoNeed), + _ => Err(()), + } + } +} + +#[allow(dead_code)] +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct AutoDiffAttrs { + pub mode: DiffMode, + pub ret_activity: DiffActivity, + pub input_activity: Vec, +} + +impl AutoDiffAttrs { + pub fn inactive() -> Self { + AutoDiffAttrs { + mode: DiffMode::Inactive, + ret_activity: DiffActivity::None, + input_activity: Vec::new(), + } + } + + pub fn is_active(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + _ => true, + } + } + + pub fn is_source(&self) -> bool { + match self.mode { + DiffMode::Source => true, + _ => false, + } + } + pub fn apply_autodiff(&self) -> bool { + match self.mode { + DiffMode::Inactive => false, + DiffMode::Source => false, + _ => true, + } + } + + pub fn into_item( + self, + source: String, + target: String, + inputs: Vec, + output: TypeTree, + ) -> AutoDiffItem { + AutoDiffItem { source, target, inputs, output, attrs: self } + } +} + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct AutoDiffItem { + pub source: String, + pub target: String, + pub attrs: AutoDiffAttrs, + pub inputs: Vec, + pub output: TypeTree, +} diff --git a/compiler/rustc_middle/src/middle/mod.rs b/compiler/rustc_middle/src/middle/mod.rs index 9c25f3009ba81..d0ed4df3707d7 100644 --- a/compiler/rustc_middle/src/middle/mod.rs +++ b/compiler/rustc_middle/src/middle/mod.rs @@ -1,3 +1,4 @@ +pub mod autodiff_attrs; pub mod codegen_fn_attrs; pub mod dependency_format; pub mod exported_symbols; @@ -31,6 +32,7 @@ pub mod privacy; pub mod region; pub mod resolve_bound_vars; pub mod stability; +pub mod typetree; pub fn provide(providers: &mut crate::ty::query::Providers) { limits::provide(providers); diff --git a/compiler/rustc_middle/src/middle/typetree.rs b/compiler/rustc_middle/src/middle/typetree.rs new file mode 100644 index 0000000000000..4049d32540bd2 --- /dev/null +++ b/compiler/rustc_middle/src/middle/typetree.rs @@ -0,0 +1,39 @@ +use std::fmt; +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub enum Kind { + Anything, + Integer, + Pointer, + Half, + Float, + Double, + Unknown, +} + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct TypeTree(pub Vec); + +#[derive(Clone, Eq, PartialEq, TyEncodable, TyDecodable, HashStable, Debug)] +pub struct Type { + pub offset: isize, + pub size: usize, + pub kind: Kind, + pub child: TypeTree, +} + +impl Type { + pub fn add_offset(self, add: isize) -> Self { + let offset = match self.offset { + -1 => add, + x => add + x, + }; + + Self { size: self.size, kind: self.kind, child: self.child, offset } + } +} + +impl fmt::Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + ::fmt(self, f) + } +} diff --git a/compiler/rustc_middle/src/query/erase.rs b/compiler/rustc_middle/src/query/erase.rs index 28a9c1eef1a6d..0b362d61eb3b4 100644 --- a/compiler/rustc_middle/src/query/erase.rs +++ b/compiler/rustc_middle/src/query/erase.rs @@ -180,6 +180,10 @@ impl EraseType for (&'_ T0, &'_ [T1]) { type Result = [u8; size_of::<(&'static (), &'static [()])>()]; } +impl EraseType for (&'_ T0, &'_ [T1], &'_ [T2]) { + type Result = [u8; size_of::<(&'static (), &'static [()], &'static [()])>()]; +} + macro_rules! trivial { ($($ty:ty),+ $(,)?) => { $( diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index 9fad2816b0d84..d0fe588628b39 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -1132,6 +1132,13 @@ rustc_queries! { separate_provide_extern } + /// The list autodiff extern functions in current crate + query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs { + desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) } + arena_cache + cache_on_disk_if { def_id.is_local() } + } + query asm_target_features(def_id: DefId) -> &'tcx FxIndexSet { desc { |tcx| "computing target features for inline asm of `{}`", tcx.def_path_str(def_id) } } @@ -1769,7 +1776,7 @@ rustc_queries! { separate_provide_extern } - query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [CodegenUnit<'tcx>]) { + query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [AutoDiffItem], &'tcx [CodegenUnit<'tcx>]) { eval_always desc { "collect_and_partition_mono_items" } } diff --git a/compiler/rustc_middle/src/ty/layout.rs b/compiler/rustc_middle/src/ty/layout.rs index 47cf48f46cf89..a5073b8ea29ec 100644 --- a/compiler/rustc_middle/src/ty/layout.rs +++ b/compiler/rustc_middle/src/ty/layout.rs @@ -1,5 +1,6 @@ use crate::fluent_generated as fluent; use crate::middle::codegen_fn_attrs::CodegenFnAttrFlags; +//use crate::middle::autodiff_attrs::AutoDiffAttrs; use crate::ty::normalize_erasing_regions::NormalizationError; use crate::ty::{self, ReprOptions, Ty, TyCtxt, TypeVisitableExt}; use rustc_errors::{DiagnosticBuilder, Handler, IntoDiagnostic}; diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index f882f54d62811..120947e36711f 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -175,6 +175,8 @@ pub struct ResolverGlobalCtxt { /// Mapping from ident span to path span for paths that don't exist as written, but that /// exist under `std`. For example, wrote `str::from_utf8` instead of `std::str::from_utf8`. pub confused_type_with_std_module: FxHashMap, + /// Mapping of autodiff function IDs + pub autodiff_map: FxHashMap, pub doc_link_resolutions: FxHashMap, pub doc_link_traits_in_scope: FxHashMap>, pub all_macro_rules: FxHashMap>, diff --git a/compiler/rustc_middle/src/ty/query.rs b/compiler/rustc_middle/src/ty/query.rs index 07d47cae5ee93..50edda0f192cf 100644 --- a/compiler/rustc_middle/src/ty/query.rs +++ b/compiler/rustc_middle/src/ty/query.rs @@ -5,6 +5,7 @@ use crate::dep_graph::DepKind; use crate::infer::canonical::{self, Canonical}; use crate::lint::LintExpectation; use crate::metadata::ModChild; +use crate::middle::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; use crate::middle::codegen_fn_attrs::CodegenFnAttrs; use crate::middle::exported_symbols::{ExportedSymbol, SymbolExportInfo}; use crate::middle::lib_features::LibFeatures; diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index 6d3a3bf906ebf..7c35f65750db6 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -19,3 +19,4 @@ rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } rustc_target = { path = "../rustc_target" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 55c937b305a49..8f26f46d22220 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -218,8 +218,8 @@ pub struct InliningMap<'tcx> { // Maps a source mono item to the range of mono items // accessed by it. // The range selects elements within the `targets` vecs. - index: FxHashMap, Range>, - targets: Vec>, + pub index: FxHashMap, Range>, + pub targets: Vec>, // Contains one bit per mono item in the `targets` field. That bit // is true if that mono item needs to be inlined into every CGU. @@ -1299,6 +1299,7 @@ impl<'v> RootCollector<'_, 'v> { /// monomorphized copy of the start lang item based on /// the return type of `main`. This is not needed when /// the user writes their own `start` manually. + /// TODO: remove annotations after automatic differentation pass fn push_extra_entry_roots(&mut self) { let Some((main_def_id, EntryFnType::Main { .. })) = self.entry_fn else { return; diff --git a/compiler/rustc_monomorphize/src/partitioning/default.rs b/compiler/rustc_monomorphize/src/partitioning/default.rs index 37b7f6bf8a8fc..ebabe54fe0e31 100644 --- a/compiler/rustc_monomorphize/src/partitioning/default.rs +++ b/compiler/rustc_monomorphize/src/partitioning/default.rs @@ -70,13 +70,20 @@ impl<'tcx> Partition<'tcx> for DefaultPartitioning { .or_insert_with(|| CodegenUnit::new(codegen_unit_name)); let mut can_be_internalized = true; + let (linkage, visibility) = mono_item_linkage_and_visibility( cx.tcx, &mono_item, &mut can_be_internalized, export_generics, ); - if visibility == Visibility::Hidden && can_be_internalized { + + //dbg!(&characteristic_def_id); + let autodiff_active = characteristic_def_id + .map(|x| cx.tcx.autodiff_attrs(x).is_active()) + .unwrap_or(false); + + if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } diff --git a/compiler/rustc_monomorphize/src/partitioning/mod.rs b/compiler/rustc_monomorphize/src/partitioning/mod.rs index 993e35c7fd251..19c95aa075a49 100644 --- a/compiler/rustc_monomorphize/src/partitioning/mod.rs +++ b/compiler/rustc_monomorphize/src/partitioning/mod.rs @@ -103,14 +103,17 @@ use std::path::{Path, PathBuf}; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_data_structures::sync; use rustc_hir::def_id::{DefIdSet, LOCAL_CRATE}; +use rustc_middle::middle::autodiff_attrs::AutoDiffItem; +use rustc_middle::middle::typetree::{Kind, Type, TypeTree}; use rustc_middle::mir; use rustc_middle::mir::mono::MonoItem; use rustc_middle::mir::mono::{CodegenUnit, Linkage}; use rustc_middle::ty::print::with_no_trimmed_paths; use rustc_middle::ty::query::Providers; -use rustc_middle::ty::TyCtxt; +use rustc_middle::ty::{ParamEnv, TyCtxt}; use rustc_session::config::{DumpMonoStatsFormat, SwitchWithOptPath}; use rustc_span::symbol::Symbol; +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; use crate::collector::InliningMap; use crate::collector::{self, MonoItemCollectionMode}; @@ -415,7 +418,10 @@ where } } -fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[CodegenUnit<'_>]) { +fn collect_and_partition_mono_items( + tcx: TyCtxt<'_>, + (): (), +) -> (&DefIdSet, &[AutoDiffItem], &[CodegenUnit<'_>]) { let collection_mode = match tcx.sess.opts.unstable_opts.print_mono_items { Some(ref s) => { let mode = s.to_lowercase(); @@ -479,6 +485,49 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co }) .collect(); + let autodiff_items = items + .iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance) => Some((item, instance)), + _ => None, + }) + .filter_map(|(item, instance)| { + let target_id = instance.def_id(); + let target_attrs = tcx.autodiff_attrs(target_id); + if !target_attrs.apply_autodiff() { + return None; + } + + let target_symbol = + symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + let range = inlining_map.index.get(&item).unwrap(); + + let source = inlining_map.targets[range.clone()] + .into_iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance_s) => { + let source_id = instance_s.def_id(); + + if tcx.autodiff_attrs(source_id).is_active() { + return Some(instance_s); + } + + None + } + _ => None, + }) + .next(); + + source.map(|inst| { + let (inputs, output) = fnc_typetrees(inst.ty(tcx, ParamEnv::empty()), tcx); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + target_attrs.clone().into_item(symb, target_symbol, inputs, output) + }) + }); + + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); + // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats { if let Err(err) = @@ -539,7 +588,7 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co } } - (tcx.arena.alloc(mono_items), codegen_units) + (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s @@ -620,7 +669,7 @@ fn dump_mono_items_stats<'tcx>( } fn codegened_and_inlined_items(tcx: TyCtxt<'_>, (): ()) -> &DefIdSet { - let (items, cgus) = tcx.collect_and_partition_mono_items(()); + let (items, _, cgus) = tcx.collect_and_partition_mono_items(()); let mut visited = DefIdSet::default(); let mut result = items.clone(); @@ -648,17 +697,173 @@ fn codegened_and_inlined_items(tcx: TyCtxt<'_>, (): ()) -> &DefIdSet { tcx.arena.alloc(result) } +use rustc_middle::ty::{self, Adt, ParamEnvAnd, Ty}; +use rustc_target::abi::FieldsShape; +use std::iter; + +pub fn typetree_empty() -> TypeTree { + TypeTree(vec![]) +} + +pub fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize) -> TypeTree { + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + + let inner_ty = ty.builtin_deref(true).unwrap().ty; + let child = typetree_from_ty(inner_ty, tcx, depth + 1); + + let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; + //println!("{:depth$} add indirection {:?}", "", tt); + + return TypeTree(vec![tt]); + } + + if ty.is_scalar() { + assert!(!ty.is_any_ptr()); + + let (kind, size) = if ty.is_integral() { + (Kind::Integer, 8) + } else { + assert!(ty.is_floating_point()); + match ty { + x if x == tcx.types.f32 => (Kind::Float, 4), + x if x == tcx.types.f64 => (Kind::Double, 8), + _ => panic!("floatTy scalar that is neither f32 nor f64"), + } + }; + + return TypeTree(vec![Type { offset: -1, child: typetree_empty(), kind, size }]); + } + + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; + + let layout = tcx.layout_of(param_env_and); + assert!(layout.is_ok()); + + let layout = layout.unwrap().layout; + let fields = layout.fields(); + let max_size = layout.size(); + + if ty.is_adt() { + let adt_def = ty.ty_adt_def().unwrap(); + let substs = match ty.kind() { + Adt(_, subst_ref) => subst_ref, + _ => panic!(""), + }; + + if adt_def.is_struct() { + let (offsets, _memory_index) = match fields { + FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), + _ => panic!(""), + }; + //println!("{:depth$} combine fields", ""); + + let fields = adt_def.all_fields(); + let fields = fields + .into_iter() + .zip(offsets.into_iter()) + .filter_map(|(field, offset)| { + let field_ty: Ty<'_> = field.ty(tcx, substs); + let field_ty: Ty<'_> = + tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); + + if field_ty.is_phantom_data() { + return None; + } + + let mut child = typetree_from_ty(field_ty, tcx, depth + 1).0; + + for c in &mut child { + if c.offset == -1 { + c.offset = offset.bytes() as isize + } else { + c.offset += offset.bytes() as isize; + } + } + + //inner_tt.offset = offset; + + //println!("{:depth$} -> {:?}", "", child); + + Some(child) + }) + .flatten() + .collect::>(); + + let ret_tt = TypeTree(fields); + //println!("{:depth$} into {:?}", "", ret_tt); + return ret_tt; + } else { + unimplemented!("adt that isn't a struct"); + } + } + + if ty.is_array() { + let (stride, count) = match fields { + FieldsShape::Array { stride: s, count: c } => (s, c), + _ => panic!(""), + }; + let byte_stride = stride.bytes_usize(); + let byte_max_size = max_size.bytes_usize(); + + assert!(byte_stride * *count as usize == byte_max_size); + assert!(*count > 0); // return empty TT for empty? + let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); + + // calculate size of subtree + let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; + let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; + let tt = TypeTree( + iter::repeat(subtt) + .take(*count as usize) + .enumerate() + .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) + .flatten() + .collect(), + ); + + //println!("{:depth$} repeated array into {:?}", "", tt); + + return tt; + } + + if ty.is_slice() { + let sub_ty = ty.builtin_index().unwrap(); + let subtt = typetree_from_ty(sub_ty, tcx, depth + 1); + + return subtt; + } + + //println!("Warning: create empty typetree for {}", ty); + typetree_empty() +} + +pub fn fnc_typetrees<'tcx>(fn_ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> (Vec, TypeTree) { + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // TODO: verify. + let x: ty::FnSig<'_> = fnc_binder.skip_binder(); + + let inputs = x.inputs().into_iter().map(|x| typetree_from_ty(*x, tcx, 0)).collect(); + + let output = typetree_from_ty(x.output(), tcx, 0); + + (inputs, output) +} pub fn provide(providers: &mut Providers) { providers.collect_and_partition_mono_items = collect_and_partition_mono_items; providers.codegened_and_inlined_items = codegened_and_inlined_items; providers.is_codegened_item = |tcx, def_id| { - let (all_mono_items, _) = tcx.collect_and_partition_mono_items(()); + let (all_mono_items, _, _) = tcx.collect_and_partition_mono_items(()); all_mono_items.contains(&def_id) }; providers.codegen_unit = |tcx, name| { - let (_, all) = tcx.collect_and_partition_mono_items(()); + let (_, _, all) = tcx.collect_and_partition_mono_items(()); all.iter() .find(|cgu| cgu.name() == name) .unwrap_or_else(|| panic!("failed to find cgu with name {name:?}")) diff --git a/compiler/rustc_passes/src/check_attr.rs b/compiler/rustc_passes/src/check_attr.rs index 06aa273791526..36ebe3bca95e7 100644 --- a/compiler/rustc_passes/src/check_attr.rs +++ b/compiler/rustc_passes/src/check_attr.rs @@ -220,6 +220,7 @@ impl CheckAttrVisitor<'_> { self.check_generic_attr(hir_id, attr, target, Target::Fn); self.check_proc_macro(hir_id, target, ProcMacroKind::Derive) } + sym::autodiff_into => self.check_autodiff(hir_id, attr, span, target), _ => {} } @@ -2264,6 +2265,20 @@ impl CheckAttrVisitor<'_> { self.abort.set(true); } } + + /// Checks if `#[autodiff]` is applied to an item other than a foreign module. + fn check_autodiff(&self, _hir_id: HirId, _attr: &Attribute, _span: Span, _target: Target) { + //match target { + // Target::ForeignMod => {} + // _ => { + // self.tcx + // .sess + // .struct_span_err(attr.span, "attribute should be applied to an `extern` block") + // .span_label(span, "not an `extern` block") + // .emit(); + // } + //} + } } impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> { diff --git a/compiler/rustc_resolve/src/lib.rs b/compiler/rustc_resolve/src/lib.rs index 590609f9ed3db..6b5707d4f3cb7 100644 --- a/compiler/rustc_resolve/src/lib.rs +++ b/compiler/rustc_resolve/src/lib.rs @@ -1411,6 +1411,7 @@ impl<'a, 'tcx> Resolver<'a, 'tcx> { trait_impls: self.trait_impls, proc_macros, confused_type_with_std_module, + autodiff_map: Default::default(), doc_link_resolutions: self.doc_link_resolutions, doc_link_traits_in_scope: self.doc_link_traits_in_scope, all_macro_rules: self.all_macro_rules, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 5976b9aa3e74a..72913d5f7bfbf 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -1442,6 +1442,8 @@ options! { "enables LTO for dylib crate type"), emit_stack_sizes: bool = (false, parse_bool, [UNTRACKED], "emit a section containing stack size metadata (default: no)"), + enzyme_print_activity: bool = (false, parse_bool, [TRACKED], + "print type trees for functions passed to enzyme"), emit_thin_lto: bool = (true, parse_bool, [TRACKED], "emit the bc module with thin LTO info (default: yes)"), export_executable_symbols: bool = (false, parse_bool, [TRACKED], diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 60efcb768cb07..61fea4e446d55 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -407,6 +407,7 @@ symbols! { attributes, augmented_assignments, auto_traits, + autodiff_into, automatically_derived, avx, avx512_target_feature, @@ -957,6 +958,7 @@ symbols! { miri, misc, mmx_reg, + mode, modifiers, module, module_path, diff --git a/config.example.toml b/config.example.toml index d0eaa9fd7ffac..04e5c2c629da2 100644 --- a/config.example.toml +++ b/config.example.toml @@ -135,6 +135,9 @@ changelog-seen = 2 # Whether or not to specify `-DLLVM_TEMPORARILY_ALLOW_OLD_TOOLCHAIN=YES` #allow-old-toolchain = false +# Whether to build enzyme +#enzyme = false + # Whether to include the Polly optimizer. #polly = false diff --git a/config.toml.example b/config.toml.example new file mode 100644 index 0000000000000..f09b35ed59a77 --- /dev/null +++ b/config.toml.example @@ -0,0 +1,748 @@ +# Sample TOML configuration file for building Rust. +# +# To configure rustbuild, copy this file to the directory from which you will be +# running the build, and name it config.toml. +# +# All options are commented out by default in this file, and they're commented +# out with their default values. The build system by default looks for +# `config.toml` in the current directory of a build for build configuration, but +# a custom configuration file can also be specified with `--config` to the build +# system. + +# Keeps track of the last version of `x.py` used. +# If it does not match the version that is currently running, +# `x.py` will prompt you to update it and read the changelog. +# See `src/bootstrap/CHANGELOG.md` for more information. +changelog-seen = 2 + +# ============================================================================= +# Global Settings +# ============================================================================= + +# Use different pre-set defaults than the global defaults. +# +# See `src/bootstrap/defaults` for more information. +# Note that this has no default value (x.py uses the defaults in `config.toml.example`). +#profile = + +# ============================================================================= +# Tweaking how LLVM is compiled +# ============================================================================= +[llvm] + +# Whether to use Rust CI built LLVM instead of locally building it. +# +# Unless you're developing for a target where Rust CI doesn't build a compiler +# toolchain or changing LLVM locally, you probably want to set this to true. +# +# This is false by default so that distributions don't unexpectedly download +# LLVM from the internet. +# +# All tier 1 targets are currently supported; set this to `"if-available"` if +# you are not sure whether you're on a tier 1 target. +# +# We also currently only support this when building LLVM for the build triple. +# +# Note that many of the LLVM options are not currently supported for +# downloading. Currently only the "assertions" option can be toggled. +#download-ci-llvm = false + +# Indicates whether LLVM rebuild should be skipped when running bootstrap. If +# this is `false` then the compiler's LLVM will be rebuilt whenever the built +# version doesn't have the correct hash. If it is `true` then LLVM will never +# be rebuilt. The default value is `false`. +#skip-rebuild = false + +# Indicates whether the LLVM build is a Release or Debug build +#optimize = true + +# Indicates whether LLVM should be built with ThinLTO. Note that this will +# only succeed if you use clang, lld, llvm-ar, and llvm-ranlib in your C/C++ +# toolchain (see the `cc`, `cxx`, `linker`, `ar`, and `ranlib` options below). +# More info at: https://clang.llvm.org/docs/ThinLTO.html#clang-bootstrap +#thin-lto = false + +# Indicates whether an LLVM Release build should include debug info +#release-debuginfo = false + +# Indicates whether we should build the LLVM Plugin Enzyme +#enzyme = true + +# Indicates whether the LLVM assertions are enabled or not +#assertions = false + +# Indicates whether the LLVM testsuite is enabled in the build or not. Does +# not execute the tests as part of the build as part of x.py build et al, +# just makes it possible to do `ninja check-llvm` in the staged LLVM build +# directory when doing LLVM development as part of Rust development. +#tests = false + +# Indicates whether the LLVM plugin is enabled or not +#plugins = false + +# Indicates whether ccache is used when building LLVM +#ccache = false +# or alternatively ... +#ccache = "/path/to/ccache" + +# If an external LLVM root is specified, we automatically check the version by +# default to make sure it's within the range that we're expecting, but setting +# this flag will indicate that this version check should not be done. +#version-check = true + +# Link libstdc++ statically into the rustc_llvm instead of relying on a +# dynamic version to be available. +#static-libstdcpp = true + +# Whether to use Ninja to build LLVM. This runs much faster than make. +#ninja = true + +# LLVM targets to build support for. +# Note: this is NOT related to Rust compilation targets. However, as Rust is +# dependent on LLVM for code generation, turning targets off here WILL lead to +# the resulting rustc being unable to compile for the disabled architectures. +# Also worth pointing out is that, in case support for new targets are added to +# LLVM, enabling them here doesn't mean Rust is automatically gaining said +# support. You'll need to write a target specification at least, and most +# likely, teach rustc about the C ABI of the target. Get in touch with the +# Rust team and file an issue if you need assistance in porting! +#targets = "AArch64;ARM;BPF;Hexagon;MSP430;Mips;NVPTX;PowerPC;RISCV;Sparc;SystemZ;WebAssembly;X86" + +# LLVM experimental targets to build support for. These targets are specified in +# the same format as above, but since these targets are experimental, they are +# not built by default and the experimental Rust compilation targets that depend +# on them will not work unless the user opts in to building them. +#experimental-targets = "AVR;M68k" + +# Cap the number of parallel linker invocations when compiling LLVM. +# This can be useful when building LLVM with debug info, which significantly +# increases the size of binaries and consequently the memory required by +# each linker process. +# If absent or 0, linker invocations are treated like any other job and +# controlled by rustbuild's -j parameter. +#link-jobs = 0 + +# When invoking `llvm-config` this configures whether the `--shared` argument is +# passed to prefer linking to shared libraries. +# NOTE: `thin-lto = true` requires this to be `true` and will give an error otherwise. +#link-shared = false + +# When building llvm, this configures what is being appended to the version. +# The default is "-rust-$version-$channel", except for dev channel where rustc +# version number is omitted. To use LLVM version as is, provide an empty string. +#version-suffix = "-rust-dev" + +# On MSVC you can compile LLVM with clang-cl, but the test suite doesn't pass +# with clang-cl, so this is special in that it only compiles LLVM with clang-cl. +# Note that this takes a /path/to/clang-cl, not a boolean. +#clang-cl = cc + +# Pass extra compiler and linker flags to the LLVM CMake build. +#cflags = "" +#cxxflags = "" +#ldflags = "" + +# Use libc++ when building LLVM instead of libstdc++. This is the default on +# platforms already use libc++ as the default C++ library, but this option +# allows you to use libc++ even on platforms when it's not. You need to ensure +# that your host compiler ships with libc++. +#use-libcxx = false + +# The value specified here will be passed as `-DLLVM_USE_LINKER` to CMake. +#use-linker = (path) + +# Whether or not to specify `-DLLVM_TEMPORARILY_ALLOW_OLD_TOOLCHAIN=YES` +#allow-old-toolchain = false + +# Whether to include the Polly optimizer. +#polly = false + +# Whether to build the clang compiler. +#clang = false + +# Custom CMake defines to set when building LLVM. +#build-config = {} + +# ============================================================================= +# General build configuration options +# ============================================================================= +[build] +# The default stage to use for the `check` subcommand +#check-stage = 0 + +# The default stage to use for the `doc` subcommand +#doc-stage = 0 + +# The default stage to use for the `build` subcommand +#build-stage = 1 + +# The default stage to use for the `test` subcommand +#test-stage = 1 + +# The default stage to use for the `dist` subcommand +#dist-stage = 2 + +# The default stage to use for the `install` subcommand +#install-stage = 2 + +# The default stage to use for the `bench` subcommand +#bench-stage = 2 + +# Build triple for the original snapshot compiler. This must be a compiler that +# nightlies are already produced for. The current platform must be able to run +# binaries of this build triple and the nightly will be used to bootstrap the +# first compiler. +# +# Defaults to platform where `x.py` is run. +#build = "x86_64-unknown-linux-gnu" (as an example) + +# Which triples to produce a compiler toolchain for. Each of these triples will +# be bootstrapped from the build triple themselves. +# +# Defaults to just the build triple. +#host = ["x86_64-unknown-linux-gnu"] (as an example) + +# Which triples to build libraries (core/alloc/std/test/proc_macro) for. Each of +# these triples will be bootstrapped from the build triple themselves. +# +# Defaults to `host`. If you set this explicitly, you likely want to add all +# host triples to this list as well in order for those host toolchains to be +# able to compile programs for their native target. +#target = ["x86_64-unknown-linux-gnu"] (as an example) + +# Use this directory to store build artifacts. +# You can use "$ROOT" to indicate the root of the git repository. +#build-dir = "build" + +# Instead of downloading the src/stage0.json version of Cargo specified, use +# this Cargo binary instead to build all Rust code +#cargo = "/path/to/cargo" + +# Instead of downloading the src/stage0.json version of the compiler +# specified, use this rustc binary instead as the stage0 snapshot compiler. +#rustc = "/path/to/rustc" + +# Instead of download the src/stage0.json version of rustfmt specified, +# use this rustfmt binary instead as the stage0 snapshot rustfmt. +#rustfmt = "/path/to/rustfmt" + +# Flag to specify whether any documentation is built. If false, rustdoc and +# friends will still be compiled but they will not be used to generate any +# documentation. +#docs = true + +# Flag to specify whether CSS, JavaScript, and HTML are minified when +# docs are generated. JSON is always minified, because it's enormous, +# and generated in already-minified form from the beginning. +#docs-minification = true + +# Indicate whether the compiler should be documented in addition to the standard +# library and facade crates. +#compiler-docs = false + +# Indicate whether git submodules are managed and updated automatically. +#submodules = true + +# Update git submodules only when the checked out commit in the submodules differs +# from what is committed in the main rustc repo. +#fast-submodules = true + +# The path to (or name of) the GDB executable to use. This is only used for +# executing the debuginfo test suite. +#gdb = "gdb" + +# The node.js executable to use. Note that this is only used for the emscripten +# target when running tests, otherwise this can be omitted. +#nodejs = "node" + +# Python interpreter to use for various tasks throughout the build, notably +# rustdoc tests, the lldb python interpreter, and some dist bits and pieces. +# +# Defaults to the Python interpreter used to execute x.py +#python = "python" + +# Force Cargo to check that Cargo.lock describes the precise dependency +# set that all the Cargo.toml files create, instead of updating it. +#locked-deps = false + +# Indicate whether the vendored sources are used for Rust dependencies or not +#vendor = false + +# Typically the build system will build the Rust compiler twice. The second +# compiler, however, will simply use its own libraries to link against. If you +# would rather to perform a full bootstrap, compiling the compiler three times, +# then you can set this option to true. You shouldn't ever need to set this +# option to true. +#full-bootstrap = false + +# Enable a build of the extended Rust tool set which is not only the compiler +# but also tools such as Cargo. This will also produce "combined installers" +# which are used to install Rust and Cargo together. This is disabled by +# default. The `tools` option (immediately below) specifies which tools should +# be built if `extended = true`. +#extended = false + +# Installs chosen set of extended tools if `extended = true`. By default builds +# all extended tools except `rust-demangler`, unless the target is also being +# built with `profiler = true`. If chosen tool failed to build the installation +# fails. If `extended = false`, this option is ignored. +#tools = ["cargo", "rls", "clippy", "rustfmt", "analysis", "src"] # + "rust-demangler" if `profiler` + +# Verbosity level: 0 == not verbose, 1 == verbose, 2 == very verbose +#verbose = 0 + +# Build the sanitizer runtimes +#sanitizers = false + +# Build the profiler runtime (required when compiling with options that depend +# on this runtime, such as `-C profile-generate` or `-C instrument-coverage`). +#profiler = false + +# Indicates whether the native libraries linked into Cargo will be statically +# linked or not. +#cargo-native-static = false + +# Run the build with low priority, by setting the process group's "nice" value +# to +10 on Unix platforms, and by using a "low priority" job object on Windows. +#low-priority = false + +# Arguments passed to the `./configure` script, used during distcheck. You +# probably won't fill this in but rather it's filled in by the `./configure` +# script. +#configure-args = [] + +# Indicates that a local rebuild is occurring instead of a full bootstrap, +# essentially skipping stage0 as the local compiler is recompiling itself again. +#local-rebuild = false + +# Print out how long each rustbuild step took (mostly intended for CI and +# tracking over time) +#print-step-timings = false + +# Print out resource usage data for each rustbuild step, as defined by the Unix +# struct rusage. (Note that this setting is completely unstable: the data it +# captures, what platforms it supports, the format of its associated output, and +# this setting's very existence, are all subject to change.) +#print-step-rusage = false + +# Always patch binaries for usage with Nix toolchains. If `true` then binaries +# will be patched unconditionally. If `false` or unset, binaries will be patched +# only if the current distribution is NixOS. This option is useful when using +# a Nix toolchain on non-NixOS distributions. +#patch-binaries-for-nix = false + +# ============================================================================= +# General install configuration options +# ============================================================================= +[install] + +# Instead of installing to /usr/local, install to this path instead. +#prefix = "/usr/local" + +# Where to install system configuration files +# If this is a relative path, it will get installed in `prefix` above +#sysconfdir = "/etc" + +# Where to install documentation in `prefix` above +#docdir = "share/doc/rust" + +# Where to install binaries in `prefix` above +#bindir = "bin" + +# Where to install libraries in `prefix` above +#libdir = "lib" + +# Where to install man pages in `prefix` above +#mandir = "share/man" + +# Where to install data in `prefix` above +#datadir = "share" + +# ============================================================================= +# Options for compiling Rust code itself +# ============================================================================= +[rust] + +# Whether or not to optimize the compiler and standard library. +# WARNING: Building with optimize = false is NOT SUPPORTED. Due to bootstrapping, +# building without optimizations takes much longer than optimizing. Further, some platforms +# fail to build without this optimization (c.f. #65352). +#optimize = true + +# Indicates that the build should be configured for debugging Rust. A +# `debug`-enabled compiler and standard library will be somewhat +# slower (due to e.g. checking of debug assertions) but should remain +# usable. +# +# Note: If this value is set to `true`, it will affect a number of +# configuration options below as well, if they have been left +# unconfigured in this file. +# +# Note: changes to the `debug` setting do *not* affect `optimize` +# above. In theory, a "maximally debuggable" environment would +# set `optimize` to `false` above to assist the introspection +# facilities of debuggers like lldb and gdb. To recreate such an +# environment, explicitly set `optimize` to `false` and `debug` +# to `true`. In practice, everyone leaves `optimize` set to +# `true`, because an unoptimized rustc with debugging +# enabled becomes *unusably slow* (e.g. rust-lang/rust#24840 +# reported a 25x slowdown) and bootstrapping the supposed +# "maximally debuggable" environment (notably libstd) takes +# hours to build. +# +#debug = false + +# Whether to download the stage 1 and 2 compilers from CI. +# This is mostly useful for tools; if you have changes to `compiler/` they will be ignored. +# +# You can set this to "if-unchanged" to only download if `compiler/` has not been modified. +#download-rustc = false + +# Number of codegen units to use for each compiler invocation. A value of 0 +# means "the number of cores on this machine", and 1+ is passed through to the +# compiler. +# +# Uses the rustc defaults: https://doc.rust-lang.org/rustc/codegen-options/index.html#codegen-units +#codegen-units = if incremental { 256 } else { 16 } + +# Sets the number of codegen units to build the standard library with, +# regardless of what the codegen-unit setting for the rest of the compiler is. +# NOTE: building with anything other than 1 is known to occasionally have bugs. +# See https://github.com/rust-lang/rust/issues/83600. +#codegen-units-std = codegen-units + +# Whether or not debug assertions are enabled for the compiler and standard +# library. Debug assertions control the maximum log level used by rustc. When +# enabled calls to `trace!` and `debug!` macros are preserved in the compiled +# binary, otherwise they are omitted. +# +# Defaults to rust.debug value +#debug-assertions = rust.debug (boolean) + +# Whether or not debug assertions are enabled for the standard library. +# Overrides the `debug-assertions` option, if defined. +# +# Defaults to rust.debug-assertions value +#debug-assertions-std = rust.debug-assertions (boolean) + +# Whether or not to leave debug! and trace! calls in the rust binary. +# Overrides the `debug-assertions` option, if defined. +# +# Defaults to rust.debug-assertions value +# +# If you see a message from `tracing` saying +# `max_level_info` is enabled and means logging won't be shown, +# set this value to `true`. +#debug-logging = rust.debug-assertions (boolean) + +# Whether or not overflow checks are enabled for the compiler and standard +# library. +# +# Defaults to rust.debug value +#overflow-checks = rust.debug (boolean) + +# Whether or not overflow checks are enabled for the standard library. +# Overrides the `overflow-checks` option, if defined. +# +# Defaults to rust.overflow-checks value +#overflow-checks-std = rust.overflow-checks (boolean) + +# Debuginfo level for most of Rust code, corresponds to the `-C debuginfo=N` option of `rustc`. +# `0` - no debug info +# `1` - line tables only - sufficient to generate backtraces that include line +# information and inlined functions, set breakpoints at source code +# locations, and step through execution in a debugger. +# `2` - full debug info with variable and type information +# Can be overridden for specific subsets of Rust code (rustc, std or tools). +# Debuginfo for tests run with compiletest is not controlled by this option +# and needs to be enabled separately with `debuginfo-level-tests`. +# +# Note that debuginfo-level = 2 generates several gigabytes of debuginfo +# and will slow down the linking process significantly. +# +# Defaults to 1 if debug is true +#debuginfo-level = 0 + +# Debuginfo level for the compiler. +#debuginfo-level-rustc = debuginfo-level + +# Debuginfo level for the standard library. +#debuginfo-level-std = debuginfo-level + +# Debuginfo level for the tools. +#debuginfo-level-tools = debuginfo-level + +# Debuginfo level for the test suites run with compiletest. +# FIXME(#61117): Some tests fail when this option is enabled. +#debuginfo-level-tests = 0 + +# Whether to run `dsymutil` on Apple platforms to gather debug info into .dSYM +# bundles. `dsymutil` adds time to builds for no clear benefit, and also makes +# it more difficult for debuggers to find debug info. The compiler currently +# defaults to running `dsymutil` to preserve its historical default, but when +# compiling the compiler itself, we skip it by default since we know it's safe +# to do so in that case. +#run-dsymutil = false + +# Whether or not `panic!`s generate backtraces (RUST_BACKTRACE) +#backtrace = true + +# Whether to always use incremental compilation when building rustc +#incremental = false + +# Build a multi-threaded rustc +# FIXME(#75760): Some UI tests fail when this option is enabled. +#parallel-compiler = false + +# The default linker that will be hard-coded into the generated +# compiler for targets that don't specify a default linker explicitly +# in their target specifications. Note that this is not the linker +# used to link said compiler. It can also be set per-target (via the +# `[target.]` block), which may be useful in a cross-compilation +# setting. +# +# See https://doc.rust-lang.org/rustc/codegen-options/index.html#linker for more information. +#default-linker = (path) + +# The "channel" for the Rust build to produce. The stable/beta channels only +# allow using stable features, whereas the nightly and dev channels allow using +# nightly features +#channel = "dev" + +# A descriptive string to be appended to `rustc --version` output, which is +# also used in places like debuginfo `DW_AT_producer`. This may be useful for +# supplementary build information, like distro-specific package versions. +#description = (string) + +# The root location of the musl installation directory. The library directory +# will also need to contain libunwind.a for an unwinding implementation. Note +# that this option only makes sense for musl targets that produce statically +# linked binaries. +# +# Defaults to /usr on musl hosts. Has no default otherwise. +#musl-root = (path) + +# By default the `rustc` executable is built with `-Wl,-rpath` flags on Unix +# platforms to ensure that the compiler is usable by default from the build +# directory (as it links to a number of dynamic libraries). This may not be +# desired in distributions, for example. +#rpath = true + +# Prints each test name as it is executed, to help debug issues in the test harness itself. +#verbose-tests = false + +# Flag indicating whether tests are compiled with optimizations (the -O flag). +#optimize-tests = true + +# Flag indicating whether codegen tests will be run or not. If you get an error +# saying that the FileCheck executable is missing, you may want to disable this. +# Also see the target's llvm-filecheck option. +#codegen-tests = true + +# Flag indicating whether git info will be retrieved from .git automatically. +# Having the git information can cause a lot of rebuilds during development. +# Note: If this attribute is not explicitly set (e.g. if left commented out) it +# will default to true if channel = "dev", but will default to false otherwise. +#ignore-git = if channel == "dev" { true } else { false } + +# When creating source tarballs whether or not to create a source tarball. +#dist-src = true + +# After building or testing extended tools (e.g. clippy and rustfmt), append the +# result (broken, compiling, testing) into this JSON file. +#save-toolstates = (path) + +# This is an array of the codegen backends that will be compiled for the rustc +# that's being compiled. The default is to only build the LLVM codegen backend, +# and currently the only standard options supported are `"llvm"`, `"cranelift"` +# and `"gcc"`. The first backend in this list will be used as default by rustc +# when no explicit backend is specified. +#codegen-backends = ["llvm"] + +# Indicates whether LLD will be compiled and made available in the sysroot for +# rustc to execute. +#lld = false + +# Indicates whether LLD will be used to link Rust crates during bootstrap on +# supported platforms. The LLD from the bootstrap distribution will be used +# and not the LLD compiled during the bootstrap. +# +# LLD will not be used if we're cross linking. +# +# Explicitly setting the linker for a target will override this option when targeting MSVC. +#use-lld = false + +# Indicates whether some LLVM tools, like llvm-objdump, will be made available in the +# sysroot. +#llvm-tools = false + +# Whether to deny warnings in crates +#deny-warnings = true + +# Print backtrace on internal compiler errors during bootstrap +#backtrace-on-ice = false + +# Whether to verify generated LLVM IR +#verify-llvm-ir = false + +# Compile the compiler with a non-default ThinLTO import limit. This import +# limit controls the maximum size of functions imported by ThinLTO. Decreasing +# will make code compile faster at the expense of lower runtime performance. +#thin-lto-import-instr-limit = if incremental { 10 } else { LLVM default (currently 100) } + +# Map debuginfo paths to `/rust/$sha/...`, generally only set for releases +#remap-debuginfo = false + +# Link the compiler against `jemalloc`, where on Linux and OSX it should +# override the default allocator for rustc and LLVM. +#jemalloc = false + +# Run tests in various test suites with the "nll compare mode" in addition to +# running the tests in normal mode. Largely only used on CI and during local +# development of NLL +#test-compare-mode = false + +# Use LLVM libunwind as the implementation for Rust's unwinder. +# Accepted values are 'in-tree' (formerly true), 'system' or 'no' (formerly false). +# This option only applies for Linux and Fuchsia targets. +# On Linux target, if crt-static is not enabled, 'no' means dynamic link to +# `libgcc_s.so`, 'in-tree' means static link to the in-tree build of llvm libunwind +# and 'system' means dynamic link to `libunwind.so`. If crt-static is enabled, +# the behavior is depend on the libc. On musl target, 'no' and 'in-tree' both +# means static link to the in-tree build of llvm libunwind, and 'system' means +# static link to `libunwind.a` provided by system. Due to the limitation of glibc, +# it must link to `libgcc_eh.a` to get a working output, and this option have no effect. +#llvm-libunwind = 'no' + +# Enable Windows Control Flow Guard checks in the standard library. +# This only applies from stage 1 onwards, and only for Windows targets. +#control-flow-guard = false + +# Enable symbol-mangling-version v0. This can be helpful when profiling rustc, +# as generics will be preserved in symbols (rather than erased into opaque T). +# When no setting is given, the new scheme will be used when compiling the +# compiler and its tools and the legacy scheme will be used when compiling the +# standard library. +# If an explicit setting is given, it will be used for all parts of the codebase. +#new-symbol-mangling = true|false (see comment) + +# ============================================================================= +# Options for specific targets +# +# Each of the following options is scoped to the specific target triple in +# question and is used for determining how to compile each target. +# ============================================================================= +[target.x86_64-unknown-linux-gnu] + +# C compiler to be used to compile C code. Note that the +# default value is platform specific, and if not specified it may also depend on +# what platform is crossing to what platform. +# See `src/bootstrap/cc_detect.rs` for details. +#cc = "cc" (path) + +# C++ compiler to be used to compile C++ code (e.g. LLVM and our LLVM shims). +# This is only used for host targets. +# See `src/bootstrap/cc_detect.rs` for details. +#cxx = "c++" (path) + +# Archiver to be used to assemble static libraries compiled from C/C++ code. +# Note: an absolute path should be used, otherwise LLVM build will break. +#ar = "ar" (path) + +# Ranlib to be used to assemble static libraries compiled from C/C++ code. +# Note: an absolute path should be used, otherwise LLVM build will break. +#ranlib = "ranlib" (path) + +# Linker to be used to bootstrap Rust code. Note that the +# default value is platform specific, and if not specified it may also depend on +# what platform is crossing to what platform. +# Setting this will override the `use-lld` option for Rust code when targeting MSVC. +#linker = "cc" (path) + +# Path to the `llvm-config` binary of the installation of a custom LLVM to link +# against. Note that if this is specified we don't compile LLVM at all for this +# target. +#llvm-config = (path) + +# Normally the build system can find LLVM's FileCheck utility, but if +# not, you can specify an explicit file name for it. +#llvm-filecheck = "/path/to/llvm-version/bin/FileCheck" + +# If this target is for Android, this option will be required to specify where +# the NDK for the target lives. This is used to find the C compiler to link and +# build native code. +# See `src/bootstrap/cc_detect.rs` for details. +#android-ndk = (path) + +# Build the sanitizer runtimes for this target. +# This option will override the same option under [build] section. +#sanitizers = build.sanitizers (bool) + +# Build the profiler runtime for this target(required when compiling with options that depend +# on this runtime, such as `-C profile-generate` or `-C instrument-coverage`). +# This option will override the same option under [build] section. +#profiler = build.profiler (bool) + +# Force static or dynamic linkage of the standard library for this target. If +# this target is a host for rustc, this will also affect the linkage of the +# compiler itself. This is useful for building rustc on targets that normally +# only use static libraries. If unset, the target's default linkage is used. +#crt-static = (bool) + +# The root location of the musl installation directory. The library directory +# will also need to contain libunwind.a for an unwinding implementation. Note +# that this option only makes sense for musl targets that produce statically +# linked binaries. +#musl-root = build.musl-root (path) + +# The full path to the musl libdir. +#musl-libdir = musl-root/lib + +# The root location of the `wasm32-wasi` sysroot. Only used for the +# `wasm32-wasi` target. If you are building wasm32-wasi target, make sure to +# create a `[target.wasm32-wasi]` section and move this field there. +#wasi-root = (path) + +# Used in testing for configuring where the QEMU images are located, you +# probably don't want to use this. +#qemu-rootfs = (path) + +# ============================================================================= +# Distribution options +# +# These options are related to distribution, mostly for the Rust project itself. +# You probably won't need to concern yourself with any of these options +# ============================================================================= +[dist] + +# This is the folder of artifacts that the build system will sign. All files in +# this directory will be signed with the default gpg key using the system `gpg` +# binary. The `asc` and `sha256` files will all be output into the standard dist +# output folder (currently `build/dist`) +# +# This folder should be populated ahead of time before the build system is +# invoked. +#sign-folder = (path) + +# The remote address that all artifacts will eventually be uploaded to. The +# build system generates manifests which will point to these urls, and for the +# manifests to be correct they'll have to have the right URLs encoded. +# +# Note that this address should not contain a trailing slash as file names will +# be appended to it. +#upload-addr = (URL) + +# Whether to build a plain source tarball to upload +# We disable that on Windows not to override the one already uploaded on S3 +# as the one built on Windows will contain backslashes in paths causing problems +# on linux +#src-tarball = true + +# Whether to allow failures when building tools +#missing-tools = false + +# List of compression formats to use when generating dist tarballs. The list of +# formats is provided to rust-installer, which must support all of them. +# +# This list must be non-empty. +#compression-formats = ["gz", "xz"] diff --git a/library/autodiff/Cargo.toml b/library/autodiff/Cargo.toml new file mode 100644 index 0000000000000..cbbff8d375e3d --- /dev/null +++ b/library/autodiff/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "autodiff" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + + +[profile.release] +lto = "fat" + +[profile.dev] +lto = "fat" + +[lib] +name = "autodiff" +proc-macro = true + +[dependencies] +quote = "1.0" +proc-macro2 = "1" +proc-macro-error = "1" +syn = { version = "1", features = ["extra-traits", "full", "visit", "visit-mut"]} + +[dev-dependencies] +macrotest = "1" +trybuild = "1" +ndarray = "0.15" diff --git a/library/autodiff/examples/array.rs b/library/autodiff/examples/array.rs new file mode 100644 index 0000000000000..60c6b63fd84cb --- /dev/null +++ b/library/autodiff/examples/array.rs @@ -0,0 +1,23 @@ +use autodiff::autodiff; + +#[autodiff(d_array, Reverse, Active, Duplicated)] +fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} + +fn main() { + let arr = [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]; + let mut d_arr = [[[0.0; 2]; 2]; 2]; + + d_array(&arr, &mut d_arr, 1.0); + + dbg!(&d_arr); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/box.rs b/library/autodiff/examples/box.rs new file mode 100644 index 0000000000000..5d4f114830bf4 --- /dev/null +++ b/library/autodiff/examples/box.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(cos_box, Reverse, Active, Duplicated)] +fn sin(x: &Box) -> f32 { + f32::sin(**x) +} + +fn main() { + let x = Box::::new(3.14); + let mut df_dx = Box::::new(0.0); + cos_box(&x, &mut df_dx, 1.0); + + dbg!(&df_dx); + + assert!(*df_dx == f32::cos(*x)); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/broken_matvec.rs b/library/autodiff/examples/broken_matvec.rs new file mode 100644 index 0000000000000..0c4b2cfe6e927 --- /dev/null +++ b/library/autodiff/examples/broken_matvec.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +type Matrix = Vec>; +type Vector = Vec; + +#[autodiff(d_matvec, Forward, Const)] +fn matvec(#[dup] mat: &Matrix, vec: &Vector, #[dup] out: &mut Vector) { + for i in 0..mat.len() - 1 { + for j in 0..mat[0].len() - 1 { + out[i] += mat[i][j] * vec[j]; + } + } +} + +fn main() { + let mat = vec![vec![1.0, 1.0], vec![1.0, 1.0]]; + let mut d_mat = vec![vec![0.0, 0.0], vec![0.0, 0.0]]; + let inp = vec![1.0, 1.0]; + let mut out = vec![0.0, 0.0]; + let mut out_tang = vec![0.0, 1.0]; + + //matvec(&mat, &inp, &mut out); + d_matvec(&mat, &mut d_mat, &inp, &mut out, &mut out_tang); + + dbg!(&out); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/hessian_sin.rs b/library/autodiff/examples/hessian_sin.rs new file mode 100644 index 0000000000000..6b1e776476fd2 --- /dev/null +++ b/library/autodiff/examples/hessian_sin.rs @@ -0,0 +1,28 @@ +use autodiff::autodiff; + +fn sin(x: &Vec, y: &mut f32) { + *y = x.into_iter().map(|x| f32::sin(*x)).sum() +} + +#[autodiff(sin, Reverse, Const, Duplicated, Duplicated)] +fn jac(x: &Vec, d_x: &mut Vec, y: &mut f32, y_t: &f32); + +#[autodiff(jac, Forward, Const, Duplicated, Const, Const, Const)] +fn hessian(x: &Vec, y_x: &Vec, d_x: &mut Vec, y: &mut f32, y_t: &f32); + +fn main() { + let inp = vec![3.1415 / 2., 1.0, 0.5]; + let mut d_inp = vec![0.0, 0.0, 0.0]; + let mut y = 0.0; + let tang = vec![1.0, 0.0, 0.0]; + hessian(&inp, &tang, &mut d_inp, &mut y, &1.0); + dbg!(&d_inp); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/ndarray.rs b/library/autodiff/examples/ndarray.rs new file mode 100644 index 0000000000000..34402c43cb3e6 --- /dev/null +++ b/library/autodiff/examples/ndarray.rs @@ -0,0 +1,25 @@ +use autodiff::autodiff; + +use ndarray::Array1; + +#[autodiff(d_collect, Reverse, Active)] +fn collect(#[dup] x: &Array1) -> f32 { + x[0] +} + +fn main() { + let a = Array1::zeros(19); + let mut d_a = Array1::zeros(19); + + d_collect(&a, &mut d_a, 1.0); + + dbg!(&d_a); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_fwd.rs b/library/autodiff/examples/rosenbrock_fwd.rs new file mode 100644 index 0000000000000..a3ab7a47578d0 --- /dev/null +++ b/library/autodiff/examples/rosenbrock_fwd.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Forward, DuplicatedNoNeed)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + let mut res = 0.0; + for i in 0..(x.len() - 1) { + let a = x[i + 1] - x[i] * x[i]; + let b = x[i] - 1.0; + res += 100.0 * a * a + b * b; + } + res +} + +fn main() { + let x = [3.14, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + let df_dx = d_rosenbrock(&x, &[1.0, 0.0]); + let df_dy = d_rosenbrock(&x, &[0.0, 1.0]); + + dbg!(&df_dx, &df_dy); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx - 9373.54).abs() < 0.1); + assert!((df_dy - (-1491.92)).abs() < 0.1); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_fwd_iter.rs b/library/autodiff/examples/rosenbrock_fwd_iter.rs new file mode 100644 index 0000000000000..1648014392f19 --- /dev/null +++ b/library/autodiff/examples/rosenbrock_fwd_iter.rs @@ -0,0 +1,34 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Forward, DuplicatedNoNeed)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + (0..x.len() - 1) + .map(|i| { + let (a, b) = (x[i + 1] - x[i] * x[i], x[i] - 1.0); + 100.0 * a * a + b * b + }) + .sum() +} + +fn main() { + let x = [3.14f64, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + + let df_dx = d_rosenbrock(&x, &[1.0, 0.0]); + let df_dy = d_rosenbrock(&x, &[0.0, 1.0]); + + dbg!(&df_dx, &df_dy); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx - 9373.54).abs() < 0.1); + assert!((df_dy - (-1491.92)).abs() < 0.1); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/rosenbrock_rev.rs b/library/autodiff/examples/rosenbrock_rev.rs new file mode 100644 index 0000000000000..b4ce00b5afe9d --- /dev/null +++ b/library/autodiff/examples/rosenbrock_rev.rs @@ -0,0 +1,33 @@ +use autodiff::autodiff; + +#[autodiff(d_rosenbrock, Reverse, Active)] +fn rosenbrock(#[dup] x: &[f64; 2]) -> f64 { + let mut res = 0.0; + for i in 0..(x.len() - 1) { + let a = x[i + 1] - x[i] * x[i]; + let b = x[i] - 1.0; + res += 100.0 * a * a + b * b; + } + res +} + +fn main() { + let x = [3.14, 2.4]; + let output = rosenbrock(&x); + println!("{output}"); + + let mut df_dx = [0.0f64; 2]; + d_rosenbrock(&x, &mut df_dx, 1.0); + + // https://www.wolframalpha.com/input?i2d=true&i=x%3D3.14%3B+y%3D2.4%3B+D%5Brosenbrock+function%5C%2840%29x%5C%2844%29+y%5C%2841%29+%2Cy%5D + assert!((df_dx[0] - 9373.54).abs() < 0.01); + assert!((df_dx[1] - (-1491.92)).abs() < 0.01); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/sin.rs b/library/autodiff/examples/sin.rs new file mode 100644 index 0000000000000..1655b1e7ecd09 --- /dev/null +++ b/library/autodiff/examples/sin.rs @@ -0,0 +1,36 @@ +use autodiff::autodiff; + +#[autodiff(cos_inplace, Reverse, Const)] +fn sin_inplace(#[dup] x: &f32, #[dup] y: &mut f32) { + *y = x.sin(); +} + + +fn main() { + // Here we can use ==, even though we work on f32. + // Enzyme will recognize the sin function and replace it with llvm's cos function (see below). + // Calling f32::cos directly will also result in calling llvm's cos function. + let a = 3.1415; + let mut da = 0.0; + let mut y = 0.0; + cos_inplace(&a, &mut da, &mut y, &mut 1.0); + + dbg!(&a, &da, &y); + assert!(da - f32::cos(a) == 0.0); +} + +// Just for curious readers, this is the (inner) function that Enzyme does generate: +// define internal { float } @diffe_ZN3sin3sin17h18f17f71fe94e58fE(float %0, float %1) unnamed_addr #35 { +// %3 = call fast float @llvm.cos.f32(float %0) +// %4 = fmul fast float %1, %3 +// %5 = insertvalue { float } undef, float %4, 0 +// ret { float } %5 +// } + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/sqrt.rs b/library/autodiff/examples/sqrt.rs new file mode 100644 index 0000000000000..d15c6f5ec2051 --- /dev/null +++ b/library/autodiff/examples/sqrt.rs @@ -0,0 +1,21 @@ +use autodiff::autodiff; + +#[autodiff(d_sqrt, Reverse, Active)] +fn sqrt(#[active] a: f32, #[dup] b: &f32, c: &f32, #[active] d: f32) -> f32 { + a * (b * b + c*c*d*d).sqrt() +} + +fn main() { + let mut d_b = 0.0; + + let (d_a, d_d) = d_sqrt(1.0, &1.0, &mut d_b, &1.0, 1.0, 1.0); + dbg!(d_a, d_b, d_d); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples/struct.rs b/library/autodiff/examples/struct.rs new file mode 100644 index 0000000000000..1235307fdbcbf --- /dev/null +++ b/library/autodiff/examples/struct.rs @@ -0,0 +1,33 @@ +use autodiff::autodiff; + +use std::io; + +// Will be represented as {f32, i16, i16} when passed by reference +// will be represented as i64 if passed by value +struct Foo { + c1: i16, + a: f32, + c2: i16, +} + +#[autodiff(cos, Reverse, Active, Duplicated)] +fn sin(x: &Foo) -> f32 { + assert!(x.c1 < x.c2); + f32::sin(x.a) +} + +fn main() { + let mut s = String::new(); + println!("Please enter a value for c1"); + io::stdin().read_line(&mut s).unwrap(); + let c2 = s.trim_end().parse::().unwrap(); + dbg!(c2); + + let foo = Foo { c1: 4, a: 3.14, c2 }; + let mut df_dfoo = Foo { c1: 4, a: 0.0, c2 }; + + dbg!(df_dfoo.a); + dbg!(cos(&foo, &mut df_dfoo, 1.0)); + dbg!(df_dfoo.a); + dbg!(f32::cos(foo.a)); +} diff --git a/library/autodiff/examples/vec.rs b/library/autodiff/examples/vec.rs new file mode 100644 index 0000000000000..e82618fac4dac --- /dev/null +++ b/library/autodiff/examples/vec.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(d_sum, Forward, Duplicated)] +fn sum(#[dup] x: &Vec>) -> f32 { + x.into_iter().map(|x| x.into_iter().map(|x| x.sqrt())).flatten().sum() +} + +fn main() { + let a = vec![vec![1.0, 2.0, 4.0, 8.0]]; + //let mut b = vec![vec![0.0, 0.0, 0.0, 0.0]]; + let b = vec![vec![1.0, 0.0, 0.0, 0.0]]; + + dbg!(&d_sum(&a, &b)); + + dbg!(&b); +} + +#[cfg(test)] +mod tests { + #[test] + fn main() { + super::main() + } +} diff --git a/library/autodiff/examples_broken/biquad.rs b/library/autodiff/examples_broken/biquad.rs new file mode 100644 index 0000000000000..7689b1cd1fc51 --- /dev/null +++ b/library/autodiff/examples_broken/biquad.rs @@ -0,0 +1,54 @@ +use autodiff::autodiff; + +#[derive(Debug)] +struct Biquad { + coeffs: [[f32; 5]; N], +} + +impl Biquad { + pub fn new() -> Self { + Biquad { coeffs: [[0.0; 5]; N] } + } + + pub fn process(&self, samples: &[f32], target: &[f32]) -> f32 { + // do some horrible inefficient biquad filtering + let mut samples = samples.to_vec(); + let mut samples_out = vec![0.0; samples.len()]; + + for coeff_set in self.coeffs { + for idx in 0..samples.len() { + samples_out[idx] = coeff_set[0] * samples[idx]; + + if idx > 0 { + samples_out[idx] += coeff_set[1] * samples[idx - 1] - + coeff_set[3] * samples_out[idx - 1]; + } + if idx > 1 { + samples_out[idx] += coeff_set[2] * samples[idx - 2] - + coeff_set[4] * samples_out[idx - 2]; + } + } + + (samples, samples_out) = (samples_out, samples); + } + + samples_out.into_iter().zip(target.into_iter()).map(|(a, b)| a - b).sum() + } + + #[autodiff(Self::process, Reverse, Active)] + pub fn deriv(#[dup] &self, params: &mut Self, samples: &[f32], target: &[f32], ret_adj: f32); +} + +fn main() { + let biquad = Biquad::<10>::new(); + let mut dbiquad = Biquad::<10>::new(); + + // create ramp and pulse train + let signal = (0..1024).map(|x| (x as f32) / 1024.0).collect::>(); + let target = (0..1024).map(|x| if x % 2 == 0 { 0.0 } else { 1.0 }).collect::>(); + + dbg!(&biquad.process(&signal, &target)); + biquad.deriv(&mut dbiquad, &signal, &target, 1.0); + + dbg!(&dbiquad); +} diff --git a/library/autodiff/examples_broken/broken_iter.rs b/library/autodiff/examples_broken/broken_iter.rs new file mode 100644 index 0000000000000..16d205f7373c8 --- /dev/null +++ b/library/autodiff/examples_broken/broken_iter.rs @@ -0,0 +1,20 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; +use std::ptr; + +#[autodiff(sin_vec, Reverse, Active)] +fn cos_vec(#[dup] x: &Vec) -> f32 { + // uses enum internally and breaks + let res = x.into_iter().collect::>(); + + *res[0] +} + +fn main() { + let x = vec![1.0, 1.0, 1.0]; + let mut d_x = vec![0.0; 3]; + + sin_vec(&x, &mut d_x, 1.0); + + dbg!(&d_x, &x); +} diff --git a/library/autodiff/examples_broken/broken_recursive.rs b/library/autodiff/examples_broken/broken_recursive.rs new file mode 100644 index 0000000000000..a1f3ff25eb511 --- /dev/null +++ b/library/autodiff/examples_broken/broken_recursive.rs @@ -0,0 +1,66 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; + +// TODO: As seen by the bloated code generated for the iterative version, +// we definetly have to disable unroll, slpvec, loop-vec before AD. +// We also should check if we have other opts that Julia, C++, Fortran etc. don't have +// and which could make our input code more "complex". +// We then however have to start doing whole-module opt after AD to re-include them, +// instead of just using enzyme to optimize the generated function. + +#[autodiff(d_power_recursive, Forward, DuplicatedNoNeed)] +fn power_recursive(#[dup] a: f64, n: i32) -> f64 { + if n == 0 { + return 1.0; + } + return a * power_recursive(a, n - 1); +} + +#[autodiff(d_power_iterative, Reverse, DuplicatedNoNeed)] +fn power_iterative(#[active] a: f64, n: i32) -> f64 { + let mut res = 1.0; + for _ in 0..n { + res *= a; + } + res +} + +fn main() { + // d/dx x^n = n * x^(n-1) + let n = 4; + let nf = n as f64; + let a = 1.337; + assert!(power_recursive(a, n) == power_iterative(a, n)); + let dpr = d_power_recursive(a, 1.0, n); + let dpi = d_power_iterative(a, n, 1.0); + let control = nf * a.powi(n - 1); + dbg!(dpr); + dbg!(dpi); + dbg!(control); + assert!(dpr == control); + assert!(dpi == control); +} + +// Again, for the curious. We can find n * x^(n-1) nicely in the LLVM-IR +// +// define internal double @fwddiffe_ZN9recursive15power_recursive17h789de751cfc6154dE(double %0, double %1, i32 %2) unnamed_addr #8 { +// => if (n == 0) goto 5: and return 0. Correct, since for n==0 we have 0 * x ^ (0-1) = 0 +// => if (n != 0) goto 7: +// %4 = icmp eq i32 %2, 0 +// br i1 %4, label %5, label %7 +// +// 5: ; preds = %7, %3 +// %6 = phi fast double [ %14, %7 ], [ 0.000000e+00, %3 ] +// ret double %6 +// +// 7: ; preds = %3 +// => reduce n by 1, +// %8 = add i32 %2, -1 +// %9 = call { double, double } @fwddiffe_ZN9recursive15power_recursive17h789de751cfc6154dE.1229(double %0, double %1, i32 %8) +// %10 = extractvalue { double, double } %9, 0 +// %11 = extractvalue { double, double } %9, 1 +// %12 = fmul fast double %11, %0 +// %13 = fmul fast double %1, %10 +// %14 = fadd fast double %12, %13 +// br label %5 +// } diff --git a/library/autodiff/examples_broken/broken_second_order.rs b/library/autodiff/examples_broken/broken_second_order.rs new file mode 100644 index 0000000000000..8b427d7dae36a --- /dev/null +++ b/library/autodiff/examples_broken/broken_second_order.rs @@ -0,0 +1,17 @@ +#![feature(bench_black_box)] +use autodiff::autodiff; + +fn sin(x: &f32) -> f32 { + f32::sin(*x) +} + +#[autodiff(sin, Reverse, Active, Active)] +fn cos(x: &f32, adj: f32) -> f32; + +//#[autodiff(cos, Reverse, Active, Active, Const)] +//fn neg_sin(x: &f32, adj: f32, adj_sec: f32) -> f32; + +fn main() { + dbg!(&cos(&1.0, 1.0)); + //dbg!(&neg_sin(&1.0, 1.0, 1.0)); +} diff --git a/library/autodiff/src/gen.rs b/library/autodiff/src/gen.rs new file mode 100644 index 0000000000000..68aae56ea3311 --- /dev/null +++ b/library/autodiff/src/gen.rs @@ -0,0 +1,217 @@ +use crate::parser::{is_ref_mut, PrimalSig}; +use crate::parser::{Activity, DiffItem, Mode}; +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{format_ident, quote}; +use syn::{parse_quote, FnArg, Ident, Pat, ReturnType, Type}; + +pub(crate) fn generate_header(item: &DiffItem) -> TokenStream { + let mode = match item.header.mode { + Mode::Forward => format_ident!("Forward"), + Mode::Reverse => format_ident!("Reverse"), + }; + let ret_act = item.header.ret_act.to_ident(); + let param_act = item.params.iter().map(|x| x.to_ident()); + + quote!(#[autodiff_into(#mode, #ret_act, #( #param_act, )*)]) +} + +pub(crate) fn primal_fnc(item: &mut DiffItem) -> TokenStream { + // construct body of primal if not given + let body = item.block.clone().map(|x| quote!(#x)).unwrap_or_else(|| { + let header_fnc = &item.header.name; + //let primal_wrapper = format_ident!("primal_{}", item.primal.ident); + //item.primal.ident = primal_wrapper.clone(); + let inputs = item.primal.inputs.iter().map(|x| only_ident(x)).collect::>(); + + quote!({ + #header_fnc(#(#inputs,)*) + }) + }); + + let sig = &item.primal; + let PrimalSig { ident, inputs, output } = sig; + + let ident = + if item.block.is_some() { ident.clone() } else { format_ident!("primal_{}", ident) }; + + let sig = quote!(fn #ident(#(#inputs,)*) #output); + + quote!( + #[autodiff_into] + #sig + #body + ) +} + +fn only_ident(arg: &FnArg) -> Ident { + match arg { + FnArg::Receiver(_) => format_ident!("self"), + FnArg::Typed(t) => match &*t.pat { + Pat::Ident(ident) => ident.ident.clone(), + _ => panic!(""), + }, + } +} + +fn only_type(arg: &FnArg) -> Type { + match arg { + FnArg::Receiver(_) => parse_quote!(Self), + FnArg::Typed(t) => match &*t.ty { + Type::Reference(t) => *t.elem.clone(), + x => x.clone(), + }, + } +} + +fn as_ref_mut(arg: &FnArg, name: &str, mutable: bool) -> FnArg { + match arg { + FnArg::Receiver(_) => { + let name = format_ident!("{}_self", name); + if mutable { parse_quote!(#name: &mut Self) } else { parse_quote!(#name: &Self) } + } + FnArg::Typed(t) => { + let inner = match &*t.ty { + Type::Reference(t) => &t.elem, + _ => panic!(""), // should not be reachable, as we checked mutability before + }; + + let pat_name = match &*t.pat { + Pat::Ident(x) => &x.ident, + _ => panic!(""), + }; + + let name = format_ident!("{}_{}", name, pat_name); + if mutable { parse_quote!(#name: &mut #inner) } else { parse_quote!(#name: &#inner) } + } + } +} + +pub(crate) fn adjoint_fnc(item: &DiffItem) -> TokenStream { + let mut res_inputs: Vec = Vec::new(); + let mut add_inputs: Vec = Vec::new(); + let out_type = match &item.primal.output { + ReturnType::Type(_, x) => Some(*x.clone()), + _ => None, + }; + + let mut outputs = if item.header.ret_act == Activity::Duplicated { + vec![out_type.clone().unwrap()] + } else { + vec![] + }; + + let PrimalSig { ident, inputs, .. } = &item.primal; + + for (input, activity) in inputs.iter().zip(item.params.iter()) { + res_inputs.push(input.clone()); + + match (item.header.mode, activity, is_ref_mut(&input)) { + (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(true)) => { + res_inputs.push(as_ref_mut(&input, "grad", true)); + add_inputs.push(as_ref_mut(&input, "grad", true)); + } + (Mode::Forward, Activity::Duplicated|Activity::DuplicatedNoNeed, Some(false)) => { + res_inputs.push(as_ref_mut(&input, "dual", false)); + add_inputs.push(as_ref_mut(&input, "dual", false)); + out_type.clone().map(|x| outputs.push(x)); + } + (Mode::Forward, Activity::Duplicated, None) => outputs.push(only_type(&input)), + (Mode::Reverse, Activity::Duplicated, Some(false)) => { + res_inputs.push(as_ref_mut(&input, "grad", true)); + add_inputs.push(as_ref_mut(&input, "grad", true)); + } + (Mode::Reverse, Activity::Duplicated | Activity::DuplicatedNoNeed, Some(true)) => { + res_inputs.push(as_ref_mut(&input, "grad", false)); + add_inputs.push(as_ref_mut(&input, "grad", false)); + } + (Mode::Reverse, Activity::Active, None) => outputs.push(only_type(&input)), + _ => {} + } + } + + match (item.header.mode, item.header.ret_act) { + (Mode::Reverse, Activity::Active) => { + let t: FnArg = match &item.primal.output { + ReturnType::Type(_, ty) => parse_quote!(tang_y: #ty), + _ => panic!(""), + }; + res_inputs.push(t.clone()); + add_inputs.push(t); + } + _ => {} + } + + // for adjoint function -> take header if primal + // -> take ident of primal function + let adjoint_ident = if item.block.is_some() { + if let Some(ident) = item.header.name.get_ident() { + ident.clone() + } else { + abort!( + item.header.name, + "not a function name"; + help = "`#[autodiff]` function name should be a single word instead of path" + ); + } + } else { + item.primal.ident.clone() + }; + + let output = match outputs.len() { + 0 => quote!(), + 1 => { + let output = outputs.first().unwrap(); + + quote!(-> #output) + } + _ => quote!(-> (#(#outputs,)*)), + }; + + let sig = quote!(fn #adjoint_ident(#(#res_inputs,)*) #output); + let inputs = inputs + .iter() + .map(|x| match x { + FnArg::Typed(ty) => { + let pat = &ty.pat; + quote!(#pat) + } + FnArg::Receiver(_) => quote!(self), + }) + .collect::>(); + let add_inputs = add_inputs + .iter() + .map(|x| match x { + FnArg::Typed(ty) => { + let pat = &ty.pat; + quote!(#pat) + } + FnArg::Receiver(_) => quote!(self), + }) + .collect::>(); + + let call_ident = match item.block.is_some() { + false => { + let ident = format_ident!("primal_{}", ident); + if item.header.name.segments.first().unwrap().ident == "Self" { + quote!(Self::#ident) + } else { + quote!(#ident) + } + } + true => quote!(#ident), + }; + + let body = quote!({ + std::hint::black_box((#call_ident(#(#inputs,)*), #(#add_inputs,)*)); + + std::hint::black_box(unsafe { std::mem::zeroed() }) + }); + let header = generate_header(&item); + + quote!( + #header + #sig + #body + ) +} diff --git a/library/autodiff/src/lib.rs b/library/autodiff/src/lib.rs new file mode 100644 index 0000000000000..b1d265fa9c59b --- /dev/null +++ b/library/autodiff/src/lib.rs @@ -0,0 +1,31 @@ +use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; +use quote::quote; + +mod gen; +mod parser; + +#[proc_macro_attribute] +#[proc_macro_error] +pub fn autodiff(args: TokenStream, input: TokenStream) -> TokenStream { + let mut params = parser::parse(args.into(), input.clone().into()); + let (primal, adjoint) = (gen::primal_fnc(&mut params), gen::adjoint_fnc(¶ms)); + + let res = quote!( + #primal + #adjoint + ); + + res.into() +} + +#[test] +pub fn expanding() { + macrotest::expand("tests/expand/*.rs"); +} + +#[test] +fn ui() { + let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/*.rs"); +} diff --git a/library/autodiff/src/parser.rs b/library/autodiff/src/parser.rs new file mode 100644 index 0000000000000..52e0c800826e7 --- /dev/null +++ b/library/autodiff/src/parser.rs @@ -0,0 +1,617 @@ +use proc_macro2::TokenStream; +use proc_macro_error::abort; +use quote::{format_ident, quote}; +use syn::{ + parse::Parser, parse_quote, punctuated::Punctuated, Attribute, Block, FnArg, ForeignItemFn, + Ident, Item, Path, ReturnType, Signature, Token, Type, +}; + +#[derive(Debug)] +pub struct PrimalSig { + pub(crate) ident: Ident, + pub(crate) inputs: Vec, + pub(crate) output: ReturnType, +} + +#[derive(Debug)] +pub struct DiffItem { + pub(crate) header: Header, + pub(crate) params: Vec, + pub(crate) primal: PrimalSig, + pub(crate) block: Option>, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Mode { + Forward, + Reverse, +} + +#[derive(Copy, Clone, Debug, PartialEq)] +pub(crate) enum Activity { + Const, + Active, + Duplicated, + DuplicatedNoNeed, +} + +impl Activity { + fn from_header(name: Option<&Ident>) -> Activity { + if name.is_none() { + return Activity::Const; + } + + match name.unwrap().to_string().as_str() { + "Const" => Activity::Const, + "Active" => Activity::Active, + "Duplicated" => Activity::Duplicated, + "DuplicatedNoNeed" => Activity::DuplicatedNoNeed, + _ => { + abort!( + name, + "unknown activity"; + help = "`#[autodiff]` should use activities (Const|Active|Duplicated|DuplicatedNoNeed)" + ); + } + } + } + + fn from_inline(name: Attribute) -> Activity { + let name = name.path.segments.first().unwrap(); + match name.ident.to_string().as_str() { + "const" => Activity::Const, + "active" => Activity::Active, + "dup" => Activity::Duplicated, + "dup_noneed" => Activity::DuplicatedNoNeed, + _ => { + abort!( + name, + "unknown activity"; + help = "`#[autodiff]` should use activities (const|active|dup|dup_noneed)" + ); + } + } + } + + pub(crate) fn to_ident(&self) -> Ident { + format_ident!( + "{}", + match self { + Activity::Const => "Const", + Activity::Active => "Active", + Activity::Duplicated => "Duplicated", + Activity::DuplicatedNoNeed => "DuplicatedNoNeed", + } + ) + } +} + +#[derive(Debug)] +pub(crate) struct Header { + pub name: Path, + pub mode: Mode, + pub ret_act: Activity, +} + +impl Header { + fn from_params(name: &Path, mode: Option<&Ident>, ret_activity: Option<&Ident>) -> Self { + // parse mode and return activity + let mode = mode + .map(|x| match x.to_string().as_str() { + "forward" | "Forward" => Mode::Forward, + "reverse" | "Reverse" => Mode::Reverse, + _ => { + abort!( + mode, + "should be forward or reverse"; + help = "`#[autodiff]` modes should be either forward or reverse" + ); + } + }) + .unwrap_or(Mode::Forward); + let ret_act = Activity::from_header(ret_activity); + + // check for invalid mode and return activity combinations + match (mode, ret_act) { + (Mode::Forward, Activity::Active) => abort!( + ret_activity, + "active return for forward mode"; + help = "`#[autodiff]` return should be Const, Duplicated or DuplicatedNoNeed in forward mode" + ), + (Mode::Reverse, Activity::Duplicated | Activity::DuplicatedNoNeed) => abort!( + ret_activity, + "duplicated return for reverse mode"; + help = "`#[autodiff]` return should be Const or Active in reverse mode" + ), + + _ => {} + } + + Header { name: name.clone(), mode, ret_act } + } + + fn parse(args: TokenStream) -> (Header, Vec) { + let args_parsed: Vec<_> = + match Punctuated::::parse_terminated.parse(args.clone().into()) { + Ok(x) => x.into_iter().collect(), + Err(_) => abort!( + args, + "duplicated return for reverse mode"; + help = "`#[autodiff]` return should be Const or Active in reverse mode" + ), + }; + + match &args_parsed[..] { + [name] => (Self::from_params(&name, None, None), vec![]), + [name, mode] => { + (Self::from_params(&name, Some(&mode.get_ident().unwrap()), None), vec![]) + } + [name, mode, ret_act, rem @ ..] => { + let params = Self::from_params( + &name, + Some(&mode.get_ident().unwrap()), + Some(&ret_act.get_ident().unwrap()), + ); + let rem = rem.into_iter() + .map(|x| x.get_ident().unwrap()) + .map(|x| Activity::from_header(Some(x))) + .map(|x| match (params.mode, x) { + (Mode::Forward, Activity::Active) => { + abort!( + args, + "active argument in forward mode"; + help = "`#[autodiff]` forward mode should be either Const, Duplicated" + ); + }, + (_, x) => x, + }) + .collect(); + + (params, rem) + } + _ => { + abort!( + args, + "please specify the autodiff function"; + help = "`#[autodiff]` needs a function name for primal or adjoint" + ); + } + } + } +} + +pub(crate) fn is_ref_mut(t: &FnArg) -> Option { + match t { + FnArg::Receiver(pat) => Some(pat.mutability.is_some()), + FnArg::Typed(pat) => match &*pat.ty { + Type::Reference(t) => Some(t.mutability.is_some()), + _ => None, + }, + } +} + +fn is_scalar(t: &Type) -> bool { + let t_f32: Type = parse_quote!(f32); + let t_f64: Type = parse_quote!(f64); + t == &t_f32 || t == &t_f64 +} + +fn ret_arg(arg: &FnArg) -> Type { + match arg { + FnArg::Receiver(_) => parse_quote!(Self), + FnArg::Typed(t) => match &*t.ty { + Type::Reference(t) => *t.elem.clone(), + x => x.clone(), + }, + } +} + +pub(crate) fn reduce_params( + mut sig: Signature, + header_acts: Vec, + is_adjoint: bool, + header: &Header, +) -> (PrimalSig, Vec) { + let mut args = Vec::new(); + let mut ret = Vec::new(); + let mut acts = Vec::new(); + let mut last_arg: Option = None; + + let mut arg_it = sig.inputs.iter_mut(); + let mut header_acts_it = header_acts.iter(); + + while let Some(arg) = arg_it.next() { + // Compare current with last argument when parsing duplicated rules. This only + // happens when we parse the signature of adjoint/augmented primal function + if let Some(prev_arg) = last_arg.take() { + match (header.mode, is_ref_mut(&prev_arg), is_ref_mut(&arg)) { + (Mode::Forward, Some(false), Some(true) | None) => abort!( + arg, + "should be an immutable reference"; + help = "`#[autodiff]` input parameter should duplicate tangent into second parameter for forward mode" + ), + (Mode::Forward, Some(true), Some(false) | None) => abort!( + arg, + "should be a mutable reference"; + help = "`#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode" + ), + (Mode::Reverse, Some(false), Some(false) | None) => abort!( + arg, + "should be a mutable reference"; + help = "`#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode" + ), + (Mode::Reverse, Some(true), Some(true) | None) => abort!( + arg, + "should be an immutable reference"; + help = "`#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode" + ), + _ => {} + } + + continue; + } + + // parse current attribute macro + let attrs: Vec<_> = match arg { + FnArg::Typed(pat) => pat.attrs.drain(..).collect(), + FnArg::Receiver(pat) => pat.attrs.drain(..).collect(), + }; + let attr = attrs.first(); + let act: Activity = match (header_acts.is_empty(), attr) { + (false, None) => header_acts_it.next().map(|x| *x).unwrap_or(Activity::Const), + (true, Some(x)) => Activity::from_inline(x.clone()), + (true, None) => Activity::Const, + _ => { + abort!( + arg, + "inline activity"; + help = "`#[autodiff]` should have activities either specified in header or as inline attributes" + ); + } + }; + + // compare indirection with activity + match (header.mode, is_ref_mut(&arg), act) { + (Mode::Forward, None, Activity::Duplicated) => abort!( + arg, + "type not behind reference"; + help = "`#[autodiff]` duplicated types should be behind a reference" + ), + (Mode::Forward, Some(false), Activity::DuplicatedNoNeed) => abort!( + arg, + "should be mutable reference"; + help = "`#[autodiff]` parameter should be output for DuplicatedNoNeed activity" + ), + (Mode::Reverse, Some(_), Activity::Active) => abort!( + arg, + "type behind reference"; + help = "`#[autodiff]` active parameter should be concrete in reverse mode" + ), + (Mode::Reverse, None, Activity::Duplicated | Activity::DuplicatedNoNeed) => abort!( + arg, + "type not behind reference"; + help = "`#[autodiff]` duplicated parameters should be behind reference in reverse mode" + ), + (Mode::Reverse, Some(false), Activity::DuplicatedNoNeed) => abort!( + arg, + "use duplicated instead"; + help = "`#[autodiff]` input parameter cannot be declared as duplicatednoneed" + ), + (Mode::Forward, Some(false), Activity::Duplicated) + if header.ret_act != Activity::Const => + { + ret.push(ret_arg(&arg)) + } + (Mode::Reverse, None, Activity::Active) => ret.push(ret_arg(&arg)), + (Mode::Forward, Some(_), Activity::Duplicated | Activity::DuplicatedNoNeed) + | (Mode::Reverse, _, Activity::Duplicated | Activity::DuplicatedNoNeed) + if is_adjoint => + { + last_arg = Some(arg.clone()) + } + _ => {} + } + + args.push(arg.clone()); + acts.push(act); + } + + // if we have adjoint signature and are in forward mode + // if duplicated -> return type * (n + 1) times + // if duplicated_no_need -> return type * n times + // if const -> no return + + // if we have adjoint signature and are in reverse mode + // if active -> input type * n times + // construct return type based on mode + let ret = if is_adjoint { + let ret_typs = match &sig.output { + ReturnType::Type(_, ref x) => match &**x { + Type::Tuple(x) => x.elems.iter().cloned().collect(), + x => vec![x.clone()], + }, + ReturnType::Default => vec![], + }; + + match (header.mode, header.ret_act) { + (Mode::Forward, Activity::Duplicated) => { + let expected = ret_typs[0].clone(); + let list = vec![expected.clone(); ret.len() + 1]; + + if list != ret_typs { + let ret = quote!((#(#list,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ); + } + + parse_quote!(-> #expected) + } + (Mode::Forward, Activity::DuplicatedNoNeed) => { + let expected = ret_typs[0].clone(); + let list = vec![expected.clone(); ret.len()]; + + if list != ret_typs { + let ret = quote!((#(#list,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ); + } + + parse_quote!(-> #expected) + } + (Mode::Reverse, Activity::Active) => { + // tangent of output is latest in parameter list + let ret_typ = match (args.pop(), acts.pop()) { + (Some(x), Some(y)) => { + let x = ret_arg(&x); + if !is_scalar(&x) { + abort!( + x, + "output tangent not a floating point"; + help = "`#[autodiff]` the output tangent should be a floating point" + ); + } else if y != Activity::Const { + abort!( + x, + "output tangent not const"; + help = "`#[autodiff]` the last parameter of an adjoint with active return should be a constant tangent" + ); + } else { + parse_quote!(-> #x) + } + } + (None, None) => abort!( + sig, + "missing output tangent parameter"; + help = "`#[autodiff]` the last parameter of an adjoint with active return should exist" + ), + _ => unreachable!(), + }; + + // check that the return tuple confirms with return types + if ret_typs != ret { + let ret = quote!((#(#ret,)*)); + abort!( + sig.output, + "invalid output"; + help = format!("`#[autodiff]` expected {}", ret) + ) + } + + ret_typ + } + (_, Activity::Const) if ret.len() > 0 => { + abort!( + ret[0], + "constant return but more than one return"; + help = "`#[autodiff]` adjoint should have a return type when active" + ) + } + _ => ReturnType::Default, + } + } else { + if header.ret_act != Activity::Const && sig.output == ReturnType::Default { + abort!( + sig, + "no return type"; + help = "`#[autodiff]` non-const return activity but no return type" + ) + } + + sig.output.clone() + }; + + let sig = if is_adjoint { + // header is used for calling if we are adjoint + format_ident!("{}", sig.ident) + } else { + sig.ident.clone() + }; + + (PrimalSig { ident: sig, inputs: args, output: ret }, acts) +} + +//fn check_output(arg: &FnArg) -> bool { +// match arg { +// FnArg::Receiver(x) => x.mutability.is_some(), +// FnArg::Typed(t) => is_ref_mut(&t.ty), +// } +//} +// +//fn dup_arg_with_name_mut(arg: &FnArg, name: &str, mutable: bool) -> FnArg { +// match arg { +// FnArg::Receiver(_) => { +// let name = format_ident!("{}_self", name); +// if mutable { +// parse_quote!(#name: &mut Self) +// } else { +// parse_quote!(#name: &Self) +// } +// }, +// FnArg::Typed(t) => { +// +// let inner = match &*t.ty { +// Type::Reference(t) => &t.elem, +// _ => panic!("") // should not be reachable, as we checked mutability before +// }; +// +// let pat_name = match &*t.pat { +// Pat::Ident(x) => &x.ident, +// _ => panic!(""), +// }; +// +// let name = format_ident!("{}_{}", name, pat_name); +// if mutable { +// parse_quote!(#name: &mut #inner) +// } else { +// parse_quote!(#name: &#inner) +// } +// } +// } +//} +// +//fn ret_arg(arg: &FnArg) -> Type { +// match arg { +// FnArg::Receiver(_) => parse_quote!(Self), +// FnArg::Typed(t) => { +// match &*t.ty { +// Type::Reference(t) => *t.elem.clone(), +// _ => panic!(""), +// } +// } +// } +//} +// +//fn create_target_signature_forward(mut sig: Signature, act: &Vec, ret_act: &Activity) -> Signature { +// let mut inputs = Vec::new(); +// let mut outputs = Vec::new(); +// for (p, a) in sig.inputs.iter().zip(act.into_iter()) { +// let is_output = check_output(p); +// +// if !is_output { +// inputs.push(p.clone()); +// +// if *a != Activity::Const { +// inputs.push(dup_arg_with_name_mut(&p, "adj", false)); +// } +// +// if *ret_act != Activity::Const { +// match sig.output { +// ReturnType::Type(_, ref ty) => outputs.push(ty.clone()), +// _ => panic!(""), +// } +// } +// } else { +// inputs.push(p.clone()); +// +// if *a != Activity::Const { +// inputs.push(dup_arg_with_name_mut(&p, "d", true)); +// } +// } +// } +// +// sig.inputs = inputs.into_iter().collect(); +// +// if *ret_act != Activity::Const { +// let ret_ty = match sig.output { +// ReturnType::Type(_, t) => t, +// _ => { +// abort!( +// sig.output, +// "no return type"; +// help = "`#[autodiff]` specified duplicated activity but function has not return" +// ); +// } +// }; +// +// sig.output = if *ret_act == Activity::Duplicated { +// parse_quote!(-> (#ret_ty, #( #outputs, )*)) +// } else { +// if outputs.len() > 1 { +// parse_quote!(-> (#( #outputs, )*)) +// } else { +// parse_quote!(-> #( #outputs )*) +// } +// }; +// } +// +// sig +//} +// +//fn create_target_signature_reverse(mut sig: Signature, act: &Vec, ret_act: &Activity) -> Signature { +// let mut inputs = Vec::new(); +// let mut outputs = Vec::new(); +// for (p, a) in sig.inputs.iter().zip(act.into_iter()) { +// let is_output = check_output(p); +// +// if !is_output { +// inputs.push(p.clone()); +// +// match a { +// Activity::Active => { +// outputs.push(ret_arg(&p)); +// }, +// Activity::Duplicated | Activity::DuplicatedNoNeed => inputs.push(dup_arg_with_name_mut(&p, "d", true)), +// _ => {} +// } +// } else { +// inputs.push(p.clone()); +// +// if *a != Activity::Const { +// inputs.push(dup_arg_with_name_mut(&p, "adj", false)); +// } +// } +// } +// +// match sig.output { +// ReturnType::Type(_, typ) => { +// inputs.push(parse_quote!(ret_adj: #typ)); +// }, +// _ => {} +// } +// +// sig.inputs = inputs.into_iter().collect(); +// +// sig.output = if *ret_act == Activity::Active { +// match outputs.len() { +// 0 => parse_quote!(), +// 1 => parse_quote!(-> #( #outputs )*), +// _ => parse_quote!(-> (#( #outputs, )*)) +// } +// } else { +// parse_quote!() +// }; +// +// sig +//} +pub(crate) fn parse(args: TokenStream, input: TokenStream) -> DiffItem { + // first parse function + let (_attrs, _, sig, block) = match syn::parse2::(input) { + Ok(Item::Fn(item)) => (item.attrs, item.vis, item.sig, Some(item.block)), + Ok(Item::Verbatim(x)) => match syn::parse2::(x) { + Ok(item) => (item.attrs, item.vis, item.sig, None), + Err(err) => panic!("Could not parse item {}", err), + }, + Ok(item) => { + abort!( + item, + "item is not a function"; + help = "`#[autodiff]` can only be used on primal or adjoint functions" + ) + } + Err(err) => panic!("Could not parse item: {}", err), + }; + + // then parse attributes + let (header, param_attrs) = Header::parse(args); + + // reduce parameters to primal parameter set + let (primal, params) = reduce_params(sig, param_attrs, !block.is_some(), &header); + + DiffItem { header, primal, params, block } +} diff --git a/library/autodiff/tests/expand/forward_duplicated.expanded.rs b/library/autodiff/tests/expand/forward_duplicated.expanded.rs new file mode 100644 index 0000000000000..bf3890154ab8e --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square(a: &Vec, b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} +#[autodiff_into(Forward, Const, Duplicated, Duplicated)] +fn d_square(a: &Vec, dual_a: &Vec, b: &mut f32, grad_b: &mut f32) { + std::hint::black_box((square(a, b), dual_a, grad_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/forward_duplicated.rs b/library/autodiff/tests/expand/forward_duplicated.rs new file mode 100644 index 0000000000000..9a0bfc6c13a47 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square, Forward, Const)] +fn square(#[dup] a: &Vec, #[dup] b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} diff --git a/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs new file mode 100644 index 0000000000000..a3754de7ab70b --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated_return.expanded.rs @@ -0,0 +1,15 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square2(a: &Vec, b: &Vec) -> f32 { + a.into_iter().map(f32::square).sum() +} +#[autodiff_into(Forward, Duplicated, Duplicated, Duplicated)] +fn d_square2( + a: &Vec, + dual_a: &Vec, + b: &Vec, + dual_b: &Vec, +) -> (f32, f32, f32) { + std::hint::black_box((square2(a, b), dual_a, dual_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/forward_duplicated_return.rs b/library/autodiff/tests/expand/forward_duplicated_return.rs new file mode 100644 index 0000000000000..3397e5309ea96 --- /dev/null +++ b/library/autodiff/tests/expand/forward_duplicated_return.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square2, Forward, Duplicated)] +fn square2(#[dup] a: &Vec, #[dup] b: &Vec) -> f32 { + a.into_iter().map(f32::square).sum() +} diff --git a/library/autodiff/tests/expand/reverse_duplicated.expanded.rs b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs new file mode 100644 index 0000000000000..60c0d7f2f696b --- /dev/null +++ b/library/autodiff/tests/expand/reverse_duplicated.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn square(a: &Vec, b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} +#[autodiff_into(Reverse, Const, Duplicated, Duplicated)] +fn d_square(a: &Vec, grad_a: &mut Vec, b: &mut f32, grad_b: &f32) { + std::hint::black_box((square(a, b), grad_a, grad_b)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_duplicated.rs b/library/autodiff/tests/expand/reverse_duplicated.rs new file mode 100644 index 0000000000000..107a708bec848 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_duplicated.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_square, Reverse, Const)] +fn square(#[dup] a: &Vec, #[dup] b: &mut f32) { + *b = a.into_iter().map(f32::square).sum(); +} diff --git a/library/autodiff/tests/expand/reverse_return_array.expanded.rs b/library/autodiff/tests/expand/reverse_return_array.expanded.rs new file mode 100644 index 0000000000000..5b784157fea7b --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_array.expanded.rs @@ -0,0 +1,10 @@ +use autodiff::autodiff; +#[autodiff_into] +fn array(arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} +#[autodiff_into(Reverse, Active, Duplicated)] +fn d_array(arr: &[[[f32; 2]; 2]; 2], grad_arr: &mut [[[f32; 2]; 2]; 2], tang_y: f32) { + std::hint::black_box((array(arr), grad_arr, tang_y)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_return_array.rs b/library/autodiff/tests/expand/reverse_return_array.rs new file mode 100644 index 0000000000000..da080a6b3a860 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_array.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_array, Reverse, Active)] +fn array(#[dup] arr: &[[[f32; 2]; 2]; 2]) -> f32 { + arr[0][0][0] * arr[1][1][1] +} diff --git a/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs new file mode 100644 index 0000000000000..f49864fb7e9b9 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_mixed.expanded.rs @@ -0,0 +1,17 @@ +use autodiff::autodiff; +#[autodiff_into] +fn sqrt(a: f32, b: &f32, c: &f32, d: f32) -> f32 { + a * (b * b + c * c * d * d).sqrt() +} +#[autodiff_into(Reverse, Active, Active, Duplicated, Const, Active)] +fn d_sqrt( + a: f32, + b: &f32, + grad_b: &mut f32, + c: &f32, + d: f32, + tang_y: f32, +) -> (f32, f32) { + std::hint::black_box((sqrt(a, b, c, d), grad_b, tang_y)); + std::hint::black_box(unsafe { std::mem::zeroed() }) +} diff --git a/library/autodiff/tests/expand/reverse_return_mixed.rs b/library/autodiff/tests/expand/reverse_return_mixed.rs new file mode 100644 index 0000000000000..3260c3560d523 --- /dev/null +++ b/library/autodiff/tests/expand/reverse_return_mixed.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sqrt, Reverse, Active)] +fn sqrt(#[active] a: f32, #[dup] b: &f32, c: &f32, #[active] d: f32) -> f32 { + a * (b * b + c*c*d*d).sqrt() +} diff --git a/library/autodiff/tests/ui/active_in_forward_mode.rs b/library/autodiff/tests/ui/active_in_forward_mode.rs new file mode 100644 index 0000000000000..10366b1b422b8 --- /dev/null +++ b/library/autodiff/tests/ui/active_in_forward_mode.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, DuplicatedNoNeed, Active)] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/active_in_forward_mode.stderr b/library/autodiff/tests/ui/active_in_forward_mode.stderr new file mode 100644 index 0000000000000..cd413564068ae --- /dev/null +++ b/library/autodiff/tests/ui/active_in_forward_mode.stderr @@ -0,0 +1,7 @@ +error: active argument in forward mode + --> tests/ui/active_in_forward_mode.rs:3:12 + | +3 | #[autodiff(d_sin, Forward, DuplicatedNoNeed, Active)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` forward mode should be either Const, Duplicated diff --git a/library/autodiff/tests/ui/activities_inline_and_header.rs b/library/autodiff/tests/ui/activities_inline_and_header.rs new file mode 100644 index 0000000000000..1ecf37ec60a8f --- /dev/null +++ b/library/autodiff/tests/ui/activities_inline_and_header.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active, Active)] +fn sin(#[active] x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/activities_inline_and_header.stderr b/library/autodiff/tests/ui/activities_inline_and_header.stderr new file mode 100644 index 0000000000000..b4d50d02a26a4 --- /dev/null +++ b/library/autodiff/tests/ui/activities_inline_and_header.stderr @@ -0,0 +1,7 @@ +error: inline activity + --> tests/ui/activities_inline_and_header.rs:4:18 + | +4 | fn sin(#[active] x: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` should have activities either specified in header or as inline attributes diff --git a/library/autodiff/tests/ui/invalid_indirection.rs b/library/autodiff/tests/ui/invalid_indirection.rs new file mode 100644 index 0000000000000..627a7cb0fc6f9 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_indirection.rs @@ -0,0 +1,19 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Const)] +fn duplicated_without_reference(#[dup] x: f32) { +} + +#[autodiff(d_sin, Reverse, Const)] +fn active_with_reference(#[active] x: &f32) { +} + +#[autodiff(d_sin, Forward, Const)] +fn duplicated_forward(#[dup] x: f32) { +} + +#[autodiff(d_sin, Forward, Const)] +fn duplicated_no_need_forward(#[dup_noneed] x: &f32) { +} + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_indirection.stderr b/library/autodiff/tests/ui/invalid_indirection.stderr new file mode 100644 index 0000000000000..cb27c542018e5 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_indirection.stderr @@ -0,0 +1,31 @@ +error: type not behind reference + --> tests/ui/invalid_indirection.rs:4:40 + | +4 | fn duplicated_without_reference(#[dup] x: f32) { + | ^^^^^^ + | + = help: `#[autodiff]` duplicated parameters should be behind reference in reverse mode + +error: type behind reference + --> tests/ui/invalid_indirection.rs:8:36 + | +8 | fn active_with_reference(#[active] x: &f32) { + | ^^^^^^^ + | + = help: `#[autodiff]` active parameter should be concrete in reverse mode + +error: type not behind reference + --> tests/ui/invalid_indirection.rs:12:30 + | +12 | fn duplicated_forward(#[dup] x: f32) { + | ^^^^^^ + | + = help: `#[autodiff]` duplicated types should be behind a reference + +error: should be mutable reference + --> tests/ui/invalid_indirection.rs:16:45 + | +16 | fn duplicated_no_need_forward(#[dup_noneed] x: &f32) { + | ^^^^^^^ + | + = help: `#[autodiff]` parameter should be output for DuplicatedNoNeed activity diff --git a/library/autodiff/tests/ui/invalid_mutability_pairs.rs b/library/autodiff/tests/ui/invalid_mutability_pairs.rs new file mode 100644 index 0000000000000..708ecc597a5be --- /dev/null +++ b/library/autodiff/tests/ui/invalid_mutability_pairs.rs @@ -0,0 +1,24 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, Duplicated)] +fn fwd_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + +#[autodiff(d_sin, Forward, Duplicated)] +fn output_immutable(#[dup] x: &mut f32, y: &f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn rev_input_no_reference(#[dup] x: &f32, y: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn rev_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn input_immutable(#[dup] x: &f32, y: &f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn output_mutable(#[dup] x: &mut f32, y: &mut f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn dupnoneed_input(#[dup_noneed] x: &f32, y: &f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_mutability_pairs.stderr b/library/autodiff/tests/ui/invalid_mutability_pairs.stderr new file mode 100644 index 0000000000000..37af0c2ad52ee --- /dev/null +++ b/library/autodiff/tests/ui/invalid_mutability_pairs.stderr @@ -0,0 +1,55 @@ +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:4:48 + | +4 | fn fwd_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:7:41 + | +7 | fn output_immutable(#[dup] x: &mut f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` output parameter should duplicate derivative into second parameter for forward mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:10:43 + | +10 | fn rev_input_no_reference(#[dup] x: &f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be an immutable reference + --> tests/ui/invalid_mutability_pairs.rs:13:48 + | +13 | fn rev_output_no_reference(#[dup] x: &mut f32, y: f32) -> f32; + | ^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be a mutable reference + --> tests/ui/invalid_mutability_pairs.rs:16:36 + | +16 | fn input_immutable(#[dup] x: &f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: should be an immutable reference + --> tests/ui/invalid_mutability_pairs.rs:19:39 + | +19 | fn output_mutable(#[dup] x: &mut f32, y: &mut f32) -> f32; + | ^^^^^^^^^^^ + | + = help: `#[autodiff]` input parameter should duplicate derivative into second parameter for reverse mode + +error: use duplicated instead + --> tests/ui/invalid_mutability_pairs.rs:22:34 + | +22 | fn dupnoneed_input(#[dup_noneed] x: &f32, y: &f32) -> f32; + | ^^^^^^^ + | + = help: `#[autodiff]` input parameter cannot be declared as duplicatednoneed diff --git a/library/autodiff/tests/ui/invalid_return.rs b/library/autodiff/tests/ui/invalid_return.rs new file mode 100644 index 0000000000000..b3c8bce1166bf --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return.rs @@ -0,0 +1,12 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Forward, Active)] +fn sin1(x: f32) -> f32; + +#[autodiff(d_sin, Reverse, Duplicated)] +fn sin2(x: f32) -> f32; + +#[autodiff(d_sin, Reverse, DuplicatedNoNeed)] +fn sin3(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_return.stderr b/library/autodiff/tests/ui/invalid_return.stderr new file mode 100644 index 0000000000000..4ddaccdba0f72 --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return.stderr @@ -0,0 +1,23 @@ +error: active return for forward mode + --> tests/ui/invalid_return.rs:3:28 + | +3 | #[autodiff(d_sin, Forward, Active)] + | ^^^^^^ + | + = help: `#[autodiff]` return should be Const, Duplicated or DuplicatedNoNeed in forward mode + +error: duplicated return for reverse mode + --> tests/ui/invalid_return.rs:6:28 + | +6 | #[autodiff(d_sin, Reverse, Duplicated)] + | ^^^^^^^^^^ + | + = help: `#[autodiff]` return should be Const or Active in reverse mode + +error: duplicated return for reverse mode + --> tests/ui/invalid_return.rs:9:28 + | +9 | #[autodiff(d_sin, Reverse, DuplicatedNoNeed)] + | ^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` return should be Const or Active in reverse mode diff --git a/library/autodiff/tests/ui/invalid_return_type.rs b/library/autodiff/tests/ui/invalid_return_type.rs new file mode 100644 index 0000000000000..7b91ccd2d650a --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return_type.rs @@ -0,0 +1,16 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active)] +fn active_but_no_return(#[active] x: f32) { +} + +#[autodiff(d_sin, Reverse, Active)] +fn invalid_primal_value(#[active] x: f32, #[active] y: Vec, #[active] z: Tensor, y_tang: f32) -> (i32, f32); + +#[autodiff(d_sin, Forward, Duplicated)] +fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32, f32); + +#[autodiff(d_sin, Forward, DuplicatedNoNeed)] +fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32); + +fn main() {} diff --git a/library/autodiff/tests/ui/invalid_return_type.stderr b/library/autodiff/tests/ui/invalid_return_type.stderr new file mode 100644 index 0000000000000..90e5e47a2a33d --- /dev/null +++ b/library/autodiff/tests/ui/invalid_return_type.stderr @@ -0,0 +1,31 @@ +error: no return type + --> tests/ui/invalid_return_type.rs:4:1 + | +4 | fn active_but_no_return(#[active] x: f32) { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` non-const return activity but no return type + +error: invalid output + --> tests/ui/invalid_return_type.rs:8:100 + | +8 | fn invalid_primal_value(#[active] x: f32, #[active] y: Vec, #[active] z: Tensor, y_tang: f32) -> (i32, f32); + | ^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, Vec < f32 >, Tensor,) + +error: invalid output + --> tests/ui/invalid_return_type.rs:11:121 + | +11 | fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32, f32); + | ^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, f32, f32, f32,) + +error: invalid output + --> tests/ui/invalid_return_type.rs:14:121 + | +14 | fn invalid_forward_return(#[dup] x: &f32, tx: &f32, #[dup] y: &Vec, ty: &Vec, #[dup] z: &Tensor, tz: &Tensor) -> (f32, f32); + | ^^^^^^^^^^^^^ + | + = help: `#[autodiff]` expected (f32, f32, f32,) diff --git a/library/autodiff/tests/ui/no_function_name.rs b/library/autodiff/tests/ui/no_function_name.rs new file mode 100644 index 0000000000000..8222ca4aaf37d --- /dev/null +++ b/library/autodiff/tests/ui/no_function_name.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/no_function_name.stderr b/library/autodiff/tests/ui/no_function_name.stderr new file mode 100644 index 0000000000000..e98add3164c9f --- /dev/null +++ b/library/autodiff/tests/ui/no_function_name.stderr @@ -0,0 +1,8 @@ +error: please specify the autodiff function + --> tests/ui/no_function_name.rs:3:1 + | +3 | #[autodiff] + | ^^^^^^^^^^^ + | + = help: `#[autodiff]` needs a function name for primal or adjoint + = note: this error originates in the attribute macro `autodiff` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/library/autodiff/tests/ui/not_a_function.rs b/library/autodiff/tests/ui/not_a_function.rs new file mode 100644 index 0000000000000..0a3c11725a086 --- /dev/null +++ b/library/autodiff/tests/ui/not_a_function.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff] +struct NotAFunction; + +fn main() {} diff --git a/library/autodiff/tests/ui/not_a_function.stderr b/library/autodiff/tests/ui/not_a_function.stderr new file mode 100644 index 0000000000000..c681841532a5e --- /dev/null +++ b/library/autodiff/tests/ui/not_a_function.stderr @@ -0,0 +1,7 @@ +error: item is not a function + --> tests/ui/not_a_function.rs:4:1 + | +4 | struct NotAFunction; + | ^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` can only be used on primal or adjoint functions diff --git a/library/autodiff/tests/ui/reverse_tangent.rs b/library/autodiff/tests/ui/reverse_tangent.rs new file mode 100644 index 0000000000000..603f7fd1789ce --- /dev/null +++ b/library/autodiff/tests/ui/reverse_tangent.rs @@ -0,0 +1,12 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, Reverse, Active)] +fn invalid_output_tangent_type(#[active] x: f32, y_tang: i32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn active_output_tangent(#[active] x: f32, #[active] y_tang: f32) -> f32; + +#[autodiff(d_sin, Reverse, Active)] +fn tangent_missing() -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/reverse_tangent.stderr b/library/autodiff/tests/ui/reverse_tangent.stderr new file mode 100644 index 0000000000000..a7b4b6e3d97d6 --- /dev/null +++ b/library/autodiff/tests/ui/reverse_tangent.stderr @@ -0,0 +1,23 @@ +error: output tangent not a floating point + --> tests/ui/reverse_tangent.rs:4:58 + | +4 | fn invalid_output_tangent_type(#[active] x: f32, y_tang: i32) -> f32; + | ^^^ + | + = help: `#[autodiff]` the output tangent should be a floating point + +error: output tangent not const + --> tests/ui/reverse_tangent.rs:7:62 + | +7 | fn active_output_tangent(#[active] x: f32, #[active] y_tang: f32) -> f32; + | ^^^ + | + = help: `#[autodiff]` the last parameter of an adjoint with active return should be a constant tangent + +error: missing output tangent parameter + --> tests/ui/reverse_tangent.rs:10:1 + | +10 | fn tangent_missing() -> f32; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = help: `#[autodiff]` the last parameter of an adjoint with active return should exist diff --git a/library/autodiff/tests/ui/wrong_mode.rs b/library/autodiff/tests/ui/wrong_mode.rs new file mode 100644 index 0000000000000..1b500711de109 --- /dev/null +++ b/library/autodiff/tests/ui/wrong_mode.rs @@ -0,0 +1,6 @@ +use autodiff::autodiff; + +#[autodiff(d_sin, WrongMode)] +fn sin(x: f32) -> f32; + +fn main() {} diff --git a/library/autodiff/tests/ui/wrong_mode.stderr b/library/autodiff/tests/ui/wrong_mode.stderr new file mode 100644 index 0000000000000..ca18d81abb306 --- /dev/null +++ b/library/autodiff/tests/ui/wrong_mode.stderr @@ -0,0 +1,7 @@ +error: should be forward or reverse + --> tests/ui/wrong_mode.rs:3:19 + | +3 | #[autodiff(d_sin, WrongMode)] + | ^^^^^^^^^ + | + = help: `#[autodiff]` modes should be either forward or reverse diff --git a/library/core/src/macros/mod.rs b/library/core/src/macros/mod.rs index b24882ddb179f..04884da2ac8cd 100644 --- a/library/core/src/macros/mod.rs +++ b/library/core/src/macros/mod.rs @@ -1312,6 +1312,18 @@ pub(crate) mod builtin { }; } + /// Differentiate function + ///#[unstable( + /// feature = "autodiff", + /// issue = "29598", + /// reason = "autodiff is not stable enough" + ///)] + ///#[rustc_builtin_macro] + ///#[macro_export] + ///pub macro autodiff($item:item) { + /// /* compiler built-in */ + ///} + /// Parses a file as an expression or an item according to the context. /// /// **Warning**: For multi-file Rust projects, the `include!` macro is probably not what you diff --git a/library/std/Cargo.toml b/library/std/Cargo.toml index 1454b00255650..41bb8e43cc6de 100644 --- a/library/std/Cargo.toml +++ b/library/std/Cargo.toml @@ -26,6 +26,7 @@ std_detect = { path = "../stdarch/crates/std_detect", default-features = false, addr2line = { version = "0.19.0", optional = true, default-features = false } rustc-demangle = { version = "0.1.21", features = ['rustc-dep-of-std'] } miniz_oxide = { version = "0.6.0", optional = true, default-features = false } + [dependencies.object] version = "0.30.0" optional = true diff --git a/src/bootstrap/builder.rs b/src/bootstrap/builder.rs index 237f65b039f82..b8338c5b43066 100644 --- a/src/bootstrap/builder.rs +++ b/src/bootstrap/builder.rs @@ -1353,7 +1353,7 @@ impl<'a> Builder<'a> { }).unwrap_or_else(|_| { eprintln!( "error: `x.py clippy` requires a host `rustc` toolchain with the `clippy` component" - ); + ); eprintln!("help: try `rustup component add clippy`"); crate::detail_exit(1); }); @@ -1365,6 +1365,11 @@ impl<'a> Builder<'a> { } } + // TODO: adjust -14 ending for Enzyme + // https://rust-lang.zulipchat.com/#narrow/stream/182449-t-compiler.2Fhelp/topic/.E2.9C.94.20link.20new.20library.20into.20stage1.2Frustc + rustflags.arg("-l"); + rustflags.arg("LLVMEnzyme-16"); + let use_new_symbol_mangling = match self.config.rust_new_symbol_mangling { Some(setting) => { // If an explicit setting is given, use that diff --git a/src/bootstrap/compile.rs b/src/bootstrap/compile.rs index 33addb90da372..4f9590dcad7b5 100644 --- a/src/bootstrap/compile.rs +++ b/src/bootstrap/compile.rs @@ -6,6 +6,10 @@ //! the compiler. This module is also responsible for assembling the sysroot as it //! goes along from the output of the previous stage. +// !#[cfg(compiler.stage == 1)] +// extern "C" { +// } + use std::borrow::Cow; use std::collections::HashSet; use std::env; @@ -1356,6 +1360,7 @@ pub struct Assemble { pub target_compiler: Compiler, } +#[allow(unreachable_code)] impl Step for Assemble { type Output = Compiler; const ONLY_HOSTS: bool = true; @@ -1411,6 +1416,24 @@ impl Step for Assemble { return target_compiler; } + // Build enzyme + let enzyme_install = if builder.config.llvm_enzyme { + Some(builder.ensure(llvm::Enzyme { target: build_compiler.host })) + } else { + None + }; + + if let Some(enzyme_install) = enzyme_install { + let src_lib = enzyme_install.join("build/Enzyme/LLVMEnzyme-16.so"); + + let libdir = builder.sysroot_libdir(build_compiler, build_compiler.host); + let target_libdir = builder.sysroot_libdir(target_compiler, target_compiler.host); + let dst_lib = libdir.join("libLLVMEnzyme-16.so"); + let target_dst_lib = target_libdir.join("libLLVMEnzyme-16.so"); + builder.copy(&src_lib, &dst_lib); + builder.copy(&src_lib, &target_dst_lib); + } + // Build the libraries for this compiler to link to (i.e., the libraries // it uses at runtime). NOTE: Crates the target compiler compiles don't // link to these. (FIXME: Is that correct? It seems to be correct most @@ -1787,7 +1810,7 @@ pub fn stream_cargo( if builder.is_verbose() && !status.success() { eprintln!( "command did not execute successfully: {:?}\n\ - expected success, got: {}", + expected success, got: {}", cargo, status ); } diff --git a/src/bootstrap/config.rs b/src/bootstrap/config.rs index bf3bc3247acaf..ee2bb8fd6ffe4 100644 --- a/src/bootstrap/config.rs +++ b/src/bootstrap/config.rs @@ -116,6 +116,7 @@ pub struct Config { // llvm codegen options pub llvm_assertions: bool, pub llvm_tests: bool, + pub llvm_enzyme: bool, pub llvm_plugins: bool, pub llvm_optimize: bool, pub llvm_thin_lto: bool, @@ -541,9 +542,9 @@ macro_rules! define_config { $($field:ident: Option<$field_ty:ty> = $field_key:literal,)* }) => { $(#[$attr])* - struct $name { - $($field: Option<$field_ty>,)* - } + struct $name { + $($field: Option<$field_ty>,)* + } impl Merge for $name { fn merge(&mut self, other: Self) { @@ -551,7 +552,7 @@ macro_rules! define_config { if !self.$field.is_some() { self.$field = other.$field; } - )* + )* } } @@ -560,64 +561,64 @@ macro_rules! define_config { // compile time of rustbuild. impl<'de> Deserialize<'de> for $name { fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct Field; - impl<'de> serde::de::Visitor<'de> for Field { - type Value = $name; - fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(concat!("struct ", stringify!($name))) - } + where + D: Deserializer<'de>, + { + struct Field; + impl<'de> serde::de::Visitor<'de> for Field { + type Value = $name; + fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(concat!("struct ", stringify!($name))) + } - #[inline] - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - $(let mut $field: Option<$field_ty> = None;)* - while let Some(key) = - match serde::de::MapAccess::next_key::(&mut map) { - Ok(val) => val, - Err(err) => { - return Err(err); - } - } - { - match &*key { - $($field_key => { - if $field.is_some() { - return Err(::duplicate_field( - $field_key, - )); - } - $field = match serde::de::MapAccess::next_value::<$field_ty>( - &mut map, - ) { - Ok(val) => Some(val), - Err(err) => { - return Err(err); + #[inline] + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + $(let mut $field: Option<$field_ty> = None;)* + while let Some(key) = + match serde::de::MapAccess::next_key::(&mut map) { + Ok(val) => val, + Err(err) => { + return Err(err); + } + } + { + match &*key { + $($field_key => { + if $field.is_some() { + return Err(::duplicate_field( + $field_key, + )); + } + $field = match serde::de::MapAccess::next_value::<$field_ty>( + &mut map, + ) { + Ok(val) => Some(val), + Err(err) => { + return Err(err); + } + }; + })* + key => { + return Err(serde::de::Error::unknown_field(key, FIELDS)); } - }; - })* - key => { - return Err(serde::de::Error::unknown_field(key, FIELDS)); + } } + Ok($name { $($field),* }) } - } - Ok($name { $($field),* }) } + const FIELDS: &'static [&'static str] = &[ + $($field_key,)* + ]; + Deserializer::deserialize_struct( + deserializer, + stringify!($name), + FIELDS, + Field, + ) } - const FIELDS: &'static [&'static str] = &[ - $($field_key,)* - ]; - Deserializer::deserialize_struct( - deserializer, - stringify!($name), - FIELDS, - Field, - ) - } } } } @@ -691,6 +692,7 @@ define_config! { release_debuginfo: Option = "release-debuginfo", assertions: Option = "assertions", tests: Option = "tests", + enzyme: Option = "enzyme", plugins: Option = "plugins", ccache: Option = "ccache", static_libstdcpp: Option = "static-libstdcpp", @@ -1092,6 +1094,7 @@ impl Config { // we'll infer default values for them later let mut llvm_assertions = None; let mut llvm_tests = None; + let mut llvm_enzyme = None; let mut llvm_plugins = None; let mut debug = None; let mut debug_assertions = None; @@ -1204,6 +1207,7 @@ impl Config { set(&mut config.ninja_in_file, llvm.ninja); llvm_assertions = llvm.assertions; llvm_tests = llvm.tests; + llvm_enzyme = llvm.enzyme; llvm_plugins = llvm.plugins; set(&mut config.llvm_optimize, llvm.optimize); set(&mut config.llvm_thin_lto, llvm.thin_lto); @@ -1268,6 +1272,7 @@ impl Config { check_ci_llvm!(llvm.polly); check_ci_llvm!(llvm.clang); check_ci_llvm!(llvm.build_config); + check_ci_llvm!(llvm.enzyme); check_ci_llvm!(llvm.plugins); } @@ -1363,6 +1368,7 @@ impl Config { config.llvm_assertions = llvm_assertions.unwrap_or(false); config.llvm_tests = llvm_tests.unwrap_or(false); + config.llvm_enzyme = llvm_enzyme.unwrap_or(false); config.llvm_plugins = llvm_plugins.unwrap_or(false); config.rust_optimize = optimize.unwrap_or(true); diff --git a/src/bootstrap/configure.py b/src/bootstrap/configure.py index 571062a3a6fd0..a97e3715e184e 100755 --- a/src/bootstrap/configure.py +++ b/src/bootstrap/configure.py @@ -70,6 +70,7 @@ def v(*args): # channel, etc. o("optimize-llvm", "llvm.optimize", "build optimized LLVM") o("llvm-assertions", "llvm.assertions", "build LLVM with assertions") +o("llvm-enzyme", "llvm.enzyme", "build LLVM with Enzyme") o("llvm-plugins", "llvm.plugins", "build LLVM with plugin interface") o("debug-assertions", "rust.debug-assertions", "build with debugging assertions") o("debug-assertions-std", "rust.debug-assertions-std", "build the standard library with debugging assertions") diff --git a/src/bootstrap/lib.rs b/src/bootstrap/lib.rs index 994336977dc6a..53b17c6eb083d 100644 --- a/src/bootstrap/lib.rs +++ b/src/bootstrap/lib.rs @@ -827,6 +827,10 @@ impl Build { self.out.join(&*target.triple).join("lld") } + fn enzyme_out(&self, target: TargetSelection) -> PathBuf { + self.out.join(&*target.triple).join("enzyme") + } + /// Output directory for all documentation for a target fn doc_out(&self, target: TargetSelection) -> PathBuf { self.out.join(&*target.triple).join("doc") diff --git a/src/bootstrap/llvm.rs b/src/bootstrap/llvm.rs index 67cb88373910c..8dc4c3af667fb 100644 --- a/src/bootstrap/llvm.rs +++ b/src/bootstrap/llvm.rs @@ -767,6 +767,72 @@ fn get_var(var_base: &str, host: &str, target: &str) -> Option { .or_else(|| env::var_os(var_base)) } +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct Enzyme { + pub target: TargetSelection, +} + +impl Step for Enzyme { + type Output = PathBuf; + const ONLY_HOSTS: bool = true; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/tools/enzyme/enzyme") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Enzyme { target: run.target }); + } + + /// Compile Enzyme for `target`. + fn run(self, builder: &Builder<'_>) -> PathBuf { + if builder.config.dry_run() { + let out_dir = builder.enzyme_out(self.target); + return out_dir; + } + let target = self.target; + + let LlvmResult { llvm_config, .. } = builder.ensure(Llvm { target: self.target }); + + let out_dir = builder.enzyme_out(target); + let done_stamp = out_dir.join("enzyme-finished-building"); + if done_stamp.exists() { + return out_dir; + } + + builder.info(&format!("Building Enzyme for {}", target)); + let _time = util::timeit(&builder); + t!(fs::create_dir_all(&out_dir)); + + builder.update_submodule(&Path::new("src").join("tools").join("enzyme")); + let mut cfg = cmake::Config::new(builder.src.join("src/tools/enzyme/enzyme/")); + // TODO: Find a nicer way to use Enzyme Debug builds + //cfg.profile("Debug"); + //cfg.define("CMAKE_BUILD_TYPE", "Debug"); + configure_cmake(builder, target, &mut cfg, true, LdFlags::default(), &[]); + + // Re-use the same flags as llvm to control the level of debug information + // generated for lld. + let profile = match (builder.config.llvm_optimize, builder.config.llvm_release_debuginfo) { + (false, _) => "Debug", + (true, false) => "Release", + (true, true) => "RelWithDebInfo", + }; + + cfg.out_dir(&out_dir) + .profile(profile) + .env("LLVM_CONFIG_REAL", &llvm_config) + .define("LLVM_ENABLE_ASSERTIONS", "ON") + .define("ENZYME_EXTERNAL_SHARED_LIB", "OFF") + .define("LLVM_DIR", builder.llvm_out(target)); + + cfg.build(); + + t!(File::create(&done_stamp)); + out_dir + } +} + #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] pub struct Lld { pub target: TargetSelection, diff --git a/src/bootstrap/native.rs b/src/bootstrap/native.rs new file mode 100644 index 0000000000000..3ae0015bb8a41 --- /dev/null +++ b/src/bootstrap/native.rs @@ -0,0 +1,1166 @@ +//! Compilation of native dependencies like LLVM. +//! +//! Native projects like LLVM unfortunately aren't suited just yet for +//! compilation in build scripts that Cargo has. This is because the +//! compilation takes a *very* long time but also because we don't want to +//! compile LLVM 3 times as part of a normal bootstrap (we want it cached). +//! +//! LLVM and compiler-rt are essentially just wired up to everything else to +//! ensure that they're always in place if needed. + +use std::env; +use std::env::consts::EXE_EXTENSION; +use std::ffi::{OsStr, OsString}; +use std::fs::{self, File}; +use std::io; +use std::path::{Path, PathBuf}; +use std::process::Command; + +use crate::builder::{Builder, RunConfig, ShouldRun, Step}; +use crate::config::TargetSelection; +use crate::util::{self, exe, output, t, up_to_date}; +use crate::{CLang, GitRepo}; + +pub struct Meta { + stamp: HashStamp, + build_llvm_config: PathBuf, + out_dir: PathBuf, + root: String, +} + +// Linker flags to pass to LLVM's CMake invocation. +#[derive(Debug, Clone, Default)] +struct LdFlags { + // CMAKE_EXE_LINKER_FLAGS + exe: OsString, + // CMAKE_SHARED_LINKER_FLAGS + shared: OsString, + // CMAKE_MODULE_LINKER_FLAGS + module: OsString, +} + +impl LdFlags { + fn push_all(&mut self, s: impl AsRef) { + let s = s.as_ref(); + self.exe.push(" "); + self.exe.push(s); + self.shared.push(" "); + self.shared.push(s); + self.module.push(" "); + self.module.push(s); + } +} + +// This returns whether we've already previously built LLVM. +// +// It's used to avoid busting caches during x.py check -- if we've already built +// LLVM, it's fine for us to not try to avoid doing so. +// +// This will return the llvm-config if it can get it (but it will not build it +// if not). +pub fn prebuilt_llvm_config( + builder: &Builder<'_>, + target: TargetSelection, +) -> Result { + // If we're using a custom LLVM bail out here, but we can only use a + // custom LLVM for the build triple. + if let Some(config) = builder.config.target_config.get(&target) { + if let Some(ref s) = config.llvm_config { + check_llvm_version(builder, s); + return Ok(s.to_path_buf()); + } + } + + let root = "src/llvm-project/llvm"; + let out_dir = builder.llvm_out(target); + + let mut llvm_config_ret_dir = builder.llvm_out(builder.config.build); + if !builder.config.build.contains("msvc") || builder.ninja() { + llvm_config_ret_dir.push("build"); + } + llvm_config_ret_dir.push("bin"); + + let build_llvm_config = llvm_config_ret_dir.join(exe("llvm-config", builder.config.build)); + + let stamp = out_dir.join("llvm-finished-building"); + let stamp = HashStamp::new(stamp, builder.in_tree_llvm_info.sha()); + + if builder.config.llvm_skip_rebuild && stamp.path.exists() { + builder.info( + "Warning: \ + Using a potentially stale build of LLVM; \ + This may not behave well.", + ); + return Ok(build_llvm_config); + } + + if stamp.is_done() { + if stamp.hash.is_none() { + builder.info( + "Could not determine the LLVM submodule commit hash. \ + Assuming that an LLVM rebuild is not necessary.", + ); + builder.info(&format!( + "To force LLVM to rebuild, remove the file `{}`", + stamp.path.display() + )); + } + return Ok(build_llvm_config); + } + + Err(Meta { stamp, build_llvm_config, out_dir, root: root.into() }) +} + +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct Llvm { + pub target: TargetSelection, +} + +impl Step for Llvm { + type Output = PathBuf; // path to llvm-config + + const ONLY_HOSTS: bool = true; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/llvm-project").path("src/llvm-project/llvm").path("src/llvm") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Llvm { target: run.target }); + } + + /// Compile LLVM for `target`. + fn run(self, builder: &Builder<'_>) -> PathBuf { + let target = self.target; + let target_native = if self.target.starts_with("riscv") { + // RISC-V target triples in Rust is not named the same as C compiler target triples. + // This converts Rust RISC-V target triples to C compiler triples. + let idx = target.triple.find('-').unwrap(); + + format!("riscv{}{}", &target.triple[5..7], &target.triple[idx..]) + } else if self.target.starts_with("powerpc") && self.target.ends_with("freebsd") { + // FreeBSD 13 had incompatible ABI changes on all PowerPC platforms. + // Set the version suffix to 13.0 so the correct target details are used. + format!("{}{}", self.target, "13.0") + } else { + target.to_string() + }; + + let Meta { stamp, build_llvm_config, out_dir, root } = + match prebuilt_llvm_config(builder, target) { + Ok(p) => return p, + Err(m) => m, + }; + + builder.update_submodule(&Path::new("src").join("llvm-project")); + if builder.config.llvm_link_shared + && (target.contains("windows") || target.contains("apple-darwin")) + { + panic!("shared linking to LLVM is not currently supported on {}", target.triple); + } + + builder.info(&format!("Building LLVM for {}", target)); + t!(stamp.remove()); + let _time = util::timeit(&builder); + t!(fs::create_dir_all(&out_dir)); + + // https://llvm.org/docs/CMake.html + let mut cfg = cmake::Config::new(builder.src.join(root)); + let mut ldflags = LdFlags::default(); + + let profile = match (builder.config.llvm_optimize, builder.config.llvm_release_debuginfo) { + (false, _) => "Debug", + (true, false) => "Release", + (true, true) => "RelWithDebInfo", + }; + + // NOTE: remember to also update `config.toml.example` when changing the + // defaults! + let llvm_targets = match &builder.config.llvm_targets { + Some(s) => s, + None => { + "AArch64;ARM;BPF;Hexagon;MSP430;Mips;NVPTX;PowerPC;RISCV;\ + Sparc;SystemZ;WebAssembly;X86" + } + }; + + let llvm_exp_targets = match builder.config.llvm_experimental_targets { + Some(ref s) => s, + None => "AVR;M68k", + }; + + let assertions = if builder.config.llvm_assertions { "ON" } else { "OFF" }; + // Not needed, not an LLVM component yet + // let enzyme = if builder.config.llvm_enzyme { "ON" } else { "OFF" }; + let plugins = if builder.config.llvm_plugins { "ON" } else { "OFF" }; + let enable_tests = if builder.config.llvm_tests { "ON" } else { "OFF" }; + + cfg.out_dir(&out_dir) + .profile(profile) + .define("LLVM_ENABLE_ASSERTIONS", assertions) + .define("LLVM_ENABLE_PLUGINS", plugins) + .define("LLVM_TARGETS_TO_BUILD", llvm_targets) + .define("LLVM_EXPERIMENTAL_TARGETS_TO_BUILD", llvm_exp_targets) + .define("LLVM_INCLUDE_EXAMPLES", "OFF") + .define("LLVM_INCLUDE_DOCS", "OFF") + .define("LLVM_INCLUDE_BENCHMARKS", "OFF") + .define("LLVM_INCLUDE_TESTS", enable_tests) + .define("LLVM_ENABLE_TERMINFO", "OFF") + .define("LLVM_ENABLE_LIBEDIT", "OFF") + .define("LLVM_ENABLE_BINDINGS", "OFF") + .define("LLVM_ENABLE_Z3_SOLVER", "OFF") + .define("LLVM_PARALLEL_COMPILE_JOBS", builder.jobs().to_string()) + .define("LLVM_TARGET_ARCH", target_native.split('-').next().unwrap()) + .define("LLVM_DEFAULT_TARGET_TRIPLE", target_native); + + // Parts of our test suite rely on the `FileCheck` tool, which is built by default in + // `build/$TARGET/llvm/build/bin` is but *not* then installed to `build/$TARGET/llvm/bin`. + // This flag makes sure `FileCheck` is copied in the final binaries directory. + cfg.define("LLVM_INSTALL_UTILS", "ON"); + + if builder.config.llvm_profile_generate { + cfg.define("LLVM_BUILD_INSTRUMENTED", "IR"); + cfg.define("LLVM_BUILD_RUNTIME", "No"); + } + if let Some(path) = builder.config.llvm_profile_use.as_ref() { + cfg.define("LLVM_PROFDATA_FILE", &path); + } + + if target != "aarch64-apple-darwin" && !target.contains("windows") { + cfg.define("LLVM_ENABLE_ZLIB", "ON"); + } else { + cfg.define("LLVM_ENABLE_ZLIB", "OFF"); + } + + // Are we compiling for iOS/tvOS? + if target.contains("apple-ios") || target.contains("apple-tvos") { + // These two defines prevent CMake from automatically trying to add a MacOSX sysroot, which leads to a compiler error. + cfg.define("CMAKE_OSX_SYSROOT", "/"); + cfg.define("CMAKE_OSX_DEPLOYMENT_TARGET", ""); + // Prevent cmake from adding -bundle to CFLAGS automatically, which leads to a compiler error because "-bitcode_bundle" also gets added. + cfg.define("LLVM_ENABLE_PLUGINS", "OFF"); + // Zlib fails to link properly, leading to a compiler error. + cfg.define("LLVM_ENABLE_ZLIB", "OFF"); + } + + if builder.config.llvm_thin_lto { + cfg.define("LLVM_ENABLE_LTO", "Thin"); + if !target.contains("apple") { + cfg.define("LLVM_ENABLE_LLD", "ON"); + } + } + + // This setting makes the LLVM tools link to the dynamic LLVM library, + // which saves both memory during parallel links and overall disk space + // for the tools. We don't do this on every platform as it doesn't work + // equally well everywhere. + // + // If we're not linking rustc to a dynamic LLVM, though, then don't link + // tools to it. + if builder.llvm_link_tools_dynamically(target) && builder.config.llvm_link_shared { + cfg.define("LLVM_LINK_LLVM_DYLIB", "ON"); + } + + if target.starts_with("riscv") && !target.contains("freebsd") { + // RISC-V GCC erroneously requires linking against + // `libatomic` when using 1-byte and 2-byte C++ + // atomics but the LLVM build system check cannot + // detect this. Therefore it is set manually here. + // FreeBSD uses Clang as its system compiler and + // provides no libatomic in its base system so does + // not want this. + ldflags.exe.push(" -latomic"); + ldflags.shared.push(" -latomic"); + } + + if target.contains("msvc") { + cfg.define("LLVM_USE_CRT_DEBUG", "MT"); + cfg.define("LLVM_USE_CRT_RELEASE", "MT"); + cfg.define("LLVM_USE_CRT_RELWITHDEBINFO", "MT"); + cfg.static_crt(true); + } + + if target.starts_with("i686") { + cfg.define("LLVM_BUILD_32_BITS", "ON"); + } + + let mut enabled_llvm_projects = Vec::new(); + + if util::forcing_clang_based_tests() { + enabled_llvm_projects.push("clang"); + enabled_llvm_projects.push("compiler-rt"); + } + + if builder.config.llvm_polly { + enabled_llvm_projects.push("polly"); + } + + if builder.config.llvm_clang { + enabled_llvm_projects.push("clang"); + } + + // We want libxml to be disabled. + // See https://github.com/rust-lang/rust/pull/50104 + cfg.define("LLVM_ENABLE_LIBXML2", "OFF"); + + if !enabled_llvm_projects.is_empty() { + enabled_llvm_projects.sort(); + enabled_llvm_projects.dedup(); + cfg.define("LLVM_ENABLE_PROJECTS", enabled_llvm_projects.join(";")); + } + + if let Some(num_linkers) = builder.config.llvm_link_jobs { + if num_linkers > 0 { + cfg.define("LLVM_PARALLEL_LINK_JOBS", num_linkers.to_string()); + } + } + + // Workaround for ppc32 lld limitation + if target == "powerpc-unknown-freebsd" { + ldflags.exe.push(" -fuse-ld=bfd"); + } + + // https://llvm.org/docs/HowToCrossCompileLLVM.html + if target != builder.config.build { + builder.ensure(Llvm { target: builder.config.build }); + // FIXME: if the llvm root for the build triple is overridden then we + // should use llvm-tblgen from there, also should verify that it + // actually exists most of the time in normal installs of LLVM. + let host_bin = builder.llvm_out(builder.config.build).join("bin"); + cfg.define("CMAKE_CROSSCOMPILING", "True"); + cfg.define("LLVM_TABLEGEN", host_bin.join("llvm-tblgen").with_extension(EXE_EXTENSION)); + cfg.define("LLVM_NM", host_bin.join("llvm-nm").with_extension(EXE_EXTENSION)); + cfg.define( + "LLVM_CONFIG_PATH", + host_bin.join("llvm-config").with_extension(EXE_EXTENSION), + ); + } + + if let Some(ref suffix) = builder.config.llvm_version_suffix { + // Allow version-suffix="" to not define a version suffix at all. + if !suffix.is_empty() { + cfg.define("LLVM_VERSION_SUFFIX", suffix); + } + } else if builder.config.channel == "dev" { + // Changes to a version suffix require a complete rebuild of the LLVM. + // To avoid rebuilds during a time of version bump, don't include rustc + // release number on the dev channel. + cfg.define("LLVM_VERSION_SUFFIX", "-rust-dev"); + } else { + let suffix = format!("-rust-{}-{}", builder.version, builder.config.channel); + cfg.define("LLVM_VERSION_SUFFIX", suffix); + } + + if let Some(ref linker) = builder.config.llvm_use_linker { + cfg.define("LLVM_USE_LINKER", linker); + } + + if builder.config.llvm_allow_old_toolchain { + cfg.define("LLVM_TEMPORARILY_ALLOW_OLD_TOOLCHAIN", "YES"); + } + + configure_cmake(builder, target, &mut cfg, true, ldflags); + + for (key, val) in &builder.config.llvm_build_config { + cfg.define(key, val); + } + + // FIXME: we don't actually need to build all LLVM tools and all LLVM + // libraries here, e.g., we just want a few components and a few + // tools. Figure out how to filter them down and only build the right + // tools and libs on all platforms. + + if builder.config.dry_run { + return build_llvm_config; + } + + cfg.build(); + + t!(stamp.write()); + + build_llvm_config + } +} + +fn check_llvm_version(builder: &Builder<'_>, llvm_config: &Path) { + if !builder.config.llvm_version_check { + return; + } + + if builder.config.dry_run { + return; + } + + let mut cmd = Command::new(llvm_config); + let version = output(cmd.arg("--version")); + let mut parts = version.split('.').take(2).filter_map(|s| s.parse::().ok()); + if let (Some(major), Some(_minor)) = (parts.next(), parts.next()) { + if major >= 12 { + return; + } + } + panic!("\n\nbad LLVM version: {}, need >=12.0\n\n", version) +} + +fn configure_cmake( + builder: &Builder<'_>, + target: TargetSelection, + cfg: &mut cmake::Config, + use_compiler_launcher: bool, + mut ldflags: LdFlags, +) { + // Do not print installation messages for up-to-date files. + // LLVM and LLD builds can produce a lot of those and hit CI limits on log size. + cfg.define("CMAKE_INSTALL_MESSAGE", "LAZY"); + + // Do not allow the user's value of DESTDIR to influence where + // LLVM will install itself. LLVM must always be installed in our + // own build directories. + cfg.env("DESTDIR", ""); + + if builder.ninja() { + cfg.generator("Ninja"); + } + cfg.target(&target.triple).host(&builder.config.build.triple); + + if target != builder.config.build { + if target.contains("netbsd") { + cfg.define("CMAKE_SYSTEM_NAME", "NetBSD"); + } else if target.contains("freebsd") { + cfg.define("CMAKE_SYSTEM_NAME", "FreeBSD"); + } else if target.contains("windows") { + cfg.define("CMAKE_SYSTEM_NAME", "Windows"); + } else if target.contains("haiku") { + cfg.define("CMAKE_SYSTEM_NAME", "Haiku"); + } else if target.contains("solaris") || target.contains("illumos") { + cfg.define("CMAKE_SYSTEM_NAME", "SunOS"); + } + // When cross-compiling we should also set CMAKE_SYSTEM_VERSION, but in + // that case like CMake we cannot easily determine system version either. + // + // Since, the LLVM itself makes rather limited use of version checks in + // CMakeFiles (and then only in tests), and so far no issues have been + // reported, the system version is currently left unset. + } + + let sanitize_cc = |cc: &Path| { + if target.contains("msvc") { + OsString::from(cc.to_str().unwrap().replace("\\", "/")) + } else { + cc.as_os_str().to_owned() + } + }; + + // MSVC with CMake uses msbuild by default which doesn't respect these + // vars that we'd otherwise configure. In that case we just skip this + // entirely. + if target.contains("msvc") && !builder.ninja() { + return; + } + + let (cc, cxx) = match builder.config.llvm_clang_cl { + Some(ref cl) => (cl.as_ref(), cl.as_ref()), + None => (builder.cc(target), builder.cxx(target).unwrap()), + }; + + // Handle msvc + ninja + ccache specially (this is what the bots use) + if target.contains("msvc") && builder.ninja() && builder.config.ccache.is_some() { + let mut wrap_cc = env::current_exe().expect("failed to get cwd"); + wrap_cc.set_file_name("sccache-plus-cl.exe"); + + cfg.define("CMAKE_C_COMPILER", sanitize_cc(&wrap_cc)) + .define("CMAKE_CXX_COMPILER", sanitize_cc(&wrap_cc)); + cfg.env("SCCACHE_PATH", builder.config.ccache.as_ref().unwrap()) + .env("SCCACHE_TARGET", target.triple) + .env("SCCACHE_CC", &cc) + .env("SCCACHE_CXX", &cxx); + + // Building LLVM on MSVC can be a little ludicrous at times. We're so far + // off the beaten path here that I'm not really sure this is even half + // supported any more. Here we're trying to: + // + // * Build LLVM on MSVC + // * Build LLVM with `clang-cl` instead of `cl.exe` + // * Build a project with `sccache` + // * Build for 32-bit as well + // * Build with Ninja + // + // For `cl.exe` there are different binaries to compile 32/64 bit which + // we use but for `clang-cl` there's only one which internally + // multiplexes via flags. As a result it appears that CMake's detection + // of a compiler's architecture and such on MSVC **doesn't** pass any + // custom flags we pass in CMAKE_CXX_FLAGS below. This means that if we + // use `clang-cl.exe` it's always diagnosed as a 64-bit compiler which + // definitely causes problems since all the env vars are pointing to + // 32-bit libraries. + // + // To hack around this... again... we pass an argument that's + // unconditionally passed in the sccache shim. This'll get CMake to + // correctly diagnose it's doing a 32-bit compilation and LLVM will + // internally configure itself appropriately. + if builder.config.llvm_clang_cl.is_some() && target.contains("i686") { + cfg.env("SCCACHE_EXTRA_ARGS", "-m32"); + } + } else { + // If ccache is configured we inform the build a little differently how + // to invoke ccache while also invoking our compilers. + if use_compiler_launcher { + if let Some(ref ccache) = builder.config.ccache { + cfg.define("CMAKE_C_COMPILER_LAUNCHER", ccache) + .define("CMAKE_CXX_COMPILER_LAUNCHER", ccache); + } + } + cfg.define("CMAKE_C_COMPILER", sanitize_cc(cc)) + .define("CMAKE_CXX_COMPILER", sanitize_cc(cxx)) + .define("CMAKE_ASM_COMPILER", sanitize_cc(cc)); + } + + cfg.build_arg("-j").build_arg(builder.jobs().to_string()); + let mut cflags: OsString = builder.cflags(target, GitRepo::Llvm, CLang::C).join(" ").into(); + if let Some(ref s) = builder.config.llvm_cflags { + cflags.push(" "); + cflags.push(s); + } + // Some compiler features used by LLVM (such as thread locals) will not work on a min version below iOS 10. + if target.contains("apple-ios") { + if target.contains("86-") { + cflags.push(" -miphonesimulator-version-min=10.0"); + } else { + cflags.push(" -miphoneos-version-min=10.0"); + } + } + if builder.config.llvm_clang_cl.is_some() { + cflags.push(&format!(" --target={}", target)); + } + cfg.define("CMAKE_C_FLAGS", cflags); + let mut cxxflags: OsString = builder.cflags(target, GitRepo::Llvm, CLang::Cxx).join(" ").into(); + if let Some(ref s) = builder.config.llvm_cxxflags { + cxxflags.push(" "); + cxxflags.push(s); + } + if builder.config.llvm_clang_cl.is_some() { + cxxflags.push(&format!(" --target={}", target)); + } + cfg.define("CMAKE_CXX_FLAGS", cxxflags); + if let Some(ar) = builder.ar(target) { + if ar.is_absolute() { + // LLVM build breaks if `CMAKE_AR` is a relative path, for some reason it + // tries to resolve this path in the LLVM build directory. + cfg.define("CMAKE_AR", sanitize_cc(ar)); + } + } + + if let Some(ranlib) = builder.ranlib(target) { + if ranlib.is_absolute() { + // LLVM build breaks if `CMAKE_RANLIB` is a relative path, for some reason it + // tries to resolve this path in the LLVM build directory. + cfg.define("CMAKE_RANLIB", sanitize_cc(ranlib)); + } + } + + if let Some(ref flags) = builder.config.llvm_ldflags { + ldflags.push_all(flags); + } + + if let Some(flags) = get_var("LDFLAGS", &builder.config.build.triple, &target.triple) { + ldflags.push_all(&flags); + } + + // For distribution we want the LLVM tools to be *statically* linked to libstdc++. + // We also do this if the user explicitly requested static libstdc++. + if builder.config.llvm_static_stdcpp { + if !target.contains("msvc") && !target.contains("netbsd") { + if target.contains("apple") || target.contains("windows") { + ldflags.push_all("-static-libstdc++"); + } else { + ldflags.push_all("-Wl,-Bsymbolic -static-libstdc++"); + } + } + } + + cfg.define("CMAKE_SHARED_LINKER_FLAGS", &ldflags.shared); + cfg.define("CMAKE_MODULE_LINKER_FLAGS", &ldflags.module); + cfg.define("CMAKE_EXE_LINKER_FLAGS", &ldflags.exe); + + if env::var_os("SCCACHE_ERROR_LOG").is_some() { + cfg.env("RUSTC_LOG", "sccache=warn"); + } +} + +// Adapted from https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2347-L2365 +fn get_var(var_base: &str, host: &str, target: &str) -> Option { + let kind = if host == target { "HOST" } else { "TARGET" }; + let target_u = target.replace("-", "_"); + env::var_os(&format!("{}_{}", var_base, target)) + .or_else(|| env::var_os(&format!("{}_{}", var_base, target_u))) + .or_else(|| env::var_os(&format!("{}_{}", kind, var_base))) + .or_else(|| env::var_os(var_base)) +} + +#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] +pub struct Lld { + pub target: TargetSelection, +} + +impl Step for Lld { + type Output = PathBuf; + const ONLY_HOSTS: bool = true; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/llvm-project/lld").path("src/tools/lld") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Lld { target: run.target }); + } + + /// Compile LLD for `target`. + fn run(self, builder: &Builder<'_>) -> PathBuf { + if builder.config.dry_run { + return PathBuf::from("lld-out-dir-test-gen"); + } + let target = self.target; + + let llvm_config = builder.ensure(Llvm { target: self.target }); + + let out_dir = builder.lld_out(target); + let done_stamp = out_dir.join("lld-finished-building"); + if done_stamp.exists() { + return out_dir; + } + + builder.info(&format!("Building LLD for {}", target)); + let _time = util::timeit(&builder); + t!(fs::create_dir_all(&out_dir)); + + let mut cfg = cmake::Config::new(builder.src.join("src/llvm-project/lld")); + configure_cmake(builder, target, &mut cfg, true, LdFlags::default()); + + // This is an awful, awful hack. Discovered when we migrated to using + // clang-cl to compile LLVM/LLD it turns out that LLD, when built out of + // tree, will execute `llvm-config --cmakedir` and then tell CMake about + // that directory for later processing. Unfortunately if this path has + // forward slashes in it (which it basically always does on Windows) + // then CMake will hit a syntax error later on as... something isn't + // escaped it seems? + // + // Instead of attempting to fix this problem in upstream CMake and/or + // LLVM/LLD we just hack around it here. This thin wrapper will take the + // output from llvm-config and replace all instances of `\` with `/` to + // ensure we don't hit the same bugs with escaping. It means that you + // can't build on a system where your paths require `\` on Windows, but + // there's probably a lot of reasons you can't do that other than this. + let llvm_config_shim = env::current_exe().unwrap().with_file_name("llvm-config-wrapper"); + + // Re-use the same flags as llvm to control the level of debug information + // generated for lld. + let profile = match (builder.config.llvm_optimize, builder.config.llvm_release_debuginfo) { + (false, _) => "Debug", + (true, false) => "Release", + (true, true) => "RelWithDebInfo", + }; + + cfg.out_dir(&out_dir) + .profile(profile) + .env("LLVM_CONFIG_REAL", &llvm_config) + .define("LLVM_CONFIG_PATH", llvm_config_shim) + .define("LLVM_INCLUDE_TESTS", "OFF"); + + // While we're using this horrible workaround to shim the execution of + // llvm-config, let's just pile on more. I can't seem to figure out how + // to build LLD as a standalone project and also cross-compile it at the + // same time. It wants a natively executable `llvm-config` to learn + // about LLVM, but then it learns about all the host configuration of + // LLVM and tries to link to host LLVM libraries. + // + // To work around that we tell our shim to replace anything with the + // build target with the actual target instead. This'll break parts of + // LLD though which try to execute host tools, such as llvm-tblgen, so + // we specifically tell it where to find those. This is likely super + // brittle and will break over time. If anyone knows better how to + // cross-compile LLD it would be much appreciated to fix this! + if target != builder.config.build { + cfg.env("LLVM_CONFIG_SHIM_REPLACE", &builder.config.build.triple) + .env("LLVM_CONFIG_SHIM_REPLACE_WITH", &target.triple) + .define( + "LLVM_TABLEGEN_EXE", + llvm_config.with_file_name("llvm-tblgen").with_extension(EXE_EXTENSION), + ); + } + + // Explicitly set C++ standard, because upstream doesn't do so + // for standalone builds. + cfg.define("CMAKE_CXX_STANDARD", "14"); + + cfg.build(); + + t!(File::create(&done_stamp)); + out_dir + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct TestHelpers { + pub target: TargetSelection, +} + +impl Step for TestHelpers { + type Output = (); + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/test/auxiliary/rust_test_helpers.c") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(TestHelpers { target: run.target }) + } + + /// Compiles the `rust_test_helpers.c` library which we used in various + /// `run-pass` tests for ABI testing. + fn run(self, builder: &Builder<'_>) { + if builder.config.dry_run { + return; + } + // The x86_64-fortanix-unknown-sgx target doesn't have a working C + // toolchain. However, some x86_64 ELF objects can be linked + // without issues. Use this hack to compile the test helpers. + let target = if self.target == "x86_64-fortanix-unknown-sgx" { + TargetSelection::from_user("x86_64-unknown-linux-gnu") + } else { + self.target + }; + let dst = builder.test_helpers_out(target); + let src = builder.src.join("src/test/auxiliary/rust_test_helpers.c"); + if up_to_date(&src, &dst.join("librust_test_helpers.a")) { + return; + } + + builder.info("Building test helpers"); + t!(fs::create_dir_all(&dst)); + let mut cfg = cc::Build::new(); + // FIXME: Workaround for https://github.com/emscripten-core/emscripten/issues/9013 + if target.contains("emscripten") { + cfg.pic(false); + } + + // We may have found various cross-compilers a little differently due to our + // extra configuration, so inform cc of these compilers. Note, though, that + // on MSVC we still need cc's detection of env vars (ugh). + if !target.contains("msvc") { + if let Some(ar) = builder.ar(target) { + cfg.archiver(ar); + } + cfg.compiler(builder.cc(target)); + } + cfg.cargo_metadata(false) + .out_dir(&dst) + .target(&target.triple) + .host(&builder.config.build.triple) + .opt_level(0) + .warnings(false) + .debug(false) + .file(builder.src.join("src/test/auxiliary/rust_test_helpers.c")) + .compile("rust_test_helpers"); + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct Sanitizers { + pub target: TargetSelection, +} + +impl Step for Sanitizers { + type Output = Vec; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/llvm-project/compiler-rt").path("src/sanitizers") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Sanitizers { target: run.target }); + } + + /// Builds sanitizer runtime libraries. + fn run(self, builder: &Builder<'_>) -> Self::Output { + let compiler_rt_dir = builder.src.join("src/llvm-project/compiler-rt"); + if !compiler_rt_dir.exists() { + return Vec::new(); + } + + let out_dir = builder.native_dir(self.target).join("sanitizers"); + let runtimes = supported_sanitizers(&out_dir, self.target, &builder.config.channel); + if runtimes.is_empty() { + return runtimes; + } + + let llvm_config = builder.ensure(Llvm { target: builder.config.build }); + if builder.config.dry_run { + return runtimes; + } + + let stamp = out_dir.join("sanitizers-finished-building"); + let stamp = HashStamp::new(stamp, builder.in_tree_llvm_info.sha()); + + if stamp.is_done() { + if stamp.hash.is_none() { + builder.info(&format!( + "Rebuild sanitizers by removing the file `{}`", + stamp.path.display() + )); + } + return runtimes; + } + + builder.info(&format!("Building sanitizers for {}", self.target)); + t!(stamp.remove()); + let _time = util::timeit(&builder); + + let mut cfg = cmake::Config::new(&compiler_rt_dir); + cfg.profile("Release"); + cfg.define("CMAKE_C_COMPILER_TARGET", self.target.triple); + cfg.define("COMPILER_RT_BUILD_BUILTINS", "OFF"); + cfg.define("COMPILER_RT_BUILD_CRT", "OFF"); + cfg.define("COMPILER_RT_BUILD_LIBFUZZER", "OFF"); + cfg.define("COMPILER_RT_BUILD_PROFILE", "OFF"); + cfg.define("COMPILER_RT_BUILD_SANITIZERS", "ON"); + cfg.define("COMPILER_RT_BUILD_XRAY", "OFF"); + cfg.define("COMPILER_RT_DEFAULT_TARGET_ONLY", "ON"); + cfg.define("COMPILER_RT_USE_LIBCXX", "OFF"); + cfg.define("LLVM_CONFIG_PATH", &llvm_config); + + // On Darwin targets the sanitizer runtimes are build as universal binaries. + // Unfortunately sccache currently lacks support to build them successfully. + // Disable compiler launcher on Darwin targets to avoid potential issues. + let use_compiler_launcher = !self.target.contains("apple-darwin"); + configure_cmake(builder, self.target, &mut cfg, use_compiler_launcher, LdFlags::default()); + + t!(fs::create_dir_all(&out_dir)); + cfg.out_dir(out_dir); + + for runtime in &runtimes { + cfg.build_target(&runtime.cmake_target); + cfg.build(); + } + t!(stamp.write()); + + runtimes + } +} + +#[derive(Clone, Debug)] +pub struct SanitizerRuntime { + /// CMake target used to build the runtime. + pub cmake_target: String, + /// Path to the built runtime library. + pub path: PathBuf, + /// Library filename that will be used rustc. + pub name: String, +} + +/// Returns sanitizers available on a given target. +fn supported_sanitizers( + out_dir: &Path, + target: TargetSelection, + channel: &str, +) -> Vec { + let darwin_libs = |os: &str, components: &[&str]| -> Vec { + components + .iter() + .map(move |c| SanitizerRuntime { + cmake_target: format!("clang_rt.{}_{}_dynamic", c, os), + path: out_dir + .join(&format!("build/lib/darwin/libclang_rt.{}_{}_dynamic.dylib", c, os)), + name: format!("librustc-{}_rt.{}.dylib", channel, c), + }) + .collect() + }; + + let common_libs = |os: &str, arch: &str, components: &[&str]| -> Vec { + components + .iter() + .map(move |c| SanitizerRuntime { + cmake_target: format!("clang_rt.{}-{}", c, arch), + path: out_dir.join(&format!("build/lib/{}/libclang_rt.{}-{}.a", os, c, arch)), + name: format!("librustc-{}_rt.{}.a", channel, c), + }) + .collect() + }; + + match &*target.triple { + "aarch64-apple-darwin" => darwin_libs("osx", &["asan", "lsan", "tsan"]), + "aarch64-fuchsia" => common_libs("fuchsia", "aarch64", &["asan"]), + "aarch64-unknown-linux-gnu" => { + common_libs("linux", "aarch64", &["asan", "lsan", "msan", "tsan", "hwasan"]) + } + "x86_64-apple-darwin" => darwin_libs("osx", &["asan", "lsan", "tsan"]), + "x86_64-fuchsia" => common_libs("fuchsia", "x86_64", &["asan"]), + "x86_64-unknown-freebsd" => common_libs("freebsd", "x86_64", &["asan", "msan", "tsan"]), + "x86_64-unknown-netbsd" => { + common_libs("netbsd", "x86_64", &["asan", "lsan", "msan", "tsan"]) + } + "x86_64-unknown-illumos" => common_libs("illumos", "x86_64", &["asan"]), + "x86_64-pc-solaris" => common_libs("solaris", "x86_64", &["asan"]), + "x86_64-unknown-linux-gnu" => { + common_libs("linux", "x86_64", &["asan", "lsan", "msan", "tsan"]) + } + "x86_64-unknown-linux-musl" => { + common_libs("linux", "x86_64", &["asan", "lsan", "msan", "tsan"]) + } + _ => Vec::new(), + } +} + +struct HashStamp { + path: PathBuf, + hash: Option>, +} + +impl HashStamp { + fn new(path: PathBuf, hash: Option<&str>) -> Self { + HashStamp { path, hash: hash.map(|s| s.as_bytes().to_owned()) } + } + + fn is_done(&self) -> bool { + match fs::read(&self.path) { + Ok(h) => self.hash.as_deref().unwrap_or(b"") == h.as_slice(), + Err(e) if e.kind() == io::ErrorKind::NotFound => false, + Err(e) => { + panic!("failed to read stamp file `{}`: {}", self.path.display(), e); + } + } + } + + fn remove(&self) -> io::Result<()> { + match fs::remove_file(&self.path) { + Ok(()) => Ok(()), + Err(e) => { + if e.kind() == io::ErrorKind::NotFound { + Ok(()) + } else { + Err(e) + } + } + } + } + + fn write(&self) -> io::Result<()> { + fs::write(&self.path, self.hash.as_deref().unwrap_or(b"")) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct CrtBeginEnd { + pub target: TargetSelection, +} + +impl Step for CrtBeginEnd { + type Output = PathBuf; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/llvm-project/compiler-rt/lib/crt") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(CrtBeginEnd { target: run.target }); + } + + /// Build crtbegin.o/crtend.o for musl target. + fn run(self, builder: &Builder<'_>) -> Self::Output { + let out_dir = builder.native_dir(self.target).join("crt"); + + if builder.config.dry_run { + return out_dir; + } + + let crtbegin_src = builder.src.join("src/llvm-project/compiler-rt/lib/crt/crtbegin.c"); + let crtend_src = builder.src.join("src/llvm-project/compiler-rt/lib/crt/crtend.c"); + if up_to_date(&crtbegin_src, &out_dir.join("crtbegin.o")) + && up_to_date(&crtend_src, &out_dir.join("crtendS.o")) + { + return out_dir; + } + + builder.info("Building crtbegin.o and crtend.o"); + t!(fs::create_dir_all(&out_dir)); + + let mut cfg = cc::Build::new(); + + if let Some(ar) = builder.ar(self.target) { + cfg.archiver(ar); + } + cfg.compiler(builder.cc(self.target)); + cfg.cargo_metadata(false) + .out_dir(&out_dir) + .target(&self.target.triple) + .host(&builder.config.build.triple) + .warnings(false) + .debug(false) + .opt_level(3) + .file(crtbegin_src) + .file(crtend_src); + + // Those flags are defined in src/llvm-project/compiler-rt/lib/crt/CMakeLists.txt + // Currently only consumer of those objects is musl, which use .init_array/.fini_array + // instead of .ctors/.dtors + cfg.flag("-std=c11") + .define("CRT_HAS_INITFINI_ARRAY", None) + .define("EH_USE_FRAME_REGISTRY", None); + + cfg.compile("crt"); + + t!(fs::copy(out_dir.join("crtbegin.o"), out_dir.join("crtbeginS.o"))); + t!(fs::copy(out_dir.join("crtend.o"), out_dir.join("crtendS.o"))); + out_dir + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct Libunwind { + pub target: TargetSelection, +} + +impl Step for Libunwind { + type Output = PathBuf; + + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path("src/llvm-project/libunwind") + } + + fn make_run(run: RunConfig<'_>) { + run.builder.ensure(Libunwind { target: run.target }); + } + + /// Build linunwind.a + fn run(self, builder: &Builder<'_>) -> Self::Output { + if builder.config.dry_run { + return PathBuf::new(); + } + + let out_dir = builder.native_dir(self.target).join("libunwind"); + let root = builder.src.join("src/llvm-project/libunwind"); + + if up_to_date(&root, &out_dir.join("libunwind.a")) { + return out_dir; + } + + builder.info(&format!("Building libunwind.a for {}", self.target.triple)); + t!(fs::create_dir_all(&out_dir)); + + let mut cc_cfg = cc::Build::new(); + let mut cpp_cfg = cc::Build::new(); + + cpp_cfg.cpp(true); + cpp_cfg.cpp_set_stdlib(None); + cpp_cfg.flag("-nostdinc++"); + cpp_cfg.flag("-fno-exceptions"); + cpp_cfg.flag("-fno-rtti"); + cpp_cfg.flag_if_supported("-fvisibility-global-new-delete-hidden"); + + for cfg in [&mut cc_cfg, &mut cpp_cfg].iter_mut() { + if let Some(ar) = builder.ar(self.target) { + cfg.archiver(ar); + } + cfg.target(&self.target.triple); + cfg.host(&builder.config.build.triple); + cfg.warnings(false); + cfg.debug(false); + // get_compiler() need set opt_level first. + cfg.opt_level(3); + cfg.flag("-fstrict-aliasing"); + cfg.flag("-funwind-tables"); + cfg.flag("-fvisibility=hidden"); + cfg.define("_LIBUNWIND_DISABLE_VISIBILITY_ANNOTATIONS", None); + cfg.include(root.join("include")); + cfg.cargo_metadata(false); + cfg.out_dir(&out_dir); + + if self.target.contains("x86_64-fortanix-unknown-sgx") { + cfg.static_flag(true); + cfg.flag("-fno-stack-protector"); + cfg.flag("-ffreestanding"); + cfg.flag("-fexceptions"); + + // easiest way to undefine since no API available in cc::Build to undefine + cfg.flag("-U_FORTIFY_SOURCE"); + cfg.define("_FORTIFY_SOURCE", "0"); + cfg.define("RUST_SGX", "1"); + cfg.define("__NO_STRING_INLINES", None); + cfg.define("__NO_MATH_INLINES", None); + cfg.define("_LIBUNWIND_IS_BAREMETAL", None); + cfg.define("__LIBUNWIND_IS_NATIVE_ONLY", None); + cfg.define("NDEBUG", None); + } + } + + cc_cfg.compiler(builder.cc(self.target)); + if let Ok(cxx) = builder.cxx(self.target) { + cpp_cfg.compiler(cxx); + } else { + cc_cfg.compiler(builder.cc(self.target)); + } + + // Don't set this for clang + // By default, Clang builds C code in GNU C17 mode. + // By default, Clang builds C++ code according to the C++98 standard, + // with many C++11 features accepted as extensions. + if cc_cfg.get_compiler().is_like_gnu() { + cc_cfg.flag("-std=c99"); + } + if cpp_cfg.get_compiler().is_like_gnu() { + cpp_cfg.flag("-std=c++11"); + } + + if self.target.contains("x86_64-fortanix-unknown-sgx") || self.target.contains("musl") { + // use the same GCC C compiler command to compile C++ code so we do not need to setup the + // C++ compiler env variables on the builders. + // Don't set this for clang++, as clang++ is able to compile this without libc++. + if cpp_cfg.get_compiler().is_like_gnu() { + cpp_cfg.cpp(false); + cpp_cfg.compiler(builder.cc(self.target)); + } + } + + let mut c_sources = vec![ + "Unwind-sjlj.c", + "UnwindLevel1-gcc-ext.c", + "UnwindLevel1.c", + "UnwindRegistersRestore.S", + "UnwindRegistersSave.S", + ]; + + let cpp_sources = vec!["Unwind-EHABI.cpp", "Unwind-seh.cpp", "libunwind.cpp"]; + let cpp_len = cpp_sources.len(); + + if self.target.contains("x86_64-fortanix-unknown-sgx") { + c_sources.push("UnwindRustSgx.c"); + } + + for src in c_sources { + cc_cfg.file(root.join("src").join(src).canonicalize().unwrap()); + } + + for src in &cpp_sources { + cpp_cfg.file(root.join("src").join(src).canonicalize().unwrap()); + } + + cpp_cfg.compile("unwind-cpp"); + + // FIXME: https://github.com/alexcrichton/cc-rs/issues/545#issuecomment-679242845 + let mut count = 0; + for entry in fs::read_dir(&out_dir).unwrap() { + let file = entry.unwrap().path().canonicalize().unwrap(); + if file.is_file() && file.extension() == Some(OsStr::new("o")) { + // file name starts with "Unwind-EHABI", "Unwind-seh" or "libunwind" + let file_name = file.file_name().unwrap().to_str().expect("UTF-8 file name"); + if cpp_sources.iter().any(|f| file_name.starts_with(&f[..f.len() - 4])) { + cc_cfg.object(&file); + count += 1; + } + } + } + assert_eq!(cpp_len, count, "Can't get object files from {:?}", &out_dir); + + cc_cfg.compile("unwind"); + out_dir + } +} diff --git a/src/bootstrap/tool.rs b/src/bootstrap/tool.rs index f13d365e3754d..acf2d5445e26d 100644 --- a/src/bootstrap/tool.rs +++ b/src/bootstrap/tool.rs @@ -215,68 +215,68 @@ macro_rules! bootstrap_tool { pub enum Tool { $( $name, - )+ + )+ } impl<'a> Builder<'a> { pub fn tool_exe(&self, tool: Tool) -> PathBuf { match tool { $(Tool::$name => - self.ensure($name { - compiler: self.compiler(0, self.config.build), - target: self.config.build, - }), - )+ + self.ensure($name { + compiler: self.compiler(0, self.config.build), + target: self.config.build, + }), + )+ } } } $( #[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)] - pub struct $name { - pub compiler: Compiler, - pub target: TargetSelection, - } + pub struct $name { + pub compiler: Compiler, + pub target: TargetSelection, + } - impl Step for $name { - type Output = PathBuf; + impl Step for $name { + type Output = PathBuf; - fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { - run.path($path) - } + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + run.path($path) + } - fn make_run(run: RunConfig<'_>) { - run.builder.ensure($name { - // snapshot compiler - compiler: run.builder.compiler(0, run.builder.config.build), - target: run.target, - }); - } + fn make_run(run: RunConfig<'_>) { + run.builder.ensure($name { + // snapshot compiler + compiler: run.builder.compiler(0, run.builder.config.build), + target: run.target, + }); + } - fn run(self, builder: &Builder<'_>) -> PathBuf { - builder.ensure(ToolBuild { - compiler: self.compiler, - target: self.target, - tool: $tool_name, - mode: if false $(|| $unstable)* { - // use in-tree libraries for unstable features - Mode::ToolStd - } else { - Mode::ToolBootstrap - }, - path: $path, - is_optional_tool: false, - source_type: if false $(|| $external)* { - SourceType::Submodule - } else { - SourceType::InTree - }, - extra_features: vec![], - allow_features: concat!($($allow_features)*), - }).expect("expected to build -- essential tool") + fn run(self, builder: &Builder<'_>) -> PathBuf { + builder.ensure(ToolBuild { + compiler: self.compiler, + target: self.target, + tool: $tool_name, + mode: if false $(|| $unstable)* { + // use in-tree libraries for unstable features + Mode::ToolStd + } else { + Mode::ToolBootstrap + }, + path: $path, + is_optional_tool: false, + source_type: if false $(|| $external)* { + SourceType::Submodule + } else { + SourceType::InTree + }, + extra_features: vec![], + allow_features: concat!($($allow_features)*), + }).expect("expected to build -- essential tool") + } } - } - )+ + )+ } } @@ -738,21 +738,21 @@ macro_rules! tool_extended { ;)+) => { $( #[derive(Debug, Clone, Hash, PartialEq, Eq)] - pub struct $name { - pub compiler: Compiler, - pub target: TargetSelection, - pub extra_features: Vec, - } + pub struct $name { + pub compiler: Compiler, + pub target: TargetSelection, + pub extra_features: Vec, + } - impl Step for $name { - type Output = Option; - const DEFAULT: bool = true; // Overwritten below - const ONLY_HOSTS: bool = true; + impl Step for $name { + type Output = Option; + const DEFAULT: bool = true; // Overwritten below + const ONLY_HOSTS: bool = true; - fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { - let builder = run.builder; - run.path($path).default_condition( - builder.config.extended + fn should_run(run: ShouldRun<'_>) -> ShouldRun<'_> { + let builder = run.builder; + run.path($path).default_condition( + builder.config.extended && builder.config.tools.as_ref().map_or( // By default, on nightly/dev enable all tools, else only // build stable tools. @@ -762,55 +762,55 @@ macro_rules! tool_extended { tools.iter().any(|tool| match tool.as_ref() { "clippy" => $tool_name == "clippy-driver", x => $tool_name == x, - }) - }), - ) - } + }) + }), + ) + } - fn make_run(run: RunConfig<'_>) { - run.builder.ensure($name { - compiler: run.builder.compiler(run.builder.top_stage, run.builder.config.build), - target: run.target, - extra_features: Vec::new(), - }); - } + fn make_run(run: RunConfig<'_>) { + run.builder.ensure($name { + compiler: run.builder.compiler(run.builder.top_stage, run.builder.config.build), + target: run.target, + extra_features: Vec::new(), + }); + } - #[allow(unused_mut)] - fn run(mut $sel, $builder: &Builder<'_>) -> Option { - let tool = $builder.ensure(ToolBuild { - compiler: $sel.compiler, - target: $sel.target, - tool: $tool_name, - mode: if false $(|| $tool_std)? { Mode::ToolStd } else { Mode::ToolRustc }, - path: $path, - extra_features: $sel.extra_features, - is_optional_tool: true, - source_type: SourceType::InTree, - allow_features: concat!($($allow_features)*), - })?; - - if (false $(|| !$add_bins_to_sysroot.is_empty())?) && $sel.compiler.stage > 0 { - let bindir = $builder.sysroot($sel.compiler).join("bin"); - t!(fs::create_dir_all(&bindir)); - - #[allow(unused_variables)] - let tools_out = $builder - .cargo_out($sel.compiler, Mode::ToolRustc, $sel.target); - - $(for add_bin in $add_bins_to_sysroot { - let bin_source = tools_out.join(exe(add_bin, $sel.target)); - let bin_destination = bindir.join(exe(add_bin, $sel.compiler.host)); - $builder.copy(&bin_source, &bin_destination); - })? - - let tool = bindir.join(exe($tool_name, $sel.compiler.host)); - Some(tool) - } else { - Some(tool) + #[allow(unused_mut)] + fn run(mut $sel, $builder: &Builder<'_>) -> Option { + let tool = $builder.ensure(ToolBuild { + compiler: $sel.compiler, + target: $sel.target, + tool: $tool_name, + mode: if false $(|| $tool_std)? { Mode::ToolStd } else { Mode::ToolRustc }, + path: $path, + extra_features: $sel.extra_features, + is_optional_tool: true, + source_type: SourceType::InTree, + allow_features: concat!($($allow_features)*), + })?; + + if (false $(|| !$add_bins_to_sysroot.is_empty())?) && $sel.compiler.stage > 0 { + let bindir = $builder.sysroot($sel.compiler).join("bin"); + t!(fs::create_dir_all(&bindir)); + + #[allow(unused_variables)] + let tools_out = $builder + .cargo_out($sel.compiler, Mode::ToolRustc, $sel.target); + + $(for add_bin in $add_bins_to_sysroot { + let bin_source = tools_out.join(exe(add_bin, $sel.target)); + let bin_destination = bindir.join(exe(add_bin, $sel.compiler.host)); + $builder.copy(&bin_source, &bin_destination); + })? + + let tool = bindir.join(exe($tool_name, $sel.compiler.host)); + Some(tool) + } else { + Some(tool) + } } } - } - )+ + )+ } } diff --git a/src/test/ui/terminal-width/flag-human.rs b/src/test/ui/terminal-width/flag-human.rs new file mode 100644 index 0000000000000..4b94ebb01fc8e --- /dev/null +++ b/src/test/ui/terminal-width/flag-human.rs @@ -0,0 +1,9 @@ +// compile-flags: --diagnostic-width=20 + +// This test checks that `-Z diagnostic-width` effects the human error output by restricting it to an +// arbitrarily low value so that the effect is visible. + +fn main() { + let _: () = 42; + //~^ ERROR mismatched types +} diff --git a/src/test/ui/terminal-width/flag-json.rs b/src/test/ui/terminal-width/flag-json.rs new file mode 100644 index 0000000000000..3add1d7d9301e --- /dev/null +++ b/src/test/ui/terminal-width/flag-json.rs @@ -0,0 +1,9 @@ +// compile-flags: --diagnostic-width=20 --error-format=json + +// This test checks that `-Z diagnostic-width` effects the JSON error output by restricting it to an +// arbitrarily low value so that the effect is visible. + +fn main() { + let _: () = 42; + //~^ ERROR mismatched types +} diff --git a/src/test/ui/terminal-width/flag-json.stderr b/src/test/ui/terminal-width/flag-json.stderr new file mode 100644 index 0000000000000..b21391d1640ef --- /dev/null +++ b/src/test/ui/terminal-width/flag-json.stderr @@ -0,0 +1,40 @@ +{"message":"mismatched types","code":{"code":"E0308","explanation":"Expected type did not match the received type. + +Erroneous code examples: + +```compile_fail,E0308 +fn plus_one(x: i32) -> i32 { + x + 1 +} + +plus_one(\"Not a number\"); +// ^^^^^^^^^^^^^^ expected `i32`, found `&str` + +if \"Not a bool\" { +// ^^^^^^^^^^^^ expected `bool`, found `&str` +} + +let x: f32 = \"Not a float\"; +// --- ^^^^^^^^^^^^^ expected `f32`, found `&str` +// | +// expected due to this +``` + +This error occurs when an expression was used in a place where the compiler +expected an expression of a different type. It can occur in several cases, the +most common being when calling a function and passing an argument which has a +different type than the matching type in the function declaration. +"},"level":"error","spans":[{"file_name":"$DIR/flag-json.rs","byte_start":243,"byte_end":245,"line_start":7,"line_end":7,"column_start":17,"column_end":19,"is_primary":true,"text":[{"text":" let _: () = 42;","highlight_start":17,"highlight_end":19}],"label":"expected `()`, found integer","suggested_replacement":null,"suggestion_applicability":null,"expansion":null},{"file_name":"$DIR/flag-json.rs","byte_start":238,"byte_end":240,"line_start":7,"line_end":7,"column_start":12,"column_end":14,"is_primary":false,"text":[{"text":" let _: () = 42;","highlight_start":12,"highlight_end":14}],"label":"expected due to this","suggested_replacement":null,"suggestion_applicability":null,"expansion":null}],"children":[],"rendered":"error[E0308]: mismatched types + --> $DIR/flag-json.rs:7:17 + | +LL | ..._: () = 42; + | -- ^^ expected `()`, found integer + | | + | expected due to this + +"} +{"message":"aborting due to previous error","code":null,"level":"error","spans":[],"children":[],"rendered":"error: aborting due to previous error + +"} +{"message":"For more information about this error, try `rustc --explain E0308`.","code":null,"level":"failure-note","spans":[],"children":[],"rendered":"For more information about this error, try `rustc --explain E0308`. +"} diff --git a/src/tools/enzyme b/src/tools/enzyme new file mode 160000 index 0000000000000..18d3da56dffab --- /dev/null +++ b/src/tools/enzyme @@ -0,0 +1 @@ +Subproject commit 18d3da56dffab7459b222322a9fc2fcc36d092d5 diff --git a/tests/rustdoc-ui/doctest/terminal-width.rs b/tests/rustdoc-ui/doctest/terminal-width.rs new file mode 100644 index 0000000000000..61961d5ec710e --- /dev/null +++ b/tests/rustdoc-ui/doctest/terminal-width.rs @@ -0,0 +1,5 @@ +// compile-flags: -Zunstable-options --diagnostic-width=10 +#![deny(rustdoc::bare_urls)] + +/// This is a long line that contains a http://link.com +pub struct Foo; //~^ ERROR diff --git a/tests/rustdoc-ui/doctest/terminal-width.stderr b/tests/rustdoc-ui/doctest/terminal-width.stderr new file mode 100644 index 0000000000000..fed049d2b37bc --- /dev/null +++ b/tests/rustdoc-ui/doctest/terminal-width.stderr @@ -0,0 +1,15 @@ +error: this URL is not a hyperlink + --> $DIR/diagnostic-width.rs:4:41 + | +LL | ... a http://link.com + | ^^^^^^^^^^^^^^^ help: use an automatic link instead: `` + | +note: the lint level is defined here + --> $DIR/diagnostic-width.rs:2:9 + | +LL | ...ny(rustdoc::bare_url... + | ^^^^^^^^^^^^^^^^^^ + = note: bare URLs are not automatically turned into clickable links + +error: aborting due to previous error + diff --git a/tests/ui/json/autodiff.rs b/tests/ui/json/autodiff.rs new file mode 100644 index 0000000000000..54f94c3765bf6 --- /dev/null +++ b/tests/ui/json/autodiff.rs @@ -0,0 +1,16 @@ +// Check autodiff attribute +// edition:2018 + +extern "C" fn rosenbrock(a: f32, b: f32, x: f32, y: f32) -> f32 { + let (z, w) = (a-x, y-x*x); + + z*z + b*w*w +} + +#[autodiff(rosenbrock, mode = "forward")] +extern "C" { + fn dx_rosenbrock(a: f32, b: f32, x: f32, y: f32, d_x: &mut f32); + fn dy_rosenbrock(a: f32, b: f32, x: f32, y: f32, d_y: &mut f32); +} + +fn main() {}