diff --git a/Cargo.toml b/Cargo.toml index bd6e1a856b..f27ec93326 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,9 @@ tokenizers = { version = "0.19.1", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" +ug = "0.0.2" +ug-cuda = "0.0.2" +ug-metal = "0.0.2" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "1.1.1", default-features = false } metal = { version = "0.27.0", features = ["mps"]} diff --git a/README.md b/README.md index 4c84a09185..246e2844ad 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,8 @@ [![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619) [![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core) [![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core) -![License](https://img.shields.io/crates/l/candle-core.svg) +[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT) +[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE) Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index cbf8f2007f..4ffc869ff8 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -28,6 +28,9 @@ rand_distr = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } thiserror = { workspace = true } +ug = { workspace = true } +ug-cuda = { workspace = true, optional = true } +ug-metal = { workspace = true, optional = true } yoke = { workspace = true } zip = { workspace = true } @@ -39,11 +42,11 @@ criterion = { workspace = true } [features] default = [] -cuda = ["cudarc", "dep:candle-kernels"] +cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] -metal = ["dep:metal", "dep:candle-metal-kernels"] +metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"] [[bench]] name = "bench_main" diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 89fe44a6e6..d3bd29030e 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -51,6 +51,27 @@ impl CudaDevice { self.device.clone() } + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result { + let mut buf = vec![]; + ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?; + let cuda_code = String::from_utf8(buf)?; + let opts = cudarc::nvrtc::CompileOptions { + use_fast_math: Some(true), + ..Default::default() + }; + let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?; + self.device.load_ptx(ptx, "ug", &[func_name]).w()?; + let func = match self.device.get_func("ug", func_name) { + Some(func) => func, + None => crate::bail!("unknown function ug::{func_name}"), + }; + Ok(func) + } + pub fn id(&self) -> DeviceId { self.id } diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs index 3a85dba9f4..c0d97d670a 100644 --- a/candle-core/src/custom_op.rs +++ b/candle-core/src/custom_op.rs @@ -375,3 +375,110 @@ impl Tensor { ) } } + +pub struct UgIOp1 { + name: &'static str, + #[cfg(feature = "cuda")] + func: cudarc::driver::CudaFunction, + #[cfg(feature = "metal")] + func: metal::ComputePipelineState, +} + +impl UgIOp1 { + #[allow(unused)] + pub fn new( + name: &'static str, + kernel: ug::lang::ssa::Kernel, + device: &crate::Device, + ) -> Result { + #[cfg(feature = "cuda")] + { + let device = device.as_cuda_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(feature = "metal")] + { + let device = device.as_metal_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(not(any(feature = "cuda", feature = "metal")))] + { + Ok(Self { name }) + } + } +} + +impl InplaceOp1 for UgIOp1 { + fn name(&self) -> &'static str { + self.name + } + + fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> { + crate::bail!("ug ops are only supported on metal/cuda at the moment") + } + + #[cfg(feature = "metal")] + fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> { + use crate::backend::BackendStorage; + use candle_metal_kernels::utils::EncoderProvider; + + let elem_count = layout.shape().elem_count(); + if sto.dtype() != crate::DType::F32 { + // TODO: support more dtypes. + crate::bail!("input is not a f32 tensor") + } + let device = sto.device(); + println!("here"); + let command_buffer = device.command_buffer()?; + let command_buffer = &command_buffer; + let encoder = command_buffer.encoder(); + let encoder = encoder.as_ref(); + encoder.set_compute_pipeline_state(&self.func); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let grid_dims = metal::MTLSize { + width: g as u64, + height: 1, + depth: 1, + }; + let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1); + candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize)); + + encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write); + encoder.dispatch_threads(grid_dims, group_dims); + + Ok(()) + } + + #[cfg(feature = "cuda")] + fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> { + use crate::cuda_backend::WrapErr; + use cudarc::driver::LaunchAsync; + + let elem_count = layout.shape().elem_count(); + // TODO: support more dtypes. + let sto = sto.as_cuda_slice::()?; + let sto = match layout.contiguous_offsets() { + None => crate::bail!("input has to be contiguous"), + Some((o1, o2)) => sto.slice(o1..o2), + }; + let params = (&sto,); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (g as u32, 1, 1), + block_dim: (b as u32, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { self.func.clone().launch(cfg, params) }.w()?; + Ok(()) + } +} diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index c4a8e9361e..18aa61aff7 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -130,6 +130,22 @@ impl Device { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) } + pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> { + match self { + Self::Cuda(d) => Ok(d), + Self::Cpu => crate::bail!("expected a cuda device, got cpu"), + Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"), + } + } + + pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> { + match self { + Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"), + Self::Cpu => crate::bail!("expected a metal device, got cpu"), + Self::Metal(d) => Ok(d), + } + } + pub fn new_cuda_with_stream(ordinal: usize) -> Result { Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?)) } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index e7112e2e61..a35bec3cbe 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -165,6 +165,9 @@ pub enum Error { #[error("Metal error {0}")] Metal(#[from] MetalError), + #[error(transparent)] + Ug(#[from] ug::Error), + #[error(transparent)] TryFromIntError(#[from] core::num::TryFromIntError), @@ -179,6 +182,10 @@ pub enum Error { #[error(transparent)] ParseInt(#[from] std::num::ParseIntError), + /// Utf8 parse error. + #[error(transparent)] + FromUtf8(#[from] std::string::FromUtf8Error), + /// I/O error. #[error(transparent)] Io(#[from] std::io::Error), diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index d8d6253213..39ca909d88 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -77,7 +77,7 @@ mod variable; pub use cuda_backend::cudnn; pub use cpu_backend::{CpuStorage, CpuStorageRef}; -pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; +pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; pub use error::{Error, Result}; diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index 29b8995bc9..46be6ce4bb 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -144,6 +144,28 @@ impl MetalDevice { self.use_mlx_mm = use_mlx_mm } + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result { + let mut buf = vec![]; + ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?; + let metal_code = String::from_utf8(buf)?; + let lib = self + .device + .new_library_with_source(&metal_code, &metal::CompileOptions::new()) + .map_err(MetalError::from)?; + let func = lib + .get_function(func_name, None) + .map_err(MetalError::from)?; + let pl = self + .device + .new_compute_pipeline_state_with_function(&func) + .map_err(MetalError::from)?; + Ok(pl) + } + pub fn id(&self) -> DeviceId { self.id } diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index be59e0c0c3..3572a4c9b2 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -143,3 +143,39 @@ fn inplace_op1() -> Result<()> { ); Ok(()) } + +#[cfg(any(feature = "cuda", feature = "metal"))] +#[allow(clippy::approx_constant)] +#[test] +fn ug_op() -> Result<()> { + let kernel = { + use ug::lang::op; + + let layout = ug::Layout::from_shape(&[12]); + let ptr = op::Arg::ptr(ug::DType::F32); + let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?; + let src = op::unary(op::UnaryOp::Exp, src)?; + let st = op::store(ptr.id(), layout, src)?; + let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]); + let opts: ug::lower_op::Opts = Default::default(); + kernel.lower(&opts.with_global(0, 12))? + }; + let device = if candle_core::utils::cuda_is_available() { + Device::new_cuda(0)? + } else if candle_core::utils::metal_is_available() { + Device::new_metal(0)? + } else { + candle_core::bail!("metal/cuda is mandatory for this test") + }; + let op = candle_core::UgIOp1::new("test", kernel, &device)?; + let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?; + t.inplace_op1(&op)?; + assert_eq!( + to_vec1_round(&t, 2)?, + &[ + 1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47, + 59874.13 + ] + ); + Ok(()) +} diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index be6160093e..222ae8ad85 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; -mod utils; +pub mod utils; pub use utils::BufferOffset; use utils::{get_block_dims, linear_split, EncoderProvider}; diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index d2cc09f495..0092ecfa58 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M } // https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 -pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { +pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { let mut pows0 = 0u64; let mut pows1 = 0u64; let mut pows2 = 0u64; @@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { } } -pub(crate) fn set_param( - encoder: &ComputeCommandEncoderRef, - position: u64, - data: P, -) { +pub fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {

::set_param(encoder, position, data) } /// Helper functions to create the various objects on the compute command encoder /// on a single line. /// Prevents getting wrong some arguments number and mixing length and size in bytes. -pub(crate) trait EncoderParam { +pub trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } macro_rules! primitive {