From 993ca95e9e3108f0352fa2a3384cab0775c7f7c1 Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 16 Oct 2024 14:13:03 +0300 Subject: [PATCH] iq4_ks: faster dot product on Metal (#90) TG-128(LLaMA-3.1-8B) goes to 52.5 t/s up from 48.4 t/s. Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-metal.metal | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 72595c91a2347..dff9326f45e28 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -6079,6 +6079,7 @@ void kernel_mul_mv_iq4_ks_f32_impl( float4 yl[4]; float2 sumf = 0.f; + float d[2]; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -6087,22 +6088,25 @@ void kernel_mul_mv_iq4_ks_f32_impl( float4 qf1, qf2; + device const float * dptr = (device const float *)cx; + d[0] = *dptr; + device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1) + ix; + dptr += row_size/4; + d[1] = *dptr; + for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - device const float * dptr = (device const float *)cx; + device const uint8_t * scales = x->scales; for (int row = 0; row < 2; ++row) { - //device const float * dptr = (device const float *)(cx + row*row_size); - const float d = *dptr; - device const block_iq4_ks * x = (device const block_iq4_ks *)(dptr + 1); - device const block_iq4_ks & xb = x[ibl]; - device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); + threadgroup const float * block_values = shared_values + ((scales[ib] & 1) << 4); + const float ls = ((scales[ib] & 254) - 127); - threadgroup const float * block_values = shared_values + ((xb.scales[ib] & 1) << 4); + device const uint32_t * q4 = (device const uint32_t *)scales + QK_K/128 + 4*ib + 2*il; float4 acc1 = {0.f}, acc2 = {0.f}; @@ -6122,14 +6126,14 @@ void kernel_mul_mv_iq4_ks_f32_impl( acc1 += acc2; - const int ls = (xb.scales[ib] & 254) - 127; - sumf[row] += d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); + sumf[row] += d[row] * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - dptr += row_size/4; + scales += row_size; } yb += 2 * QK_K; + x += 2; } sumf = simd_sum(sumf);