diff --git a/.github/workflows/binaries.yml b/.github/workflows/binaries.yml
index 85e005f..78a5f82 100644
--- a/.github/workflows/binaries.yml
+++ b/.github/workflows/binaries.yml
@@ -102,3 +102,39 @@ jobs:
           draft: true
           files: ${{ steps.precompile.outputs.file-path }}
         if: startsWith(github.ref, 'refs/tags/')
+
+  # Builds the NIF with the `metal` cargo feature for each matrix entry and,
+  # on tag refs only, attaches the produced file to a draft GitHub release.
+  build_metal:
+    name: metal / ${{ matrix.target }} / ${{ matrix.os }}
+    runs-on: ${{ matrix.os }}
+    permissions:
+      contents: write
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - target: aarch64-apple-darwin
+            os: macos-13
+
+    steps:
+      - uses: actions/checkout@v4
+      - run: rustup target add ${{ matrix.target }}
+
+      - uses: philss/rustler-precompiled-action@main
+        id: precompile
+        with:
+          project-dir: ${{ env.PROJECT_DIR }}
+          project-name: ${{ env.PROJECT_NAME }}
+          project-version: ${{ env.PROJECT_VERSION }}
+          target: ${{ matrix.target }}
+          use-cross: null  # NOTE(review): presumably disables `cross` here; confirm against the action's inputs
+          nif-version: ${{ env.NIF_VERSION }}
+          variant: metal
+          cargo-args: "--features metal"
+
+      - uses: softprops/action-gh-release@v1
+        with:
+          draft: true
+          files: ${{ steps.precompile.outputs.file-path }}
+        if: startsWith(github.ref, 'refs/tags/')
diff --git a/README.md b/README.md
index e5df25e..197eecb 100644
--- a/README.md
+++ b/README.md
@@ -57,6 +57,7 @@ only the host CPU.
 | --- | --- |
 | cpu | |
 | cuda | CUDA 12.x |
+| metal | Metal ? |
 
 To use Candlex with NVidia GPU you need [CUDA](https://developer.nvidia.com/cuda-downloads)
 compatible with your GPU drivers.
diff --git a/config/config.exs b/config/config.exs index aff71b9..8b0424b 100644 --- a/config/config.exs +++ b/config/config.exs @@ -1,3 +1,5 @@ import Config -config :candlex, use_cuda: System.get_env("CANDLEX_NIF_TARGET") == "cuda" +config :candlex, + use_cuda: System.get_env("CANDLEX_NIF_TARGET") == "cuda", + use_metal: System.get_env("CANDLEX_NIF_TARGET") == "metal" diff --git a/lib/candlex/native.ex b/lib/candlex/native.ex index 427d39b..64386c4 100644 --- a/lib/candlex/native.ex +++ b/lib/candlex/native.ex @@ -25,6 +25,7 @@ defmodule Candlex.Native do "x86_64-unknown-linux-gnu" ], variants: %{ + "aarch64-apple-darwin" => [metal: fn -> Application.compile_env(:candlex, :use_metal) end], "x86_64-unknown-linux-gnu" => [cuda: fn -> Application.compile_env(:candlex, :use_cuda) end] } diff --git a/native/candlex/Cargo.lock b/native/candlex/Cargo.lock index 8ea7256..f1816cf 100644 --- a/native/candlex/Cargo.lock +++ b/native/candlex/Cargo.lock @@ -71,6 +71,18 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" + +[[package]] +name = "block" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" + [[package]] name = "bytemuck" version = "1.14.0" @@ -104,6 +116,8 @@ source = "git+https://github.com/huggingface/candle#c7e613ab5efd46934eddbc16f18a dependencies = [ "byteorder", "candle-kernels", + "candle-metal", + "candle-metal-kernels", "cudarc", "gemm", "half", @@ -129,6 +143,33 @@ dependencies = [ "rayon", ] +[[package]] +name = "candle-metal" +version = "0.27.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0e697df3a971c0299102fed60943db63442c37cd9985fba7856266b5aa808b1b" +dependencies = [ + "bitflags 2.4.1", + "block", + "core-graphics-types", + "foreign-types", + "half", + "log", + "objc", + "paste", +] + +[[package]] +name = "candle-metal-kernels" +version = "0.3.1" +source = "git+https://github.com/huggingface/candle#c7e613ab5efd46934eddbc16f18aeea4dab4366a" +dependencies = [ + "candle-metal", + "once_cell", + "thiserror", + "tracing", +] + [[package]] name = "candlex" version = "0.1.0" @@ -157,6 +198,33 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" + +[[package]] +name = "core-graphics-types" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bb142d41022986c1d8ff29103a1411c8a3dfad3552f87a4f8dc50d61d4f4e33" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + [[package]] name = "crc32fast" version = "1.3.2" @@ -230,6 +298,33 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + [[package]] name = "gemm" version = "0.16.14" @@ -420,6 +515,21 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" +[[package]] +name = "log" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" + +[[package]] +name = "malloc_buf" +version = "0.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62bb907fe88d54d8d9ce32a3cceab4218ed2f6b7d35617cafe9adf84e43919cb" +dependencies = [ + "libc", +] + [[package]] name = "matrixmultiply" version = "0.3.8" @@ -544,6 +654,25 @@ dependencies = [ "libc", ] +[[package]] +name = "objc" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" +dependencies = [ + "malloc_buf", + "objc_exception", +] + +[[package]] +name = "objc_exception" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad970fb455818ad6cba4c122ad012fae53ae8b4795f86378bce65e4f6bab2ca4" +dependencies = [ + "cc", +] + [[package]] name = "object" version = "0.32.1" @@ -565,6 +694,12 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -646,7 +781,7 @@ version = "10.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -903,6 +1038,37 @@ dependencies = [ "syn 2.0.37", ] +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/native/candlex/Cargo.toml b/native/candlex/Cargo.toml index 306760b..e71536a 100644 --- a/native/candlex/Cargo.toml +++ b/native/candlex/Cargo.toml @@ -22,3 +22,4 @@ anyhow = "1.0.75" [features] cuda = ["candle-core/cuda"] +metal = ["candle-core/metal"]