
Commit

Merge.
LaurentMazare committed Oct 28, 2024
2 parents 40be50e + 0e2c8c1 commit d279939
Showing 12 changed files with 224 additions and 12 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
@@ -70,6 +70,9 @@ tokenizers = { version = "0.19.1", default-features = false }
tracing = "0.1.37"
tracing-chrome = "0.7.1"
tracing-subscriber = "0.3.7"
ug = "0.0.2"
ug-cuda = "0.0.2"
ug-metal = "0.0.2"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "1.1.1", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
3 changes: 2 additions & 1 deletion README.md
@@ -2,7 +2,8 @@
[![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619)
[![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core)
[![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core)
![License](https://img.shields.io/crates/l/candle-core.svg)
[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT)
[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE)

Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support)
and ease of use. Try our online demos:
7 changes: 5 additions & 2 deletions candle-core/Cargo.toml
@@ -28,6 +28,9 @@ rand_distr = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
thiserror = { workspace = true }
ug = { workspace = true }
ug-cuda = { workspace = true, optional = true }
ug-metal = { workspace = true, optional = true }
yoke = { workspace = true }
zip = { workspace = true }

@@ -39,11 +42,11 @@ criterion = { workspace = true }

[features]
default = []
cuda = ["cudarc", "dep:candle-kernels"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"]

[[bench]]
name = "bench_main"
21 changes: 21 additions & 0 deletions candle-core/src/cuda_backend/device.rs
@@ -51,6 +51,27 @@ impl CudaDevice {
self.device.clone()
}

pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<CudaFunction> {
let mut buf = vec![];
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
let cuda_code = String::from_utf8(buf)?;
let opts = cudarc::nvrtc::CompileOptions {
use_fast_math: Some(true),
..Default::default()
};
let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?;
self.device.load_ptx(ptx, "ug", &[func_name]).w()?;
let func = match self.device.get_func("ug", func_name) {
Some(func) => func,
None => crate::bail!("unknown function ug::{func_name}"),
};
Ok(func)
}

pub fn id(&self) -> DeviceId {
self.id
}
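Note: the new compile method chains ug code generation (ug_cuda::code_gen::gen emits CUDA C source), NVRTC compilation to PTX with fast-math enabled, and loading the PTX under the "ug" module name. A minimal sketch of calling it, assuming the cuda feature and the ug crate as a dependency (the kernel name here is hypothetical; see the test further down for how such a kernel is built):

use candle_core::{CudaDevice, Result};

fn load_ug_kernel(dev: &CudaDevice, kernel: ug::lang::ssa::Kernel) -> Result<()> {
    // Generates CUDA C for the kernel, JIT-compiles it with NVRTC, and
    // loads the resulting PTX on the device under the "ug" module name.
    let func = dev.compile("my_kernel", kernel)?;
    // `func` is a cudarc::driver::CudaFunction; the UgIOp1::cuda_fwd
    // implementation below shows how it is launched.
    let _ = func;
    Ok(())
}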
107 changes: 107 additions & 0 deletions candle-core/src/custom_op.rs
@@ -375,3 +375,110 @@ impl Tensor {
)
}
}

pub struct UgIOp1 {
name: &'static str,
#[cfg(feature = "cuda")]
func: cudarc::driver::CudaFunction,
#[cfg(feature = "metal")]
func: metal::ComputePipelineState,
}

impl UgIOp1 {
#[allow(unused)]
pub fn new(
name: &'static str,
kernel: ug::lang::ssa::Kernel,
device: &crate::Device,
) -> Result<Self> {
#[cfg(feature = "cuda")]
{
let device = device.as_cuda_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(feature = "metal")]
{
let device = device.as_metal_device()?;
let func = device.compile(name, kernel)?;
Ok(Self { name, func })
}
#[cfg(not(any(feature = "cuda", feature = "metal")))]
{
Ok(Self { name })
}
}
}

impl InplaceOp1 for UgIOp1 {
fn name(&self) -> &'static str {
self.name
}

fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> {
crate::bail!("ug ops are only supported on metal/cuda at the moment")
}

#[cfg(feature = "metal")]
fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> {
use crate::backend::BackendStorage;
use candle_metal_kernels::utils::EncoderProvider;

let elem_count = layout.shape().elem_count();
if sto.dtype() != crate::DType::F32 {
// TODO: support more dtypes.
crate::bail!("input is not a f32 tensor")
}
let device = sto.device();
println!("here");
let command_buffer = device.command_buffer()?;
let command_buffer = &command_buffer;
let encoder = command_buffer.encoder();
let encoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&self.func);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let grid_dims = metal::MTLSize {
width: g as u64,
height: 1,
depth: 1,
};
let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1);
candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize));

encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write);
encoder.dispatch_threads(grid_dims, group_dims);

Ok(())
}

#[cfg(feature = "cuda")]
fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> {
use crate::cuda_backend::WrapErr;
use cudarc::driver::LaunchAsync;

let elem_count = layout.shape().elem_count();
// TODO: support more dtypes.
let sto = sto.as_cuda_slice::<f32>()?;
let sto = match layout.contiguous_offsets() {
None => crate::bail!("input has to be contiguous"),
Some((o1, o2)) => sto.slice(o1..o2),
};
let params = (&sto,);
let (g, b) = if elem_count % 32 == 0 {
(elem_count / 32, 32)
} else {
(elem_count, 1)
};
let cfg = cudarc::driver::LaunchConfig {
grid_dim: (g as u32, 1, 1),
block_dim: (b as u32, 1, 1),
shared_mem_bytes: 0,
};
unsafe { self.func.clone().launch(cfg, params) }.w()?;
Ok(())
}
}
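Both backends pick their launch geometry with the same heuristic: a multiple of 32 elements runs as elem_count / 32 groups of 32 threads, anything else as elem_count groups of a single thread. A standalone sketch of that logic (launch_dims is a hypothetical helper, not part of this diff):

// Hypothetical helper mirroring the (grid, block) computation used by both
// cuda_fwd and metal_fwd above.
fn launch_dims(elem_count: usize) -> (usize, usize) {
    if elem_count % 32 == 0 {
        (elem_count / 32, 32) // full 32-wide thread groups
    } else {
        (elem_count, 1) // fallback: one thread per group
    }
}

fn main() {
    assert_eq!(launch_dims(64), (2, 32));
    assert_eq!(launch_dims(12), (12, 1)); // the 12-element test below takes this path
}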
16 changes: 16 additions & 0 deletions candle-core/src/device.rs
@@ -130,6 +130,22 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}

pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> {
match self {
Self::Cuda(d) => Ok(d),
Self::Cpu => crate::bail!("expected a cuda device, got cpu"),
Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"),
}
}

pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> {
match self {
Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"),
Self::Cpu => crate::bail!("expected a metal device, got cpu"),
Self::Metal(d) => Ok(d),
}
}

pub fn new_cuda_with_stream(ordinal: usize) -> Result<Self> {
Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?))
}
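The two accessors give backend-specific code such as UgIOp1::new a checked way to reach the concrete device behind a Device, returning a descriptive error rather than panicking on a mismatch. A small usage sketch, assuming both GPU features are enabled:

use candle_core::{Device, Result};

// Sketch: downcast a generic Device before calling a backend-specific API.
fn backend_name(device: &Device) -> Result<&'static str> {
    match device {
        Device::Cuda(_) => {
            let _cuda = device.as_cuda_device()?; // checked &CudaDevice
            Ok("cuda")
        }
        Device::Metal(_) => {
            let _metal = device.as_metal_device()?; // checked &MetalDevice
            Ok("metal")
        }
        Device::Cpu => Ok("cpu"),
    }
}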
7 changes: 7 additions & 0 deletions candle-core/src/error.rs
@@ -165,6 +165,9 @@ pub enum Error {
#[error("Metal error {0}")]
Metal(#[from] MetalError),

#[error(transparent)]
Ug(#[from] ug::Error),

#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

Expand All @@ -179,6 +182,10 @@ pub enum Error {
#[error(transparent)]
ParseInt(#[from] std::num::ParseIntError),

/// Utf8 parse error.
#[error(transparent)]
FromUtf8(#[from] std::string::FromUtf8Error),

/// I/O error.
#[error(transparent)]
Io(#[from] std::io::Error),
2 changes: 1 addition & 1 deletion candle-core/src/lib.rs
@@ -77,7 +77,7 @@ mod variable;
pub use cuda_backend::cudnn;

pub use cpu_backend::{CpuStorage, CpuStorageRef};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
pub use device::{Device, DeviceLocation, NdArray};
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
pub use error::{Error, Result};
22 changes: 22 additions & 0 deletions candle-core/src/metal_backend/device.rs
@@ -144,6 +144,28 @@ impl MetalDevice {
self.use_mlx_mm = use_mlx_mm
}

pub fn compile(
&self,
func_name: &'static str,
kernel: ug::lang::ssa::Kernel,
) -> Result<metal::ComputePipelineState> {
let mut buf = vec![];
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
let metal_code = String::from_utf8(buf)?;
let lib = self
.device
.new_library_with_source(&metal_code, &metal::CompileOptions::new())
.map_err(MetalError::from)?;
let func = lib
.get_function(func_name, None)
.map_err(MetalError::from)?;
let pl = self
.device
.new_compute_pipeline_state_with_function(&func)
.map_err(MetalError::from)?;
Ok(pl)
}

pub fn id(&self) -> DeviceId {
self.id
}
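The Metal path mirrors the CUDA one: ug_metal::code_gen::gen emits MSL source, which is compiled into a library and wrapped in a ComputePipelineState ready for metal_fwd to bind. A sketch under the same assumptions as the CUDA example above (metal feature, ug and metal crates as dependencies, hypothetical kernel name):

use candle_core::{MetalDevice, Result};

fn load_ug_kernel(
    dev: &MetalDevice,
    kernel: ug::lang::ssa::Kernel,
) -> Result<metal::ComputePipelineState> {
    // Generates MSL for the kernel, builds a library from source, and
    // returns a pipeline state for set_compute_pipeline_state.
    dev.compile("my_kernel", kernel)
}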
36 changes: 36 additions & 0 deletions candle-core/tests/custom_op_tests.rs
@@ -143,3 +143,39 @@ fn inplace_op1() -> Result<()> {
);
Ok(())
}

#[cfg(any(feature = "cuda", feature = "metal"))]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {
let kernel = {
use ug::lang::op;

let layout = ug::Layout::from_shape(&[12]);
let ptr = op::Arg::ptr(ug::DType::F32);
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
let src = op::unary(op::UnaryOp::Exp, src)?;
let st = op::store(ptr.id(), layout, src)?;
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
let opts: ug::lower_op::Opts = Default::default();
kernel.lower(&opts.with_global(0, 12))?
};
let device = if candle_core::utils::cuda_is_available() {
Device::new_cuda(0)?
} else if candle_core::utils::metal_is_available() {
Device::new_metal(0)?
} else {
candle_core::bail!("metal/cuda is mandatory for this test")
};
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
t.inplace_op1(&op)?;
assert_eq!(
to_vec1_round(&t, 2)?,
&[
1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47,
59874.13
]
);
Ok(())
}
2 changes: 1 addition & 1 deletion candle-metal-kernels/src/lib.rs
@@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;

mod utils;
pub mod utils;
pub use utils::BufferOffset;
use utils::{get_block_dims, linear_split, EncoderProvider};

10 changes: 3 additions & 7 deletions candle-metal-kernels/src/utils.rs
@@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M
}

// https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96
pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
let mut pows0 = 0u64;
let mut pows1 = 0u64;
let mut pows2 = 0u64;
@@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize {
}
}

pub(crate) fn set_param<P: EncoderParam>(
encoder: &ComputeCommandEncoderRef,
position: u64,
data: P,
) {
pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {
<P as EncoderParam>::set_param(encoder, position, data)
}

/// Helper functions to create the various objects on the compute command encoder
/// on a single line.
/// Prevents getting wrong some arguments number and mixing length and size in bytes.
pub(crate) trait EncoderParam {
pub trait EncoderParam {
fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}
macro_rules! primitive {
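These visibility bumps (utils, get_block_dims, set_param, EncoderParam) exist so that UgIOp1::metal_fwd in candle-core can reuse the encoder helpers from outside the crate. A minimal external-caller sketch, assuming macOS with the metal feature available:

// Sketch: calling the newly public helper from outside candle-metal-kernels.
use candle_metal_kernels::utils::get_block_dims;

fn main() {
    // Clamp a 12 x 1 x 1 workload to a Metal threadgroup size.
    let dims = get_block_dims(12, 1, 1);
    println!("block dims: {} x {} x {}", dims.width, dims.height, dims.depth);
}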
