diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index b0df49978d..3c24c0e546 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -321,7 +321,7 @@ fn mul_mat_via_q8_1(
     // Start by quantizing y
     let k_padded = pad(k, MATRIX_ROW_PADDING);
     let y_size_in_bytes =
-        k_padded * y_rows * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
+        k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
     let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
     quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
 
@@ -707,4 +707,30 @@ mod test {
         assert_eq!(vs[15], 13138824.0);
         Ok(())
     }
+
+    // The following test used to fail under compute-sanitizer until #2526.
+    // It uses y_cols (2048) much larger than y_rows (16): sizing the Q8_1 scratch buffer from
+    // y_rows left it too small for the y_cols columns handed to quantize_q8_1.
+    #[test]
+    fn cuda_mm_q8_1_pad() -> Result<()> {
+        let dev = CudaDevice::new(0)?;
+        let (x_rows, ncols, y_cols) = (4, 16, 2048);
+        let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
+        let y = dev.htod_sync_copy(&vs).w()?;
+        let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
+        xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
+        let cuda_storage = mul_mat_via_q8_1(
+            &xs.data,
+            &y.slice(..),
+            /* dtype */ GgmlDType::Q4_0,
+            /* x_rows */ x_rows,
+            /* x_cols */ ncols,
+            /* y_rows */ ncols,
+            /* y_cols */ y_cols,
+            &dev,
+        )?;
+        let vs = cuda_storage.as_cuda_slice::<f32>()?;
+        let _vs = dev.dtoh_sync_copy(&vs.slice(..)).unwrap();
+        Ok(())
+    }
 }