// Summary of the change set below (git unified diff over two files).
//
// 1. candle-core/src/metal_backend/mod.rs, hunk @@ -1244,6 +1244,7 @@
//    (inside `impl BackendStorage for MetalStorage`): adds the match arm
//    (DType::U32, DType::U32) => "gather_u32_u32" after the existing
//    U32/F32, U32/F16 and U32/BF16 arms. Dtype pairs without an arm still
//    reach the catch-all, which bails with "Metal gather ... not implemented".
//
// 2. candle-metal-kernels/src/indexing.metal:
//    - The hunks starting at old lines 17, 68, 105 and 145 touch the helpers
//      index, gather, scatter_add and index_add. In this view every removed
//      line has a matching added line carrying the same tokens.
//      NOTE(review): presumably a whitespace-only cleanup (e.g. trailing
//      whitespace); the difference is not visible here, confirm against the
//      raw patch.
//    - Hunk @@ -214,6 +214,7 @@ adds GATHER_OP(gather_u32_u32, uint, uint).
//      It sits after the #endif that closes the __HAVE_BFLOAT__ guard, so,
//      unlike gather_u32_bf16, it is not conditional on that define.
//
// NOTE(review): the string "gather_u32_u32" in mod.rs presumably has to match
// the kernel name produced by GATHER_OP; the macro body is not part of these
// hunks, so verify against the full file.
// NOTE(review): line breaks appear collapsed in this copy and the `template`
// lines show no parameter list; hunk line counts cannot be checked from here.
// NOTE(review): the in-code comment "we're only allowing unsized" presumably
// means "unsigned"; the patch carries it over unchanged.
diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 47f54c8d59..e8159f46ff 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1244,6 +1244,7 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F32) => "gather_u32_f32", (DType::U32, DType::F16) => "gather_u32_f16", (DType::U32, DType::BF16) => "gather_u32_bf16", + (DType::U32, DType::U32) => "gather_u32_u32", (left, right) => crate::bail!("Metal gather {left:?} {right:?} not implemented"), }; let command_buffer = self.device.command_buffer()?; diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index c14f2c1ff1..2594689cf7 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -17,33 +17,33 @@ METAL_FUNC uint get_strided_index( } template -METAL_FUNC void index( - constant size_t &dst_size, - constant size_t &left_size, - constant size_t &src_dim_size, - constant size_t &right_size, +METAL_FUNC void index( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, constant size_t &ids_size, constant bool &contiguous, constant size_t *src_dims, constant size_t *src_strides, const device TYPENAME *input, - const device INDEX_TYPENAME *input_ids, - device TYPENAME *output, - uint tid [[ thread_position_in_grid ]] -) { - if (tid >= dst_size) { + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { return; - } - const size_t id_i = (tid / right_size) % ids_size; - const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); - const size_t right_rank_i = tid % right_size; - const size_t left_rank_i = tid / right_size / ids_size; - /* - // Force prevent out of bounds indexing - // since there doesn't seem to be a good way to force crash - // No need to check for 
zero we're only allowing unsized. - */ - const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; + } + const size_t id_i = (tid / right_size) % ids_size; + const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size / ids_size; + /* + // Force prevent out of bounds indexing + // since there doesn't seem to be a good way to force crash + // No need to check for zero we're only allowing unsized. + */ + const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides); output[tid] = input[strided_src_i]; } @@ -68,25 +68,25 @@ kernel void NAME( \ template -METAL_FUNC void gather( - constant size_t &dst_size, - constant size_t &left_size, - constant size_t &src_dim_size, - constant size_t &right_size, - constant size_t &ids_size, - const device TYPENAME *input, - const device INDEX_TYPENAME *input_ids, - device TYPENAME *output, - uint tid [[ thread_position_in_grid ]] -) { - if (tid >= dst_size) { - return; - } - const INDEX_TYPENAME input_i = input_ids[tid]; - const size_t right_rank_i = tid % right_size; - const size_t left_rank_i = tid / right_size / ids_size; - const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; - output[tid] = input[src_i]; +METAL_FUNC void gather( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &ids_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const INDEX_TYPENAME input_i = input_ids[tid]; + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = 
tid / right_size / ids_size; + const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i; + output[tid] = input[src_i]; } # define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \ @@ -105,27 +105,27 @@ kernel void NAME( \ } template -METAL_FUNC void scatter_add( - constant size_t &dst_size, - constant size_t &left_size, - constant size_t &src_dim_size, - constant size_t &right_size, - constant size_t &dst_dim_size, - const device TYPENAME *input, - const device INDEX_TYPENAME *input_ids, - device TYPENAME *output, - uint tid [[ thread_position_in_grid ]] -) { - if (tid >= dst_size) { - return; - } - const size_t right_rank_i = tid % right_size; - const size_t left_rank_i = tid / right_size; +METAL_FUNC void scatter_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; for (unsigned int j = 0; j < src_dim_size; ++j) { - const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; const INDEX_TYPENAME idx = input_ids[src_i]; - const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; - output[dst_i] += input[src_i]; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; } } @@ -145,28 +145,28 @@ kernel void NAME( \ } template -METAL_FUNC void index_add( - constant size_t &dst_size, - constant size_t &left_size, - constant size_t &src_dim_size, - constant size_t &right_size, - constant size_t &dst_dim_size, - constant size_t &ids_dim_size, - const device TYPENAME 
*input, - const device INDEX_TYPENAME *input_ids, - device TYPENAME *output, - uint tid [[ thread_position_in_grid ]] -) { - if (tid >= dst_size) { - return; - } - const size_t right_rank_i = tid % right_size; - const size_t left_rank_i = tid / right_size; +METAL_FUNC void index_add( + constant size_t &dst_size, + constant size_t &left_size, + constant size_t &src_dim_size, + constant size_t &right_size, + constant size_t &dst_dim_size, + constant size_t &ids_dim_size, + const device TYPENAME *input, + const device INDEX_TYPENAME *input_ids, + device TYPENAME *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dst_size) { + return; + } + const size_t right_rank_i = tid % right_size; + const size_t left_rank_i = tid / right_size; for (unsigned int j = 0; j < ids_dim_size; ++j) { const INDEX_TYPENAME idx = input_ids[j]; - const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; - const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; - output[dst_i] += input[src_i]; + const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i; + const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i; + output[dst_i] += input[src_i]; } } @@ -214,6 +214,7 @@ GATHER_OP(gather_u32_f16, uint, half) #if defined(__HAVE_BFLOAT__) GATHER_OP(gather_u32_bf16, uint, bfloat) #endif +GATHER_OP(gather_u32_u32, uint, uint) SCATTER_ADD_OP(sa_u32_f32, uint32_t, float) SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)