Support for UG kernels. #2579

Merged · 2 commits · Oct 27, 2024
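In short, this PR adds support for kernels written with the ug crate: the workspace gains ug/ug-cuda dependencies, CudaDevice learns to compile a lowered ug kernel to PTX via NVRTC, and a new UgIOp1 in-place op wraps the compiled function so it can be applied to a tensor. Below is a minimal usage sketch adapted from the test added in this PR; it assumes candle-core is built with the cuda feature, that GPU 0 is available, and the "my_exp" op name is purely illustrative.

use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // Build a tiny ug kernel computing y[i] = exp(x[i]) over 12 contiguous f32 values.
    let kernel = {
        use ug::lang::op;

        let layout = ug::Layout::from_shape(&[12]);
        let ptr = op::Arg::ptr(ug::DType::F32);
        let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
        let src = op::unary(op::UnaryOp::Exp, src)?;
        let st = op::store(ptr.id(), layout, src)?;
        let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
        let opts: ug::lower_op::Opts = Default::default();
        kernel.lower(&opts.with_global(0, 12))?
    };
    // Compile it for the CUDA device and wrap it as an in-place op.
    let device = Device::new_cuda(0)?;
    let op = candle_core::UgIOp1::new("my_exp", kernel, &device)?;
    // Apply it in place to a contiguous f32 tensor.
    let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
    t.inplace_op1(&op)?;
    println!("{t}");
    Ok(())
}

Note that the ? on the ug calls only compiles because this PR also adds a Ug variant to candle's Error (see candle-core/src/error.rs below).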
2 changes: 2 additions & 0 deletions Cargo.toml
@@ -70,6 +70,8 @@ tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.0.2"
ug-cuda = "0.0.2"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
4 changes: 3 additions & 1 deletion candle-core/Cargo.toml
@@ -28,6 +28,8 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
ug = { workspace = true }
ug-cuda = { workspace = true, optional = true }
yoke = { workspace = true }
zip = { workspace = true }

@@ -39,7 +41,7 @@ criterion = { workspace = true }

[features]
default = []
cuda = ["cudarc", "dep:candle-kernels"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
21 changes: 21 additions & 0 deletions candle-core/src/cuda_backend/device.rs
@@ -51,6 +51,27 @@ impl CudaDevice {
        self.device.clone()
    }

    pub fn compile(
        &self,
        func_name: &'static str,
        kernel: ug::lang::ssa::Kernel,
    ) -> Result<CudaFunction> {
        let mut buf = vec![];
        ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
        let cuda_code = String::from_utf8(buf)?;
        let opts = cudarc::nvrtc::CompileOptions {
            use_fast_math: Some(true),
            ..Default::default()
        };
        let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
        self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
        let func = match self.device.get_func("ug", func_name) {
            Some(func) => func,
            None => crate::bail!("unknown function ug::{func_name}"),
        };
        Ok(func)
    }

    pub fn id(&self) -> DeviceId {
        self.id
    }
67 changes: 67 additions & 0 deletions candle-core/src/custom_op.rs
@@ -375,3 +375,70 @@ impl Tensor {
)
}
}

pub struct UgIOp1 {
    name: &'static str,
    #[cfg(feature = "cuda")]
    func: cudarc::driver::CudaFunction,
}

impl UgIOp1 {
    #[allow(unused)]
    pub fn new(
        name: &'static str,
        kernel: ug::lang::ssa::Kernel,
        device: &crate::Device,
    ) -> Result<Self> {
        #[cfg(feature = "cuda")]
        {
            let device = device.as_cuda_device()?;
            let func = device.compile(name, kernel)?;
            Ok(Self { name, func })
        }
        #[cfg(not(feature = "cuda"))]
        {
            Ok(Self { name })
        }
    }
}

impl InplaceOp1 for UgIOp1 {
    fn name(&self) -> &'static str {
        self.name
    }

    fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on cuda at the moment")
    }

    fn metal_fwd(&self, _: &mut MetalStorage, _: &Layout) -> Result<()> {
        crate::bail!("ug ops are only supported on cuda at the moment")
    }

    #[cfg(feature = "cuda")]
    fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
        use crate::cuda_backend::WrapErr;
        use cudarc::driver::LaunchAsync;

        let elem_count = layout.shape().elem_count();
        // TODO: support more dtypes.
        let sto = sto.as_cuda_slice::<f32>()?;
        let sto = match layout.contiguous_offsets() {
            None => crate::bail!("input has to be contiguous"),
            Some((o1, o2)) => sto.slice(o1..o2),
        };
        let params = (&sto,);
        let (g, b) = if elem_count % 32 == 0 {
            (elem_count / 32, 32)
        } else {
            (elem_count, 1)
        };
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (g as u32, 1, 1),
            block_dim: (b as u32, 1, 1),
            shared_mem_bytes: 0,
        };
        unsafe { self.func.clone().launch(cfg, params) }.w()?;
        Ok(())
    }
}
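The launch configuration in cuda_fwd above follows a simple heuristic: when the element count is a multiple of 32 it launches elem_count / 32 blocks of 32 threads, otherwise it falls back to one single-thread block per element. A small illustrative sketch of that arithmetic (the launch_dims helper is not part of the PR):

// Illustrative helper mirroring the grid/block choice in UgIOp1::cuda_fwd.
fn launch_dims(elem_count: usize) -> (u32, u32) {
    if elem_count % 32 == 0 {
        ((elem_count / 32) as u32, 32) // e.g. 64 elements -> 2 blocks of 32 threads
    } else {
        (elem_count as u32, 1) // e.g. 12 elements (as in the test below) -> 12 blocks of 1 thread
    }
}

fn main() {
    assert_eq!(launch_dims(64), (2, 32));
    assert_eq!(launch_dims(12), (12, 1));
}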
8 changes: 8 additions & 0 deletions candle-core/src/device.rs
@@ -130,6 +130,14 @@ impl Device {
        Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
    }

    pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
        match self {
            Self::Cuda(d) => Ok(d),
            Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
            Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
        }
    }

    pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
        Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
    }
7 changes: 7 additions & 0 deletions candle-core/src/error.rs
@@ -165,6 +165,9 @@ pub enum Error {
    #[error("Metal error {0}")]
    Metal(#[from] MetalError),

    #[error(transparent)]
    Ug(#[from] ug::Error),

    #[error(transparent)]
    TryFromIntError(#[from] core::num::TryFromIntError),

@@ -179,6 +182,10 @@ pub enum Error {
    #[error(transparent)]
    ParseInt(#[from] std::num::ParseIntError),

    /// Utf8 parse error.
    #[error(transparent)]
    FromUtf8(#[from] std::string::FromUtf8Error),

    /// I/O error.
    #[error(transparent)]
    Io(#[from] std::io::Error),
2 changes: 1 addition & 1 deletion candle-core/src/lib.rs
@@ -77,7 +77,7 @@ mod variable;
pub use cuda_backend::cudnn;

pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
30 changes: 30 additions & 0 deletions candle-core/tests/custom_op_tests.rs
@@ -143,3 +143,33 @@ fn inplace_op1() -> Result<()> {
    );
    Ok(())
}

#[cfg(feature = "cuda")]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {
    let kernel = {
        use ug::lang::op;

        let layout = ug::Layout::from_shape(&[12]);
        let ptr = op::Arg::ptr(ug::DType::F32);
        let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
        let src = op::unary(op::UnaryOp::Exp, src)?;
        let st = op::store(ptr.id(), layout, src)?;
        let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
        let opts: ug::lower_op::Opts = Default::default();
        kernel.lower(&opts.with_global(0, 12))?
    };
    let device = Device::new_cuda(0)?;
    let op = candle_core::UgIOp1::new("test", kernel, &device)?;
    let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
    t.inplace_op1(&op)?;
    assert_eq!(
        to_vec1_round(&t, 4)?,
        &[
            1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578,
            8103.0806, 22026.469, 59874.133
        ]
    );
    Ok(())
}