From afa0e1deb4d5ad71f352432471efa664280087a7 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 10 Sep 2025 21:02:51 +0900 Subject: [PATCH 1/2] feature: fused MoE kernel --- .vscode/settings.json | 14 + Cargo.toml | 4 +- README.md | 1 + candle-moe/Cargo.toml | 23 ++ candle-moe/README.md | 3 + candle-moe/build.rs | 77 ++++ candle-moe/kernels/cuda_compat.h | 49 +++ candle-moe/kernels/fused_moe.cu | 521 +++++++++++++++++++++++++ candle-moe/kernels/topk_softmax.cu | 497 +++++++++++++++++++++++ candle-moe/src/ffi.rs | 49 +++ candle-moe/src/lib.rs | 421 ++++++++++++++++++++ candle-moe/tests/moe_tests.rs | 143 +++++++ candle-moe/tests/topk_softmax_tests.rs | 58 +++ 13 files changed, 1859 insertions(+), 1 deletion(-) create mode 100644 .vscode/settings.json create mode 100644 candle-moe/Cargo.toml create mode 100644 candle-moe/README.md create mode 100644 candle-moe/build.rs create mode 100644 candle-moe/kernels/cuda_compat.h create mode 100644 candle-moe/kernels/fused_moe.cu create mode 100644 candle-moe/kernels/topk_softmax.cu create mode 100644 candle-moe/src/ffi.rs create mode 100644 candle-moe/src/lib.rs create mode 100644 candle-moe/tests/moe_tests.rs create mode 100644 candle-moe/tests/topk_softmax_tests.rs diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..c9c987c --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,14 @@ +{ + "files.associations": { + "array": "cpp", + "format": "cpp", + "initializer_list": "cpp", + "list": "cpp", + "utility": "cpp", + "vector": "cpp", + "xhash": "cpp", + "xstring": "cpp", + "xtree": "cpp", + "xutility": "cpp" + } +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 62d015f..85875d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "candle-rotary", "candle-flash-attn-v1", "candle-cublaslt", + "candle-moe", ] resolver = "2" @@ -23,7 +24,8 @@ candle = { version = "0.*", package = "candle-core", features = ["cuda"]} cudarc = { version = "0.*" } half = { version = "2.3.1", features = ["num-traits"] } # Dev -candle-nn = { version = "0.*", features = ["cuda"] } +candle-nn = { version = "0.8", features = ["cuda"] } +candle-transformers = { version = "0.8" } # Build anyhow = { version = "1", features = ["backtrace"] } bindgen_cuda = "0.1.1" diff --git a/README.md b/README.md index f7779c7..9247f25 100644 --- a/README.md +++ b/README.md @@ -8,3 +8,4 @@ raw candle expressions, usually because they *fuse* kernels directly. - [candle-layer-norm](./candle-layer-norm) - [candle-rotary](./candle-rotary) - [candle-flash-attn-v1](./candle-flash-attn-v1) +- [candle-moe](./candle-moe) diff --git a/candle-moe/Cargo.toml b/candle-moe/Cargo.toml new file mode 100644 index 0000000..f924974 --- /dev/null +++ b/candle-moe/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "candle-moe" +description = "fused MoE layer for the candle ML framework." 
+readme = "README.md"
+version.workspace = true
+edition.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+repository.workspace = true
+
+[dependencies]
+candle = { workspace = true }
+half = { workspace = true }
+
+[build-dependencies]
+anyhow = { workspace = true }
+bindgen_cuda = { workspace = true }
+
+[dev-dependencies]
+anyhow = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
diff --git a/candle-moe/README.md b/candle-moe/README.md
new file mode 100644
index 0000000..7f18d7e
--- /dev/null
+++ b/candle-moe/README.md
@@ -0,0 +1,3 @@
+# candle-moe
+
+Fused MoE kernels for the candle ML framework.
diff --git a/candle-moe/build.rs b/candle-moe/build.rs
new file mode 100644
index 0000000..a299d9d
--- /dev/null
+++ b/candle-moe/build.rs
@@ -0,0 +1,77 @@
+// Build script to run nvcc and generate the C glue code for launching the fused MoE kernels.
+// The cuda build time is very long so one can set the CANDLE_MOE_BUILD_DIR environment
+// variable in order to cache the compiled artifacts and avoid recompiling too often.
+use anyhow::{Context, Result};
+use std::path::PathBuf;
+
+const KERNEL_FILES: [&str; 2] = ["kernels/topk_softmax.cu", "kernels/fused_moe.cu"];
+
+fn main() -> Result<()> {
+    println!("cargo:rerun-if-changed=build.rs");
+    for kernel_file in KERNEL_FILES.iter() {
+        println!("cargo:rerun-if-changed={kernel_file}");
+    }
+
+    let out_dir = PathBuf::from(std::env::var("OUT_DIR").context("OUT_DIR not set")?);
+    let build_dir = match std::env::var("CANDLE_MOE_BUILD_DIR") {
+        Err(_) =>
+        {
+            #[allow(clippy::redundant_clone)]
+            out_dir.clone()
+        }
+        Ok(build_dir) => {
+            let path = PathBuf::from(build_dir);
+            let current_dir = std::env::current_dir()?;
+            path.canonicalize().unwrap_or_else(|_| {
+                panic!(
+                    "Directory doesn't exist: {} (the current directory is {})",
+                    &path.display(),
+                    current_dir.display()
+                )
+            })
+        }
+    };
+
+    let kernels: Vec<_> = KERNEL_FILES.iter().collect();
+    let builder = bindgen_cuda::Builder::default()
+        .kernel_paths(kernels)
+        .out_dir(build_dir.clone())
+        .arg("-std=c++17")
+        .arg("-O3")
+        .arg("--compiler-options")
+        .arg("-fPIC")
+        .arg("-U__CUDA_NO_HALF_OPERATORS__")
+        .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+        .arg("-U__CUDA_NO_HALF2_OPERATORS__")
+        .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
+        .arg("--expt-relaxed-constexpr")
+        .arg("--expt-extended-lambda")
+        .arg("--use_fast_math")
+        .arg("--ptxas-options=-v")
+        .arg("--verbose");
+
+    let target = std::env::var("TARGET").unwrap();
+
+    let out_file = if target.contains("msvc") {
+        build_dir.join("moe.lib")
+    } else {
+        build_dir.join("libmoe.a")
+    };
+    builder.build_lib(out_file);
+
+    println!("cargo:rustc-link-search={}", build_dir.display());
+    println!("cargo:rustc-link-lib=moe");
+    println!("cargo:rustc-link-lib=dylib=cudart");
+
+    if target.contains("msvc") {
+        // nothing to link to
+    } else if target.contains("apple") || target.contains("freebsd") || target.contains("openbsd") {
+        println!("cargo:rustc-link-lib=dylib=c++");
+    } else if target.contains("android") {
+        println!("cargo:rustc-link-lib=dylib=c++_shared");
+    } else {
+        println!("cargo:rustc-link-lib=dylib=stdc++");
+    }
+
+    Ok(())
+}
diff --git a/candle-moe/kernels/cuda_compat.h b/candle-moe/kernels/cuda_compat.h
new file mode 100644
index 0000000..82e5561
--- /dev/null
+++ b/candle-moe/kernels/cuda_compat.h
@@ -0,0 +1,49 @@
+#pragma once
+
+#ifdef USE_ROCM
+  #include <hip/hip_runtime.h>
+#endif
+
+#ifndef USE_ROCM
+  #define WARP_SIZE 32
+#else
+  #define 
WARP_SIZE warpSize +#endif + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + +#ifndef USE_ROCM + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif diff --git a/candle-moe/kernels/fused_moe.cu b/candle-moe/kernels/fused_moe.cu new file mode 100644 index 0000000..63892c9 --- /dev/null +++ b/candle-moe/kernels/fused_moe.cu @@ -0,0 +1,521 @@ +#include +#include +#include + +#include +#include +#include + +// Block sizes for tiled matrix multiplication +#define BLOCK_M 64 +#define BLOCK_N 64 +#define BLOCK_K 32 +#define WARP_SIZE 32 + +// Structure to hold sorted token-expert pairs +struct TokenExpertPair { + int token_idx; + int expert_idx; + float routing_weight; + int original_idx; // Position in original routing weights +}; + +// Helper function for SiLU activation +__device__ __forceinline__ float silu(float x) { + return x / (1.0f + expf(-x)); +} + +// Helper function for GELU activation +__device__ __forceinline__ float gelu(float x) { + return 0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x))); +} + +template +__global__ void nomic_fused_moe_kernel( + const T* __restrict__ input, // [num_tokens, hidden_dim] + const T* __restrict__ gate_weights, // [num_experts, hidden_dim, intermediate_dim] + const T* __restrict__ up_weights, // [num_experts, hidden_dim, intermediate_dim] + const float* __restrict__ routing_weights, // [num_tokens, num_selected_experts] + const uint32_t* __restrict__ expert_indices, // [num_tokens, num_selected_experts] + T* __restrict__ output, // [num_tokens, hidden_dim] + int num_tokens, + int hidden_dim, + int intermediate_dim, + int num_selected_experts, + int activation_type // 0: SiLU, 1: GELU, 2: ReLU +) { + extern __shared__ char shared_mem[]; + T* shared_input = (T*)shared_mem; + T* shared_intermediate = shared_input + hidden_dim; + + const int token_idx = blockIdx.x; + const int tid = threadIdx.x; + const int block_size = blockDim.x; + + if (token_idx >= num_tokens) { + return; + } + + for (int i = tid; i < hidden_dim; i += block_size) { + shared_input[i] = input[token_idx * hidden_dim + i]; + } + __syncthreads(); + + for (int i = tid; i < hidden_dim; i += block_size) { + output[token_idx * hidden_dim + i] = T(0.0f); + } + __syncthreads(); + + for (int k = 0; k < num_selected_experts; k++) { + int expert_id = expert_indices[token_idx * num_selected_experts + k]; + float routing_weight = routing_weights[token_idx * num_selected_experts + 
k]; + + const T* gate_w = gate_weights + expert_id * hidden_dim * intermediate_dim; + const T* up_w = up_weights + expert_id * hidden_dim * intermediate_dim; + + for (int i = tid; i < intermediate_dim; i += block_size) { + float gate_val = 0.0f; + + for (int j = 0; j < hidden_dim; j++) { + float input_val = float(shared_input[j]); + gate_val += input_val * float(gate_w[j * intermediate_dim + i]); + } + + if (activation_type == 0) { + gate_val = silu(gate_val); + } else if (activation_type == 1) { + gate_val = gelu(gate_val); + } else if (activation_type == 2) { + gate_val = fmaxf(0.0f, gate_val); + } + + shared_intermediate[i] = T(gate_val); + } + __syncthreads(); + + for (int i = tid; i < hidden_dim; i += block_size) { + float acc = 0.0f; + for (int j = 0; j < intermediate_dim; j++) { + acc += float(shared_intermediate[j]) * float(up_w[i * intermediate_dim + j]); + } + output[token_idx * hidden_dim + i] += T(acc * routing_weight); + } + __syncthreads(); + } +} + +template +__global__ void qwen3_fused_moe_kernel( + const T* __restrict__ input, // [num_tokens, hidden_dim] + const T* __restrict__ gate_weights, // [num_experts, hidden_dim, intermediate_dim] + const T* __restrict__ up_weights, // [num_experts, hidden_dim, intermediate_dim] + const T* __restrict__ down_weights, // [num_experts, intermediate_dim, hidden_dim] + const float* __restrict__ routing_weights, // [num_tokens, num_selected_experts] + const uint32_t* __restrict__ expert_indices, // [num_tokens, num_selected_experts] + T* __restrict__ output, // [num_tokens, hidden_dim] + int num_tokens, + int hidden_dim, + int intermediate_dim, + int num_selected_experts, + int activation_type // 0: SiLU, 1: GELU, 2: ReLU +) { + extern __shared__ char shared_mem[]; + T* shared_input = (T*)shared_mem; + T* shared_intermediate = shared_input + hidden_dim; + + const int token_idx = blockIdx.x; + const int tid = threadIdx.x; + const int block_size = blockDim.x; + + if (token_idx >= num_tokens) { + return; + } + + for (int i = tid; i < hidden_dim; i += block_size) { + shared_input[i] = input[token_idx * hidden_dim + i]; + } + __syncthreads(); + + for (int i = tid; i < hidden_dim; i += block_size) { + output[token_idx * hidden_dim + i] = T(0.0f); + } + __syncthreads(); + + for (int k = 0; k < num_selected_experts; k++) { + int expert_id = expert_indices[token_idx * num_selected_experts + k]; + float routing_weight = routing_weights[token_idx * num_selected_experts + k]; + + const T* gate_w = gate_weights + expert_id * hidden_dim * intermediate_dim; + const T* up_w = up_weights + expert_id * hidden_dim * intermediate_dim; + const T* down_w = down_weights + expert_id * intermediate_dim * hidden_dim; + + for (int i = tid; i < intermediate_dim; i += block_size) { + float gate_val = 0.0f; + float up_val = 0.0f; + + for (int j = 0; j < hidden_dim; j++) { + float input_val = float(shared_input[j]); + gate_val += input_val * float(gate_w[j * intermediate_dim + i]); + up_val += input_val * float(up_w[j * intermediate_dim + i]); + } + + if (activation_type == 0) { + gate_val = silu(gate_val); + } else if (activation_type == 1) { + gate_val = gelu(gate_val); + } else if (activation_type == 2) { + gate_val = fmaxf(0.0f, gate_val); + } + + shared_intermediate[i] = T(gate_val * up_val); + } + __syncthreads(); + + for (int i = tid; i < hidden_dim; i += block_size) { + float down_val = 0.0f; + + for (int j = 0; j < intermediate_dim; j++) { + down_val += float(shared_intermediate[j]) * float(down_w[j * hidden_dim + i]); + } + + output[token_idx * hidden_dim + 
i] += T(down_val * routing_weight); + } + __syncthreads(); + } +} + +// Optimized fused MoE kernel with tiling and better memory access patterns +template +__global__ void fused_moe_kernel_optimized( + const T* __restrict__ input, // [num_tokens, hidden_dim] + const T* __restrict__ gate_weights, // [num_experts, hidden_dim, intermediate_dim] + const T* __restrict__ up_weights, // [num_experts, hidden_dim, intermediate_dim] + const T* __restrict__ down_weights, // [num_experts, intermediate_dim, hidden_dim] + const TokenExpertPair* __restrict__ sorted_pairs, // Sorted token-expert pairs + const int* __restrict__ expert_offsets, // Start offset for each expert in sorted_pairs + T* __restrict__ output, // [num_tokens, hidden_dim] + T* __restrict__ intermediate_cache, // [num_tokens, intermediate_dim] workspace + int num_tokens, + int hidden_dim, + int intermediate_dim, + int num_experts, + int total_pairs, + int activation_type // 0: SiLU, 1: GELU, 2: ReLU +) { + // Thread block processes one tile of the output + const int tid = threadIdx.x + threadIdx.y * blockDim.x; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; + + // Shared memory for tiles + extern __shared__ char shared_mem[]; + T* tile_input = (T*)shared_mem; + T* tile_weight = tile_input + TILE_M * TILE_K; + float* tile_accum = (float*)(tile_weight + TILE_K * TILE_N); + + // Grid-stride loop over expert groups + for (int expert_id = blockIdx.z; expert_id < num_experts; expert_id += gridDim.z) { + int start_idx = (expert_id == 0) ? 0 : expert_offsets[expert_id - 1]; + int end_idx = expert_offsets[expert_id]; + + if (start_idx >= end_idx) continue; + + // Get expert weight pointers + const T* gate_w = gate_weights + expert_id * hidden_dim * intermediate_dim; + const T* up_w = up_weights + expert_id * hidden_dim * intermediate_dim; + const T* down_w = down_weights + expert_id * intermediate_dim * hidden_dim; + + // Process tokens assigned to this expert in blocks + for (int token_block = start_idx + blockIdx.x * TILE_M; + token_block < end_idx; + token_block += gridDim.x * TILE_M) { + + // Phase 1: Compute gate and up projections + for (int out_tile = blockIdx.y * TILE_N; + out_tile < intermediate_dim; + out_tile += gridDim.y * TILE_N) { + + // Initialize accumulator + float accum_gate[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + float accum_up[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + + // Loop over K dimension in tiles + for (int k_tile = 0; k_tile < hidden_dim; k_tile += TILE_K) { + // Cooperatively load input tile + __syncthreads(); + for (int idx = tid; idx < TILE_M * TILE_K; idx += blockDim.x * blockDim.y) { + int local_m = idx / TILE_K; + int local_k = idx % TILE_K; + int token_idx = token_block + local_m; + + if (token_idx < end_idx && k_tile + local_k < hidden_dim) { + int actual_token = sorted_pairs[token_idx].token_idx; + tile_input[local_m * TILE_K + local_k] = + input[actual_token * hidden_dim + k_tile + local_k]; + } else { + tile_input[local_m * TILE_K + local_k] = T(0); + } + } + + // Load weight tiles for gate and up + for (int idx = tid; idx < TILE_K * TILE_N; idx += blockDim.x * blockDim.y) { + int local_k = idx / TILE_N; + int local_n = idx % TILE_N; + + if (k_tile + local_k < hidden_dim && out_tile + local_n < intermediate_dim) { + // Gate weights + tile_weight[local_k * TILE_N + local_n] = + gate_w[(k_tile + local_k) * intermediate_dim + out_tile + local_n]; + } else { + tile_weight[local_k * TILE_N + local_n] = T(0); + } + } + __syncthreads(); + + // Compute partial dot products for gate + int local_m = 
threadIdx.y; + int local_n = threadIdx.x; + + if (local_m < TILE_M && local_n < TILE_N) { + for (int k = 0; k < TILE_K; k++) { + accum_gate[0] += float(tile_input[local_m * TILE_K + k]) * + float(tile_weight[k * TILE_N + local_n]); + } + } + + // Load up weights + __syncthreads(); + for (int idx = tid; idx < TILE_K * TILE_N; idx += blockDim.x * blockDim.y) { + int local_k = idx / TILE_N; + int local_n = idx % TILE_N; + + if (k_tile + local_k < hidden_dim && out_tile + local_n < intermediate_dim) { + tile_weight[local_k * TILE_N + local_n] = + up_w[(k_tile + local_k) * intermediate_dim + out_tile + local_n]; + } + } + __syncthreads(); + + // Compute partial dot products for up + if (local_m < TILE_M && local_n < TILE_N) { + for (int k = 0; k < TILE_K; k++) { + accum_up[0] += float(tile_input[local_m * TILE_K + k]) * + float(tile_weight[k * TILE_N + local_n]); + } + } + } + + // Apply activation and store intermediate results + __syncthreads(); + int local_m = threadIdx.y; + int local_n = threadIdx.x; + + if (local_m < TILE_M && local_n < TILE_N) { + int token_idx = token_block + local_m; + int out_idx = out_tile + local_n; + + if (token_idx < end_idx && out_idx < intermediate_dim) { + float gate_val = accum_gate[0]; + float up_val = accum_up[0]; + + // Apply activation + if (activation_type == 0) { + gate_val = silu(gate_val); + } else if (activation_type == 1) { + gate_val = gelu(gate_val); + } else { + gate_val = fmaxf(0.0f, gate_val); + } + + // Store to intermediate cache + int actual_token = sorted_pairs[token_idx].token_idx; + intermediate_cache[actual_token * intermediate_dim + out_idx] = + T(gate_val * up_val); + } + } + } + } + + __syncthreads(); + + // Phase 2: Down projection + for (int token_block = start_idx + blockIdx.x * TILE_M; + token_block < end_idx; + token_block += gridDim.x * TILE_M) { + + for (int out_tile = blockIdx.y * TILE_N; + out_tile < hidden_dim; + out_tile += gridDim.y * TILE_N) { + + float accum[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + + // Loop over K dimension (intermediate_dim) + for (int k_tile = 0; k_tile < intermediate_dim; k_tile += TILE_K) { + // Load intermediate results + __syncthreads(); + for (int idx = tid; idx < TILE_M * TILE_K; idx += blockDim.x * blockDim.y) { + int local_m = idx / TILE_K; + int local_k = idx % TILE_K; + int token_idx = token_block + local_m; + + if (token_idx < end_idx && k_tile + local_k < intermediate_dim) { + int actual_token = sorted_pairs[token_idx].token_idx; + tile_input[local_m * TILE_K + local_k] = + intermediate_cache[actual_token * intermediate_dim + k_tile + local_k]; + } else { + tile_input[local_m * TILE_K + local_k] = T(0); + } + } + + // Load down weights + for (int idx = tid; idx < TILE_K * TILE_N; idx += blockDim.x * blockDim.y) { + int local_k = idx / TILE_N; + int local_n = idx % TILE_N; + + if (k_tile + local_k < intermediate_dim && out_tile + local_n < hidden_dim) { + tile_weight[local_k * TILE_N + local_n] = + down_w[(k_tile + local_k) * hidden_dim + out_tile + local_n]; + } + } + __syncthreads(); + + // Compute partial products + int local_m = threadIdx.y; + int local_n = threadIdx.x; + + if (local_m < TILE_M && local_n < TILE_N) { + for (int k = 0; k < TILE_K; k++) { + accum[0] += float(tile_input[local_m * TILE_K + k]) * + float(tile_weight[k * TILE_N + local_n]); + } + } + } + + // Accumulate to output with routing weights + int local_m = threadIdx.y; + int local_n = threadIdx.x; + + if (local_m < TILE_M && local_n < TILE_N) { + int token_idx = token_block + local_m; + int out_idx = out_tile + local_n; + + 
if (token_idx < end_idx && out_idx < hidden_dim) { + int actual_token = sorted_pairs[token_idx].token_idx; + float routing_weight = sorted_pairs[token_idx].routing_weight; + + atomicAdd(&output[actual_token * hidden_dim + out_idx], + T(accum[0] * routing_weight)); + } + } + } + } + } +} + +// Kernel to sort tokens by expert for better data locality +__global__ void prepare_sorted_pairs( + const uint32_t* expert_indices, // [num_tokens, num_selected_experts] + const float* routing_weights, // [num_tokens, num_selected_experts] + TokenExpertPair* sorted_pairs, // Output: sorted pairs + int* expert_counts, // Output: count per expert + int num_tokens, + int num_selected_experts +) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid < num_tokens * num_selected_experts) { + int token_idx = tid / num_selected_experts; + int k = tid % num_selected_experts; + + TokenExpertPair pair; + pair.token_idx = token_idx; + pair.expert_idx = expert_indices[tid]; + pair.routing_weight = routing_weights[tid]; + pair.original_idx = tid; + + sorted_pairs[tid] = pair; + + // Count tokens per expert + atomicAdd(&expert_counts[pair.expert_idx], 1); + } +} + +#define CALL_NOMIC_FUSED_MOE_FORWARD(T) \ + nomic_fused_moe_kernel<<>>( \ + reinterpret_cast(input), \ + reinterpret_cast(gate_weights), \ + reinterpret_cast(up_weights), \ + routing_weights, \ + expert_indices, \ + reinterpret_cast(output), \ + num_tokens, \ + hidden_dim, \ + intermediate_dim, \ + num_selected_experts, \ + activation_type \ + ); + +#define CALL_QWEN3_FUSED_MOE_FORWARD(T) \ + qwen3_fused_moe_kernel<<>>( \ + reinterpret_cast(input), \ + reinterpret_cast(gate_weights), \ + reinterpret_cast(up_weights), \ + reinterpret_cast(down_weights), \ + routing_weights, \ + expert_indices, \ + reinterpret_cast(output), \ + num_tokens, \ + hidden_dim, \ + intermediate_dim, \ + num_selected_experts, \ + activation_type \ + ); + +// C interface for optimized fused MoE +extern "C" { + +void fused_moe_forward( + void* input, + void* gate_weights, + void* up_weights, + void* down_weights, + float* routing_weights, + uint32_t* expert_indices, + void* output, + int num_tokens, + int hidden_dim, + int intermediate_dim, + int num_selected_experts, + int activation_type, + uint32_t moe_type, // 0 => qwen3, 1 => nomic + uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 +) { + const cudaStream_t stream = 0; + const int threads = 256; + + if (moe_type == 0) { + if (dtype == 0) { + int shared_mem_size = (hidden_dim + intermediate_dim) * sizeof(half); + CALL_QWEN3_FUSED_MOE_FORWARD(half); + } else if (dtype == 1) { + int shared_mem_size = (hidden_dim + intermediate_dim) * sizeof(__nv_bfloat16); + CALL_QWEN3_FUSED_MOE_FORWARD(__nv_bfloat16); + } else { + int shared_mem_size = (hidden_dim + intermediate_dim) * sizeof(float); + CALL_QWEN3_FUSED_MOE_FORWARD(float); + } + } else if (moe_type == 1) { + if (dtype == 0) { + int shared_mem_size = (hidden_dim + intermediate_dim) * sizeof(half); + CALL_NOMIC_FUSED_MOE_FORWARD(half); + } else if (dtype == 1) { + int shared_mem_size = (hidden_dim + intermediate_dim) * sizeof(__nv_bfloat16); + CALL_NOMIC_FUSED_MOE_FORWARD(__nv_bfloat16); + } else { + int shared_mem_size = (hidden_dim + intermediate_dim) * sizeof(float); + CALL_NOMIC_FUSED_MOE_FORWARD(float); + } + } +} + +} // extern "C" diff --git a/candle-moe/kernels/topk_softmax.cu b/candle-moe/kernels/topk_softmax.cu new file mode 100644 index 0000000..aab90ee --- /dev/null +++ b/candle-moe/kernels/topk_softmax.cu @@ -0,0 +1,497 @@ +/* + * Adapted from 
https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu + * Copyright (c) 2024, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#include +#include + +#include "cuda_compat.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm { +namespace moe { + +/// Aligned array type +template < + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = sizeof(T) * N +> +class alignas(Alignment) AlignedArray { + float data[N]; +}; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) +{ + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. 
+ if ((finished != nullptr) && finished[blockIdx.x]) + { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) + { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) + { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = val; + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, + int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +{ + + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) + { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) + { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) + { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) + { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. 
+*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + int* source_rows, const int k, const int start_expert, const int end_expert) +{ + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) + { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. 
In theory, + // this can support all powers of 2 up to 16. + // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. + // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. + using AccessType = AlignedArray; + + // Finally, we pull in the data from global mem + float row_chunk[VPT]; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) + { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + float thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) + { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index. + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) + { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) + { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) + { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) + { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. 
+// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) + { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) + { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) + { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) + { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } +} + +namespace detail +{ +// Constructs some constants needed to partition the work across threads at compile time. 
+template +struct TopkConstants +{ + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, + int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) +{ + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); +} + +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indicies, \ + token_expert_indices, num_tokens, topk, 0, num_experts, \ + stream); + +void topkGatingSoftmaxKernelLauncher( + const float* gating_output, + float* topk_weights, + int* topk_indicies, + int* token_expert_indices, + const int num_tokens, + const int num_experts, + const int topk, + cudaStream_t stream +) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SOFTMAX(1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SOFTMAX(2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SOFTMAX(4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SOFTMAX(8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SOFTMAX(16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SOFTMAX(32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SOFTMAX(64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SOFTMAX(128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SOFTMAX(256, WARPS_PER_TB); + break; + default: { + LAUNCH_SOFTMAX(256, WARPS_PER_TB); + break; + } + } +} + +} // namespace moe +} // namespace vllm + +extern "C" void topk_softmax( + void *gating_output, // [num_tokens, num_experts] + void *topk_weights, // [num_tokens, topk] + void *topk_indices, // [num_tokens, topk] + void *token_expert_indices, // [num_tokens, topk] + + int32_t num_experts, + int64_t num_tokens, + int32_t topk +) { + const cudaStream_t stream = 0; + + vllm::moe::topkGatingSoftmaxKernelLauncher( + reinterpret_cast(gating_output), + reinterpret_cast(topk_weights), + reinterpret_cast(topk_indices), + reinterpret_cast(token_expert_indices), + num_tokens, + num_experts, + topk, + stream + ); +} diff --git a/candle-moe/src/ffi.rs b/candle-moe/src/ffi.rs new file mode 100644 index 0000000..7f53f5f --- /dev/null +++ b/candle-moe/src/ffi.rs @@ -0,0 +1,49 @@ +use core::ffi::{c_int, c_void}; + +unsafe extern "C" { + pub(crate) fn topk_softmax( + gating_output: *const c_void, + topk_weight: *const c_void, + topk_indices: *const c_void, + token_expert_indices: *const c_void, + + num_experts: c_int, + num_tokens: c_int, + topk: c_int, + ); + + // Fused MoE forward pass + 
pub(crate) fn fused_moe_forward( + input: *const c_void, + gate_weights: *const c_void, + up_weights: *const c_void, + down_weights: *const c_void, + routing_weights: *const c_void, + expert_indices: *const c_void, + output: *const c_void, + num_tokens: i32, + hidden_dim: i32, + intermediate_dim: i32, + num_selected_experts: i32, + activation_type: i32, + moe_type: u32, + dtype: u32, + ); + + #[allow(dead_code)] + pub(crate) fn fused_moe_forward_optimized( + input: *const c_void, + gate_weights: *const c_void, + up_weights: *const c_void, + down_weights: *const c_void, + routing_weights: *const c_void, + expert_indices: *const c_void, + output: *mut c_void, + num_tokens: i32, + hidden_dim: i32, + intermediate_dim: i32, + num_experts: i32, + activation_type: i32, + dtype: u32, + ); +} diff --git a/candle-moe/src/lib.rs b/candle-moe/src/lib.rs new file mode 100644 index 0000000..cb163e3 --- /dev/null +++ b/candle-moe/src/lib.rs @@ -0,0 +1,421 @@ +pub mod ffi; + +use candle::cuda_backend::cudarc::driver::DevicePtr; +use candle::{DType, Result, Storage, Tensor}; +use half::{bf16, f16}; +use std::ptr; + +pub fn apply_topk_softmax_< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, +>( + gating_output: &Tensor, + topk_weight: &Tensor, + topk_indices: &Tensor, + token_expert_indices: &Tensor, +) -> Result<()> { + let (g, g_l) = gating_output.storage_and_layout(); + let g: &candle::CudaStorage = match &*g { + Storage::Cuda(g) => g, + _ => candle::bail!("gating_output must be a cuda tensor"), + }; + + let (w, w_l) = topk_weight.storage_and_layout(); + let w = match &*w { + Storage::Cuda(w) => w, + _ => candle::bail!("topk_weight must be a cuda tensor"), + }; + + let (i, i_l) = topk_indices.storage_and_layout(); + let i = match &*i { + Storage::Cuda(i) => i, + _ => candle::bail!("topk_indices must be a cuda tensor"), + }; + + let (ei, ei_l) = token_expert_indices.storage_and_layout(); + let ei: &candle::CudaStorage = match &*ei { + Storage::Cuda(ei) => ei, + _ => candle::bail!("token_expert_indices must be a cuda tensor"), + }; + + let g_rank = g_l.stride().len(); + let w_rank = w_l.stride().len(); + let i_rank = i_l.stride().len(); + let ei_rank = ei_l.stride().len(); + + if g_rank != 2 || w_rank != 2 || i_rank != 2 || ei_rank != 2 { + candle::bail!( + "apply_topk_softmax_inplace expects input tensors of rank 2 (w: {w_l:?}, i: {i_l:?}, ei: {ei_l:?}, g: {g_l:?})" + ) + } + + // Get cuda slices for all tensors + let g = g.as_cuda_slice::()?; + let w = w.as_cuda_slice::()?; + let i = i.as_cuda_slice::()?; + let ei = ei.as_cuda_slice::()?; + + // Get cuda views for all tensors + let g = g.slice(g_l.start_offset()..); + let w = w.slice(w_l.start_offset()..); + let i = i.slice(i_l.start_offset()..); + let ei = ei.slice(ei_l.start_offset()..); + + let (num_tokens, top_k) = w_l.shape().dims2()?; + let (_, num_experts) = g_l.shape().dims2()?; + + let is_pow2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + if !is_pow2 || num_experts > 256 { + candle::bail!( + "num_experts should be power of 2 and smaller than 256 (num_experts: {num_experts:?})" + ) + } + + if (num_tokens, top_k) != i_l.shape().dims2()? { + candle::bail!( + "shape mismatch topk_indices {:?}, expected {:?}", + i_l.shape(), + (num_tokens, top_k) + ) + } + + if (num_tokens, top_k) != ei_l.shape().dims2()? 
{ + candle::bail!( + "shape mismatch token_expert_indices {:?}, expected {:?}", + ei_l.shape(), + (num_tokens, top_k) + ) + } + + let gate_ptr = *g.device_ptr() as *const core::ffi::c_void; + let weight_ptr = *w.device_ptr() as *const core::ffi::c_void; + let indices_ptr = *i.device_ptr() as *const core::ffi::c_void; + let expert_indices_ptr = *ei.device_ptr() as *const core::ffi::c_void; + + unsafe { + ffi::topk_softmax( + gate_ptr, + weight_ptr, + indices_ptr, + expert_indices_ptr, + num_experts as i32, + num_tokens as i32, + top_k as i32, + ) + } + + Ok(()) +} + +pub fn apply_topk_softmax_inplace( + gating_output: &Tensor, + topk_weight: &Tensor, + topk_indices: &Tensor, + token_expert_indices: &Tensor, +) -> Result<()> { + match topk_weight.dtype() { + DType::F16 => apply_topk_softmax_::( + gating_output, + topk_weight, + topk_indices, + token_expert_indices, + ), + DType::BF16 => apply_topk_softmax_::( + gating_output, + topk_weight, + topk_indices, + token_expert_indices, + ), + DType::F32 => apply_topk_softmax_::( + gating_output, + topk_weight, + topk_indices, + token_expert_indices, + ), + dt => { + candle::bail!( + "apply_topk_softmax_inplace is only supported for f32, f16 and bf16 ({dt:?})" + ) + } + } +} + +pub struct FusedMoeForward { + num_experts: usize, + num_selected_experts: usize, + activation: Activation, +} + +#[derive(Clone, Copy, Debug)] +pub enum Activation { + Silu, + Gelu, + Relu, +} + +impl Activation { + fn to_int(self) -> i32 { + match self { + Activation::Silu => 0, + Activation::Gelu => 1, + Activation::Relu => 2, + } + } +} + +fn moe_internal_type(dtype: DType) -> Result { + let internal_type: u32 = match dtype { + DType::F16 => 0, + DType::BF16 => 1, + DType::F32 => 2, + dtype => candle::bail!("dtype {dtype:?} is not supported"), + }; + Ok(internal_type) +} + +impl FusedMoeForward { + pub fn new(num_experts: usize, num_selected_experts: usize, activation: Activation) -> Self { + Self { + num_experts, + num_selected_experts, + activation, + } + } + + /// Performs fused MoE forward pass + /// Args: + /// - input: [num_tokens, hidden_dim] + /// - gate_weights: [num_experts, hidden_dim, intermediate_dim] + /// - up_weights: [num_experts, hidden_dim, intermediate_dim] + /// - down_weights: [num_experts, intermediate_dim, hidden_dim] + /// - routing_weights: [num_tokens, num_selected_experts] + /// - expert_indices: [num_tokens, num_selected_experts] + /// - moe_type: qwen3: 0, nomic: 1 + /// + /// Returns: + /// - output: [num_tokens, hidden_dim] + #[allow(clippy::too_many_arguments)] + pub fn forward( + &self, + input: &Tensor, + gate_weights: &Tensor, + up_weights: &Tensor, + down_weights: Option<&Tensor>, + routing_weights: &Tensor, + expert_indices: &Tensor, + moe_type: u32, + ) -> Result { + let device = input.device(); + + // Validate inputs + let (num_tokens, hidden_dim) = input.dims2()?; + let (ne_g, hd_g, id_g) = gate_weights.dims3()?; + let (ne_u, hd_u, id_u) = up_weights.dims3()?; + let (ne_d, id_d, hd_d) = if let Some(dw) = down_weights { + dw.dims3()? 
+ } else { + (self.num_experts, id_u, hd_u) + }; + let (nt, nse) = routing_weights.dims2()?; + let (nt2, nse2) = expert_indices.dims2()?; + + if ne_g != self.num_experts || ne_u != self.num_experts || ne_d != self.num_experts { + candle::bail!("Number of experts mismatch"); + } + if hd_g != hidden_dim || hd_u != hidden_dim { + candle::bail!("Hidden dimension mismatch for gate/up weights"); + } + if hd_d != hidden_dim { + candle::bail!( + "Hidden dimension mismatch for down weights (expected {}, got {})", + hidden_dim, + hd_d + ); + } + if id_g != id_u || id_u != id_d { + candle::bail!( + "Intermediate dimension mismatch (gate: {}, up: {}, down: {})", + id_g, + id_u, + id_d + ); + } + + if nt != num_tokens || nt2 != num_tokens { + candle::bail!("Number of tokens mismatch"); + } + if nse != self.num_selected_experts || nse2 != self.num_selected_experts { + candle::bail!("Number of selected experts mismatch"); + } + if moe_type > 1 { + candle::bail!("moe_type must be one of 0 or 1"); + } + + // Create output tensor + let output = Tensor::zeros((num_tokens, hidden_dim), input.dtype(), device)?; + + _ = match input.dtype() { + DType::F16 => self.cuda_fwd::( + input, + gate_weights, + up_weights, + down_weights, + routing_weights, + expert_indices, + moe_type, + &output, + ), + DType::BF16 => self.cuda_fwd::( + input, + gate_weights, + up_weights, + down_weights, + routing_weights, + expert_indices, + moe_type, + &output, + ), + DType::F32 => self.cuda_fwd::( + input, + gate_weights, + up_weights, + down_weights, + routing_weights, + expert_indices, + moe_type, + &output, + ), + dt => { + candle::bail!("FusedMoeForward is only supported for f32, f16 and bf16 ({dt:?})") + } + }; + + Ok(output) + } + + #[allow(clippy::too_many_arguments)] + fn cuda_fwd< + T: candle::cuda_backend::CudaDType + candle::cuda_backend::cudarc::driver::DeviceRepr, + >( + &self, + input: &Tensor, + gate_weights: &Tensor, + up_weights: &Tensor, + down_weights: Option<&Tensor>, + routing_weights: &Tensor, + expert_indices: &Tensor, + moe_type: u32, + output: &Tensor, + ) -> Result<()> { + let (num_tokens, hidden_dim) = input.dims2()?; + let (_, hd_gate, intermediate_dim) = gate_weights.dims3()?; + + // Validate that gate weights have correct dimensions + if hd_gate != hidden_dim { + candle::bail!( + "gate_weights hidden_dim {} doesn't match input {}", + hd_gate, + hidden_dim + ); + } + + // Get storage and layouts + let (input_storage, input_layout) = input.storage_and_layout(); + let (gate_storage, gate_layout) = gate_weights.storage_and_layout(); + let (up_storage, up_layout) = up_weights.storage_and_layout(); + let (routing_storage, routing_layout) = routing_weights.storage_and_layout(); + let (indices_storage, indices_layout) = expert_indices.storage_and_layout(); + let (output_storage, output_layout) = output.storage_and_layout(); + + // Extract CUDA storage + let input_cuda = match &*input_storage { + Storage::Cuda(cuda_storage) => cuda_storage, + _ => candle::bail!("input must be a cuda tensor"), + }; + let gate_cuda = match &*gate_storage { + Storage::Cuda(cuda_storage) => cuda_storage, + _ => candle::bail!("gate_weights must be a cuda tensor"), + }; + let up_cuda = match &*up_storage { + Storage::Cuda(cuda_storage) => cuda_storage, + _ => candle::bail!("up_weights must be a cuda tensor"), + }; + let routing_cuda = match &*routing_storage { + Storage::Cuda(cuda_storage) => cuda_storage, + _ => candle::bail!("routing_weights must be a cuda tensor"), + }; + let indices_cuda = match &*indices_storage { + 
Storage::Cuda(cuda_storage) => cuda_storage, + _ => candle::bail!("expert_indices must be a cuda tensor"), + }; + let output_cuda = match &*output_storage { + Storage::Cuda(cuda_storage) => cuda_storage, + _ => candle::bail!("output must be a cuda tensor"), + }; + + let input_slice = input_cuda + .as_cuda_slice::()? + .slice(input_layout.start_offset()..); + let gate_slice = gate_cuda + .as_cuda_slice::()? + .slice(gate_layout.start_offset()..); + let up_slice = up_cuda + .as_cuda_slice::()? + .slice(up_layout.start_offset()..); + let routing_slice = routing_cuda + .as_cuda_slice::()? + .slice(routing_layout.start_offset()..); + let indices_slice = indices_cuda + .as_cuda_slice::()? + .slice(indices_layout.start_offset()..); + let output_slice = output_cuda + .as_cuda_slice::()? + .slice(output_layout.start_offset()..); + + let input_ptr = *input_slice.device_ptr() as *const core::ffi::c_void; + let gate_ptr = *gate_slice.device_ptr() as *const core::ffi::c_void; + let up_ptr = *up_slice.device_ptr() as *const core::ffi::c_void; + let routing_ptr = *routing_slice.device_ptr() as *const core::ffi::c_void; + let indices_ptr = *indices_slice.device_ptr() as *const core::ffi::c_void; + let output_ptr = *output_slice.device_ptr() as *const core::ffi::c_void; + + let down_ptr = if let Some(dw) = down_weights { + let (down_storage, down_layout) = dw.storage_and_layout(); + + let down_cuda = match &*down_storage { + Storage::Cuda(cuda_storage) => cuda_storage, + _ => candle::bail!("down_weights must be a cuda tensor"), + }; + + let down_slice = down_cuda + .as_cuda_slice::()? + .slice(down_layout.start_offset()..); + + *down_slice.device_ptr() as *const core::ffi::c_void + } else { + ptr::null() + }; + + let internal_dtype = moe_internal_type(input.dtype())?; + + unsafe { + ffi::fused_moe_forward( + input_ptr, + gate_ptr, + up_ptr, + down_ptr, + routing_ptr, + indices_ptr, + output_ptr, + num_tokens as i32, + hidden_dim as i32, + intermediate_dim as i32, + self.num_selected_experts as i32, + self.activation.to_int(), + moe_type, + internal_dtype, + ); + } + + Ok(()) + } +} diff --git a/candle-moe/tests/moe_tests.rs b/candle-moe/tests/moe_tests.rs new file mode 100644 index 0000000..50f89ea --- /dev/null +++ b/candle-moe/tests/moe_tests.rs @@ -0,0 +1,143 @@ +use anyhow::Result; +use candle::{D, DType, Device, IndexOp, Tensor}; +use candle_transformers::models::deepseek2::{BincountOp, NonZeroOp}; + +#[allow(dead_code)] +fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { + let b = 10f32.powi(digits); + let t = t.to_vec2::()?; + let t = t + .iter() + .map(|row| { + row.iter() + .map(|val| (val * b).round() / b) + .collect::>() + }) + .collect::>>(); + Ok(t) +} + +fn forward_moe_router( + weights: &Tensor, + seq_len: usize, + top_k: usize, + device: &Device, +) -> Result<(Tensor, Tensor)> { + let topk_weight = Tensor::zeros((seq_len, top_k), DType::F32, device)?; + let topk_indices = Tensor::zeros((seq_len, top_k), DType::U32, device)?; + let token_expert_indices = Tensor::zeros((seq_len, top_k), DType::U32, device)?; + + candle_moe::apply_topk_softmax_inplace( + weights, + &topk_weight, + &topk_indices, + &token_expert_indices, + )?; + + Ok((topk_weight, topk_indices)) +} + +fn forward_moe_mlp(x: &Tensor, w1: &Tensor, w2: &Tensor, expert_idx: usize) -> Result { + let expert_w1 = w1.narrow(0, expert_idx, 1)?.squeeze(0)?.t()?; + let expert_w2 = w2.narrow(0, expert_idx, 1)?.squeeze(0)?; + + let x = x.broadcast_matmul(&expert_w1)?; + let x = x.gelu()?; + + Ok(x.broadcast_matmul(&expert_w2)?) 
+} + +fn forward_moe_expert( + hidden_states: &Tensor, + gate: &Tensor, + up: &Tensor, + scores: &Tensor, + indices: &Tensor, + hidden_size: usize, + num_experts: usize, +) -> Result { + let hidden_states = hidden_states.reshape(((), hidden_size))?; + + let mut out = Tensor::zeros_like(&hidden_states)?; + + let counts = indices.flatten_all()?.bincount(num_experts as u32)?; + + for (expert_idx, &count) in counts.iter().enumerate().take(num_experts) { + if count == 0u32 { + continue; + } + + let idx_top = indices.eq(expert_idx as f64)?.nonzero()?.t()?; + let idx = &idx_top.i(0)?.contiguous()?; + let top = &idx_top.i(1)?.contiguous()?; + + let expert_out = + forward_moe_mlp(&hidden_states.index_select(idx, 0)?, gate, up, expert_idx)? + .broadcast_mul( + &scores + .index_select(idx, 0)? + .gather(&top.unsqueeze(1)?, 1)? + .squeeze(1)? + .unsqueeze(D::Minus1)? + .to_dtype(hidden_states.dtype())?, + )?; + + out = out.index_add(idx, &expert_out, 0)?; + } + + Ok(out) +} + +#[test] +fn fused_moe() -> Result<()> { + let device = Device::new_cuda(0)?; + + let n_embed = 768; + let n_inner = n_embed * 4; + let seq_len = 7; + let num_experts = 8; + let top_k = 2; + + let hidden_states = + Tensor::randn(0.0, 1.0, (seq_len, n_embed), &device)?.to_dtype(DType::F32)?; + let weights = Tensor::randn(0.0, 1.0, (seq_len, num_experts), &device)?.to_dtype(DType::F32)?; + + let (scores, indices) = forward_moe_router(&weights, seq_len, top_k, &device)?; + + let gate_weights = + Tensor::randn(0.0, 1.0, (num_experts, n_embed, n_inner), &device)?.to_dtype(DType::F32)?; + let up_weights = + Tensor::randn(0.0, 1.0, (num_experts, n_embed, n_inner), &device)?.to_dtype(DType::F32)?; + + let fused_moe = + candle_moe::FusedMoeForward::new(num_experts, top_k, candle_moe::Activation::Gelu); + + let fused_moe_output = fused_moe.forward( + &hidden_states, + &gate_weights, + &up_weights, + None, + &scores, + &indices, + 1_u32, + )?; + println!("fused moe: {:}", fused_moe_output); + + let naive_moe_output = forward_moe_expert( + &hidden_states, + &gate_weights.permute((0, 2, 1))?, + &up_weights.permute((0, 2, 1))?, + &scores, + &indices, + n_embed, + num_experts, + )?; + println!("naive moe: {:}", naive_moe_output); + + assert_eq!( + to_vec2_round(fused_moe_output, 6)?, + to_vec2_round(naive_moe_output, 6)? 
+ ); + + Ok(()) +} diff --git a/candle-moe/tests/topk_softmax_tests.rs b/candle-moe/tests/topk_softmax_tests.rs new file mode 100644 index 0000000..ceefaad --- /dev/null +++ b/candle-moe/tests/topk_softmax_tests.rs @@ -0,0 +1,58 @@ +use anyhow::Result; +use candle::{DType, Device, Tensor}; +use candle_transformers::models::deepseek2::{TopKLastDimOp, TopKOutput}; + +fn to_vec2_round(t: Tensor, digits: i32) -> Result>> { + let b = 10f32.powi(digits); + let t = t.to_vec2::()?; + let t = t + .iter() + .map(|row| { + row.iter() + .map(|val| (val * b).round() / b) + .collect::>() + }) + .collect::>>(); + Ok(t) +} + +#[test] +fn topk_softmax() -> Result<()> { + let device = Device::new_cuda(0)?; + + let seq_len = 8; + let num_experts = 4; + let top_k = 2; + + let weights = Tensor::randn(0.0, 1.0, (seq_len, num_experts), &device)?.to_dtype(DType::F32)?; + + let softmax_weights = candle_nn::ops::softmax_last_dim(&weights)?; + + let TopKOutput { + values: expected_values, + indices: expected_indices, + } = softmax_weights.topk(top_k)?; + + let topk_weight = Tensor::zeros((seq_len, top_k), DType::F32, &device)?; + let topk_indices = Tensor::zeros((seq_len, top_k), DType::U32, &device)?; + let token_expert_indices = Tensor::zeros((seq_len, top_k), DType::U32, &device)?; + + candle_moe::apply_topk_softmax_inplace( + &weights, + &topk_weight, + &topk_indices, + &token_expert_indices, + )?; + + assert_eq!( + to_vec2_round(expected_values, 3)?, + to_vec2_round(topk_weight, 3)? + ); + + assert_eq!( + expected_indices.to_vec2::()?, + topk_indices.to_vec2::()?, + ); + + Ok(()) +} From 618708dcb20bfaacdfa0d606498ef3e2c45197c1 Mon Sep 17 00:00:00 2001 From: kozistr Date: Wed, 10 Sep 2025 21:04:38 +0900 Subject: [PATCH 2/2] update: remove vscode --- .vscode/settings.json | 14 -------------- 1 file changed, 14 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index c9c987c..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "files.associations": { - "array": "cpp", - "format": "cpp", - "initializer_list": "cpp", - "list": "cpp", - "utility": "cpp", - "vector": "cpp", - "xhash": "cpp", - "xstring": "cpp", - "xtree": "cpp", - "xutility": "cpp" - } -} \ No newline at end of file