Skip to content

Commit

Permalink
Incorporating review changes - added elem count check in kernel…
Browse files Browse the repository at this point in the history
…, using for call strategy
  • Loading branch information
AnubhabB committed Oct 1, 2024
1 parent 942a617 commit bc60875
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
7 changes: 6 additions & 1 deletion candle-metal-kernels/src/fill.metal
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@ using namespace metal;
// Writes `value`, converted to the element type T, into out[tid].
//
// Parameters:
//   out   - destination buffer; out[tid] is the single element this call writes.
//   value - fill value, passed as a float and cast to T on store.
//   numel - bound checked against tid; indices at or past it are not written.
//   tid   - this thread's position in the dispatch grid.
//
// NOTE(review): presumably the bounds guard is required because the host side
// dispatches whole threadgroups, which can cover more threads than there are
// elements - confirm against call_const_fill.
template<typename T> METAL_FUNC void fill_with(
device T *out,
constant float &value,
constant size_t &numel,
uint tid [[thread_position_in_grid]]
) {
// Threads with tid >= numel return without touching `out`.
if (tid >= numel) {
return;
}
out[tid] = static_cast<T>(value);
}

// Defines the kernel entry point fill_<NAME> for element type T. The kernel
// takes the output buffer, the fill value, the element count and the thread
// index, and forwards them to fill_with<T>.
//
// NOTE(review): this text is a flattened diff hunk, so two calls appear below.
// The three-argument call does not match fill_with's four-parameter signature
// and looks like the line the commit removed; the four-argument call (passing
// numel) looks like its replacement. Confirm against the actual fill.metal.
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \
fill_with<T>(out, value, tid); \
fill_with<T>(out, value, numel, tid); \
} \


Expand Down
10 changes: 4 additions & 6 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2375,14 +2375,12 @@ pub fn call_const_fill(

encoder.set_compute_pipeline_state(&pipeline);

set_params!(encoder, (output, v));
set_params!(encoder, (output, v, length));

let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

encoder.use_resource(output, metal::MTLResourceUsage::Write);

let grid_size = MTLSize { width: length as u64, height: 1, depth: 1 };
let thread_group_size = MTLSize { width: pipeline.max_total_threads_per_threadgroup(), height: 1, depth: 1 };

encoder.dispatch_threads(grid_size, thread_group_size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);

Ok(())
}
Expand Down

0 comments on commit bc60875

Please sign in to comment.