Skip to content

Commit

Permalink
feat: support metal
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Jan 16, 2024
1 parent 0ace617 commit def76e3
Show file tree
Hide file tree
Showing 11 changed files with 266 additions and 11 deletions.
34 changes: 34 additions & 0 deletions .github/workflows/binaries.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,37 @@ jobs:
draft: true
files: ${{ steps.precompile.outputs.file-path }}
if: startsWith(github.ref, 'refs/tags/')

build_metal:
name: metal / ${{ matrix.target }} / ${{ matrix.os }}
runs-on: macos-13
permissions:
contents: write
strategy:
fail-fast: false
matrix:
include:
- target: aarch64-apple-darwin
os: macos-13

steps:
- uses: actions/checkout@v4
- run: rustup target add ${{ matrix.target }}

- uses: philss/rustler-precompiled-action@main
id: precompile
with:
project-dir: ${{ env.PROJECT_DIR }}
project-name: ${{ env.PROJECT_NAME }}
project-version: ${{ env.PROJECT_VERSION }}
target: ${{ matrix.target }}
use-cross: null
nif-version: ${{ env.NIF_VERSION }}
variant: metal
cargo-args: "--features metal"

- uses: softprops/action-gh-release@v1
with:
draft: true
files: ${{ steps.precompile.outputs.file-path }}
if: startsWith(github.ref, 'refs/tags/')
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ only the host CPU.
| --- | --- |
| cpu | |
| cuda | CUDA 12.x |
| metal | Metal ? |

To use Candlex with NVidia GPU you need [CUDA](https://developer.nvidia.com/cuda-downloads) compatible with your
GPU drivers.
Expand Down
4 changes: 3 additions & 1 deletion config/config.exs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Config

config :candlex, use_cuda: System.get_env("CANDLEX_NIF_TARGET") == "cuda"
config :candlex,
use_cuda: System.get_env("CANDLEX_NIF_TARGET") == "cuda",
use_metal: System.get_env("CANDLEX_NIF_TARGET") == "metal"
18 changes: 14 additions & 4 deletions lib/candlex/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ defmodule Candlex.Backend do

@device_cuda :cuda
@device_cpu :cpu
@device_metal :metal

@impl true
def init(opts) do
Expand Down Expand Up @@ -1198,10 +1199,15 @@ defmodule Candlex.Backend do
end

defp default_device do
if cuda_available?() do
@device_cuda
else
@device_cpu
cond do
cuda_available?() ->
@device_cuda

metal_available?() ->
@device_metal

true ->
@device_cpu
end
end

Expand All @@ -1217,6 +1223,10 @@ defmodule Candlex.Backend do
Native.is_cuda_available()
end

def metal_available? do
Native.is_metal_available()
end

defp unsupported_dtype(t) do
raise("Unsupported candle dtype for #{inspect(t)}")
end
Expand Down
11 changes: 10 additions & 1 deletion lib/candlex/native.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@ defmodule Candlex.Native do
# mode = if Mix.env() in [:dev, :test], do: :debug, else: :release
mode = :release

features =
cond do
Application.compile_env(:candlex, :use_cuda) -> [:cuda]
Application.compile_env(:candlex, :use_metal) -> [:metal]
true -> []
end

use RustlerPrecompiled,
otp_app: :candlex,
features: if(Application.compile_env(:candlex, :use_cuda), do: [:cuda], else: []),
features: features,
base_url: "#{source_url}/releases/download/v#{version}",
force_build: System.get_env("CANDLEX_NIF_BUILD") in ["1", "true"],
mode: mode,
Expand All @@ -25,6 +32,7 @@ defmodule Candlex.Native do
"x86_64-unknown-linux-gnu"
],
variants: %{
"aarch64-apple-darwin" => [metal: fn -> Application.compile_env(:candlex, :use_metal) end],
"x86_64-unknown-linux-gnu" => [cuda: fn -> Application.compile_env(:candlex, :use_cuda) end]
}

Expand Down Expand Up @@ -138,6 +146,7 @@ defmodule Candlex.Native do
end

def is_cuda_available(), do: error()
def is_metal_available(), do: error()
def to_device(_tensor, _device), do: error()

defp error(), do: :erlang.nif_error(:nif_not_loaded)
Expand Down
168 changes: 168 additions & 0 deletions native/candlex/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions native/candlex/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ bindgen_cuda = { version = "0.1.1", optional = true }

[features]
cuda = ["candle-core/cuda", "dep:bindgen_cuda"]
metal = ["candle-core/metal"]
5 changes: 5 additions & 0 deletions native/candlex/src/devices.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@
pub fn is_cuda_available() -> bool {
candle_core::utils::cuda_is_available()
}

#[rustler::nif(schedule = "DirtyCpu")]
pub fn is_metal_available() -> bool {
candle_core::utils::metal_is_available()
}
3 changes: 2 additions & 1 deletion native/candlex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ rustler::init! {
tensors::sum_pool2d,
tensors::max_pool2d,
tensors::contiguous,
devices::is_cuda_available
devices::is_cuda_available,
devices::is_metal_available
],
load = load
}
Loading

0 comments on commit def76e3

Please sign in to comment.