Skip to content

Commit

Permalink
Add a few Metal gather ops.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jan 25, 2025
1 parent 333d94a commit cb29411
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
6 changes: 6 additions & 0 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,12 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F16) => "gather_u32_f16",
(DType::U32, DType::BF16) => "gather_u32_bf16",
(DType::U32, DType::U32) => "gather_u32_u32",
+ (DType::U32, DType::I64) => "gather_u32_i64",
+ (DType::I64, DType::F32) => "gather_i64_f32",
+ (DType::I64, DType::F16) => "gather_i64_f16",
+ (DType::I64, DType::BF16) => "gather_i64_bf16",
+ (DType::I64, DType::U32) => "gather_i64_u32",
+ (DType::I64, DType::I64) => "gather_i64_i64",
(left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"),
};
let command_buffer = self.device.command_buffer()?;
Expand Down
6 changes: 6 additions & 0 deletions candle-metal-kernels/src/indexing.metal
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,18 @@ INDEX_OP(is_u8_f16, uint8_t, half)
INDEX_OP(is_u8_bf16, uint8_t, bfloat)
#endif

+ GATHER_OP(gather_i64_f32, int64_t, float)
+ GATHER_OP(gather_i64_f16, int64_t, half)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
#if defined(__HAVE_BFLOAT__)
+ GATHER_OP(gather_i64_bf16, int64_t, bfloat)
GATHER_OP(gather_u32_bf16, uint, bfloat)
#endif
+ GATHER_OP(gather_i64_u32, int64_t, uint)
GATHER_OP(gather_u32_u32, uint, uint)
+ GATHER_OP(gather_i64_i64, int64_t, int64_t)
+ GATHER_OP(gather_u32_i64, uint, int64_t)

SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
Expand Down
4 changes: 2 additions & 2 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2029,7 +2029,7 @@ pub fn call_sdpa_vector_2pass(
)]));

let pipeline =
- kernels.load_pipeline_with_constants(device, Source::Sdpa, &name_pass1, constants)?;
+ kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
Expand Down Expand Up @@ -2104,7 +2104,7 @@ pub fn call_sdpa_vector_2pass(

let b = (q_shape[0] * q_shape[1]) as i32;

- let pipeline = kernels.load_pipeline(device, Source::Sdpa, &name_pass2)?;
+ let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
Expand Down

0 comments on commit cb29411

Please sign in to comment.