Skip to content

Commit

Permalink
fix data race
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed Jun 14, 2024
1 parent 80ba2ae commit bff3a20
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions ggml-cuda/mmq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -850,16 +850,18 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

  const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);

- x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = 0;
-
  #pragma unroll
- for (int l = 0; l < QR3_K; ++l) {
+ for (int l = 0; l < QR2_K; ++l) {
  const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;

  int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
  x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
  x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);

+ if (kqsx % QR2_K != 0) {
+ continue;
+ }
+
  x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
  }

Expand Down Expand Up @@ -1011,6 +1013,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
  int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
  x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);

+ if (kqsx % 2 != 0) {
+ continue;
+ }
+
  x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
  }
}
Expand Down

0 comments on commit bff3a20

Please sign in to comment.