Commit: custom unary strided
grzuy committed Jan 17, 2024
1 parent 7a61a72 commit 8e827d3
Showing 3 changed files with 121 additions and 18 deletions.
56 changes: 56 additions & 0 deletions native/candlex/src/metal_kernels.rs
@@ -167,6 +167,62 @@ pub fn call_custom_unary_contiguous(
    Ok(())
}

pub fn call_custom_unary_strided(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernel_name: custom_unary::strided::Kernel,
    shape: &[usize],
    strides: &[usize],
    input_buffer: &Buffer,
    input_offset: usize,
    output_buffer: &Buffer,
    output_offset: usize,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::CustomUnary, kernel_name.0)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    let num_dims: usize = shape.len();
    let length: usize = shape.iter().product();

    encoder.set_bytes(
        0,
        core::mem::size_of::<usize>() as u64,
        &length as *const usize as *const c_void,
    );
    encoder.set_bytes(
        1,
        core::mem::size_of::<usize>() as u64,
        &num_dims as *const usize as *const c_void,
    );

    encoder.set_bytes(
        2,
        core::mem::size_of_val(shape) as u64,
        shape.as_ptr() as *const c_void,
    );

    encoder.set_bytes(
        3,
        core::mem::size_of_val(strides) as u64,
        strides.as_ptr() as *const c_void,
    );

    encoder.set_buffer(4, Some(input_buffer), input_offset as u64);
    encoder.set_buffer(5, Some(output_buffer), output_offset as u64);

    encoder.use_resource(input_buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output_buffer, metal::MTLResourceUsage::Write);

    let width: usize = shape.iter().product();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

    encoder.end_encoding();

    Ok(())
}

pub fn call_custom_binary_contiguous(
    device: &Device,
    command_buffer: &CommandBufferRef,
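For reference, the new wrapper binds arguments in the order the strided kernel declares them: indices 0 through 3 pass length, num_dims, dims, and strides by value through set_bytes, and indices 4 and 5 bind the input and output buffers, where the input offset is a byte offset. The standalone Rust sketch below only illustrates that size and offset arithmetic; the values are made up and none of it is code from this commit.

use core::mem::{size_of, size_of_val};

fn main() {
    let shape: Vec<usize> = vec![3, 2];
    let strides: Vec<usize> = vec![1, 3];

    // set_bytes(2, ..) and set_bytes(3, ..) pass the raw slice contents:
    // num_dims machine words each.
    assert_eq!(size_of_val(&shape[..]), shape.len() * size_of::<usize>());
    assert_eq!(size_of_val(&strides[..]), strides.len() * size_of::<usize>());

    // set_buffer(4, .., input_offset) takes a byte offset; the ops.rs change
    // further below converts the layout's element offset with dtype.size_in_bytes().
    let start_offset_elems = 5usize; // hypothetical element offset into an f32 buffer
    assert_eq!(start_offset_elems * size_of::<f32>(), 20);
}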
29 changes: 29 additions & 0 deletions native/candlex/src/metal_kernels/custom_unary.metal
@@ -2,6 +2,21 @@

using namespace metal;

METAL_FUNC uint get_strided_index(
    uint idx,
    constant size_t &num_dims,
    constant size_t *dims,
    constant size_t *strides
) {
    uint strided_i = 0;
    for (uint d = 0; d < num_dims; d++) {
        uint dim_idx = num_dims - 1 - d;
        strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
        idx /= dims[dim_idx];
    }
    return strided_i;
}

#define CUSTOM_UNARY(IN_TYPE, OUT_TYPE, FN_NAME, FN) \
kernel void FN_NAME( \
    constant size_t &dim, \
@@ -13,6 +28,20 @@ kernel void FN_NAME( \
        return; \
    } \
    output[tid] = OUT_TYPE(FN(IN_TYPE(input[tid]))); \
}\
kernel void FN_NAME##_strided( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    device const IN_TYPE *input, \
    device OUT_TYPE *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (tid >= dim) { \
        return; \
    } \
    output[tid] = OUT_TYPE(FN(IN_TYPE(input[get_strided_index(tid, num_dims, dims, strides)]))); \
}

CUSTOM_UNARY(float, float, acos_f32, acos)
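The get_strided_index helper above maps a linear thread id over the logical (row-major) shape to the element offset of a non-contiguous input: it walks the dimensions from innermost to outermost, peeling off one coordinate per axis with % and / and accumulating coordinate times stride. A direct Rust port with a worked transpose example follows, as a sketch for illustration only (not code from this commit).

// Rust port of the Metal helper above, for illustration.
fn get_strided_index(mut idx: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut strided = 0;
    // Innermost dimension first: peel off one coordinate per axis and
    // accumulate coordinate * stride.
    for d in (0..dims.len()).rev() {
        strided += (idx % dims[d]) * strides[d];
        idx /= dims[d];
    }
    strided
}

fn main() {
    // A 2x3 row-major buffer read through its 3x2 transposed view:
    // shape [3, 2], strides [1, 3].
    let data = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let view: Vec<f32> = (0..6)
        .map(|tid| data[get_strided_index(tid, &[3, 2], &[1, 3])])
        .collect();
    assert_eq!(view, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}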
54 changes: 36 additions & 18 deletions native/candlex/src/ops.rs
@@ -106,30 +106,48 @@ macro_rules! custom_unary_op {
     use crate::metal_kernels;
     use candle_core::{backend::BackendStorage, DType};

-    if !(layout.is_contiguous() && layout.start_offset() == 0) {
-        candle_core::bail!("Non contiguous not supported");
-    }
-
     let device = storage.device();
     let command_buffer = device.command_buffer()?;
     let elem_count = layout.shape().elem_count();
     let dtype = storage.dtype();
     let output_buffer = device.new_buffer(elem_count, dtype, stringify!($name))?;

-    let kernel_name = match storage.dtype() {
-        DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
-        dtype => {
-            candle_core::bail!("Metal contiguous custom unary $name {dtype:?} not implemented")
-        }
-    };
-    metal_kernels::call_custom_unary_contiguous(
-        &device.device(),
-        &command_buffer,
-        kernel_name,
-        elem_count,
-        storage.buffer(),
-        &output_buffer,
-    ).unwrap();
+    if (layout.is_contiguous() && layout.start_offset() == 0) {
+        let kernel_name = match storage.dtype() {
+            DType::F32 => metal_kernels::custom_unary::contiguous::$name::FLOAT,
+            dtype => {
+                candle_core::bail!("Metal contiguous custom unary $name {dtype:?} not implemented")
+            }
+        };
+
+        metal_kernels::call_custom_unary_contiguous(
+            &device.device(),
+            &command_buffer,
+            kernel_name,
+            elem_count,
+            storage.buffer(),
+            &output_buffer,
+        ).unwrap();
+    } else {
+        let kernel_name = match storage.dtype() {
+            DType::F32 => metal_kernels::custom_unary::strided::$name::FLOAT,
+            dtype => {
+                candle_core::bail!("Metal strided custom unary $name {dtype:?} not implemented")
+            }
+        };
+
+        metal_kernels::call_custom_unary_strided(
+            &device.device(),
+            &command_buffer,
+            kernel_name,
+            layout.dims(),
+            layout.stride(),
+            storage.buffer(),
+            layout.start_offset() * dtype.size_in_bytes(),
+            &output_buffer,
+            0,
+        ).unwrap();
+    }

     Ok((MetalStorage::new(output_buffer, device.clone(), dtype), layout.shape().clone()))
 }
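After this change the macro takes the contiguous fast path only when the layout is row-major contiguous and starts at element 0; transposed, sliced, or offset views now dispatch to the strided kernel with the layout's dims, strides, and byte start offset instead of bailing with "Non contiguous not supported". The minimal sketch below illustrates that criterion; is_contiguous here is a simplified stand-in for candle's Layout::is_contiguous (it ignores edge cases such as size-1 dimensions) and is not code from the commit.

// Row-major strides for a shape, e.g. [2, 3] -> [3, 1].
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    for d in (0..shape.len().saturating_sub(1)).rev() {
        strides[d] = strides[d + 1] * shape[d + 1];
    }
    strides
}

// Simplified stand-in for the layout check used in the dispatch above.
fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    strides == contiguous_strides(shape).as_slice()
}

fn main() {
    // Plain row-major tensor: takes the contiguous kernel.
    assert!(is_contiguous(&[2, 3], &[3, 1]));
    // Transposed view of the same buffer: now falls back to the strided kernel.
    assert!(!is_contiguous(&[3, 2], &[1, 3]));
}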
