From ea2b79534ed6237e1fdfe1605738a9778f1eb23d Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 3 Apr 2024 21:03:45 +0200 Subject: [PATCH 1/6] ggml : group all experts in a single ggml_mul_mat_id cuda : improve mmid row copy --- ggml-cuda.cu | 216 ++++++++++++++++++++++++++++++------- ggml-cuda/convert.cu | 2 + ggml.c | 125 +++++++++++---------- ggml.h | 7 +- llama.cpp | 59 +++++----- scripts/compare-commits.sh | 6 +- tests/test-backend-ops.cpp | 132 +++++++++++++++++++---- 7 files changed, 395 insertions(+), 152 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f51b2042df359..74a50594f632c 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas( if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 - ggml_cuda_pool_alloc src0_as_f16(ctx.pool()); + ggml_cuda_pool_alloc src0_as_f16(ctx.pool(id)); if (src0->type != GGML_TYPE_F16) { const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); GGML_ASSERT(to_fp16_cuda != nullptr); @@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas( } const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get(); - ggml_cuda_pool_alloc src1_as_f16(ctx.pool()); + ggml_cuda_pool_alloc src1_as_f16(ctx.pool(id)); if (src1->type != GGML_TYPE_F16) { const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); GGML_ASSERT(to_fp16_cuda != nullptr); @@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas( to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); } const half * src1_ptr = src1->type == GGML_TYPE_F16 ? 
(const half *) src1_ddf_i : src1_as_f16.get(); - ggml_cuda_pool_alloc dst_f16(ctx.pool(), row_diff*src1_ncols); + ggml_cuda_pool_alloc dst_f16(ctx.pool(id), row_diff*src1_ncols); const half alpha_f16 = 1.0f; const half beta_f16 = 0.0f; @@ -1960,20 +1960,84 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } } +struct mmid_row_mapping { + int64_t i1; + int64_t i2; +}; + +static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous, + int * cur_src1_row, mmid_row_mapping * row_mapping, + const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0, + int64_t ids_ne1, int64_t n_ids, + int64_t ne11, + size_t nb11, size_t nb12) { + int64_t iid1 = blockIdx.x; + int64_t id = blockIdx.y; + + if (iid1 >= ids_ne1 || id >= n_ids) { + return; + } + + const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0); + + if (row_id_i != i02) { + return; + } + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + __shared__ int src1_row; + if (threadIdx.x == 0) { + src1_row = atomicAdd(cur_src1_row, 1); + row_mapping[src1_row] = {id, iid1}; + } + __syncthreads(); + + const char * src1_row_original = src1_original + i11*nb11 + i12*nb12; + char * src1_row_contiguous = src1_contiguous + src1_row*nb11; + + for (int i = threadIdx.x; i < nb11; i += blockDim.x) { + src1_row_contiguous[i] = src1_row_original[i]; + } +} + +static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous, + const mmid_row_mapping * row_mapping, + int64_t n_rows, + int64_t nb1, int64_t nb2) { + int64_t i = blockIdx.x; + + if (i >= n_rows) { + return; + } + + const int64_t i1 = row_mapping[i].i1; + const int64_t i2 = row_mapping[i].i2; + + const char * dst_row_contiguous = dst_contiguous + i*nb1; + char * dst_row_original = dst_original + i1*nb1 + i2*nb2; + + for (int j = threadIdx.x; j < nb1; j += blockDim.x) { + dst_row_original[j] = dst_row_contiguous[j]; + } +} + +//#define MMID_MEMCPY + static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; const ggml_tensor * ids = dst->src[2]; + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers"); cudaStream_t stream = ctx.stream(); - const size_t nb11 = src1->nb[1]; - const size_t nb1 = dst->nb[1]; - - const int32_t id = ((int32_t *) dst->op_params)[0]; - const int32_t n_as = src0->ne[2]; + const int64_t n_as = ne02; + const int64_t n_ids = ids->ne[0]; std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; @@ -1982,7 +2046,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_tensor src0_row = *src0; ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + ggml_tensor dst_row = *dst; char * src0_original = (char *) src0->data; char * src1_original = (char *) src1->data; @@ -1990,19 +2054,39 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * src0_row.ne[2] = 1; src0_row.ne[3] = 1; - src0_row.nb[3] = src0->nb[2]; + src0_row.nb[3] = nb02; - if (src1->ne[1] == 1) { - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + src1_row.ne[1] = 1; + src1_row.ne[2] = 1; + src1_row.ne[3] = 1; + src1_row.nb[2] = nb11; + src1_row.nb[3] = nb11; - GGML_ASSERT(row_id >= 0 && 
row_id < n_as); + dst_row.ne[1] = 1; + dst_row.ne[2] = 1; + dst_row.ne[3] = 1; + dst_row.nb[2] = nb1; + dst_row.nb[3] = nb1; - src0_row.data = src0_original + row_id*src0->nb[2]; - src1_row.data = src1_original + i01*src1->nb[1]; - dst_row.data = dst_original + i01*dst->nb[1]; + if (ne12 == 1) { + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + + const int64_t i1 = id; + const int64_t i2 = i12; + + src0_row.data = src0_original + i02*nb02; + src1_row.data = src1_original + i11*nb11 + i12*nb12; + dst_row.data = dst_original + i1*nb1 + i2*nb2; + + ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } } } else { ggml_cuda_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); @@ -2011,55 +2095,104 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); - for (int32_t row_id = 0; row_id < n_as; ++row_id) { + for (int64_t i02 = 0; i02 < n_as; i02++) { int64_t num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); - if (row_id_i != row_id) { - continue; - } + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + + if (row_id_i != i02) { + continue; + } - GGML_ASSERT(row_id >= 0 && row_id < n_as); + GGML_ASSERT(i02 >= 0 && i02 < n_as); - CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11, - nb11, cudaMemcpyDeviceToDevice, stream)); - num_src1_rows++; +#ifdef MMID_MEMCPY + const int64_t i11 = id % ne11; + const int64_t i12 = iid1; + CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, + src1_original + i11*nb11 + i12*nb12, + nb11, cudaMemcpyDeviceToDevice, stream)); +#endif + num_src1_rows++; + } } if (num_src1_rows == 0) { continue; } - src0_row.data = src0_original + row_id*src0->nb[2]; +#ifndef MMID_MEMCPY + ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); + ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); + CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); - src1_row.ne[1] = num_src1_rows; - dst_row.ne[1] = num_src1_rows; + { + dim3 block_dims(std::min((uint)nb11, 1024u)); + dim3 grid_dims(ids->ne[1], n_ids); + k_copy_src1_to_contiguous<<>>( + src1_original, src1_contiguous.get(), + dev_cur_src1_row.get(), dev_row_mapping.get(), + ids_dev, i02, ids->nb[1], ids->nb[0], + ids->ne[1], n_ids, + ne11, + nb11, nb12); + CUDA_CHECK(cudaGetLastError()); + } +#endif + + src0_row.data = src0_original + i02*nb02; + GGML_ASSERT(nb11 == sizeof(float)*ne10); + GGML_ASSERT(nb1 == sizeof(float)*ne0); + + src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); +#ifndef MMID_MEMCPY + { + dim3 block_dims(std::min((uint)nb1, 1024u)); + dim3 grid_dims(num_src1_rows); + 
k_copy_dst_from_contiguous<<>>( + dst_original, dst_contiguous.get(), + dev_row_mapping.get(), + num_src1_rows, nb1, nb2); + CUDA_CHECK(cudaGetLastError()); + } +#endif + +#ifdef MMID_MEMCPY num_src1_rows = 0; - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - if (row_id_i != row_id) { - continue; - } + if (row_id_i != i02) { + continue; + } - GGML_ASSERT(row_id >= 0 && row_id < n_as); + GGML_ASSERT(i02 >= 0 && i02 < n_as); - CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1, - nb1, cudaMemcpyDeviceToDevice, stream)); - num_src1_rows++; + const int64_t i1 = id; + const int64_t i2 = iid1; + + CUDA_CHECK(cudaMemcpyAsync(dst_original + i1*nb1 + i2*nb2, + dst_contiguous.get() + num_src1_rows*nb1, + nb1, cudaMemcpyDeviceToDevice, stream)); + num_src1_rows++; + } } +#endif } } } @@ -2487,7 +2620,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) { const int min_batch_size = 32; - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); GGML_UNUSED(backend); } diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu index 18a31edc34f83..fa24480b83b53 100644 --- a/ggml-cuda/convert.cu +++ b/ggml-cuda/convert.cu @@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h vals[ix] = x0[ix]; } + __syncthreads(); + #pragma unroll for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) { if (need_check && i0 + iy + 2*threadIdx.x >= k) { diff --git a/ggml.c b/ggml.c index c9b0a6a0ef776..1a206c7f8a4e2 100644 --- a/ggml.c +++ b/ggml.c @@ -4573,21 +4573,32 @@ void ggml_mul_mat_set_prec( // ggml_mul_mat_id -// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed -// this will allow computing all the used experts in a single matrix multiplication +/* + c = ggml_mul_mat_id(ctx, as, b, ids); + + as -> [cols, rows, n_expert] + ids -> [n_experts_used, n_tokens] (i32) + b -> [cols, n_expert_used, n_tokens] + c -> [cols, n_expert_used, n_tokens] + + in b, n_experts_used can be broadcasted to match the n_expert_used of ids + + c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e in ids +*/ struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, struct ggml_tensor * as, - struct ggml_tensor * ids, - int id, - struct ggml_tensor * b) { - + struct ggml_tensor * b, + struct ggml_tensor * ids) { + GGML_ASSERT(!ggml_is_transposed(as)); GGML_ASSERT(ids->type == GGML_TYPE_I32); + + GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert) + GGML_ASSERT(b->ne[3] == 1); // b is 3d GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d - GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row - GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]); - GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id + GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat + GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast 
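A minimal usage sketch of the new signature, under assumed illustrative sizes (n_embd, n_ff, n_expert, n_expert_used and n_tokens are placeholders, not taken from the patch): all experts selected in ids are computed in a single call, and b with ne[1] == 1 is broadcast to every selected expert.

static struct ggml_tensor * example_moe_matmul(struct ggml_context * ctx,
        int64_t n_embd, int64_t n_ff, int64_t n_expert, int64_t n_expert_used, int64_t n_tokens) {
    // one weight matrix per expert: as -> [n_embd, n_ff, n_expert]
    struct ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_ff, n_expert);
    // top-k selected expert ids per token: ids -> [n_expert_used, n_tokens]
    struct ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, n_tokens);
    // activations; ne[1] == 1 so the same row is broadcast to every selected expert: b -> [n_embd, 1, n_tokens]
    struct ggml_tensor * b   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, 1, n_tokens);
    // result: [n_ff, n_expert_used, n_tokens], one output row per (expert slot, token)
    return ggml_mul_mat_id(ctx, as, b, ids);
}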
bool is_node = false; @@ -4595,11 +4606,9 @@ struct ggml_tensor * ggml_mul_mat_id( is_node = true; } - const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - ggml_set_op_params_i32(result, 0, id); - result->op = GGML_OP_MUL_MAT_ID; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = as; @@ -10958,10 +10967,10 @@ static void ggml_compute_forward_mul_mat_id( enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; - GGML_ASSERT(ne0 == ne01); - GGML_ASSERT(ne1 == ne11); - GGML_ASSERT(ne2 == ne12); - GGML_ASSERT(ne3 == ne13); + //GGML_ASSERT(ne0 == ne01); + //GGML_ASSERT(ne1 == ne11); + //GGML_ASSERT(ne2 == ne12); + //GGML_ASSERT(ne3 == ne13); // we don't support permuted src0 or src1 GGML_ASSERT(nb00 == ggml_type_size(type)); @@ -10973,13 +10982,9 @@ static void ggml_compute_forward_mul_mat_id( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // broadcast is not supported with mmid - assert(ne12 == 1); - assert(ne13 == 1); - // row groups - const int id = ggml_get_op_params_i32(dst, 0); - const int n_as = src0->ne[2]; + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert char * wdata_src1_end = (src1->type == vec_dot_type) ? (char *) params->wdata : @@ -10988,8 +10993,6 @@ static void ggml_compute_forward_mul_mat_id( int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11] - #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)] - if (params->type == GGML_TASK_TYPE_INIT) { if (ith != 0) { return; @@ -11015,13 +11018,21 @@ static void ggml_compute_forward_mul_mat_id( GGML_ASSERT(wdata == wdata_src1_end); memset(matrix_row_counts, 0, n_as*sizeof(int64_t)); +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)] +#define MAKE_I64(lo, hi) (((int64_t)(lo)) | (((int64_t)(hi)) << 32)) +#define LO_I64(i64) ((int32_t)(i64)) +#define HI_I64(i64) ((int32_t)((i64) >> 32)) + // group rows by src0 matrix - for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) { - const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + assert(i02 >= 0 && i02 < n_as); - GGML_ASSERT(row_id >= 0 && row_id < n_as); - MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01; - matrix_row_counts[row_id] += 1; + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = MAKE_I64(iid1, id); + matrix_row_counts[i02] += 1; + } } return; @@ -11039,15 +11050,13 @@ static void ggml_compute_forward_mul_mat_id( continue; } - size_t src0_offset = cur_a*src0->nb[2]; + const char * src0_cur = (const char *) src0->data + cur_a*nb02; const void * wdata = (src1->type == vec_dot_type) ? 
src1->data : params->wdata; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - const int64_t nr0 = ne01; // src0 rows - const int64_t nr1 = cne1*ne12*ne13; // src1 rows - - //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1); + const int64_t nr0 = ne01; // src0 rows + const int64_t nr1 = cne1; // src1 rows // distribute the thread work across the inner or outer loop based on which one is larger @@ -11066,13 +11075,11 @@ static void ggml_compute_forward_mul_mat_id( const int64_t ir110 = dr1*ith1; const int64_t ir111 = MIN(ir110 + dr1, nr1); - //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111); - // threads with no work simply yield (not sure if it helps) - if (ir010 >= ir011 || ir110 >= ir111) { - sched_yield(); - continue; - } + //if (ir010 >= ir011 || ir110 >= ir111) { + // sched_yield(); + // continue; + //} // block-tiling attempt const int64_t blck_0 = 16; @@ -11084,20 +11091,15 @@ static void ggml_compute_forward_mul_mat_id( for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) { for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) { for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) { - const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix - const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1; - const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1); - const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11); + const int64_t _i12 = ir1; // logical row index for this expert - // broadcast src0 into src1 - //const int64_t i03 = i13/r3; - //const int64_t i02 = i12/r2; + const int id = HI_I64(MMID_MATRIX_ROW(cur_a, _i12)); // selected expert index - const int64_t i1 = i11; - const int64_t i2 = i12; - const int64_t i3 = i13; + const int64_t i11 = id % ne11; + const int64_t i12 = LO_I64(MMID_MATRIX_ROW(cur_a, _i12)); // row index in src1 - const char * src0_row = (const char *) src0->data + src0_offset; + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using @@ -11105,25 +11107,30 @@ static void ggml_compute_forward_mul_mat_id( // TODO: this is a bit of a hack, we should probably have a better way to handle this const char * src1_col = (const char *) wdata + (src1_cont || src1->type != vec_dot_type - ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size - : (i11*nb11 + i12*nb12 + i13*nb13)); + ? 
(i11 + i12*ne11)*row_size + : (i11*nb11 + i12*nb12)); - float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3)); + float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2)); //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col); //} for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { - vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1); + vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1); } + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } } } } - #undef MMID_MATRIX_ROW +#undef MMID_MATRIX_ROW +#undef MAKE_I64 +#undef LO_I64 +#undef HI_I64 } // ggml_compute_forward_out_prod @@ -18462,7 +18469,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa const int n_as = src0->ne[2]; cur += GGML_PAD(cur, sizeof(int64_t)); // align cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows + cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows } break; case GGML_OP_OUT_PROD: { @@ -20862,12 +20869,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p ok = ok && cur != NULL; - ggml_set_name(cur, ctx->infos[i].name.data); - if (!ok) { break; } + ggml_set_name(cur, ctx->infos[i].name.data); + // point the data member to the appropriate location in the binary blob using the tensor infos if (!params.no_alloc) { //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file diff --git a/ggml.h b/ggml.h index 5cef45c0ba4ad..0350b4027fab4 100644 --- a/ggml.h +++ b/ggml.h @@ -1161,13 +1161,12 @@ extern "C" { enum ggml_prec prec); // indirect matrix multiplication - // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b) + // TODO: document GGML_API struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, struct ggml_tensor * as, - struct ggml_tensor * ids, - int id, - struct ggml_tensor * b); + struct ggml_tensor * b, + struct ggml_tensor * ids); // A: m columns, n rows, // B: p columns, n rows, diff --git a/llama.cpp b/llama.cpp index 9a1c11043b94a..c441ed5299b7a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6392,52 +6392,56 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] + ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_expert, n_tokens] cb(logits, "ffn_moe_logits", il); - ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] + ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens] cb(probs, "ffn_moe_probs", il); // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); + cb(selected_experts, "ffn_moe_topk", il); ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); + ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] cb(weights, "ffn_moe_weights", il); - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok] + 
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); + ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] cb(weights_sum, "ffn_moe_weights_sum", il); - weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] + weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] cb(weights, "ffn_moe_weights_norm", il); - // compute expert outputs - ggml_tensor * moe_out = nullptr; + cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); + ggml_tensor * up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert; + ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); - ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur); - cb(cur_up, "ffn_moe_up", il); + gate = ggml_silu(ctx0, gate); + cb(gate, "ffn_moe_silu", il); - ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur); - cb(cur_gate, "ffn_moe_gate", il); + ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens] + cb(par, "ffn_moe_gate_par", il); - cur_gate = ggml_silu(ctx0, cur_gate); - cb(cur_gate, "ffn_moe_silu", il); + ggml_tensor * experts = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + cb(experts, "ffn_moe_down", il); - cur_expert = ggml_mul(ctx0, cur_up, cur_gate); - cb(cur_expert, "ffn_moe_gate_par", il); + experts = ggml_mul(ctx0, experts, + ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens)); - cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd] - cb(cur_expert, "ffn_moe_down", il); + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); - cur_expert = ggml_mul(ctx0, cur_expert, - ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); - cb(cur_expert, "ffn_moe_weighted", il); + // FIXME: non-contiguous add broken in cuda + cur_expert = ggml_cont(ctx0, cur_expert); if (i == 0) { moe_out = cur_expert; @@ -6953,11 +6957,12 @@ struct llm_build_context { for (int i = 0; i < n_expert_used; ++i) { ggml_tensor * cur_expert; + // FIXME - ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur); + ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, cur); cb(cur_up, "ffn_moe_up", il); - ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur); + ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, cur); cb(cur_gate, "ffn_moe_gate", il); //GeLU @@ -6967,7 +6972,7 @@ struct llm_build_context { cur_expert = ggml_mul(ctx0, cur_up, cur_gate); cb(cur_expert, "ffn_moe_gate_par", il); - cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd] + cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, cur_expert); // [n_tokens, n_embd] 
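For a concrete picture of the MoE FFN graph built above, this is the shape flow with Mixtral-like sizes assumed for illustration (n_embd = 4096, n_ff = 14336, n_expert = 8, n_expert_used = 2, n_tokens = 32); per token t the block computes moe_out[:,t] = sum over selected experts e of weights[e,t] * down_e(silu(gate_e(x_t)) * up_e(x_t)):

// logits, probs                               [8, 32]        one score per expert per token
// selected_experts = top_k(probs, 2)          [2, 32]
// weights = get_rows(probs, selected_experts) [1, 2, 32]  -> reshaped to [2, 32], normalized by weights_sum [1, 32]
// cur (reshaped)                              [4096, 1, 32]
// up, gate, par = up * silu(gate)             [14336, 2, 32]
// experts = down(par)                         [4096, 2, 32]
// experts * weights (as [1, 2, 32])           [4096, 2, 32]
// moe_out = sum of the 2 expert views         [4096, 32]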
cb(cur_expert, "ffn_moe_down", il); cur_expert = ggml_mul(ctx0, cur_expert, diff --git a/scripts/compare-commits.sh b/scripts/compare-commits.sh index d1272506cd58a..6d6699f40c1a4 100755 --- a/scripts/compare-commits.sh +++ b/scripts/compare-commits.sh @@ -22,9 +22,9 @@ fi make_opts="" -if [[ "$backend" == "cuda" ]]; then - make_opts="LLAMA_CUDA=1" -fi +#if [[ "$backend" == "cuda" ]]; then +# make_opts="LLAMA_CUDA=1" +#fi git checkout $1 make clean && make -j32 $make_opts llama-bench diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 51b3487b2a948..effeb39f7961a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -478,8 +478,9 @@ struct test_case { } double err = nmse(f1.data(), f2.data(), f1.size()); + printf("[%s] NMSE = %.9f > %.9f \n", ggml_op_desc(t1), err, ud->max_err); if (err > ud->max_err) { - printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); + //printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err); //for (int i = 0; i < (int) f1.size(); i++) { // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]); //} @@ -948,14 +949,14 @@ struct test_mul_mat_id : public test_case { const ggml_type type_a; const ggml_type type_b; const int n_mats; - const int id; + const int n_used; + const bool b; // brodcast b matrix const int64_t m; const int64_t n; const int64_t k; - const bool v; // view (non-contiguous ids) std::string vars() override { - return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v); + return VARS_TO_STR7(type_a, type_b, n_mats, b, m, n, k); } double max_nmse_err() override { @@ -972,20 +973,18 @@ struct test_mul_mat_id : public test_case { } test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, - int n_mats = 2, int id = 0, - int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false) - : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id), - m(m), n(n), k(k), v(v) {} + int n_mats = 8, int n_used = 2, bool b = false, + int64_t m = 32, int64_t n = 32, int64_t k = 32) + : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b), + m(m), n(n), k(k) {} ggml_tensor * build_graph(ggml_context * ctx) override { // C^T = A * B^T: (k, m) * (k, n) => (m, n) - ggml_tensor * mats = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats); + ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats); ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n); - if (v) { - ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0); - } - ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n); - ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, v ? id/2 : id, b); + ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0); + ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 
1 : n_used, n); + ggml_tensor * out = ggml_mul_mat_id(ctx, as, ids, b); return out; } @@ -1858,6 +1857,92 @@ struct test_falcon : public test_llm { } }; + +// Mixtral MOE +struct test_moe : public test_case { + const int n_expert; + const int n_expert_used; + const int n_tokens; + const int n_embd; + const int n_ff; + + std::string op_desc(ggml_tensor * t) override { + return "MOE"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR5(n_expert, n_expert_used, n_tokens, n_embd, n_ff); + } + + test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336) + : n_expert(n_experts), n_expert_used(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) { + } + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_type wtype = GGML_TYPE_F32; + ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_expert); + + ggml_tensor * ffn_gate_exps = ggml_new_tensor_3d(ctx, wtype, n_embd, n_ff, n_expert); + ggml_tensor * ffn_down_exps = ggml_new_tensor_3d(ctx, wtype, n_ff, n_embd, n_expert); + ggml_tensor * ffn_up_exps = ggml_new_tensor_3d(ctx, wtype, n_embd, n_ff, n_expert); + + ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_expert, n_tokens] + + //ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd), 0.0f); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + + ggml_tensor * weights = ggml_get_rows(ctx, + ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + + weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + + ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens] + + weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens] + + cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); + ggml_tensor * up = ggml_mul_mat_id(ctx, ffn_up_exps, selected_experts, cur); // [n_ff, n_expert_used, n_tokens] + + ggml_tensor * gate = ggml_mul_mat_id(ctx, ffn_gate_exps, selected_experts, cur); // [n_ff, n_expert_used, n_tokens] + + gate = ggml_silu(ctx, gate); + + ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] + + ggml_tensor * experts = ggml_mul_mat_id(ctx, ffn_down_exps, selected_experts, par); // [n_embd, n_expert_used, n_tokens] + + printf("mul: src0: %ld %ld %ld %ld\n", experts->ne[0], experts->ne[1], experts->ne[2], experts->ne[3]); + printf("mul: src1: %ld %ld %ld %ld\n", 1, n_expert_used, n_tokens, 1); + experts = ggml_mul(ctx, experts, + ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens)); + + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); + cur_expert = ggml_cont(ctx, cur_expert); + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx, moe_out, cur_expert); + } + } + + cur = moe_out; + + return cur; + } +}; + + static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { std::vector> test_cases; std::default_random_engine rng(0); @@ -1875,6 +1960,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op GGML_TYPE_IQ4_NL, 
GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, }; +test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024)); +test_cases.emplace_back(new test_moe(8, 2, 32, 4096, 8*1024)); + // unary ops for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { test_cases.emplace_back(new test_unary((ggml_unary_op) op)); @@ -1944,6 +2032,10 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } }; + // mul: src0: 4096 2 32 1 + // mul: src1: 1 2 32 1 + add_test_bin_bcast(GGML_TYPE_F32, {1, 2, 32, 1}, {4096, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1}); @@ -2012,10 +2104,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (int n_mats : {2, 4, 8}) { - for (int id = 0; id < n_mats; id++) { - for (bool v : {false, true}) { - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v)); - } + for (bool b : {false, true}) { + // cur shape: 4096 1 32 1 + // ffn_up_exps shape: 4096 8192 8 1 + // selected_experts shape: 2 32 1 1 + int m = 8192; + int n = 32; + int k = 4096; + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, 2, b, m, n, k)); } } } From 1b5d78d3ee237571f9035332cd7c25ad7c4ed368 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 6 Apr 2024 15:19:47 +0200 Subject: [PATCH 2/6] minor --- ggml-cuda.cu | 61 ++++++++++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 35 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 74a50594f632c..297fdbe130b41 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1961,22 +1961,17 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } struct mmid_row_mapping { - int64_t i1; - int64_t i2; + int32_t i1; + int32_t i2; }; -static __global__ void k_copy_src1_to_contiguous(const char * src1_original, char * src1_contiguous, - int * cur_src1_row, mmid_row_mapping * row_mapping, - const char * ids_dev, int64_t i02, int64_t ids_nb1, int64_t ids_nb0, - int64_t ids_ne1, int64_t n_ids, - int64_t ne11, +static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, + int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, + const char * ids_dev, int64_t i02, size_t ids_nb1, size_t ids_nb0, + int64_t ne11, int64_t ne10, size_t nb11, size_t nb12) { - int64_t iid1 = blockIdx.x; - int64_t id = blockIdx.y; - - if (iid1 >= ids_ne1 || id >= n_ids) { - return; - } + int32_t iid1 = blockIdx.x; + int32_t id = blockIdx.y; const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0); @@ -1994,31 +1989,27 @@ static __global__ void k_copy_src1_to_contiguous(const char * src1_original, cha } __syncthreads(); - const char * src1_row_original = src1_original + i11*nb11 + i12*nb12; - char * src1_row_contiguous = src1_contiguous + src1_row*nb11; + const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); + float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); - for (int i = threadIdx.x; i < nb11; i += blockDim.x) { + for (int i = threadIdx.x; i < ne10; i += blockDim.x) { src1_row_contiguous[i] = src1_row_original[i]; } } -static __global__ void k_copy_dst_from_contiguous(char * dst_original, const char * dst_contiguous, - const 
mmid_row_mapping * row_mapping, - int64_t n_rows, - int64_t nb1, int64_t nb2) { - int64_t i = blockIdx.x; - - if (i >= n_rows) { - return; - } +static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous, + const mmid_row_mapping * __restrict__ row_mapping, + int64_t ne0, + size_t nb1, size_t nb2) { + int32_t i = blockIdx.x; - const int64_t i1 = row_mapping[i].i1; - const int64_t i2 = row_mapping[i].i2; + const int32_t i1 = row_mapping[i].i1; + const int32_t i2 = row_mapping[i].i2; - const char * dst_row_contiguous = dst_contiguous + i*nb1; - char * dst_row_original = dst_original + i1*nb1 + i2*nb2; + const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1); + float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2); - for (int j = threadIdx.x; j < nb1; j += blockDim.x) { + for (int j = threadIdx.x; j < ne0; j += blockDim.x) { dst_row_original[j] = dst_row_contiguous[j]; } } @@ -2129,14 +2120,13 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); { - dim3 block_dims(std::min((uint)nb11, 1024u)); + dim3 block_dims(std::min((uint)ne10, 512u)); dim3 grid_dims(ids->ne[1], n_ids); k_copy_src1_to_contiguous<<>>( src1_original, src1_contiguous.get(), dev_cur_src1_row.get(), dev_row_mapping.get(), ids_dev, i02, ids->nb[1], ids->nb[0], - ids->ne[1], n_ids, - ne11, + ne11, ne10, nb11, nb12); CUDA_CHECK(cudaGetLastError()); } @@ -2161,12 +2151,13 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * #ifndef MMID_MEMCPY { - dim3 block_dims(std::min((uint)nb1, 1024u)); + dim3 block_dims(std::min((uint)ne0, 512u)); dim3 grid_dims(num_src1_rows); k_copy_dst_from_contiguous<<>>( dst_original, dst_contiguous.get(), dev_row_mapping.get(), - num_src1_rows, nb1, nb2); + ne0, + nb1, nb2); CUDA_CHECK(cudaGetLastError()); } #endif From bc615548d3f69b21a3602ca2afc37292a2139585 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 7 Apr 2024 20:31:09 +0200 Subject: [PATCH 3/6] fix windows build --- ggml-cuda.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 297fdbe130b41..67c04a1765ae9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1967,13 +1967,13 @@ struct mmid_row_mapping { static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous, int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping, - const char * ids_dev, int64_t i02, size_t ids_nb1, size_t ids_nb0, + const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, int64_t ne11, int64_t ne10, size_t nb11, size_t nb12) { int32_t iid1 = blockIdx.x; int32_t id = blockIdx.y; - const int32_t row_id_i = *(const int32_t *) (ids_dev + iid1*ids_nb1 + id*ids_nb0); + const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); if (row_id_i != i02) { return; @@ -2120,7 +2120,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); { - dim3 block_dims(std::min((uint)ne10, 512u)); + dim3 block_dims(std::min((unsigned int)ne10, 512u)); dim3 grid_dims(ids->ne[1], n_ids); k_copy_src1_to_contiguous<<>>( src1_original, src1_contiguous.get(), @@ -2151,7 +2151,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * #ifndef MMID_MEMCPY { - dim3 
block_dims(std::min((uint)ne0, 512u)); + dim3 block_dims(std::min((unsigned int)ne0, 512u)); dim3 grid_dims(num_src1_rows); k_copy_dst_from_contiguous<<>>( dst_original, dst_contiguous.get(), From f3f7627bd8e40e53d01227dd912dedc288786eb6 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 7 Apr 2024 21:14:23 +0200 Subject: [PATCH 4/6] refactor moe ffn to llm_build_moe_ffn --- llama.cpp | 222 +++++++++++++++++++++++++----------------------------- 1 file changed, 103 insertions(+), 119 deletions(-) diff --git a/llama.cpp b/llama.cpp index c441ed5299b7a..00132b53436e3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5849,6 +5849,93 @@ static struct ggml_tensor * llm_build_ffn( return cur; } +static struct ggml_tensor * llm_build_moe_ffn( + struct ggml_context * ctx, + struct ggml_tensor * cur, + struct ggml_tensor * gate_inp, + struct ggml_tensor * up_exps, + struct ggml_tensor * gate_exps, + struct ggml_tensor * down_exps, + int64_t n_expert, + int64_t n_expert_used, + llm_ffn_op_type type_op, + const llm_build_cb & cb, + int il) { + int64_t n_embd = cur->ne[0]; + int64_t n_tokens = cur->ne[1]; + + ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens] + cb(probs, "ffn_moe_probs", il); + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens] + cb(selected_experts->src[0], "ffn_moe_argsort", il); + cb(selected_experts, "ffn_moe_topk", il); + + ggml_tensor * weights = ggml_get_rows(ctx, + ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights", il); + + weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens); + + ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens] + cb(weights_sum, "ffn_moe_weights_sum", il); + + weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens] + cb(weights, "ffn_moe_weights_norm", il); + + cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); + ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(gate, "ffn_moe_gate", il); + + switch (type_op) { + case LLM_FFN_SILU: + { + gate = ggml_silu(ctx, gate); + cb(gate, "ffn_moe_silu", il); + } break; + case LLM_FFN_GELU: + { + gate = ggml_gelu(ctx, gate); + cb(gate, "ffn_moe_gelu", il); + } break; + default: + GGML_ASSERT(false); + } + + ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] + cb(par, "ffn_moe_gate_par", il); + + ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + cb(experts, "ffn_moe_down", il); + + experts = ggml_mul(ctx, experts, + ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens)); + + // aggregate experts + ggml_tensor * moe_out = nullptr; + for (int i = 0; i < n_expert_used; ++i) { + ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens, + experts->nb[2], i*experts->nb[1]); + + // FIXME: non-contiguous add broken in cuda + cur_expert = ggml_cont(ctx, cur_expert); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx, moe_out, cur_expert); + } + } + + return moe_out; +} // if max_alibi_bias > 0 then apply ALiBi static struct ggml_tensor * 
llm_build_kqv( struct ggml_context * ctx, @@ -6392,66 +6479,14 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_expert, n_tokens] - cb(logits, "ffn_moe_logits", il); - - ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens] - cb(probs, "ffn_moe_probs", il); - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_expert_used, n_tokens] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - cb(selected_experts, "ffn_moe_topk", il); - - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens] - cb(weights, "ffn_moe_weights", il); - - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); - - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens] - cb(weights_sum, "ffn_moe_weights_sum", il); - - weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens] - cb(weights, "ffn_moe_weights_norm", il); - - cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens); - ggml_tensor * up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); - - ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(gate, "ffn_moe_gate", il); - - gate = ggml_silu(ctx0, gate); - cb(gate, "ffn_moe_silu", il); - - ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens] - cb(par, "ffn_moe_gate_par", il); - - ggml_tensor * experts = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] - cb(experts, "ffn_moe_down", il); - - experts = ggml_mul(ctx0, experts, - ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens)); - - // aggregate experts - ggml_tensor * moe_out = nullptr; - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, - experts->nb[2], i*experts->nb[1]); - - // FIXME: non-contiguous add broken in cuda - cur_expert = ggml_cont(ctx0, cur_expert); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - cb(moe_out, "ffn_moe_out", il); - } - } - - cur = moe_out; + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, cb, il); + cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -6930,64 +6965,14 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] - cb(logits, "ffn_moe_logits", il); - - ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] - cb(probs, "ffn_moe_probs", il); - - // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok] - cb(selected_experts->src[0], "ffn_moe_argsort", il); - - ggml_tensor * weights = ggml_get_rows(ctx0, - ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); - cb(weights, "ffn_moe_weights", il); - - weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, 
num_experts_per_tok] - - ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); - cb(weights_sum, "ffn_moe_weights_sum", il); - - weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok] - cb(weights, "ffn_moe_weights_norm", il); - - // compute expert outputs - ggml_tensor * moe_out = nullptr; - - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert; - // FIXME - - ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, cur); - cb(cur_up, "ffn_moe_up", il); - - ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, cur); - cb(cur_gate, "ffn_moe_gate", il); - - //GeLU - cur_gate = ggml_gelu(ctx0, cur_gate); - cb(cur_gate, "ffn_moe_gelu", il); - - cur_expert = ggml_mul(ctx0, cur_up, cur_gate); - cb(cur_expert, "ffn_moe_gate_par", il); - - cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, cur_expert); // [n_tokens, n_embd] - cb(cur_expert, "ffn_moe_down", il); - - cur_expert = ggml_mul(ctx0, cur_expert, - ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); - cb(cur_expert, "ffn_moe_weighted", il); - - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - cb(moe_out, "ffn_moe_out", il); - } - } - - cur = moe_out; + cur = llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_GELU, cb, il); + cb(cur, "ffn_moe_out", il); // Grok // if layer_out_norm is present then apply it before adding the input @@ -6999,7 +6984,6 @@ struct llm_build_context { cb(cur, "layer_out_norm", il); } - cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); From 23f7d71a2b98d3d596b82be025c5fa803a9f7d37 Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 7 Apr 2024 22:04:33 +0200 Subject: [PATCH 5/6] cleanup --- ggml-cuda.cu | 40 ++-------------------------------------- 1 file changed, 2 insertions(+), 38 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 67c04a1765ae9..0511923c8da50 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2014,8 +2014,6 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin } } -//#define MMID_MEMCPY - static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -2093,19 +2091,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * for (int64_t id = 0; id < n_ids; id++) { const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + if (row_id_i != i02) { continue; } - GGML_ASSERT(i02 >= 0 && i02 < n_as); - -#ifdef MMID_MEMCPY - const int64_t i11 = id % ne11; - const int64_t i12 = iid1; - CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, - src1_original + i11*nb11 + i12*nb12, - nb11, cudaMemcpyDeviceToDevice, stream)); -#endif num_src1_rows++; } } @@ -2114,7 +2105,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * continue; } -#ifndef MMID_MEMCPY ggml_cuda_pool_alloc dev_cur_src1_row(ctx.pool(), 1); ggml_cuda_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream)); @@ -2130,7 +2120,6 @@ static void 
ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * nb11, nb12); CUDA_CHECK(cudaGetLastError()); } -#endif src0_row.data = src0_original + i02*nb02; @@ -2149,7 +2138,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row); -#ifndef MMID_MEMCPY { dim3 block_dims(std::min((unsigned int)ne0, 512u)); dim3 grid_dims(num_src1_rows); @@ -2160,30 +2148,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * nb1, nb2); CUDA_CHECK(cudaGetLastError()); } -#endif - -#ifdef MMID_MEMCPY - num_src1_rows = 0; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); - - if (row_id_i != i02) { - continue; - } - - GGML_ASSERT(i02 >= 0 && i02 < n_as); - - const int64_t i1 = id; - const int64_t i2 = iid1; - - CUDA_CHECK(cudaMemcpyAsync(dst_original + i1*nb1 + i2*nb2, - dst_contiguous.get() + num_src1_rows*nb1, - nb1, cudaMemcpyDeviceToDevice, stream)); - num_src1_rows++; - } - } -#endif } } } From 9a43e808200cd49a01095c18723cd9659a5a747e Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 8 Apr 2024 18:01:02 +0200 Subject: [PATCH 6/6] update imatrix --- examples/imatrix/imatrix.cpp | 57 ++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index d8cb0a6420456..295a326f5884c 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -44,7 +44,7 @@ class IMatrixCollector { std::mutex m_mutex; int m_last_call = 0; std::vector m_src1_data; - std::vector m_ids; // the expert ids from ggml_mul_mat_id + std::vector m_ids; // the expert ids from ggml_mul_mat_id // void save_imatrix(const char * file_name) const; void keep_imatrix(int ncall) const; @@ -81,6 +81,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * if (ask) { if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications if (t->op != GGML_OP_MUL_MAT) return false; + // why are small batches ignored (<16 tokens)? if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false; if (!(wname.substr(0, 4) == "blk." || (m_params.collect_output_weight && wname == "output.weight"))) return false; return true; @@ -101,16 +102,19 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * // this has been adapted to the new format of storing merged experts in a single 3d tensor // ref: https://github.com/ggerganov/llama.cpp/pull/6387 if (t->op == GGML_OP_MUL_MAT_ID) { - const int idx = ((int32_t *) t->op_params)[0]; + // ids -> [n_experts_used, n_tokens] + // src1 -> [cols, n_expert_used, n_tokens] const ggml_tensor * ids = t->src[2]; const int n_as = src0->ne[2]; + const int n_ids = ids->ne[0]; // the top-k selected expert ids are stored in the ids tensor // for simplicity, always copy ids to host, because it is small // take into account that ids is not contiguous! 
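Because ids can be a non-contiguous view, the host copy is kept as raw bytes and indexed with the tensor's byte strides rather than as a flat int array. A small helper of the kind assumed here (hypothetical, not part of the patch) makes that access pattern explicit, given <vector> and ggml.h:

// read element (i0, i1) of a 2D GGML_TYPE_I32 tensor from a raw host copy of its bytes,
// using the byte strides nb[0]/nb[1] so that non-contiguous views are handled correctly
static int32_t read_i32_2d(const std::vector<char> & host_copy, const struct ggml_tensor * t, int64_t i0, int64_t i1) {
    return *(const int32_t *) (host_copy.data() + i1*t->nb[1] + i0*t->nb[0]);
}

With such a helper, the expert id in slot idx of token row would be read_i32_2d(m_ids, ids, idx, row), matching the pointer arithmetic used in the loop below.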
- GGML_ASSERT(ids->ne[1] == src1->ne[1]); - GGML_ASSERT(n_as*ggml_nrows(ids)*sizeof(int) == GGML_PAD(ggml_nbytes(ids), n_as*sizeof(int))); - m_ids.resize(ggml_nbytes(ids)/sizeof(int)); + + GGML_ASSERT(ids->ne[1] == src1->ne[2]); + + m_ids.resize(ggml_nbytes(ids)); ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids)); auto & e = m_stats[wname]; @@ -120,26 +124,35 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * // using the following line, we can correct for that if needed by replacing the line above with: //if (idx == t->src[0]->ne[0] - 1) ++e.ncall; + if (e.values.empty()) { + e.values.resize(src1->ne[0]*n_as, 0); + } + else if (e.values.size() != (size_t)src1->ne[0]*n_as) { + fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); + exit(1); //GGML_ASSERT(false); + } + if (m_params.verbosity > 1) { + printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type); + } // loop over all possible experts, regardless if they are used or not in the batch for (int ex = 0; ex < n_as; ++ex) { size_t e_start = ex*src1->ne[0]; - if (e.values.empty()) { - e.values.resize(src1->ne[0]*n_as, 0); - } - else if (e.values.size() != (size_t)src1->ne[0]*n_as) { - fprintf(stderr, "Oops: inconsistent size for %s (%d vs %d)\n", wname.c_str(), (int)e.values.size(), (int)src1->ne[0]*n_as); - exit(1); //GGML_ASSERT(false); - } - if (m_params.verbosity > 1) { - printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type); - } - for (int row = 0; row < (int)src1->ne[1]; ++row) { - const int excur = m_ids[row*n_as + idx]; - GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check - if (excur != ex) continue; - const float * x = data + row * src1->ne[0]; - for (int j = 0; j < (int)src1->ne[0]; ++j) { - e.values[e_start + j] += x[j]*x[j]; + + for (int idx = 0; idx < n_ids; ++idx) { + for (int row = 0; row < (int)src1->ne[2]; ++row) { + const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]); + + GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check + + if (excur != ex) continue; + + const int64_t i11 = idx % src1->ne[1]; + const int64_t i12 = row; + const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]); + + for (int j = 0; j < (int)src1->ne[0]; ++j) { + e.values[e_start + j] += x[j]*x[j]; + } } } if (e.ncall > m_last_call) {
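Stepping back from the individual hunks: the device-side pattern that patch 1 introduces in k_copy_src1_to_contiguous is a grid-wide compaction. Each block that matches the current expert claims the next free output slot with one atomicAdd, records the (i1, i2) origin of the row for the later scatter back, and then copies the row cooperatively. A stripped-down sketch of that pattern, with simplified contiguous layouts and names assumed here rather than taken from the patch:

struct row_mapping { int32_t i1; int32_t i2; };

// grid: (n_tokens, n_slots) blocks, one candidate row each; ids is assumed contiguous [n_slots, n_tokens]
// src is assumed contiguous [ne_row, n_slots, n_tokens]; dst receives the matching rows back to back
static __global__ void compact_rows(const float * __restrict__ src, float * __restrict__ dst,
                                     const int32_t * __restrict__ ids, int expert,
                                     int * __restrict__ counter, row_mapping * __restrict__ mapping,
                                     int64_t ne_row) {
    const int i2 = blockIdx.x; // token
    const int i1 = blockIdx.y; // expert slot

    if (ids[i2*gridDim.y + i1] != expert) {
        return; // uniform per block, so skipping the __syncthreads below is safe
    }

    // one thread claims the next free slot in dst and records where the row came from
    __shared__ int slot;
    if (threadIdx.x == 0) {
        slot = atomicAdd(counter, 1);
        mapping[slot] = {i1, i2};
    }
    __syncthreads();

    const float * src_row = src + ((int64_t) i2*gridDim.y + i1)*ne_row;
    float       * dst_row = dst + (int64_t) slot*ne_row;

    for (int64_t i = threadIdx.x; i < ne_row; i += blockDim.x) {
        dst_row[i] = src_row[i];
    }
}

After the grouped matrix multiplication, a mirror kernel (k_copy_dst_from_contiguous in the patch) walks mapping[0..counter) and scatters each contiguous output row back to its (i1, i2) position in the original dst layout.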