diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index a270bb2888..be6160093e 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -2372,16 +2372,11 @@ pub fn call_const_fill( let pipeline = kernels.load_pipeline(device, Source::Fill, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); - encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (output, v, length)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - Ok(()) } diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index f37ab5bb9c..637bf2e243 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -2309,66 +2309,32 @@ fn conv_transpose1d_u32() { assert_eq!(results, expected); } -fn constant_fill(name: &'static str, len: usize, value: f32) -> Vec { - let dev = device(); - let kernels = Kernels::new(); - let command_queue = dev.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - - let buffer = dev.new_buffer( - (len * std::mem::size_of::()) as u64, - MTLResourceOptions::StorageModePrivate, - ); - - call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap(); - - command_buffer.commit(); - command_buffer.wait_until_completed(); - - read_to_vec::(&buffer, len) -} - #[test] fn const_fill() { - let fills = [ - "fill_u8", - "fill_u32", - "fill_i64", - "fill_f16", - "fill_bf16", - "fill_f32", - ]; - - for name in fills { + fn constant_fill(name: &'static str, len: usize, value: f32) -> Vec { + let dev = device(); + let kernels = Kernels::new(); + let command_queue = dev.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let buffer = dev.new_buffer( + (len * std::mem::size_of::()) as u64, + MTLResourceOptions::StorageModePrivate, + ); + call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec::(&buffer, len) + } + fn test T>(name: &'static str, f: F) { let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16); let value = rand::thread_rng().gen_range(1. ..19.); - - match name { - "fill_u8" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![value as u8; len]) - } - "fill_u32" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![value as u32; len]) - } - "fill_i64" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![value as i64; len]) - } - "fill_f16" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![f16::from_f32(value); len]) - } - "fill_bf16" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![bf16::from_f32(value); len]) - } - "fill_f32" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![value; len]) - } - _ => unimplemented!(), - }; + let v = constant_fill::(name, len, value); + assert_eq!(v, vec![f(value); len]) } + test::("fill_u8", |v| v as u8); + test::("fill_u32", |v| v as u32); + test::("fill_i64", |v| v as i64); + test::("fill_f16", f16::from_f32); + test::("fill_bf16", bf16::from_f32); + test::("fill_f32", |v| v); }