Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ members = [
"candle-rotary",
"candle-flash-attn-v1",
"candle-cublaslt",
"candle-moe",
]
resolver = "2"

Expand All @@ -23,7 +24,8 @@ candle = { version = "0.*", package = "candle-core", features = ["cuda"]}
cudarc = { version = "0.*" }
half = { version = "2.3.1", features = ["num-traits"] }
# Dev
candle-nn = { version = "0.*", features = ["cuda"] }
candle-nn = { version = "0.8", features = ["cuda"] }
candle-transformers = { version = "0.8" }
# Build
anyhow = { version = "1", features = ["backtrace"] }
bindgen_cuda = "0.1.1"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ raw candle expressions, usually because they *fuse* kernels directly.
- [candle-layer-norm](./candle-layer-norm)
- [candle-rotary](./candle-rotary)
- [candle-flash-attn-v1](./candle-flash-attn-v1)
- [candle-moe](./candle-moe)
23 changes: 23 additions & 0 deletions candle-moe/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Manifest for the `candle-moe` crate; shared metadata is inherited from the workspace root.
[package]
name = "candle-moe"
description = "fused MoE layer for the candle ML framework."
readme = "README.md"
version.workspace = true
edition.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
repository.workspace = true

[dependencies]
candle = { workspace = true }
half = { workspace = true }

# Used by build.rs to compile the CUDA kernels.
[build-dependencies]
anyhow = { workspace = true }
bindgen_cuda = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
candle-nn = { workspace = true }
# The key must match the name declared in the workspace dependency table
# (`candle-transformers`), otherwise workspace inheritance cannot resolve it.
candle-transformers = { workspace = true }
3 changes: 3 additions & 0 deletions candle-moe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# candle-moe

Fused MoE kernels for the Candle ML framework.
77 changes: 77 additions & 0 deletions candle-moe/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Build script to run nvcc and generate the C glue code for launching the fused MoE kernels.
// The cuda build time is very long so one can set the CANDLE_MOE_BUILD_DIR environment
// variable in order to cache the compiled artifacts and avoid recompiling too often.
use anyhow::{Context, Result};
use std::path::PathBuf;

/// CUDA source files, relative to the crate root, that `main` watches for changes and
/// hands to `bindgen_cuda` for compilation.
const KERNEL_FILES: [&str; 2] = ["kernels/topk_softmax.cu", "kernels/fused_moe.cu"];

fn main() -> Result<()> {
println!("cargo:rerun-if-changed=build.rs");
for kernel_file in KERNEL_FILES.iter() {
println!("cargo:rerun-if-changed={kernel_file}");
}

let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
let build_dir = match std::env::var("CANDLE_MOE_BUILD_DIR") {
Err(_) =>
{
#[allow(clippy::redundant_clone)]
out_dir.clone()
}
Ok(build_dir) => {
let path = PathBuf::from(build_dir);
let current_dir = std::env::current_dir()?;
path.canonicalize().unwrap_or_else(|_| {
panic!(
"Directory doesn't exists: {} (the current directory is {})",
&path.display(),
current_dir.display()
)
})
}
};

let kernels: Vec<_> = KERNEL_FILES.iter().collect();
let builder = bindgen_cuda::Builder::default()
.kernel_paths(kernels)
.out_dir(build_dir.clone())
.arg("-std=c++17")
.arg("-O3")
.arg("--compiler-options")
.arg("-fPIC")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--ptxas-options=-v")
.arg("--verbose");

let target = std::env::var("TARGET").unwrap();

let out_file = if target.contains("msvc") {
build_dir.join("moe.lib")
} else {
build_dir.join("libmoe.a")
};
builder.build_lib(out_file);

println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=moe");
println!("cargo:rustc-link-lib=dylib=cudart");

if target.contains("msvc") {
// nothing to link to
} else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") {
println!("cargo:rustc-link-lib=dylib=c++");
} else if target.contains("android") {
println!("cargo:rustc-link-lib=dylib=c++_shared");
} else {
println!("cargo:rustc-link-lib=dylib=stdc++");
}

Ok(())
}
49 changes: 49 additions & 0 deletions candle-moe/kernels/cuda_compat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

// Compatibility macros selected by USE_ROCM: when it is defined the HIP
// spellings are used, otherwise the CUDA ones. The CUDA `*_sync` shuffles are
// always called with an all-ones mask (`uint32_t(-1)`).

#ifdef USE_ROCM

  #include <hip/hip_runtime.h>

  #define WARP_SIZE warpSize

  #define VLLM_LDG(arg) *(arg)

  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor(var, lane_mask, width)

  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)

  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)

  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)

#else  // !USE_ROCM

  #define WARP_SIZE 32

  #define VLLM_LDG(arg) __ldg(arg)

  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask)
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor_sync(uint32_t(-1), var, lane_mask, width)

  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)

  #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
    __shfl_down_sync(uint32_t(-1), var, lane_delta)

  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)

#endif  // USE_ROCM
Loading