diff --git a/Cargo.toml b/Cargo.toml index bd6e1a856b..64e1460ebe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,8 @@ tokenizers = { version = "0.19.1", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" +ug = "0.0.2" +ug-cuda = "0.0.2" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "1.1.1", default-features = false } metal = { version = "0.27.0", features = ["mps"]} diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index cbf8f2007f..8ea2b08c03 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -28,6 +28,8 @@ rand_distr = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } thiserror = { workspace = true } +ug = { workspace = true } +ug-cuda = { workspace = true, optional = true } yoke = { workspace = true } zip = { workspace = true } @@ -39,7 +41,7 @@ criterion = { workspace = true } [features] default = [] -cuda = ["cudarc", "dep:candle-kernels"] +cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 89fe44a6e6..d3bd29030e 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -51,6 +51,27 @@ impl CudaDevice { self.device.clone() } + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result { + let mut buf = vec![]; + ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?; + let cuda_code = String::from_utf8(buf)?; + let opts = cudarc::nvrtc::CompileOptions { + use_fast_math: Some(true), + ..Default::default() + }; + let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?; + self.device.load_ptx(ptx, "ug", &[func_name]).w()?; + let func = match self.device.get_func("ug", func_name) { + Some(func) => func, + None => crate::bail!("unknown function ug::{func_name}"), + }; + Ok(func) + } + pub fn id(&self) -> DeviceId { self.id } diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs index 3a85dba9f4..276e3658e7 100644 --- a/candle-core/src/custom_op.rs +++ b/candle-core/src/custom_op.rs @@ -375,3 +375,70 @@ impl Tensor { ) } } + +pub struct UgIOp1 { + name: &'static str, + #[cfg(feature = "cuda")] + func: cudarc::driver::CudaFunction, +} + +impl UgIOp1 { + #[allow(unused)] + pub fn new( + name: &'static str, + kernel: ug::lang::ssa::Kernel, + device: &crate::Device, + ) -> Result { + #[cfg(feature = "cuda")] + { + let device = device.as_cuda_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(not(feature = "cuda"))] + { + Ok(Self { name }) + } + } +} + +impl InplaceOp1 for UgIOp1 { + fn name(&self) -> &'static str { + self.name + } + + fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> { + crate::bail!("ug ops are only supported on cuda at the moment") + } + + fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> { + crate::bail!("ug ops are only supported on cuda at the moment") + } + + #[cfg(feature = "cuda")] + fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> { + use crate::cuda_backend::WrapErr; + use cudarc::driver::LaunchAsync; + + let elem_count = layout.shape().elem_count(); + // TODO: support more dtypes. + let sto = sto.as_cuda_slice::()?; + let sto = match layout.contiguous_offsets() { + None => crate::bail!("input has to be contiguous"), + Some((o1, o2)) => sto.slice(o1..o2), + }; + let params = (&sto,); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (g as u32, 1, 1), + block_dim: (b as u32, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { self.func.clone().launch(cfg, params) }.w()?; + Ok(()) + } +} diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index c4a8e9361e..91925b5781 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -130,6 +130,14 @@ impl Device { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) } + pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> { + match self { + Self::Cuda(d) => Ok(d), + Self::Cpu => crate::bail!("expected a cuda device, got cpu"), + Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"), + } + } + pub fn new_cuda_with_stream(ordinal: usize) -> Result { Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?)) } diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index e7112e2e61..a35bec3cbe 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -165,6 +165,9 @@ pub enum Error { #[error("Metal error {0}")] Metal(#[from] MetalError), + #[error(transparent)] + Ug(#[from] ug::Error), + #[error(transparent)] TryFromIntError(#[from] core::num::TryFromIntError), @@ -179,6 +182,10 @@ pub enum Error { #[error(transparent)] ParseInt(#[from] std::num::ParseIntError), + /// Utf8 parse error. + #[error(transparent)] + FromUtf8(#[from] std::string::FromUtf8Error), + /// I/O error. #[error(transparent)] Io(#[from] std::io::Error), diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index d8d6253213..39ca909d88 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -77,7 +77,7 @@ mod variable; pub use cuda_backend::cudnn; pub use cpu_backend::{CpuStorage, CpuStorageRef}; -pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; +pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; pub use error::{Error, Result}; diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index be59e0c0c3..f2c01aca8e 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -143,3 +143,33 @@ fn inplace_op1() -> Result<()> { ); Ok(()) } + +#[cfg(feature = "cuda")] +#[allow(clippy::approx_constant)] +#[test] +fn ug_op() -> Result<()> { + let kernel = { + use ug::lang::op; + + let layout = ug::Layout::from_shape(&[12]); + let ptr = op::Arg::ptr(ug::DType::F32); + let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?; + let src = op::unary(op::UnaryOp::Exp, src)?; + let st = op::store(ptr.id(), layout, src)?; + let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]); + let opts: ug::lower_op::Opts = Default::default(); + kernel.lower(&opts.with_global(0, 12))? + }; + let device = Device::new_cuda(0)?; + let op = candle_core::UgIOp1::new("test", kernel, &device)?; + let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?; + t.inplace_op1(&op)?; + assert_eq!( + to_vec1_round(&t, 4)?, + &[ + 1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578, + 8103.0806, 22026.469, 59874.133 + ] + ); + Ok(()) +}