From 2c5da05ca1839327b67f5ff93bf11fad6150d938 Mon Sep 17 00:00:00 2001 From: thucpham Date: Tue, 11 Jun 2024 14:05:53 +0200 Subject: [PATCH 01/10] test hqq quantization 4 bit --- include/ctranslate2/layers/common.h | 1 + include/ctranslate2/models/model.h | 1 - include/ctranslate2/ops/dequantize.h | 11 +++++ include/ctranslate2/types.h | 1 + python/ctranslate2/specs/common_spec.py | 1 + python/ctranslate2/specs/model_spec.py | 56 ++++++++++++++++++++++++- src/cpu/primitives.cc | 1 + src/layers/attention.cc | 3 ++ src/layers/common.cc | 9 ++-- src/layers/transformer.cc | 3 ++ src/models/model.cc | 3 +- src/ops/dequantize.cc | 30 +++++++++++++ src/ops/dequantize_cpu.cc | 41 ++++++++++++++++++ src/ops/dequantize_gpu.cu | 45 ++++++++++++++++++++ src/storage_view.cc | 6 +++ src/type_dispatch.h | 3 ++ src/types.cc | 2 + 17 files changed, 211 insertions(+), 6 deletions(-) diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h index 3985b3feb..eb1f09db9 100644 --- a/include/ctranslate2/layers/common.h +++ b/include/ctranslate2/layers/common.h @@ -62,6 +62,7 @@ namespace ctranslate2 { const StorageView& _embeddings; const DataType _output_type; const StorageView* _qscale; + const StorageView* _qzero; }; // This enum order should remain fixed. diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 32e4f8403..aedf9930f 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -191,7 +191,6 @@ namespace ctranslate2 { std::unordered_map> _variable_index; bool _use_flash_attention = false; bool _tensor_parallel = false; - QUANTIZATION_TYPE _quant_method = QUANTIZATION_TYPE::CT2; }; template<> diff --git a/include/ctranslate2/ops/dequantize.h b/include/ctranslate2/ops/dequantize.h index 06d3c5c85..fb2ec8734 100644 --- a/include/ctranslate2/ops/dequantize.h +++ b/include/ctranslate2/ops/dequantize.h @@ -14,6 +14,11 @@ namespace ctranslate2 { const StorageView& scale, StorageView& output) const; + void operator()(const StorageView& input, + const StorageView& scale, + const StorageView& zero, + StorageView& output) const; + // Rescales the int32 GEMM output to float32, given the input scales. 
void operator()(const StorageView& c, const StorageView& a_scale, @@ -29,6 +34,12 @@ namespace ctranslate2 { const StorageView& scale, StorageView& output) const; + template + void dequantize_i4(const StorageView& input, + const StorageView& scale, + const StorageView& zero, + StorageView& output) const; + template void dequantize_gemm_output(const StorageView& c, const StorageView& a_scale, diff --git a/include/ctranslate2/types.h b/include/ctranslate2/types.h index a6dc0e1fb..e0ed0e1ac 100644 --- a/include/ctranslate2/types.h +++ b/include/ctranslate2/types.h @@ -20,6 +20,7 @@ namespace ctranslate2 { INT32, FLOAT16, BFLOAT16, + UINT8, }; std::string dtype_name(DataType type); diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py index b1162839c..f0f3b1a82 100644 --- a/python/ctranslate2/specs/common_spec.py +++ b/python/ctranslate2/specs/common_spec.py @@ -62,4 +62,5 @@ class EmbeddingsSpec(model_spec.LayerSpec): def __init__(self): self.weight = None self.weight_scale = model_spec.OPTIONAL + self.weight_zero = model_spec.OPTIONAL self.multiply_by_sqrt_depth = model_spec.OPTIONAL diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 41710ff41..9ed09ca38 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -25,6 +25,10 @@ CURRENT_BINARY_VERSION = 6 ACCEPTED_MODEL_TYPES = ( + "int4", + "int4_float32", + "int4_float16", + "int4_bfloat16", "int8", "int8_float32", "int8_float16", @@ -188,6 +192,11 @@ def _alias_variables(self): setattr(spec, attr_name, other_name) break + def _pack_4bit_u8(self, w_q): + w_q = w_q.astype(np.uint8) + _step = int(w_q.shape[1] / 2) + return (w_q[:, :_step] << 4) | w_q[:, _step:] + def _quantize(self, quantization): """Possibly quantizes the variable of the layer.""" if quantization is not None and quantization not in ACCEPTED_MODEL_TYPES: @@ -202,6 +211,7 @@ def _quantize(spec, name, value): key = _split_scope(name)[-1] scale = None + zero = None is_quantizable = hasattr(spec, "%s_scale" % key) is_convertible = value.dtype in ("float32", "float16", "bfloat16") @@ -244,6 +254,48 @@ def _quantize(spec, name, value): value = NumpyVariable(value) elif quantization in ("float16", "bfloat16", "float32"): value = value.to(quantization) + elif quantization in ( + "int4", + "int4_float32", + "int4_float16", + "int4_bfloat16", + ) and value.shape != 3: + value = value.to("float32").numpy() + print("AAAAAAAAAAAAA") + print(value) + group_size = 32 + old_shape = value.shape + new_shape = old_shape[:-1] + (old_shape[-1] // 2,) + value = value.reshape(-1, group_size) + gmax = np.amax(value, axis=1) + gmin = np.amin(value, axis=1) + + max_v = 15 + min_v = 0 + scale = np.clip(max_v / (gmax - gmin), 0, 2e4) + zero = -gmin * scale + + print(name) + print("XXXXXXXXXXXXXXXX") + print(value) + value = np.clip(np.round(value * np.expand_dims(scale, 1) + np.expand_dims(zero, 1)), min_v, max_v) + print("YYYYYYYYYYYYYYYY") + print(value) + value = self._pack_4bit_u8(value) + print(value.shape) + value = value.reshape(new_shape) + zero = zero.reshape(new_shape[0], -1) + scale = scale.reshape(new_shape[0], -1) + print("ZZZZZZZZZZZZZZZZ") + print(value) + print("scale") + print(scale) + print("zero") + print(zero) + + zero = NumpyVariable(zero) + scale = NumpyVariable(scale) + value = NumpyVariable(value) elif is_convertible: if quantization in ("float16", "int8_float16"): @@ -256,6 +308,8 @@ def _quantize(spec, name, value): setattr(spec, key, value) if scale 
is not None: setattr(spec, "%s_scale" % key, scale) + if zero is not None: + setattr(spec, "%s_zero" % key, zero) self._visit(_quantize) @@ -279,7 +333,7 @@ def _visit(self, fn): def _dtype_to_type_id(object_dtype): # Order should match the DataType enum in include/ctranslate2/types.h - dtypes = ("float32", "int8", "int16", "int32", "float16", "bfloat16") + dtypes = ("float32", "int8", "int16", "int32", "float16", "bfloat16", "uint8") try: return dtypes.index(object_dtype) except ValueError: diff --git a/src/cpu/primitives.cc b/src/cpu/primitives.cc index 0c6377bbb..dbbc04c4a 100644 --- a/src/cpu/primitives.cc +++ b/src/cpu/primitives.cc @@ -1207,6 +1207,7 @@ namespace ctranslate2 { template void \ primitives::mul(T a, const T* x, T* y, dim_t size); + DECLARE_IMPL_NO_FLOAT(uint8_t) DECLARE_IMPL_NO_FLOAT(int8_t) DECLARE_IMPL_NO_FLOAT(int16_t) DECLARE_IMPL_NO_FLOAT(int32_t) diff --git a/src/layers/attention.cc b/src/layers/attention.cc index 18e2710f7..69130108e 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -9,6 +9,7 @@ #include "dispatch.h" #include "cpu/parallel.h" +#include namespace ctranslate2 { namespace layers { @@ -319,7 +320,9 @@ namespace ctranslate2 { q = &queries_proj; } + //std::cout << "qqqqqqqqqqqqqqqqqqq: " << *q << std::endl; _linear[0](*q, fused_proj); + //std::cout << "fused_projjjjjjjjjjjjj: " << fused_proj << std::endl; dim_t beam_size = 1; diff --git a/src/layers/common.cc b/src/layers/common.cc index 86fb66a7d..9f602c1bc 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -271,7 +271,6 @@ namespace ctranslate2 { , _weight(get_linear_weight(model, scope, &_packed_weight)) , _bias(model.get_variable_if_exists(scope + "/bias")) , _qscale(model.get_variable_if_exists(scope + "/weight_scale")) - , _qzero(model.get_variable_if_exists(scope + "/weight_zero")) , _u8_shift_compensation((_weight.device() == Device::CPU && _weight.dtype() == DataType::INT8 && cpu::prefer_u8s8s32_gemm()) @@ -282,7 +281,6 @@ namespace ctranslate2 { , _partial_qscale(_weight.device(), DataType::FLOAT32) , _partial_u8_shift_compensation(_weight.device(), DataType::INT32) , _output_type(get_default_float_type(model.effective_compute_type())) - , _quant_method(model.quant_method()) , _quantized_gemm(_weight.dtype() == DataType::INT16 || _weight.dtype() == DataType::INT8) , _gemm_op(/*alpha=*/1, /*beta=*/0, @@ -297,7 +295,6 @@ namespace ctranslate2 { /*shift_to_uint8=*/bool(_u8_shift_compensation), /*round_before_cast=*/model.round_before_cast_in_quantization()) , _dequantize_op(activation_type) - , _activation_type(activation_type) , _is_layer_out(is_layer_out) { } @@ -427,6 +424,12 @@ namespace ctranslate2 { throw std::invalid_argument("Dense forward: invalid quantized type," "support only ct2 and awq quantization"); } + } else if (input.dim(-1) != weight->dim(-1)) { + //std::cout << "weighttttttttttttt: " << *weight << std::endl; + StorageView weight_dequant(input.dtype(), input.device()); + _dequantize_op(*weight, *qscale, *_qzero, weight_dequant); + //std::cout << "weighttttttttttttt dequant: " << weight_dequant << std::endl; + _gemm_op(input, weight_dequant, output, nullptr, bias); } else { _gemm_op(input, *weight, output, nullptr, bias); } diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index 291101eae..cf470f41a 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -1,6 +1,7 @@ #include "ctranslate2/layers/transformer.h" #include +#include namespace ctranslate2 { namespace layers { @@ -486,7 +487,9 @@ namespace 
ctranslate2 { StorageView layer_in(dtype, device); StorageView layer_out(dtype, device); + //std::cout << "idssssssss: " << ids << std::endl; _embeddings(ids, layer_in); + //std::cout << "layer innnnnnnnnnn: " << layer_in << std::endl; if (_start_from_zero_embedding) zero_first_timestep(layer_in, step); if (_embeddings_scale && (!_start_from_zero_embedding || step != 0)) diff --git a/src/models/model.cc b/src/models/model.cc index b8e1c2d8f..d7b3903f7 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -6,6 +6,7 @@ #include "ctranslate2/ops/ops.h" #include "ctranslate2/utils.h" #include +#include #ifdef CT2_WITH_CUDA # include "cuda/utils.h" @@ -173,7 +174,7 @@ namespace ctranslate2 { _device_index = index; } - void Model::set_compute_type(ComputeType type, Device device, int device_index, bool update_weight) { + void Model::set_compute_type(ComputeType type, Device device, int device_index) { if (_device != Device::CPU) throw std::runtime_error("set_compute_type expects the variables to be on CPU"); diff --git a/src/ops/dequantize.cc b/src/ops/dequantize.cc index 4463cf948..b5b52abdd 100644 --- a/src/ops/dequantize.cc +++ b/src/ops/dequantize.cc @@ -43,6 +43,36 @@ namespace ctranslate2 { } } + void Dequantize::operator()(const StorageView& input, + const StorageView& scale, + const StorageView& zero, + StorageView& output) const { + PROFILE("Dequantize lower bit"); + output.resize_as(input); + + //switch (input.dtype()) { + //case DataType::INT4: { + const dim_t block_size = 32; + const dim_t batch_size = input.size() / input.dim(-1); + const dim_t nb_group = input.dim(-1) / (block_size / 2); + if (scale.size() != batch_size * nb_group) + throw std::invalid_argument("INT4 dequantization expects per-block scales in each batch"); + auto shape = input.shape(); + shape[shape.size() - 1] *= 2; + output.resize(std::move(shape)); + + DEVICE_AND_FLOAT_DISPATCH("Dequantize", output.device(), output.dtype(), + (dequantize_i4(input, scale, zero, output))); + + // break; + // } + + // default: + // throw std::invalid_argument("Dequantize: invalid quantized type " + dtype_name(input.dtype()) + // + ", expected int8 or int16"); + //} + } + void Dequantize::operator()(const StorageView& c, const StorageView& a_scale, const StorageView& b_scale, diff --git a/src/ops/dequantize_cpu.cc b/src/ops/dequantize_cpu.cc index 5a3ab42bc..ad19eb8ce 100644 --- a/src/ops/dequantize_cpu.cc +++ b/src/ops/dequantize_cpu.cc @@ -20,6 +20,22 @@ namespace ctranslate2 { }); } + static inline void dequantize_i4_kernel(const uint8_t* x, + const float scale, + const float zero, + const dim_t size, + float* y, + bool rp) { + const float r_scale = 1.f / scale; + cpu::parallel_unary_transform(x, y, size, /*work_size=*/4, + [r_scale, rp, zero](uint8_t v) { + if (rp) + return (static_cast(v & 0xF) - zero) * r_scale; + else + return (static_cast(v >> 4) - zero) * r_scale; + }); + } + template<> void Dequantize::dequantize(const StorageView& input, const StorageView& scale, @@ -49,6 +65,31 @@ namespace ctranslate2 { }); } + + template<> + void Dequantize::dequantize_i4(const StorageView& input, + const StorageView& scale, + const StorageView& zero, + StorageView& output) const { + const dim_t block_size = 32; + //const dim_t depth = input.dim(-1); + //const dim_t batch_size = input.size() / input.dim(-1); + const dim_t scale_size = scale.size(); + + const auto* input_data = input.data(); + const auto* scale_data = scale.data(); + const auto* zero_data = zero.data(); + auto* output_data = output.data(); + + 
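+    // Layout note: each group of 32 weights occupies 16 packed bytes, so
+    // `offset = i * 16` input bytes per scale/zero pair and the outputs for
+    // group i start at offset * 2 = i * 32. Byte j of a group carries output
+    // element j in its high nibble and element j + 16 in its low nibble,
+    // mirroring the converter's _pack_4bit_u8 (first-half columns << 4 |
+    // second-half columns). Each nibble q is reconstructed as
+    // (q - zero) / scale, hence the reciprocal r_scale in the kernel above.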
cpu::parallel_for(0, scale_size, 1, [&](dim_t begin, dim_t end) { + for (dim_t i = begin; i < end; ++i) { + const dim_t offset = i * (block_size / 2); + dequantize_i4_kernel(input_data + offset, scale_data[i], zero_data[i], block_size / 2, output_data + offset * 2, false); + dequantize_i4_kernel(input_data + offset, scale_data[i], zero_data[i], block_size / 2, output_data + offset * 2 + block_size / 2, true); + } + }); + } + template<> void Dequantize::dequantize_gemm_output(const StorageView& c, const StorageView& a_scale, diff --git a/src/ops/dequantize_gpu.cu b/src/ops/dequantize_gpu.cu index 241b3acdb..994de57ad 100644 --- a/src/ops/dequantize_gpu.cu +++ b/src/ops/dequantize_gpu.cu @@ -7,6 +7,10 @@ namespace ctranslate2 { template struct dequantize_func { + __device__ __forceinline__ + OutT operator()(float scale, InT x, float zero) const { + return __fdividef(__fsub_rn(static_cast(x), zero) , scale); + } __device__ __forceinline__ OutT operator()(float scale, InT x) const { return __fdividef(static_cast(x), scale); @@ -26,6 +30,41 @@ namespace ctranslate2 { cuda::repeat_vec_depth(depth)); } + template + __global__ void dequantize_i4_kernel(const float* a, + const float* z, + const uint8_t* b, + T* c, + cuda::index_t depth) { + const int32_t block_size = 32; + const auto rescale_func = dequantize_func(); + const cuda::index_t i = blockIdx.x; + for (cuda::index_t j = threadIdx.x; j < depth; j += blockDim.x) { + const cuda::index_t index = i * depth + j; + const float scale = a[index / (block_size / 2)]; + const float zero = z[index / (block_size / 2)]; + uint8_t b1 = b[index] >> 4; + uint8_t b2 = b[index] & 0xF; + T v1 = rescale_func(scale, b1, zero); + T v2 = rescale_func(scale, b2, zero); + c[index] = v1; + c[index % (block_size / 2) + (index / (block_size / 2)) * block_size + block_size / 2] = v2; + } + } + + template + void Dequantize::dequantize_i4(const StorageView& input, + const StorageView& scale, + const StorageView& zero, + StorageView& output) const { + const dim_t depth = input.dim(-1); + const dim_t batch_size = input.size() / depth; + const dim_t blocks = std::min(batch_size, cuda::max_blocks); + const dim_t threads = std::min(depth, cuda::max_threads); + dequantize_i4_kernel<<>>( + scale.data(), zero.data(), input.data(), output.data(), depth); + } + template __global__ void dequantize_gemm_output_kernel(const int32_t* c, @@ -144,6 +183,12 @@ namespace ctranslate2 { const StorageView&, \ StorageView&) const; \ template void \ + Dequantize::dequantize_i4( \ + const StorageView&, \ + const StorageView&, \ + const StorageView&, \ + StorageView&) const; \ + template void \ Dequantize::dequantize_gemm_output( \ const StorageView&, \ const StorageView&, \ diff --git a/src/storage_view.cc b/src/storage_view.cc index 0cbdf25c2..9636a3f06 100644 --- a/src/storage_view.cc +++ b/src/storage_view.cc @@ -440,6 +440,12 @@ namespace ctranslate2 { return os; } + template<> + std::ostream& print_value(std::ostream& os, const uint8_t& val) { + os << static_cast(val); + return os; + } + std::ostream& operator<<(std::ostream& os, const StorageView& storage) { StorageView printable(storage.dtype()); printable.copy_from(storage); diff --git a/src/type_dispatch.h b/src/type_dispatch.h index 7ecab93b5..9f6c87b36 100644 --- a/src/type_dispatch.h +++ b/src/type_dispatch.h @@ -41,6 +41,7 @@ namespace ctranslate2 { } MATCH_TYPE_AND_ENUM(float, DataType::FLOAT32); + MATCH_TYPE_AND_ENUM(uint8_t, DataType::UINT8); MATCH_TYPE_AND_ENUM(int8_t, DataType::INT8); MATCH_TYPE_AND_ENUM(int16_t, 
DataType::INT16); MATCH_TYPE_AND_ENUM(int32_t, DataType::INT32); @@ -60,6 +61,7 @@ namespace ctranslate2 { #define TYPE_DISPATCH(TYPE_ENUM, STMTS) \ switch (TYPE_ENUM) { \ TYPE_CASE(float, SINGLE_ARG(STMTS)) \ + TYPE_CASE(uint8_t, SINGLE_ARG(STMTS)) \ TYPE_CASE(int8_t, SINGLE_ARG(STMTS)) \ TYPE_CASE(int16_t, SINGLE_ARG(STMTS)) \ TYPE_CASE(int32_t, SINGLE_ARG(STMTS)) \ @@ -69,6 +71,7 @@ namespace ctranslate2 { #define DECLARE_ALL_TYPES(FUNC) \ FUNC(float) \ + FUNC(uint8_t) \ FUNC(int8_t) \ FUNC(int16_t) \ FUNC(int32_t) \ diff --git a/src/types.cc b/src/types.cc index 2431bce66..9c06c3b3f 100644 --- a/src/types.cc +++ b/src/types.cc @@ -15,6 +15,8 @@ namespace ctranslate2 { switch (type) { case DataType::FLOAT32: return "float32"; + case DataType::UINT8: + return "uint8"; case DataType::INT8: return "int8"; case DataType::INT16: From 253bceb4a903e6f6eac101ad44123e3e52a5fdd0 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Thu, 27 Jun 2024 13:50:11 +0200 Subject: [PATCH 02/10] fix dequantize gpu --- python/ctranslate2/specs/model_spec.py | 13 ------------- src/ops/dequantize_gpu.cu | 16 ++++++++-------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 9ed09ca38..478f8d39a 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -261,8 +261,6 @@ def _quantize(spec, name, value): "int4_bfloat16", ) and value.shape != 3: value = value.to("float32").numpy() - print("AAAAAAAAAAAAA") - print(value) group_size = 32 old_shape = value.shape new_shape = old_shape[:-1] + (old_shape[-1] // 2,) @@ -275,23 +273,12 @@ def _quantize(spec, name, value): scale = np.clip(max_v / (gmax - gmin), 0, 2e4) zero = -gmin * scale - print(name) - print("XXXXXXXXXXXXXXXX") - print(value) value = np.clip(np.round(value * np.expand_dims(scale, 1) + np.expand_dims(zero, 1)), min_v, max_v) - print("YYYYYYYYYYYYYYYY") - print(value) value = self._pack_4bit_u8(value) print(value.shape) value = value.reshape(new_shape) zero = zero.reshape(new_shape[0], -1) scale = scale.reshape(new_shape[0], -1) - print("ZZZZZZZZZZZZZZZZ") - print(value) - print("scale") - print(scale) - print("zero") - print(zero) zero = NumpyVariable(zero) scale = NumpyVariable(scale) diff --git a/src/ops/dequantize_gpu.cu b/src/ops/dequantize_gpu.cu index 994de57ad..d50f9d18b 100644 --- a/src/ops/dequantize_gpu.cu +++ b/src/ops/dequantize_gpu.cu @@ -41,14 +41,14 @@ namespace ctranslate2 { const cuda::index_t i = blockIdx.x; for (cuda::index_t j = threadIdx.x; j < depth; j += blockDim.x) { const cuda::index_t index = i * depth + j; - const float scale = a[index / (block_size / 2)]; - const float zero = z[index / (block_size / 2)]; - uint8_t b1 = b[index] >> 4; - uint8_t b2 = b[index] & 0xF; - T v1 = rescale_func(scale, b1, zero); - T v2 = rescale_func(scale, b2, zero); - c[index] = v1; - c[index % (block_size / 2) + (index / (block_size / 2)) * block_size + block_size / 2] = v2; + const cuda::index_t m = index / (block_size / 2); + const cuda::index_t n = index % (block_size / 2); + const float scale = a[m]; + const float zero = z[m]; + uint8_t b1 = (b[index] & 0xF0) >> 4; + uint8_t b2 = (b[index] & 0x0F); + c[n + m * block_size] = rescale_func(scale, b1, zero); + c[n + m * block_size + block_size / 2] = rescale_func(scale, b2, zero); } } From 836cd8ef1f3aa8571d10bda9d3fa93d8589ca19b Mon Sep 17 00:00:00 2001 From: thucpham Date: Fri, 28 Jun 2024 17:34:46 +0200 Subject: [PATCH 03/10] test gemm int4 --- CMakeLists.txt | 3 
+- include/ctranslate2/ops/gemm.h | 17 +- python/ctranslate2/specs/model_spec.py | 42 +- src/layers/common.cc | 7 +- src/ops/gemm.cc | 16 +- src/ops/int4gemm_cpu.cc | 21 + src/ops/int4gemm_gpu.cu | 923 +++++++++++++++++++++++++ 7 files changed, 1021 insertions(+), 8 deletions(-) create mode 100644 src/ops/int4gemm_cpu.cc create mode 100644 src/ops/int4gemm_gpu.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index ac94aac57..5f0d92979 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -198,6 +198,7 @@ set(SOURCES src/ops/awq/gemv.cc src/ops/awq/gemv_cpu.cc src/ops/sum.cc + src/ops/int4gemm_cpu.cc src/padder.cc src/profiler.cc src/random.cc @@ -324,7 +325,7 @@ if(WITH_MKL) endif() # Find MKL libraries. - find_library(MKL_CORE_LIBRARY NAMES mkl_core HINTS ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64) + find_library(MKL_CORE_LIBRARY NAMES mkl_core PATHS ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64) if(MKL_CORE_LIBRARY) get_filename_component(MKL_LIBRARY_DIR ${MKL_CORE_LIBRARY} DIRECTORY) message(STATUS "Found MKL library directory: ${MKL_LIBRARY_DIR}") diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h index 3c4efbb02..f80ba24c4 100644 --- a/include/ctranslate2/ops/gemm.h +++ b/include/ctranslate2/ops/gemm.h @@ -18,7 +18,8 @@ namespace ctranslate2 { bool trans_b = false, bool a_is_packed = false, bool b_is_packed = false, - const ActivationType* activation_type = nullptr); + const ActivationType* activation_type = nullptr, + int _group_size = 0); void operator()(const StorageView& a, const StorageView& b, @@ -26,6 +27,12 @@ namespace ctranslate2 { const StorageView* a_shift_compensation = nullptr, const StorageView* bias = nullptr) const; + void operator()(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c, + const StorageView* bias = nullptr) const; + // Return the packed representation of b, if implemented by the GEMM backend. 
static StorageView pack_b_input(const StorageView& b, const bool transpose, @@ -49,12 +56,20 @@ namespace ctranslate2 { bool _trans_b; bool _a_is_packed; bool _b_is_packed; + const ActivationType* _activation_type; + const int _group_size; template void compute(const StorageView& a, const StorageView& b, StorageView& c, const StorageView* a_shift_compensation) const; + + template + void compute(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c) const; }; } diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 478f8d39a..aa8a6fa38 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -197,6 +197,45 @@ def _pack_4bit_u8(self, w_q): _step = int(w_q.shape[1] / 2) return (w_q[:, :_step] << 4) | w_q[:, _step:] + def _hqq_quants_to_torch_quants( + self, w_q, scales, zeros, shape, nbits=4 + ): + w_q = w_q.astype(np.float32) + scales = scales.astype(np.float32) + zeros = zeros.astype(np.float32) + + max_int = 2**nbits - 1 + min_int = 0 + dump = 2 ** (nbits - 1) + + # HQQ -> torch logic + new_zeros = (scales * dump) - zeros * scales + + min_val = new_zeros - scales * dump + + # group_quantize_tensor_from_qparams + w_r = (w_q - zeros) * scales + + w_q = w_r - min_val + w_q = w_q / scales + w_q = np.round(w_q) + w_q = np.clip(w_q, min_int, max_int) + w_q = w_q.astype(np.int32) + w_q = w_q.reshape(shape) + n = w_q.shape[0] + k = w_q.shape[1] + inner_k_tiles = 8 + w_q = w_q.reshape([n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2,]) + + scales = scales.reshape(shape[0], -1) + new_zeros = new_zeros.reshape(shape[0], -1) + scales = scales.reshape(scales.shape[0], scales.shape[1], 1) + new_zeros = new_zeros.reshape(new_zeros.shape[0], new_zeros.shape[1], 1) + scale_and_zero = np.concatenate([scales, new_zeros], axis=2) + scale_and_zero = scale_and_zero.transpose(1, 0, 2) + + return w_q, scale_and_zero + def _quantize(self, quantization): """Possibly quantizes the variable of the layer.""" if quantization is not None and quantization not in ACCEPTED_MODEL_TYPES: @@ -275,12 +314,11 @@ def _quantize(spec, name, value): value = np.clip(np.round(value * np.expand_dims(scale, 1) + np.expand_dims(zero, 1)), min_v, max_v) value = self._pack_4bit_u8(value) - print(value.shape) value = value.reshape(new_shape) zero = zero.reshape(new_shape[0], -1) scale = scale.reshape(new_shape[0], -1) + value, scale = self._hqq_quants_to_torch_quants(value, scale, zero, old_shape) - zero = NumpyVariable(zero) scale = NumpyVariable(scale) value = NumpyVariable(value) diff --git a/src/layers/common.cc b/src/layers/common.cc index 9f602c1bc..50da9f256 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -426,10 +426,11 @@ namespace ctranslate2 { } } else if (input.dim(-1) != weight->dim(-1)) { //std::cout << "weighttttttttttttt: " << *weight << std::endl; - StorageView weight_dequant(input.dtype(), input.device()); - _dequantize_op(*weight, *qscale, *_qzero, weight_dequant); + //StorageView weight_dequant(input.dtype(), input.device()); + //_dequantize_op(*weight, *qscale, *_qzero, weight_dequant); //std::cout << "weighttttttttttttt dequant: " << weight_dequant << std::endl; - _gemm_op(input, weight_dequant, output, nullptr, bias); + //_gemm_op(input, weight_dequant, output, nullptr, bias); + _gemm_op(input, *weight, *qscale, output, bias); } else { _gemm_op(input, *weight, output, nullptr, bias); } diff --git a/src/ops/gemm.cc b/src/ops/gemm.cc index e6ff87f9d..8900cc83e 
100644 --- a/src/ops/gemm.cc +++ b/src/ops/gemm.cc @@ -25,7 +25,8 @@ namespace ctranslate2 { bool trans_b, bool a_is_packed, bool b_is_packed, - const ActivationType* activation_type) + const ActivationType* activation_type, + int group_size) : _alpha(alpha) , _beta(beta) , _trans_a(trans_a) @@ -33,6 +34,7 @@ namespace ctranslate2 { , _a_is_packed(a_is_packed) , _b_is_packed(b_is_packed) , _activation_type(activation_type) + , _group_size(group_size) { } @@ -69,6 +71,18 @@ namespace ctranslate2 { apply_bias_and_activation(c, bias, _activation_type); } + void Gemm::operator()(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c, + const StorageView* bias) const { + PROFILE("Gemm"); + + DEVICE_DISPATCH(a.device(), (compute(a, b, scaleAndZero, c))); + + apply_bias_and_activation(c, bias, _activation_type); + } + template void Gemm::compute(const StorageView& a, const StorageView& b, diff --git a/src/ops/int4gemm_cpu.cc b/src/ops/int4gemm_cpu.cc new file mode 100644 index 000000000..9a7c3eeb1 --- /dev/null +++ b/src/ops/int4gemm_cpu.cc @@ -0,0 +1,21 @@ +#include + +namespace ctranslate2 { + namespace ops { + template + void Gemm::compute(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c) const { + } + +#define DECLARE_IMPL(T) \ + template void \ + Gemm::compute(const StorageView& a, \ + const StorageView& b, \ + const StorageView& scaleAndZero, \ + StorageView& c) const; + + DECLARE_IMPL(bfloat16_t) + } +} \ No newline at end of file diff --git a/src/ops/int4gemm_gpu.cu b/src/ops/int4gemm_gpu.cu new file mode 100644 index 000000000..0562773e7 --- /dev/null +++ b/src/ops/int4gemm_gpu.cu @@ -0,0 +1,923 @@ +#include +#include "cuda/helpers.h" + +namespace ctranslate2 { + namespace ops { + template + constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + return (a / b); + } + + template + constexpr __host__ __device__ auto divUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + // Overflow safe variant of (a + b - 1) / b + const uint64_t blocks = a / b + (a % b != 0); + return blocks; + } + + template + constexpr __host__ __device__ auto roundDown(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + return divDown(a, b) * b; + } + + template + constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + return divUp(a, b) * b; + } + + template + constexpr __host__ __device__ bool isEvenDivisor(U a, V b) { + static_assert(std::is_integral::value && std::is_integral::value, ""); + return (a % V(b) == 0) && ((a / V(b)) >= 1); + } + + template + constexpr __host__ __device__ T pow(T n, int power) { + return (power > 0 ? n * pow(n, power - 1) : 1); + } + + template + constexpr __host__ __device__ T pow2(int power) { + return pow(2, power); + } + + static_assert(pow2(8) == 256, "pow2"); + + template + constexpr __host__ __device__ int log2(T n, int p = 0) { + return (n <= 1) ? 
p : log2(n / 2, p + 1); + } + + static_assert(log2(2) == 1, "log2"); + static_assert(log2(3) == 1, "log2"); + static_assert(log2(4) == 2, "log2"); + + template + constexpr __host__ __device__ bool isPowerOf2(T v) { + static_assert(std::is_integral::value, ""); + return (v && !(v & (v - 1))); + } + + static_assert(isPowerOf2(2048), "isPowerOf2"); + static_assert(!isPowerOf2(3333), "isPowerOf2"); + + template + constexpr __host__ __device__ T nextHighestPowerOf2(T v) { + static_assert(std::is_integral::value, ""); + return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1))); + } + + static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); + + static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); + static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); + + static_assert( + nextHighestPowerOf2(1536000000u) == 2147483648u, + "nextHighestPowerOf2"); + static_assert( + nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, + "nextHighestPowerOf2"); + + template + constexpr __host__ __device__ T nextLowestPowerOf2(T v) { + static_assert(std::is_integral::value, ""); + return (isPowerOf2(v) ? v / (T)2 : ((T)1 << (log2(v)))); + } + + static_assert(nextLowestPowerOf2(1) == 0, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(2) == 1, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(3) == 2, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(4) == 2, "nextLowestPowerOf2"); + + static_assert(nextLowestPowerOf2(15) == 8, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(16) == 8, "nextLowestPowerOf2"); + static_assert(nextLowestPowerOf2(17) == 16, "nextLowestPowerOf2"); + + inline __host__ __device__ bool isPointerAligned(const void* p, int align) { + return reinterpret_cast(p) % align == 0; + } + +// Returns the increment needed to aligned the pointer to the next highest +// aligned address + template + inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { + static_assert(isPowerOf2(Align), ""); + const uint32_t diff = uint32_t(uintptr_t(p) & uintptr_t(Align - 1)); + return diff == 0 ? 
0 : uint32_t(Align) - diff; + } + + constexpr int32_t kWarpSize = 32; + +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) + // f16 vector types +struct __align__(2) f16x1 { + __half vals[1]; +}; + +struct __align__(4) f16x2 { + __half vals[2]; +}; + +struct __align__(8) f16x4 { + __half vals[4]; +}; + +struct __align__(16) f16x8 { + __half vals[8]; +}; + +// bf16 vector types +struct __align__(2) bf16x1 { + __nv_bfloat16 vals[1]; +}; + +struct __align__(4) bf16x2 { + __nv_bfloat16 vals[2]; +}; + +struct __align__(8) bf16x4 { + __nv_bfloat16 vals[4]; +}; + +struct __align__(16) bf16x8 { + __nv_bfloat16 vals[8]; +}; + +// bf162 vector types +struct __align__(4) bf16x2x1 { + __nv_bfloat162 vals[1]; +}; + +struct __align__(8) bf16x2x2 { + __nv_bfloat162 vals[2]; +}; + +struct __align__(16) bf16x2x4 { + __nv_bfloat162 vals[4]; +}; + +struct __align__(16) bf16x2x4_u32 { + uint32_t vals[4]; +}; + +struct __align__(8) bf16x2x2_u32 { + uint32_t vals[2]; +}; + +struct __align__(4) bf16x2x1_u32 { + uint32_t vals[1]; +}; + +template +struct __align__(sizeof(T) * N) VectorType { + T vals[N]; +}; + +// from +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { + bf16x2x4 result; + constexpr int kElements = 8; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = source; + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so + // we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#pragma unroll + for (int ii = 1; ii < kElements / 2; ++ii) { + i4s >>= 4; // or is it 8? + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + +// Finally, we construct the output numbers. 
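+  // After the lop3 above, each 16-bit half of h[ii] holds 0x4300 | v, which as
+  // bf16 is the value 128 + v (the 4-bit nibble v sits in the low mantissa
+  // bits). The fma below multiplies by BF16_ONE (1.0) and adds BF16_BIAS
+  // (-136.0 per half), so each output becomes (128 + v) - 136 = v - 8,
+  // mapping the unsigned nibble range [0, 15] onto the signed int4 range
+  // [-8, 7].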
+#pragma unroll + for (int ii = 0; ii < kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + + return result; +} + + + +enum class KReductionType { + // No k-reduction is needed between blocks as the number of k-tiles processed + // per block are exact and we can directly write the output + None, +}; + +// Loads the A matrix in 16-bit standard m x k row major layout, and writes +// the C matrix in 16-bit standard m x n row major layout: +// +// size [m][k] +template +struct ALayout_RM { + static constexpr int32_t kMTileSize = 16; + static constexpr int32_t kNTileSize = 8; + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + const void* A, + int32_t m, + int32_t k, + int32_t mTiles, + int32_t mTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, + bf16x2x4_u32 out[KTilesToLoad]) { + const auto mLane = mTile * kMTileSize + (laneId / 4); + const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2; + + // access + // [mTile * kMTileSize + (laneId / 4)] + // [kTileStart * kKTileSize + (laneId % 4) * 2] + auto aPtr = reinterpret_cast(A) + mLane * k + kLane; + + auto aPtrPlus8Rows = aPtr + 8 * k; + + bool m0InBounds = mLane < m; + bool m1InBounds = (mLane + 8) < m; + +#pragma unroll + for (int i = 0; i < KTilesToLoad; ++i) { + out[i].vals[0] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize) + : uint32_t(0); + out[i].vals[1] = m1InBounds + ? *reinterpret_cast(aPtrPlus8Rows + i * kKTileSize) + : uint32_t(0); + + out[i].vals[2] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize + 8) + : uint32_t(0); + out[i].vals[3] = m1InBounds ? *reinterpret_cast( + aPtrPlus8Rows + i * kKTileSize + 8) + : uint32_t(0); + } + } + + static __device__ void store( + void* C, + int32_t m, + int32_t n, + int32_t mOutTiles, + int32_t mTile, + int32_t nOutTiles, + int32_t nTile, + int32_t laneId, + const float4& out) { + static_assert(ReduceType == KReductionType::None, ""); + + if constexpr (ReduceType == KReductionType::None) { + // sum.x / sum.y are written at + // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // sum.z / sum.w are written at + // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // i.e., same columns, different row. + const int outRow = mTile * kMTileSize + (laneId / 4); + const int outCol = nTile * kNTileSize + (laneId % 4) * 2; + + // Pointer where sum.x / sum.y is written + auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol; + + auto v01 = __float22bfloat162_rn(float2{out.x, out.y}); + auto v23 = __float22bfloat162_rn(float2{out.z, out.w}); + + if (outRow < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01; + } + + // sum.z, sum.w at +8 rows from cPtr + if (outRow + 8 < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23; + } + } + } +}; + +template +struct BLayout_TC_int4 { + static constexpr int32_t kInnerKTiles = InnerKTiles; + static constexpr int32_t kMTileSize = 16; + static constexpr int32_t kNTileSize = 8; + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2] + // n / 8: n-tiles (n8) + // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16) + // 32: value per warp lane + // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile. 
+ // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest + // value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a + // uint32x4 (128 bits) + const void* __restrict__ B, + // size [k / qGroupSize][n][2] + // Contains the scale and zero point of each of the quantized int4 values + // within B + // v_reconstructed = (bf16(B_int4_val) * scale) - zero + const void* __restrict__ quantizationInfo, + int32_t n, + int32_t k, + int32_t nTiles, + int32_t nTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, + bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) { + // offset [nTile][kTileStart / InnerKTiles][laneId][0] + auto bPtr = reinterpret_cast(B) + + (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) * + kWarpSize) + + laneId) * + (InnerKTiles / 2); + + int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2]; + +#pragma unroll + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { + auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2); + + if constexpr (InnerKTiles == 2) { + b_int4[i][0] = bPtrCur[0]; + } + + if constexpr (InnerKTiles == 4) { + // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]) + // : "l"(bPtrCur)); + + int2 load8 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load8.x; + b_int4[i][1] = load8.y; + } + + if constexpr (InnerKTiles == 8) { + // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]), + // "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur)); + + int4 load16 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load16.x; + b_int4[i][1] = load16.y; + b_int4[i][2] = load16.z; + b_int4[i][3] = load16.w; + } + } + + // Load needed info for dequantization + + static_assert(isPowerOf2(QGroupSize), ""); + static_assert(isEvenDivisor(QGroupSize, kKTileSize), ""); + // smallest quantization group size is 32 (2 k-tiles are packed in an int32) + static_assert(QGroupSize >= kKTileSize * 2, ""); + constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize); + // a q-group could be larger than what we are handling in a single warp + constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1 + ? 1 + : (KTilesToLoad / kKTilesPerQGroup); + + __nv_bfloat162 qScaleAndZero[kNumQGroups]; + { + int32_t laneN = nTile * kNTileSize + (laneId / 4); + int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize; + + int32_t n = nTiles * kNTileSize; + + // offset [qScale_kGroup][qScale_n][0] + auto qInfoPtr = reinterpret_cast(quantizationInfo) + + (groupStart * n + laneN) * 2; + +#pragma unroll + for (int i = 0; i < kNumQGroups; ++i) { + qScaleAndZero[i] = + *reinterpret_cast(qInfoPtr + i * n * 2); + } + } + + // + // De-quantize int4 values to bf16. Values are dequantized as truly int4 + // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + { + // FIXME: does this negatively affect register counts, or will nvcc + // move this expansion (and data loads above) closer to the point of use? 
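+      // qScaleAndZero[i].x is the group's scale and .y its zero-point term,
+      // expressed for the signed [-8, 7] values produced by
+      // convert_i4x8_to_bf16x2x4. Broadcasting each into both halves of a
+      // __nv_bfloat162 lets a single __hfma2 reconstruct value * scale + zero
+      // for a pair of dequantized elements at a time.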
+ __nv_bfloat162 qScale[kNumQGroups]; + __nv_bfloat162 qZero[kNumQGroups]; + +#pragma unroll + for (int i = 0; i < kNumQGroups; ++i) { + qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x); + qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y); + } + +#pragma unroll + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { +#pragma unroll + for (int j = 0; j < InnerKTiles / 2; ++j) { + bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]); + + int curKTile = i * InnerKTiles + j * 2; + int curQGroup = (curKTile * kKTileSize) / QGroupSize; + + // The dequantized values in `v` for a given lane have the same n + // dimension (the B tensor core layout has all values in the same + // thread along the same n) but different k dimension, but all are + // guaranteed to occur within the same quantization group, so we need + // only load a single scale + zero to cover what this lane has +#pragma unroll + for (int k = 0; k < 4; ++k) { + v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]); + } + + // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and + // can't be used as a 32-bit asm register argument for `mma` + static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), ""); + std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32)); + } + } + } + } +}; + +template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerIteration> +__global__ +__launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( + // Data for the A matrix, loaded as per ALayout + const void* const __restrict__ A, + + // Data for the B matrix, loaded as per BLayout + const void* const __restrict__ B, + + // Optional quantization data for dequantizing B, loaded as per BLayout + const void* const __restrict__ B_quantizationInfo, + + // Output data for the C matrix, stored as per CLayout + void* __restrict__ C, + + // The size of the matrix multiplication + int32_t m, + int32_t n, + int32_t k, + + // The size of the matrix multiplication, in multiples of our TC tile size + int32_t mTiles, + int32_t nTiles, + int32_t kTiles) { + constexpr int32_t kMTileSize = 16; + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + static_assert( + ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize && + ALayout::kKTileSize == kKTileSize, + ""); + + static_assert( + BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize && + BLayout::kKTileSize == kKTileSize, + ""); + + static_assert( + CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize && + CLayout::kKTileSize == kKTileSize, + ""); + + constexpr int kInnerKTiles = BLayout::kInnerKTiles; + + // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads + static_assert( + kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, ""); + + // We always process at least kInnerKTiles k-tiles back to back in a warp + static_assert( + KTilesPerIteration >= kInnerKTiles && + isEvenDivisor(KTilesPerIteration, kInnerKTiles), + ""); + + auto warpId = threadIdx.y; + auto laneId = threadIdx.x; + + int32_t mTile = blockIdx.z; + int32_t nTile = blockIdx.y; + + float4 c{0.0f, 0.0f, 0.0f, 0.0f}; + + // First, handle whole multiples of KTilesPerIteration + auto kTilesLimit = roundDown(kTiles, KTilesPerIteration); + + // Each warp handles a set of KTilesPerIteration under the above limit + for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration; + kTileBase < kTilesLimit; + kTileBase += Warps * KTilesPerIteration) { + // + // Load data from A + // + 
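+      // Per k-tile, ALayout_RM::load leaves each lane with 8 bf16 values
+      // (4 packed uint32s): rows mTile*16 + laneId/4 and that row + 8, at
+      // columns (laneId%4)*2, +1, +8 and +9 of the k-tile, matching the A
+      // operand layout expected by the mma.m16n8k16 instruction used below.
+      // Rows falling outside m are zero-filled.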
bf16x2x4_u32 a[KTilesPerIteration]; + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a); + + // + // Load data from B and de-quantize as needed + // Each k-tile is bf16x2x2 + // + bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBase, + laneId, + b); + + // + // Now, perform the matrix multiplication + // + + // We accumulate across k-tiles here +#pragma unroll + for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) { + static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, ""); +#pragma unroll + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. Instead, we only periodically accumulate into + // `c` + float4 cTmp[2]; + +#pragma unroll + for (int k = 0; k < 2; ++k) { + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), + "=f"(cTmp[k].y), + "=f"(cTmp[k].z), + "=f"(cTmp[k].w) + : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; + } + } + } + } // for all tiles under kTilesLimit + + // Now, there could be a remainder of 1 to KTilesPerIteration - 1 k-tiles + // remaining. We guarantee that the number of warps is >= KTilesPerIteration / + // kInnerKTiles, so that each warp can simply load kInnerKTiles and do its + // thing without needing more warps + static_assert(Warps >= KTilesPerIteration / kInnerKTiles, ""); + + auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles; + + // If we have any remainder k-tiles, some warps will handle them, processing + // kInnerKTiles k-tiles at a time + if (kTileBaseRemaining < kTiles) { + bf16x2x4_u32 a[kInnerKTiles]; + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a); + + bf16x2x4_u32 b[1][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBaseRemaining, + laneId, + b); + +#pragma unroll + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. 
Instead, we only periodically accumulate into + // `c` + float4 cTmp[2]; + +#pragma unroll + for (int k = 0; k < 2; ++k) { + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w) + : "r"(a[j * 2 + k].vals[0]), + "r"(a[j * 2 + k].vals[1]), + "r"(a[j * 2 + k].vals[2]), + "r"(a[j * 2 + k].vals[3]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); + } + +#pragma unroll + for (int k = 0; k < 2; ++k) { + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; + } + } + } + + // + // Reduce independent k-tiles (same m/n) across warps + // + __shared__ float4 smem_sum[Warps][kWarpSize]; + + // FIXME: this likely doesn't need to be a true reduction tree, can just be a + // serial sum, maybe (unless nvcc/ptxas goes back to its old ways) + // smem_sum[warpId][laneId] = TreeReduce4::reduce(c); + smem_sum[warpId][laneId] = c; + + __syncthreads(); + + if (warpId == 0) { + float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f}; + + // Reduce across the block in the first warp + for (int i = 0; i < Warps; ++i) { + float4 v = smem_sum[i][laneId]; + sum_f32.x += v.x; + sum_f32.y += v.y; + sum_f32.z += v.z; + sum_f32.w += v.w; + } + + // Write the reduced result (in the first warp) into the output + CLayout::store( + C, + m, + n, + mTiles, + mTile, + // n for C output becomes k for A input, so for m16n8k16, + // we need to halve the tiles + nTiles / 2, + nTile, + laneId, + sum_f32); + } +} + + +template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerWarp> +void launch_tinygemm_kernel( + const StorageView& A, + const StorageView& B, + const StorageView* qScaleAndZeros, /* optional */ + StorageView& C_final, + int32_t mTiles, + int32_t nTiles, + int32_t kTiles, + int32_t m, + int32_t n, + int32_t k, + cudaStream_t stream) { + // After intra-block reduction across the k dimension, we are left with this + // many tiles + // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp); + int32_t postKernelKTiles = 1; // we loop + + auto grid = dim3(postKernelKTiles, nTiles, mTiles); + auto block = dim3(kWarpSize, Warps); + + auto func = + tinygemm_m16n8k16_chunk_kernel; + + func<<>>( + A.data(), + B.data(), + qScaleAndZeros ? 
qScaleAndZeros->data() : nullptr, + C_final.data(), + m, + n, + k, + mTiles, + nTiles, + kTiles); +} +#endif + + template + void Gemm::compute(const StorageView& a, + const StorageView& b, + const StorageView& scaleAndZero, + StorageView& c) const { + constexpr int32_t kMTileSize = 16; + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + // row major layout + auto m = a.dim(0); + auto mTiles = divUp(m, kMTileSize); + + // tensor core layout + auto nTiles = b.dim(0); + auto n = nTiles * kNTileSize; + + // row major layout + auto k = a.dim(1); + auto kTiles = divUp(k, kKTileSize); + + // The number of inner k tiles is the innermost dimension of times 2 + // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4 + // packed into 1 int32 for int4 B + auto B_innerKTiles = b.dim(3) * 2; + + //TORCH_CHECK(qScaleAndZeros.dim() == 3); + auto numQGroups = scaleAndZero.dim(0); + // Output is a standard row-major matrix + c.resize({m, n}); + +#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) + auto stream = cuda::get_cuda_stream(); +#define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ + do { \ + using ACLayout = ALayout_RM; \ + \ + switch (B_innerKTiles) { \ + case 2: \ + if constexpr (K_TILES_PER_WARP >= 2) { \ + using BLayout = BLayout_TC_int4<2, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + a, \ + b, \ + &scaleAndZero, \ + c, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + case 4: \ + if constexpr (K_TILES_PER_WARP >= 4) { \ + using BLayout = BLayout_TC_int4<4, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + a, \ + b, \ + &scaleAndZero, \ + c, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + case 8: \ + if constexpr (K_TILES_PER_WARP >= 8) { \ + using BLayout = BLayout_TC_int4<8, Q_GROUP_SIZE>; \ + launch_tinygemm_kernel< \ + ACLayout, \ + BLayout, \ + ACLayout, \ + WARPS, \ + K_TILES_PER_WARP>( \ + a, \ + b, \ + &scaleAndZero, \ + c, \ + mTiles, \ + nTiles, \ + kTiles, \ + m, \ + n, \ + k, \ + stream); \ + } \ + break; \ + default: \ + break; \ + } \ + } while (false) + +#define HANDLE_Q_GROUP(WARPS, K_TILES_PER_WARP, REDUCE_TYPE) \ + do { \ + switch (_group_size) { \ + case 32: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 32, REDUCE_TYPE); \ + break; \ + case 64: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 64, REDUCE_TYPE); \ + break; \ + case 128: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 128, REDUCE_TYPE); \ + break; \ + case 256: \ + RUN_GEMM(WARPS, K_TILES_PER_WARP, 256, REDUCE_TYPE); \ + break; \ + } \ + } while (false) + + HANDLE_Q_GROUP(8, 8, KReductionType::None); + +#undef HANDLE_Q_GROUP +#undef RUN_GEMM +#endif + } + +#define DECLARE_IMPL(T) \ + template void \ + Gemm::compute(const StorageView& a, \ + const StorageView& b, \ + const StorageView& scaleAndZero, \ + StorageView& c) const; + + DECLARE_IMPL(bfloat16_t) + } +} \ No newline at end of file From 0b6e9cbe4614be417bca63eeaf1a7ec93136a478 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Thu, 4 Jul 2024 11:16:43 +0200 Subject: [PATCH 04/10] fix rebase --- include/ctranslate2/models/model.h | 1 + src/models/model.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index aedf9930f..32e4f8403 100644 --- 
a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -191,6 +191,7 @@ namespace ctranslate2 { std::unordered_map> _variable_index; bool _use_flash_attention = false; bool _tensor_parallel = false; + QUANTIZATION_TYPE _quant_method = QUANTIZATION_TYPE::CT2; }; template<> diff --git a/src/models/model.cc b/src/models/model.cc index d7b3903f7..2a5aeebd7 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -174,7 +174,7 @@ namespace ctranslate2 { _device_index = index; } - void Model::set_compute_type(ComputeType type, Device device, int device_index) { + void Model::set_compute_type(ComputeType type, Device device, int device_index, bool update_weight) { if (_device != Device::CPU) throw std::runtime_error("set_compute_type expects the variables to be on CPU"); From 4ea2e509b73f756c14f4b5ed6682480d0bfb420f Mon Sep 17 00:00:00 2001 From: minhthuc Date: Mon, 8 Jul 2024 17:47:49 +0200 Subject: [PATCH 05/10] hqq quant --- CMakeLists.txt | 1 + include/ctranslate2/models/model.h | 5 +- include/ctranslate2/ops/gemm.h | 11 +- python/ctranslate2/specs/common_spec.py | 1 + python/ctranslate2/specs/model_spec.py | 68 ++++++------ src/layers/attention.cc | 3 - src/layers/common.cc | 3 + src/models/model.cc | 33 ++++-- src/ops/gemm.cc | 9 ++ src/ops/int4gemm_cpu.cc | 6 ++ src/ops/int4gemm_gpu.cu | 134 ++++++++++++++++++++++-- src/types.cc | 2 + 12 files changed, 221 insertions(+), 55 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5f0d92979..1fc3e2932 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -606,6 +606,7 @@ if (WITH_CUDA) src/ops/awq/gemm_gpu.cu src/ops/awq/gemv_gpu.cu src/ops/awq/dequantize_gpu.cu + src/ops/int4gemm_gpu.cu ) set_source_files_properties( diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 32e4f8403..86d44c1c4 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -14,7 +14,8 @@ namespace ctranslate2 { enum class QUANTIZATION_TYPE { CT2, AWQ_GEMM, - AWQ_GEMV + AWQ_GEMV, + HQQ_4BIT, }; static const size_t current_binary_version = 6; @@ -113,7 +114,7 @@ namespace ctranslate2 { } // If the model contains variables, they will be moved to the new device. - void set_device(const Device device, const int index = 0); + void set_device(const Device device, const int index = 0, const bool format_lower_bit = false); // Copy the model to another device. std::shared_ptr copy_to(Device device, int device_index = 0) const; diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h index f80ba24c4..041c44482 100644 --- a/include/ctranslate2/ops/gemm.h +++ b/include/ctranslate2/ops/gemm.h @@ -19,7 +19,7 @@ namespace ctranslate2 { bool a_is_packed = false, bool b_is_packed = false, const ActivationType* activation_type = nullptr, - int _group_size = 0); + int _group_size = 32); void operator()(const StorageView& a, const StorageView& b, @@ -33,6 +33,9 @@ namespace ctranslate2 { StorageView& c, const StorageView* bias = nullptr) const; + StorageView convert_to_int4pack(const StorageView& input, + int32_t innerKTiles); + // Return the packed representation of b, if implemented by the GEMM backend. 
static StorageView pack_b_input(const StorageView& b, const bool transpose, @@ -56,7 +59,6 @@ namespace ctranslate2 { bool _trans_b; bool _a_is_packed; bool _b_is_packed; - const ActivationType* _activation_type; const int _group_size; template @@ -70,6 +72,11 @@ namespace ctranslate2 { const StorageView& b, const StorageView& scaleAndZero, StorageView& c) const; + + template + void convert_weight_to_int4pack(const StorageView& a, + StorageView& b, + int32_t innerKTiles); }; } diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py index f0f3b1a82..d16ed88eb 100644 --- a/python/ctranslate2/specs/common_spec.py +++ b/python/ctranslate2/specs/common_spec.py @@ -29,6 +29,7 @@ class Quantization(enum.IntEnum): CT2 = 0 AWQ_GEMM = 1 AWQ_GEMV = 2 + HQQ_INT4 = 3 class LayerNormSpec(model_spec.LayerSpec): diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index aa8a6fa38..02c4c6376 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -21,14 +21,15 @@ except ImportError: torch_is_available = False +from ctranslate2.specs import ( + common_spec, +) + OPTIONAL = "__optional" CURRENT_BINARY_VERSION = 6 ACCEPTED_MODEL_TYPES = ( - "int4", - "int4_float32", - "int4_float16", - "int4_bfloat16", + "hqq_int4", "int8", "int8_float32", "int8_float16", @@ -197,35 +198,24 @@ def _pack_4bit_u8(self, w_q): _step = int(w_q.shape[1] / 2) return (w_q[:, :_step] << 4) | w_q[:, _step:] - def _hqq_quants_to_torch_quants( - self, w_q, scales, zeros, shape, nbits=4 - ): - w_q = w_q.astype(np.float32) - scales = scales.astype(np.float32) - zeros = zeros.astype(np.float32) + def _hqq_quants_to_torch_quants(self, w_q, scales, zeros, shape, nbits=4): + scales = scales.reshape([w_q.shape[0], -1]) + zeros = zeros.reshape([w_q.shape[0], -1]) max_int = 2**nbits - 1 min_int = 0 dump = 2 ** (nbits - 1) # HQQ -> torch logic - new_zeros = (scales * dump) - zeros * scales + new_zeros = scales * dump - zeros * scales min_val = new_zeros - scales * dump - - # group_quantize_tensor_from_qparams - w_r = (w_q - zeros) * scales - - w_q = w_r - min_val + w_q = w_q - min_val w_q = w_q / scales w_q = np.round(w_q) w_q = np.clip(w_q, min_int, max_int) w_q = w_q.astype(np.int32) w_q = w_q.reshape(shape) - n = w_q.shape[0] - k = w_q.shape[1] - inner_k_tiles = 8 - w_q = w_q.reshape([n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2,]) scales = scales.reshape(shape[0], -1) new_zeros = new_zeros.reshape(shape[0], -1) @@ -251,7 +241,7 @@ def _quantize(spec, name, value): key = _split_scope(name)[-1] scale = None zero = None - is_quantizable = hasattr(spec, "%s_scale" % key) + is_quantizable = hasattr(spec, "%s_scale" % key) and 'embeddings' not in name is_convertible = value.dtype in ("float32", "float16", "bfloat16") if is_quantizable: @@ -294,10 +284,7 @@ def _quantize(spec, name, value): elif quantization in ("float16", "bfloat16", "float32"): value = value.to(quantization) elif quantization in ( - "int4", - "int4_float32", - "int4_float16", - "int4_bfloat16", + "hqq_int4", ) and value.shape != 3: value = value.to("float32").numpy() group_size = 32 @@ -312,20 +299,19 @@ def _quantize(spec, name, value): scale = np.clip(max_v / (gmax - gmin), 0, 2e4) zero = -gmin * scale - value = np.clip(np.round(value * np.expand_dims(scale, 1) + np.expand_dims(zero, 1)), min_v, max_v) - value = self._pack_4bit_u8(value) - value = value.reshape(new_shape) zero = zero.reshape(new_shape[0], -1) scale = scale.reshape(new_shape[0], 
-1) + scale = 1 / scale value, scale = self._hqq_quants_to_torch_quants(value, scale, zero, old_shape) scale = NumpyVariable(scale) + scale = scale.to("bfloat16") value = NumpyVariable(value) elif is_convertible: if quantization in ("float16", "int8_float16"): value = value.to("float16") - elif quantization in ("bfloat16", "int8_bfloat16"): + elif quantization in ("bfloat16", "int8_bfloat16", "hqq_int4"): value = value.to("bfloat16") elif quantization in ("float32", "int16", "int8_float32"): value = value.to("float32") @@ -333,8 +319,6 @@ def _quantize(spec, name, value): setattr(spec, key, value) if scale is not None: setattr(spec, "%s_scale" % key, scale) - if zero is not None: - setattr(spec, "%s_zero" % key, zero) self._visit(_quantize) @@ -493,6 +477,28 @@ def _write_string(string): _write_string(variable_name) + def _save_quantization_type(self, quantization): + if quantization == 'hqq_int4': + self._config.add_attribute("quantization_type", common_spec.Quantization.HQQ_INT4) + elif quantization is not None: + self._config.add_attribute("quantization_type", common_spec.Quantization.CT2) + + + def optimize(self, quantization: Optional[str] = None) -> None: + """Recursively applies some optimizations to its layer: + + * Alias variables with the same shape and value. + * Quantize weights. + + Arguments: + quantization: Weight quantization scheme (possible values are: int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + """ + self._save_quantization_type(quantization) + self._alias_variables() + self._quantize(quantization) + + def _flatten_vocabularies(vocabularies): for name, vocabulary in vocabularies.items(): if len(vocabulary) == 1: diff --git a/src/layers/attention.cc b/src/layers/attention.cc index 69130108e..18e2710f7 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -9,7 +9,6 @@ #include "dispatch.h" #include "cpu/parallel.h" -#include namespace ctranslate2 { namespace layers { @@ -320,9 +319,7 @@ namespace ctranslate2 { q = &queries_proj; } - //std::cout << "qqqqqqqqqqqqqqqqqqq: " << *q << std::endl; _linear[0](*q, fused_proj); - //std::cout << "fused_projjjjjjjjjjjjj: " << fused_proj << std::endl; dim_t beam_size = 1; diff --git a/src/layers/common.cc b/src/layers/common.cc index 50da9f256..aecaa24cc 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -271,6 +271,7 @@ namespace ctranslate2 { , _weight(get_linear_weight(model, scope, &_packed_weight)) , _bias(model.get_variable_if_exists(scope + "/bias")) , _qscale(model.get_variable_if_exists(scope + "/weight_scale")) + , _qzero(model.get_variable_if_exists(scope + "/weight_zero")) , _u8_shift_compensation((_weight.device() == Device::CPU && _weight.dtype() == DataType::INT8 && cpu::prefer_u8s8s32_gemm()) @@ -281,6 +282,7 @@ namespace ctranslate2 { , _partial_qscale(_weight.device(), DataType::FLOAT32) , _partial_u8_shift_compensation(_weight.device(), DataType::INT32) , _output_type(get_default_float_type(model.effective_compute_type())) + , _quant_method(model.quant_method()) , _quantized_gemm(_weight.dtype() == DataType::INT16 || _weight.dtype() == DataType::INT8) , _gemm_op(/*alpha=*/1, /*beta=*/0, @@ -295,6 +297,7 @@ namespace ctranslate2 { /*shift_to_uint8=*/bool(_u8_shift_compensation), /*round_before_cast=*/model.round_before_cast_in_quantization()) , _dequantize_op(activation_type) + , _activation_type(activation_type) , _is_layer_out(is_layer_out) { } diff --git a/src/models/model.cc b/src/models/model.cc index 2a5aeebd7..344ca7e5a 100644 --- 
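The `_hqq_quants_to_torch_quants` call above rewrites HQQ's (scale, zero) pair into the convention expected by the int4 tensor-core GEMM, where the 4-bit codes are treated as signed values centered at 8. A toy NumPy check of the identity it relies on (all numbers made up):

    import numpy as np

    w_q = np.array([0, 3, 7, 15], dtype=np.float32)  # 4-bit codes of one group
    scale, zero = np.float32(0.05), np.float32(6.0)  # HQQ dequantization parameters

    # HQQ convention: W_r = (W_q - zero) * scale
    w_hqq = (w_q - zero) * scale

    # torch int4mm convention: W_r = (W_q - 8) * scale + new_zero,
    # with new_zero = scale * 8 - zero * scale as computed above
    new_zero = scale * 8 - zero * scale
    w_torch = (w_q - 8) * scale + new_zero

    assert np.allclose(w_hqq, w_torch)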
a/src/models/model.cc +++ b/src/models/model.cc @@ -51,6 +51,15 @@ namespace ctranslate2 { + std::to_string(position)); } + static void format_weight(const std::string name, StorageView& weight, int device_index) { + if (!ends_with(name, "embeddings/weight") && ends_with(name, "weight")) { + ops::Gemm gemm_ops; + StorageView tmp = gemm_ops.convert_to_int4pack(weight, 8); + weight = std::move(tmp); + } + synchronize_device(weight.device(), device_index); + } + template T consume(std::istream& in) { const std::streampos position = in.tellg(); @@ -88,19 +97,23 @@ namespace ctranslate2 { } template - static void move_variables_to_device(VariablesCollection& variables, const Device device) { + static void move_variables_to_device(VariablesCollection& variables, const Device device, + bool format_lower_bit, const int device_index) { for (auto& pair : variables) { StorageView& variable = *pair.second; if (variable.is_scalar() || variable.device() == device) continue; variable = variable.to(device); + if (format_lower_bit) + format_weight(pair.first, variable, device_index); } } template static void move_variables(VariablesCollection& variables, const Device src_device, const int src_device_index, - const Device dst_device, const int dst_device_index) { + const Device dst_device, const int dst_device_index, + const bool format_lower_bit) { if (variables.empty()) return; if (src_device == dst_device && src_device_index == dst_device_index) @@ -109,13 +122,13 @@ namespace ctranslate2 { // Move variables back to the CPU device. if (src_device != Device::CPU && dst_device == Device::CPU) { ScopedDeviceSetter scoped_device_setter(src_device, src_device_index); - move_variables_to_device(variables, Device::CPU); + move_variables_to_device(variables, Device::CPU, format_lower_bit, 0); } // Move variables to the destination device. if (src_device == Device::CPU && dst_device != Device::CPU) { ScopedDeviceSetter scoped_device_setter(dst_device, dst_device_index); - move_variables_to_device(variables, dst_device); + move_variables_to_device(variables, dst_device, format_lower_bit, dst_device_index); } synchronize_device(src_device, src_device_index); // Wait for asynchronous deallocations. @@ -168,8 +181,8 @@ namespace ctranslate2 { return 1; } - void Model::set_device(const Device device, const int index) { - move_variables(_variable_index, _device, _device_index, device, index); + void Model::set_device(const Device device, const int index, const bool format_lower_bit) { + move_variables(_variable_index, _device, _device_index, device, index, format_lower_bit); _device = device; _device_index = index; } @@ -641,8 +654,9 @@ namespace ctranslate2 { QUANTIZATION_TYPE quantization_type = QUANTIZATION_TYPE::CT2; if (model->config.contains("quantization_type")) - model->set_quant_method(model->config["quantization_type"]); + quantization_type =model->config["quantization_type"]; + model->set_quant_method(quantization_type); for (uint32_t i = 0; i < num_variables; ++i) { auto name = consume(model_file); const size_t rank = consume(model_file); @@ -760,13 +774,16 @@ namespace ctranslate2 { case QUANTIZATION_TYPE::AWQ_GEMV: model->set_compute_type(ComputeType::FLOAT16, device, device_index, false); break; + case QUANTIZATION_TYPE::HQQ_4BIT: + model->set_compute_type(compute_type, device, device_index, false); + break; default: throw std::invalid_argument("Quantization type is not supported"); break; } // Move variables to the target device. 
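At load time the HQQ_4BIT branch above keeps the model's compute type and repacks weights into the int4 tensor-core layout while the variables are moved to the device; `format_weight` skips the embeddings and converts every other weight with `convert_to_int4pack(weight, 8)`. A small Python sketch of that name filter (hypothetical variable names, shown only to make the rule explicit):

    def should_repack(name: str) -> bool:
        # mirrors format_weight(): embeddings stay as-is, other "*/weight"
        # variables get the int4 tensor-core packing
        return name.endswith("weight") and not name.endswith("embeddings/weight")

    assert should_repack("decoder/layer_0/self_attention/linear_0/weight")
    assert not should_repack("decoder/embeddings/weight")
    assert not should_repack("decoder/layer_0/self_attention/linear_0/bias")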
- model->set_device(device, device_index); + model->set_device(device, device_index, quantization_type == QUANTIZATION_TYPE::HQQ_4BIT); // Register variable aliases. if (binary_version >= 3) { diff --git a/src/ops/gemm.cc b/src/ops/gemm.cc index 8900cc83e..257340b43 100644 --- a/src/ops/gemm.cc +++ b/src/ops/gemm.cc @@ -78,7 +78,10 @@ namespace ctranslate2 { const StorageView* bias) const { PROFILE("Gemm"); + dim_t batch_size = a.dim(0); + dim_t time = a.dim(1); DEVICE_DISPATCH(a.device(), (compute(a, b, scaleAndZero, c))); + c.reshape({batch_size, time, -1}); apply_bias_and_activation(c, bias, _activation_type); } @@ -191,5 +194,11 @@ namespace ctranslate2 { return compensation; } + StorageView Gemm::convert_to_int4pack(const StorageView& input, + int32_t innerKTiles) { + StorageView output(input.device(), input.dtype()); + DEVICE_DISPATCH(input.device(), (convert_weight_to_int4pack(input, output, innerKTiles))); + return output; + } } } diff --git a/src/ops/int4gemm_cpu.cc b/src/ops/int4gemm_cpu.cc index 9a7c3eeb1..82f30c831 100644 --- a/src/ops/int4gemm_cpu.cc +++ b/src/ops/int4gemm_cpu.cc @@ -9,6 +9,12 @@ namespace ctranslate2 { StorageView& c) const { } + template <> + void Gemm::convert_weight_to_int4pack(const StorageView& a, + StorageView& b, + int32_t innerKTiles) { + } + #define DECLARE_IMPL(T) \ template void \ Gemm::compute(const StorageView& a, \ diff --git a/src/ops/int4gemm_gpu.cu b/src/ops/int4gemm_gpu.cu index 0562773e7..f9ae1eaaa 100644 --- a/src/ops/int4gemm_gpu.cu +++ b/src/ops/int4gemm_gpu.cu @@ -117,7 +117,7 @@ namespace ctranslate2 { constexpr int32_t kWarpSize = 32; -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +//#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) // f16 vector types struct __align__(2) f16x1 { __half vals[1]; @@ -733,6 +733,69 @@ __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( } } +// FIXME: parallelize better, smem staging etc? + template + __global__ void + matrix_to_m16n8k16_Bint4_layout( + // size [n][k] + const int32_t* in, + int32_t n, + int32_t depth, + int32_t depth_output1, + int32_t depth_output2, + int32_t depth_output3, + // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] + int32_t* out) { + // int4 values are packed into int32 values, which require at least 8. Given + // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of + // innermost k-tiles that we can use is 2. 
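The kernel body that follows packs an [n][k] matrix of 4-bit codes (stored one per int32) into the m16n8k16 B layout. A slow NumPy transcription of the same packing, as a sketch that assumes n is a multiple of 8 and k a multiple of inner_k_tiles * 16 and omits the bounds checks the kernel performs:

    import numpy as np

    def pack_b_int4_m16n8k16(w_q, inner_k_tiles=8):
        """Repack [n, k] 4-bit codes (values 0..15) into
        [n/8, k/(inner_k_tiles*16), 32, inner_k_tiles/2] uint32 words."""
        n, k = w_q.shape
        out = np.zeros((n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
                       dtype=np.uint32)
        for n_tile in range(out.shape[0]):
            for k_outer in range(out.shape[1]):
                for t in range(32):                       # warp lane
                    n0 = n_tile * 8 + t // 4
                    for inner in range(0, inner_k_tiles, 2):
                        k0 = (k_outer * inner_k_tiles + inner) * 16
                        k1 = k0 + 16
                        ks = [k0 + (t % 4) * 2, 0, 0, 0, k1 + (t % 4) * 2, 0, 0, 0]
                        ks[1], ks[2], ks[3] = ks[0] + 1, ks[0] + 8, ks[0] + 9
                        ks[5], ks[6], ks[7] = ks[4] + 1, ks[4] + 8, ks[4] + 9
                        v = [int(w_q[n0, x]) & 0xF for x in ks]
                        out[n_tile, k_outer, t, inner // 2] = (
                            (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16)
                            | (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0])
        return out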
+ static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), ""); + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles + auto kOuterTile = blockIdx.x; + auto nTile = blockIdx.y; + auto t = threadIdx.x; + + // Two k-tiles are packed into an int32 at a time +#pragma unroll + for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { + // n dimension that this lane loads from + auto n0 = nTile * kNTileSize + (t / 4); + + bool n0Valid = n0 < n; + + int32_t ks[8]; + + auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; + ks[0] = kBase0 + (t % 4) * 2; + ks[1] = ks[0] + 1; + ks[2] = ks[0] + 8; + ks[3] = ks[0] + 8 + 1; + + auto kBase1 = kBase0 + kKTileSize; + ks[4] = kBase1 + (t % 4) * 2; + ks[5] = ks[4] + 1; + ks[6] = ks[4] + 8; + ks[7] = ks[4] + 8 + 1; + + auto pIn = in + (n0 * depth); + + uint32_t v[8]; +#pragma unroll + for (int i = 0; i < 8; ++i) { + v[i] = (n0Valid && (ks[i] < depth)) ? pIn[ks[i]] : uint32_t(0); + } + + int32_t pack = (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16) | + (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0]; + + // inner k-tiles pack two at a time + out[nTile * depth_output3 + kOuterTile * depth_output2 + t * depth_output1 + innerKTile / 2] = pack; + } + } template < typename ALayout, @@ -775,7 +838,7 @@ void launch_tinygemm_kernel( nTiles, kTiles); } -#endif +//#endif template void Gemm::compute(const StorageView& a, @@ -787,7 +850,7 @@ void launch_tinygemm_kernel( constexpr int32_t kKTileSize = 16; // row major layout - auto m = a.dim(0); + auto m = a.rank() == 3 ? a.dim(0) * a.dim(1) : a.dim(0); auto mTiles = divUp(m, kMTileSize); // tensor core layout @@ -795,20 +858,19 @@ void launch_tinygemm_kernel( auto n = nTiles * kNTileSize; // row major layout - auto k = a.dim(1); + auto k = a.rank() == 3 ? 
a.dim(2) : a.dim(1); auto kTiles = divUp(k, kKTileSize); // The number of inner k tiles is the innermost dimension of times 2 // 2 k-tiles (4 values per lane per tile, 8 values total) quantized to int4 // packed into 1 int32 for int4 B - auto B_innerKTiles = b.dim(3) * 2; + const int32_t B_innerKTiles = b.dim(3) * 2; //TORCH_CHECK(qScaleAndZeros.dim() == 3); auto numQGroups = scaleAndZero.dim(0); // Output is a standard row-major matrix c.resize({m, n}); - -#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +//#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) auto stream = cuda::get_cuda_stream(); #define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ do { \ @@ -860,7 +922,7 @@ void launch_tinygemm_kernel( } \ break; \ case 8: \ - if constexpr (K_TILES_PER_WARP >= 8) { \ + if constexpr (K_TILES_PER_WARP >= 8) { \ using BLayout = BLayout_TC_int4<8, Q_GROUP_SIZE>; \ launch_tinygemm_kernel< \ ACLayout, \ @@ -908,7 +970,61 @@ void launch_tinygemm_kernel( #undef HANDLE_Q_GROUP #undef RUN_GEMM -#endif +//#endif + } + + template <> + void Gemm::convert_weight_to_int4pack( + const StorageView& a, + StorageView& b, + int32_t innerKTiles) { + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + auto nTiles = divUp(a.dim(0), kNTileSize); + + // k-tiles are packed back to back in the innermost dimension in order to + // allow for 4/8/16 byte loads + // kSuperTiles is the number of k-tiles assuming k is innerKTiles * kKTileSize + auto kSuperTiles = divUp(a.dim(1), innerKTiles * kKTileSize); + + // each block handles `innerKTiles` k-tiles. + // 2 k-tiles are a single int32 + b.resize({nTiles, kSuperTiles, 32, innerKTiles / 2}); + +//#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) + auto stream = ctranslate2::cuda::get_cuda_stream(); + dim3 grid(kSuperTiles, nTiles); + + if (innerKTiles == 2) { + matrix_to_m16n8k16_Bint4_layout<2><<>>( + a.data(), + a.dim(0), + a.dim(-1), + b.dim(-1), + b.stride(1), + b.stride(0), + b.data()); + } else if (innerKTiles == 4) { + matrix_to_m16n8k16_Bint4_layout<4><<>>( + a.data(), + a.dim(0), + a.dim(-1), + b.dim( 1), + b.stride(1), + b.stride(0), + b.data()); + } else if (innerKTiles == 8) { + matrix_to_m16n8k16_Bint4_layout<8><<>>( + a.data(), + a.dim(0), + a.dim(-1), + b.dim(-1), + b.stride(1), + b.stride(0), + b.data()); + } +//#endif } #define DECLARE_IMPL(T) \ diff --git a/src/types.cc b/src/types.cc index 9c06c3b3f..cb7c9b1ce 100644 --- a/src/types.cc +++ b/src/types.cc @@ -89,6 +89,8 @@ namespace ctranslate2 { return "float16"; case ComputeType::BFLOAT16: return "bfloat16"; + case ComputeType::INT32_BFLOAT16: + return "int32_bfloat16"; }; throw std::invalid_argument("Invalid compute type value"); } From 9b0d425d124a176128bbbff7d84f60a0fc44fdae Mon Sep 17 00:00:00 2001 From: minhthuc Date: Mon, 8 Jul 2024 17:48:59 +0200 Subject: [PATCH 06/10] hqq quant --- include/ctranslate2/types.h | 1 + src/layers/common.cc | 7 +------ src/types.cc | 19 +++++++++++++++++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/include/ctranslate2/types.h b/include/ctranslate2/types.h index e0ed0e1ac..90a974777 100644 --- a/include/ctranslate2/types.h +++ b/include/ctranslate2/types.h @@ -37,6 +37,7 @@ namespace ctranslate2 { INT16, FLOAT16, BFLOAT16, + INT32_BFLOAT16, }; ComputeType str_to_compute_type(const std::string& compute_type); diff --git 
a/src/layers/common.cc b/src/layers/common.cc index aecaa24cc..c9c26a845 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -5,7 +5,7 @@ #include "ctranslate2/ops/activation.h" #include "cpu/backend.h" #include "dispatch.h" - +#include namespace ctranslate2 { namespace layers { @@ -428,11 +428,6 @@ namespace ctranslate2 { "support only ct2 and awq quantization"); } } else if (input.dim(-1) != weight->dim(-1)) { - //std::cout << "weighttttttttttttt: " << *weight << std::endl; - //StorageView weight_dequant(input.dtype(), input.device()); - //_dequantize_op(*weight, *qscale, *_qzero, weight_dequant); - //std::cout << "weighttttttttttttt dequant: " << weight_dequant << std::endl; - //_gemm_op(input, weight_dequant, output, nullptr, bias); _gemm_op(input, *weight, *qscale, output, bias); } else { _gemm_op(input, *weight, output, nullptr, bias); diff --git a/src/types.cc b/src/types.cc index cb7c9b1ce..5ecc23449 100644 --- a/src/types.cc +++ b/src/types.cc @@ -166,8 +166,13 @@ namespace ctranslate2 { const bool support_float16 = mayiuse_float16(device, device_index); const bool support_int16 = mayiuse_int16(device, device_index); const bool support_int8 = mayiuse_int8(device, device_index); - - switch (requested_compute_type) { + const bool lower_bit_model = model_compute_type == ComputeType::INT32_BFLOAT16; + ComputeType accepted_request_compute_type = requested_compute_type; + if (lower_bit_model) { + // model quantized to lower bit could run in its mode only + accepted_request_compute_type = model_compute_type; + } + switch (accepted_request_compute_type) { case ComputeType::FLOAT32: { return ComputeType::FLOAT32; @@ -189,6 +194,12 @@ namespace ctranslate2 { return ComputeType::FLOAT32; } + case ComputeType::INT32_BFLOAT16: { + if (!support_bfloat16) + unsupported_compute_type("bfloat16"); + return ComputeType::INT32_BFLOAT16; + } + case ComputeType::INT16: { if (support_int16) return ComputeType::INT16; @@ -314,6 +325,8 @@ namespace ctranslate2 { return std::make_pair(DataType::FLOAT16, DataType::FLOAT16); case ComputeType::BFLOAT16: return std::make_pair(DataType::BFLOAT16, DataType::BFLOAT16); + case ComputeType::INT32_BFLOAT16: + return std::make_pair(DataType::INT32, DataType::BFLOAT16); default: throw std::invalid_argument("resolve_compute_type should be called first"); } @@ -333,6 +346,8 @@ namespace ctranslate2 { } case DataType::INT16: return ComputeType::INT16; + case DataType::INT32: + return ComputeType::INT32_BFLOAT16; case DataType::FLOAT16: return ComputeType::FLOAT16; case DataType::BFLOAT16: From 2ea15c50a0fc3deccc7a1e7e1e6d0e8a0d2b363a Mon Sep 17 00:00:00 2001 From: minhthuc Date: Tue, 9 Jul 2024 18:17:27 +0200 Subject: [PATCH 07/10] update converter --- python/ctranslate2/converters/converter.py | 24 +++- python/ctranslate2/specs/model_spec.py | 131 +++++++++++++-------- src/storage_view.cc | 2 +- 3 files changed, 107 insertions(+), 50 deletions(-) diff --git a/python/ctranslate2/converters/converter.py b/python/ctranslate2/converters/converter.py index ecede044a..f8ea1671b 100644 --- a/python/ctranslate2/converters/converter.py +++ b/python/ctranslate2/converters/converter.py @@ -5,7 +5,7 @@ from typing import Optional -from ctranslate2.specs.model_spec import ACCEPTED_MODEL_TYPES, ModelSpec +from ctranslate2.specs.model_spec import ACCEPTED_MODEL_TYPES, DEVICE, ModelSpec class Converter(abc.ABC): @@ -30,6 +30,18 @@ def declare_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParse choices=ACCEPTED_MODEL_TYPES, help="Weight quantization type.", 
) + parser.add_argument( + "--group_size", + type=int, + default=None, + help="Group size used in quantization lower bit", + ) + parser.add_argument( + "--device", + default=None, + choices=DEVICE, + help="Device where the quantization runs on. Available only with hqq quantization", + ) parser.add_argument( "--force", action="store_true", @@ -51,6 +63,8 @@ def convert_from_args(self, args: argparse.Namespace) -> str: args.output_dir, vmap=args.vocab_mapping, quantization=args.quantization, + group_size=args.group_size, + device=args.device, force=args.force, ) @@ -59,6 +73,8 @@ def convert( output_dir: str, vmap: Optional[str] = None, quantization: Optional[str] = None, + group_size: Optional[int] = None, + device: Optional[str] = None, force: bool = False, ) -> str: """Converts the model to the CTranslate2 format. @@ -69,6 +85,10 @@ def convert( in the converted model directory. quantization: Weight quantization scheme (possible values are: int8, int8_float32, int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + group_size: group size used by the quantization in lower bit (possible values + are: 32, 64, 128...) + device: Device where the compute of the scales and zero in quantization lower bit + runs (possible values are: cuda, cpu) force: Override the output directory if it already exists. Returns: @@ -95,7 +115,7 @@ def convert( model_spec.register_vocabulary_mapping(vmap) model_spec.validate() - model_spec.optimize(quantization=quantization) + model_spec.optimize(quantization=quantization, group_size=group_size, device=device) # Create model directory. if os.path.exists(output_dir): diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 02c4c6376..fd909b7fd 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -11,6 +11,11 @@ import struct from typing import Dict, List, Optional +try: + from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer + hqq_is_available = True +except ImportError: + hqq_is_available = False import numpy as np @@ -42,6 +47,10 @@ SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor") +DEVICE = ( + "cuda", + "cpu", +) def _join_scope(scope, name): if not scope: @@ -199,34 +208,49 @@ def _pack_4bit_u8(self, w_q): return (w_q[:, :_step] << 4) | w_q[:, _step:] def _hqq_quants_to_torch_quants(self, w_q, scales, zeros, shape, nbits=4): - scales = scales.reshape([w_q.shape[0], -1]) - zeros = zeros.reshape([w_q.shape[0], -1]) - max_int = 2**nbits - 1 min_int = 0 dump = 2 ** (nbits - 1) # HQQ -> torch logic - new_zeros = scales * dump - zeros * scales + new_zeros = (scales * dump) - zeros * scales min_val = new_zeros - scales * dump - w_q = w_q - min_val - w_q = w_q / scales - w_q = np.round(w_q) - w_q = np.clip(w_q, min_int, max_int) - w_q = w_q.astype(np.int32) - w_q = w_q.reshape(shape) - - scales = scales.reshape(shape[0], -1) - new_zeros = new_zeros.reshape(shape[0], -1) - scales = scales.reshape(scales.shape[0], scales.shape[1], 1) - new_zeros = new_zeros.reshape(new_zeros.shape[0], new_zeros.shape[1], 1) - scale_and_zero = np.concatenate([scales, new_zeros], axis=2) - scale_and_zero = scale_and_zero.transpose(1, 0, 2) + + # group_quantize_tensor_from_qparams + w_r = (w_q - zeros) * scales + + w_q = ( + w_r.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape(shape) + .contiguous() + ) + + # group_dequantize_tensor_from_qparams + # W_r = W_q*scales + min_val + + scales = 
scales.contiguous().reshape(shape[0], -1) + new_zeros = new_zeros.contiguous().reshape(shape[0], -1) + scale_and_zero = ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + new_zeros.reshape(new_zeros.size(0), new_zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + return w_q, scale_and_zero - def _quantize(self, quantization): + def _quantize(self, quantization, group_size, device): """Possibly quantizes the variable of the layer.""" if quantization is not None and quantization not in ACCEPTED_MODEL_TYPES: raise ValueError( @@ -241,7 +265,7 @@ def _quantize(spec, name, value): key = _split_scope(name)[-1] scale = None zero = None - is_quantizable = hasattr(spec, "%s_scale" % key) and 'embeddings' not in name + is_quantizable = hasattr(spec, "%s_scale" % key) is_convertible = value.dtype in ("float32", "float16", "bfloat16") if is_quantizable: @@ -285,28 +309,31 @@ def _quantize(spec, name, value): value = value.to(quantization) elif quantization in ( "hqq_int4", - ) and value.shape != 3: - value = value.to("float32").numpy() - group_size = 32 - old_shape = value.shape - new_shape = old_shape[:-1] + (old_shape[-1] // 2,) - value = value.reshape(-1, group_size) - gmax = np.amax(value, axis=1) - gmin = np.amin(value, axis=1) - - max_v = 15 - min_v = 0 - scale = np.clip(max_v / (gmax - gmin), 0, 2e4) - zero = -gmin * scale - - zero = zero.reshape(new_shape[0], -1) - scale = scale.reshape(new_shape[0], -1) - scale = 1 / scale - value, scale = self._hqq_quants_to_torch_quants(value, scale, zero, old_shape) - - scale = NumpyVariable(scale) - scale = scale.to("bfloat16") - value = NumpyVariable(value) + ) and value.shape != 3 and hqq_is_available: + if 'embeddings' in name: + value = value.to("bfloat16") + else: + quant_config = BaseQuantizeConfig(nbits=4, group_size=group_size, + quant_zero=False, quant_scale=False, axis=1) + hqq_linear = HQQLinear(None, quant_config=quant_config, + compute_dtype=torch.bfloat16, device=device) + hqq_linear.quantize(value.to("bfloat16").tensor, **hqq_linear.quant_config) + + value = hqq_linear.W_q.cpu() + scale = hqq_linear.meta['scale'].cpu() + zero = hqq_linear.meta['zero'].cpu() + old_shape = hqq_linear.meta['shape'] + new_shape = old_shape[:-1] + (old_shape[-1] // 2,) + value = Quantizer.unpack[hqq_linear.meta["packing"]](value) + value, scale = self._hqq_quants_to_torch_quants(value, scale, zero, old_shape) + + scale = scale.cpu() + value = value.cpu() + scale = PyTorchVariable(scale) + value = PyTorchVariable(value) + del hqq_linear.W_q + del hqq_linear.meta['scale'] + del hqq_linear.meta['zero'] elif is_convertible: if quantization in ("float16", "int8_float16"): @@ -322,7 +349,9 @@ def _quantize(spec, name, value): self._visit(_quantize) - def optimize(self, quantization: Optional[str] = None) -> None: + def optimize(self, quantization: Optional[str] = None, + group_size: Optional[int] = None, + device: Optional[str] = None) -> None: """Recursively applies some optimizations to this layer: * Alias variables with the same shape and value. @@ -331,9 +360,11 @@ def optimize(self, quantization: Optional[str] = None) -> None: Arguments: quantization: Weight quantization scheme (possible values are: int8, int8_float32, int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). 
+ group_size: Group size of quantization lower bits + device: device where the quantization runs on (support hqq only) """ self._alias_variables() - self._quantize(quantization) + self._quantize(quantization, group_size, device) def _visit(self, fn): """Recursively visits this layer and its children.""" @@ -477,14 +508,18 @@ def _write_string(string): _write_string(variable_name) - def _save_quantization_type(self, quantization): + def _save_quantization_type(self, quantization, group_size): if quantization == 'hqq_int4': self._config.add_attribute("quantization_type", common_spec.Quantization.HQQ_INT4) elif quantization is not None: self._config.add_attribute("quantization_type", common_spec.Quantization.CT2) + if group_size is not None: + self._config.add_attribute("quantization_group_size", group_size) - def optimize(self, quantization: Optional[str] = None) -> None: + def optimize(self, quantization: Optional[str] = None, + group_size: Optional[int] = None, + device: Optional[str] = None) -> None: """Recursively applies some optimizations to its layer: * Alias variables with the same shape and value. @@ -493,10 +528,12 @@ def optimize(self, quantization: Optional[str] = None) -> None: Arguments: quantization: Weight quantization scheme (possible values are: int8, int8_float32, int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + group_size: Group size of quantization lower bits + device: device where the quantization runs on (support hqq only) """ - self._save_quantization_type(quantization) + self._save_quantization_type(quantization, group_size) self._alias_variables() - self._quantize(quantization) + self._quantize(quantization, group_size, device) def _flatten_vocabularies(vocabularies): diff --git a/src/storage_view.cc b/src/storage_view.cc index 9636a3f06..1ad45caf3 100644 --- a/src/storage_view.cc +++ b/src/storage_view.cc @@ -3,7 +3,7 @@ #include "ctranslate2/primitives.h" #include "dispatch.h" - +#include #define PRINT_MAX_VALUES 6 namespace ctranslate2 { From f57a72fa709c38934c41a7539490c40574913a74 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Thu, 11 Jul 2024 09:07:55 +0200 Subject: [PATCH 08/10] fix stuffs --- include/ctranslate2/ops/dequantize.h | 11 - include/ctranslate2/ops/gemm.h | 3 +- python/ctranslate2/specs/model_spec.py | 7 +- src/cpu/primitives.cc | 1 - src/layers/common.cc | 8 +- src/layers/transformer.cc | 3 - src/models/model.cc | 1 - src/ops/dequantize.cc | 32 +- src/ops/dequantize_cpu.cc | 41 - src/ops/dequantize_gpu.cu | 42 - src/ops/gemm.cc | 9 +- src/ops/int4gemm_cpu.cc | 4 + src/ops/int4gemm_gpu.cu | 1251 +++++++++++------------- src/storage_view.cc | 7 - src/type_dispatch.h | 3 - src/types.cc | 2 - 16 files changed, 613 insertions(+), 812 deletions(-) diff --git a/include/ctranslate2/ops/dequantize.h b/include/ctranslate2/ops/dequantize.h index fb2ec8734..06d3c5c85 100644 --- a/include/ctranslate2/ops/dequantize.h +++ b/include/ctranslate2/ops/dequantize.h @@ -14,11 +14,6 @@ namespace ctranslate2 { const StorageView& scale, StorageView& output) const; - void operator()(const StorageView& input, - const StorageView& scale, - const StorageView& zero, - StorageView& output) const; - // Rescales the int32 GEMM output to float32, given the input scales. 
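With the converter changes in patch 07 above, `group_size` and `device` flow from the command line (or the Python API) down to `ModelSpec.optimize()` and, for hqq_int4, into HQQLinear. A hedged example of the Python call path, assuming the Transformers frontend and a publicly available checkpoint:

    import ctranslate2

    converter = ctranslate2.converters.TransformersConverter("meta-llama/Llama-2-7b-hf")
    converter.convert(
        "llama2_ct2_hqq4",
        quantization="hqq_int4",  # stored in the model config as quantization_type
        group_size=64,            # stored as quantization_group_size
        device="cuda",            # where HQQ computes the scales and zero points
    )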
void operator()(const StorageView& c, const StorageView& a_scale, @@ -34,12 +29,6 @@ namespace ctranslate2 { const StorageView& scale, StorageView& output) const; - template - void dequantize_i4(const StorageView& input, - const StorageView& scale, - const StorageView& zero, - StorageView& output) const; - template void dequantize_gemm_output(const StorageView& c, const StorageView& a_scale, diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h index 041c44482..8ecd96083 100644 --- a/include/ctranslate2/ops/gemm.h +++ b/include/ctranslate2/ops/gemm.h @@ -67,6 +67,7 @@ namespace ctranslate2 { StorageView& c, const StorageView* a_shift_compensation) const; +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 template void compute(const StorageView& a, const StorageView& b, @@ -77,7 +78,7 @@ namespace ctranslate2 { void convert_weight_to_int4pack(const StorageView& a, StorageView& b, int32_t innerKTiles); +#endif }; - } } diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index fd909b7fd..116bba912 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -202,11 +202,6 @@ def _alias_variables(self): setattr(spec, attr_name, other_name) break - def _pack_4bit_u8(self, w_q): - w_q = w_q.astype(np.uint8) - _step = int(w_q.shape[1] / 2) - return (w_q[:, :_step] << 4) | w_q[:, _step:] - def _hqq_quants_to_torch_quants(self, w_q, scales, zeros, shape, nbits=4): max_int = 2**nbits - 1 min_int = 0 @@ -373,7 +368,7 @@ def _visit(self, fn): def _dtype_to_type_id(object_dtype): # Order should match the DataType enum in include/ctranslate2/types.h - dtypes = ("float32", "int8", "int16", "int32", "float16", "bfloat16", "uint8") + dtypes = ("float32", "int8", "int16", "int32", "float16", "bfloat16") try: return dtypes.index(object_dtype) except ValueError: diff --git a/src/cpu/primitives.cc b/src/cpu/primitives.cc index dbbc04c4a..0c6377bbb 100644 --- a/src/cpu/primitives.cc +++ b/src/cpu/primitives.cc @@ -1207,7 +1207,6 @@ namespace ctranslate2 { template void \ primitives::mul(T a, const T* x, T* y, dim_t size); - DECLARE_IMPL_NO_FLOAT(uint8_t) DECLARE_IMPL_NO_FLOAT(int8_t) DECLARE_IMPL_NO_FLOAT(int16_t) DECLARE_IMPL_NO_FLOAT(int32_t) diff --git a/src/layers/common.cc b/src/layers/common.cc index c9c26a845..82862eef1 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -5,7 +5,7 @@ #include "ctranslate2/ops/activation.h" #include "cpu/backend.h" #include "dispatch.h" -#include + namespace ctranslate2 { namespace layers { @@ -290,7 +290,9 @@ namespace ctranslate2 { /*trans_b=*/true, /*a_is_packed=*/false, _packed_weight, - _quantized_gemm ? nullptr : activation_type) + _quantized_gemm ? nullptr : activation_type, + (model.config.contains("quantization_group_size") ? + model.config["quantization_group_size"] : nullptr)) , _quantize_op(model.use_global_int16_scale() ? 
ops::Quantize::ScaleType::GLOBAL : ops::Quantize::ScaleType::PER_LAYER, @@ -427,7 +429,7 @@ namespace ctranslate2 { throw std::invalid_argument("Dense forward: invalid quantized type," "support only ct2 and awq quantization"); } - } else if (input.dim(-1) != weight->dim(-1)) { + } else if (_quant_method == models::QUANTIZATION_TYPE::HQQ_4BIT) { _gemm_op(input, *weight, *qscale, output, bias); } else { _gemm_op(input, *weight, output, nullptr, bias); diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index cf470f41a..291101eae 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -1,7 +1,6 @@ #include "ctranslate2/layers/transformer.h" #include -#include namespace ctranslate2 { namespace layers { @@ -487,9 +486,7 @@ namespace ctranslate2 { StorageView layer_in(dtype, device); StorageView layer_out(dtype, device); - //std::cout << "idssssssss: " << ids << std::endl; _embeddings(ids, layer_in); - //std::cout << "layer innnnnnnnnnn: " << layer_in << std::endl; if (_start_from_zero_embedding) zero_first_timestep(layer_in, step); if (_embeddings_scale && (!_start_from_zero_embedding || step != 0)) diff --git a/src/models/model.cc b/src/models/model.cc index 344ca7e5a..672b2c520 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -6,7 +6,6 @@ #include "ctranslate2/ops/ops.h" #include "ctranslate2/utils.h" #include -#include #ifdef CT2_WITH_CUDA # include "cuda/utils.h" diff --git a/src/ops/dequantize.cc b/src/ops/dequantize.cc index b5b52abdd..96d50f899 100644 --- a/src/ops/dequantize.cc +++ b/src/ops/dequantize.cc @@ -1,4 +1,4 @@ -#include "ctranslate2/ops/dequantize.h" + #include "ctranslate2/ops/dequantize.h" #include "dispatch.h" @@ -43,36 +43,6 @@ namespace ctranslate2 { } } - void Dequantize::operator()(const StorageView& input, - const StorageView& scale, - const StorageView& zero, - StorageView& output) const { - PROFILE("Dequantize lower bit"); - output.resize_as(input); - - //switch (input.dtype()) { - //case DataType::INT4: { - const dim_t block_size = 32; - const dim_t batch_size = input.size() / input.dim(-1); - const dim_t nb_group = input.dim(-1) / (block_size / 2); - if (scale.size() != batch_size * nb_group) - throw std::invalid_argument("INT4 dequantization expects per-block scales in each batch"); - auto shape = input.shape(); - shape[shape.size() - 1] *= 2; - output.resize(std::move(shape)); - - DEVICE_AND_FLOAT_DISPATCH("Dequantize", output.device(), output.dtype(), - (dequantize_i4(input, scale, zero, output))); - - // break; - // } - - // default: - // throw std::invalid_argument("Dequantize: invalid quantized type " + dtype_name(input.dtype()) - // + ", expected int8 or int16"); - //} - } - void Dequantize::operator()(const StorageView& c, const StorageView& a_scale, const StorageView& b_scale, diff --git a/src/ops/dequantize_cpu.cc b/src/ops/dequantize_cpu.cc index ad19eb8ce..5a3ab42bc 100644 --- a/src/ops/dequantize_cpu.cc +++ b/src/ops/dequantize_cpu.cc @@ -20,22 +20,6 @@ namespace ctranslate2 { }); } - static inline void dequantize_i4_kernel(const uint8_t* x, - const float scale, - const float zero, - const dim_t size, - float* y, - bool rp) { - const float r_scale = 1.f / scale; - cpu::parallel_unary_transform(x, y, size, /*work_size=*/4, - [r_scale, rp, zero](uint8_t v) { - if (rp) - return (static_cast(v & 0xF) - zero) * r_scale; - else - return (static_cast(v >> 4) - zero) * r_scale; - }); - } - template<> void Dequantize::dequantize(const StorageView& input, const StorageView& scale, @@ -65,31 +49,6 @@ namespace 
ctranslate2 { }); } - - template<> - void Dequantize::dequantize_i4(const StorageView& input, - const StorageView& scale, - const StorageView& zero, - StorageView& output) const { - const dim_t block_size = 32; - //const dim_t depth = input.dim(-1); - //const dim_t batch_size = input.size() / input.dim(-1); - const dim_t scale_size = scale.size(); - - const auto* input_data = input.data(); - const auto* scale_data = scale.data(); - const auto* zero_data = zero.data(); - auto* output_data = output.data(); - - cpu::parallel_for(0, scale_size, 1, [&](dim_t begin, dim_t end) { - for (dim_t i = begin; i < end; ++i) { - const dim_t offset = i * (block_size / 2); - dequantize_i4_kernel(input_data + offset, scale_data[i], zero_data[i], block_size / 2, output_data + offset * 2, false); - dequantize_i4_kernel(input_data + offset, scale_data[i], zero_data[i], block_size / 2, output_data + offset * 2 + block_size / 2, true); - } - }); - } - template<> void Dequantize::dequantize_gemm_output(const StorageView& c, const StorageView& a_scale, diff --git a/src/ops/dequantize_gpu.cu b/src/ops/dequantize_gpu.cu index d50f9d18b..ef3adc11d 100644 --- a/src/ops/dequantize_gpu.cu +++ b/src/ops/dequantize_gpu.cu @@ -30,42 +30,6 @@ namespace ctranslate2 { cuda::repeat_vec_depth(depth)); } - template - __global__ void dequantize_i4_kernel(const float* a, - const float* z, - const uint8_t* b, - T* c, - cuda::index_t depth) { - const int32_t block_size = 32; - const auto rescale_func = dequantize_func(); - const cuda::index_t i = blockIdx.x; - for (cuda::index_t j = threadIdx.x; j < depth; j += blockDim.x) { - const cuda::index_t index = i * depth + j; - const cuda::index_t m = index / (block_size / 2); - const cuda::index_t n = index % (block_size / 2); - const float scale = a[m]; - const float zero = z[m]; - uint8_t b1 = (b[index] & 0xF0) >> 4; - uint8_t b2 = (b[index] & 0x0F); - c[n + m * block_size] = rescale_func(scale, b1, zero); - c[n + m * block_size + block_size / 2] = rescale_func(scale, b2, zero); - } - } - - template - void Dequantize::dequantize_i4(const StorageView& input, - const StorageView& scale, - const StorageView& zero, - StorageView& output) const { - const dim_t depth = input.dim(-1); - const dim_t batch_size = input.size() / depth; - const dim_t blocks = std::min(batch_size, cuda::max_blocks); - const dim_t threads = std::min(depth, cuda::max_threads); - dequantize_i4_kernel<<>>( - scale.data(), zero.data(), input.data(), output.data(), depth); - } - - template __global__ void dequantize_gemm_output_kernel(const int32_t* c, const float* a_scales, @@ -183,12 +147,6 @@ namespace ctranslate2 { const StorageView&, \ StorageView&) const; \ template void \ - Dequantize::dequantize_i4( \ - const StorageView&, \ - const StorageView&, \ - const StorageView&, \ - StorageView&) const; \ - template void \ Dequantize::dequantize_gemm_output( \ const StorageView&, \ const StorageView&, \ diff --git a/src/ops/gemm.cc b/src/ops/gemm.cc index 257340b43..62d758925 100644 --- a/src/ops/gemm.cc +++ b/src/ops/gemm.cc @@ -77,13 +77,16 @@ namespace ctranslate2 { StorageView& c, const StorageView* bias) const { PROFILE("Gemm"); - +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 dim_t batch_size = a.dim(0); dim_t time = a.dim(1); DEVICE_DISPATCH(a.device(), (compute(a, b, scaleAndZero, c))); c.reshape({batch_size, time, -1}); apply_bias_and_activation(c, bias, _activation_type); +#else + throw std::runtime_error("int4mm is supported only GPU Arch >= 800"); +#endif } template @@ -197,7 +200,11 @@ namespace 
ctranslate2 { StorageView Gemm::convert_to_int4pack(const StorageView& input, int32_t innerKTiles) { StorageView output(input.device(), input.dtype()); +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800) DEVICE_DISPATCH(input.device(), (convert_weight_to_int4pack(input, output, innerKTiles))); +#else + throw std::runtime_error("convert weight to int4pack is supported only GPU Arch >= 800"); +#endif return output; } } diff --git a/src/ops/int4gemm_cpu.cc b/src/ops/int4gemm_cpu.cc index 82f30c831..bf24b21f0 100644 --- a/src/ops/int4gemm_cpu.cc +++ b/src/ops/int4gemm_cpu.cc @@ -7,12 +7,16 @@ namespace ctranslate2 { const StorageView& b, const StorageView& scaleAndZero, StorageView& c) const { + // todo + throw std::runtime_error("int4mm is not supported for CPU"); } template <> void Gemm::convert_weight_to_int4pack(const StorageView& a, StorageView& b, int32_t innerKTiles) { + // todo + throw std::runtime_error("convert_weight_to_int4pack is not supported for CPU"); } #define DECLARE_IMPL(T) \ diff --git a/src/ops/int4gemm_gpu.cu b/src/ops/int4gemm_gpu.cu index f9ae1eaaa..71c890ba6 100644 --- a/src/ops/int4gemm_gpu.cu +++ b/src/ops/int4gemm_gpu.cu @@ -23,12 +23,6 @@ namespace ctranslate2 { return divDown(a, b) * b; } - template - constexpr __host__ __device__ auto roundUp(U a, V b) -> decltype(a + b) { - static_assert(std::is_integral::value && std::is_integral::value, ""); - return divUp(a, b) * b; - } - template constexpr __host__ __device__ bool isEvenDivisor(U a, V b) { static_assert(std::is_integral::value && std::is_integral::value, ""); @@ -116,729 +110,672 @@ namespace ctranslate2 { } constexpr int32_t kWarpSize = 32; - -//#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) // f16 vector types -struct __align__(2) f16x1 { - __half vals[1]; -}; - -struct __align__(4) f16x2 { - __half vals[2]; -}; - -struct __align__(8) f16x4 { - __half vals[4]; -}; - -struct __align__(16) f16x8 { - __half vals[8]; -}; - -// bf16 vector types -struct __align__(2) bf16x1 { - __nv_bfloat16 vals[1]; -}; - -struct __align__(4) bf16x2 { - __nv_bfloat16 vals[2]; -}; - -struct __align__(8) bf16x4 { - __nv_bfloat16 vals[4]; -}; - -struct __align__(16) bf16x8 { - __nv_bfloat16 vals[8]; -}; - -// bf162 vector types -struct __align__(4) bf16x2x1 { - __nv_bfloat162 vals[1]; -}; - -struct __align__(8) bf16x2x2 { - __nv_bfloat162 vals[2]; -}; - -struct __align__(16) bf16x2x4 { - __nv_bfloat162 vals[4]; -}; - -struct __align__(16) bf16x2x4_u32 { - uint32_t vals[4]; -}; - -struct __align__(8) bf16x2x2_u32 { - uint32_t vals[2]; -}; - -struct __align__(4) bf16x2x1_u32 { - uint32_t vals[1]; -}; - -template -struct __align__(sizeof(T) * N) VectorType { - T vals[N]; -}; - -// from -// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { - bf16x2x4 result; - constexpr int kElements = 8; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const source_i4s = source; - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - - // We don't have enough mantissa to remove as much shift overhead as FP16, so - // we must loop. No shift needed for first item. 
- uint32_t i4s = source_i4s; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); -#pragma unroll - for (int ii = 1; ii < kElements / 2; ++ii) { - i4s >>= 4; // or is it 8? - // (i4s & 0x000f000f) | 0x43004300 - asm volatile( - "lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[ii]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - } - - // This is the BF16 {-136, -136} represented as an integer. - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; -// Finally, we construct the output numbers. + struct __align__(16) bf16x2x4 { + __nv_bfloat162 vals[4]; + }; + + struct __align__(16) bf16x2x4_u32 { + uint32_t vals[4]; + }; + + // from + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + inline __device__ bf16x2x4 convert_i4x8_to_bf16x2x4(uint32_t source) { + bf16x2x4 result; + constexpr int kElements = 8; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = source; + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so + // we must loop. No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); #pragma unroll - for (int ii = 0; ii < kElements / 2; ++ii) { - // Since this section is for Ampere+, we use bf16 fma to do the bias - // subtraction - asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" - : "=r"(h[ii]) - : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); - } + for (int ii = 1; ii < kElements / 2; ++ii) { + i4s >>= 4; // or is it 8? + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } - return result; -} - - - -enum class KReductionType { - // No k-reduction is needed between blocks as the number of k-tiles processed - // per block are exact and we can directly write the output - None, -}; - -// Loads the A matrix in 16-bit standard m x k row major layout, and writes -// the C matrix in 16-bit standard m x n row major layout: -// -// size [m][k] -template -struct ALayout_RM { - static constexpr int32_t kMTileSize = 16; - static constexpr int32_t kNTileSize = 8; - static constexpr int32_t kKTileSize = 16; - - template - static __device__ void load( - const void* A, - int32_t m, - int32_t k, - int32_t mTiles, - int32_t mTile, - int32_t kTiles, - int32_t kTileStart, - int32_t laneId, - bf16x2x4_u32 out[KTilesToLoad]) { - const auto mLane = mTile * kMTileSize + (laneId / 4); - const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2; - - // access - // [mTile * kMTileSize + (laneId / 4)] - // [kTileStart * kKTileSize + (laneId % 4) * 2] - auto aPtr = reinterpret_cast(A) + mLane * k + kLane; - - auto aPtrPlus8Rows = aPtr + 8 * k; - - bool m0InBounds = mLane < m; - bool m1InBounds = (mLane + 8) < m; + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + // Finally, we construct the output numbers. 
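The magic constants above exploit the bfloat16 encoding: OR-ing a 4-bit code into the low mantissa bits of 0x4300 produces the value 128 + code, and the subsequent fma with the {-136, -136} bias recovers code - 8, i.e. a signed int4 in [-8, 7]. A small NumPy check, widening bf16 bit patterns to float32 since NumPy has no native bfloat16:

    import numpy as np

    def bf16_bits_to_float(bits16: int) -> float:
        # a bfloat16 value is the top 16 bits of the float32 with the same value
        return float(np.array([bits16 << 16], dtype=np.uint32).view(np.float32)[0])

    for code in range(16):                    # unsigned 4-bit code
        magic = bf16_bits_to_float(0x4300 | code)
        assert magic == 128.0 + code          # effect of I4s_TO_BF16s_MAGIC_NUM
        assert magic - 136.0 == code - 8      # fma with BF16_BIAS and BF16_ONE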
#pragma unroll - for (int i = 0; i < KTilesToLoad; ++i) { - out[i].vals[0] = m0InBounds - ? *reinterpret_cast(aPtr + i * kKTileSize) - : uint32_t(0); - out[i].vals[1] = m1InBounds - ? *reinterpret_cast(aPtrPlus8Rows + i * kKTileSize) - : uint32_t(0); - - out[i].vals[2] = m0InBounds - ? *reinterpret_cast(aPtr + i * kKTileSize + 8) - : uint32_t(0); - out[i].vals[3] = m1InBounds ? *reinterpret_cast( - aPtrPlus8Rows + i * kKTileSize + 8) - : uint32_t(0); + for (int ii = 0; ii < kElements / 2; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + + return result; } - } - static __device__ void store( - void* C, - int32_t m, - int32_t n, - int32_t mOutTiles, - int32_t mTile, - int32_t nOutTiles, - int32_t nTile, - int32_t laneId, - const float4& out) { - static_assert(ReduceType == KReductionType::None, ""); - - if constexpr (ReduceType == KReductionType::None) { - // sum.x / sum.y are written at - // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] - // sum.z / sum.w are written at - // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] - // i.e., same columns, different row. - const int outRow = mTile * kMTileSize + (laneId / 4); - const int outCol = nTile * kNTileSize + (laneId % 4) * 2; - - // Pointer where sum.x / sum.y is written - auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol; - - auto v01 = __float22bfloat162_rn(float2{out.x, out.y}); - auto v23 = __float22bfloat162_rn(float2{out.z, out.w}); - - if (outRow < m) { - *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01; + + + enum class KReductionType { + // No k-reduction is needed between blocks as the number of k-tiles processed + // per block are exact and we can directly write the output + None, + }; + + // Loads the A matrix in 16-bit standard m x k row major layout, and writes + // the C matrix in 16-bit standard m x n row major layout: + // + // size [m][k] + template + struct ALayout_RM { + static constexpr int32_t kMTileSize = 16; + static constexpr int32_t kNTileSize = 8; + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + const void* A, + int32_t m, + int32_t k, + int32_t mTiles, + int32_t mTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, + bf16x2x4_u32 out[KTilesToLoad]) { + const auto mLane = mTile * kMTileSize + (laneId / 4); + const auto kLane = kTileStart * kKTileSize + (laneId % 4) * 2; + + // access + // [mTile * kMTileSize + (laneId / 4)] + // [kTileStart * kKTileSize + (laneId % 4) * 2] + auto aPtr = reinterpret_cast(A) + mLane * k + kLane; + + auto aPtrPlus8Rows = aPtr + 8 * k; + + bool m0InBounds = mLane < m; + bool m1InBounds = (mLane + 8) < m; + +#pragma unroll + for (int i = 0; i < KTilesToLoad; ++i) { + out[i].vals[0] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize) + : uint32_t(0); + out[i].vals[1] = m1InBounds + ? *reinterpret_cast(aPtrPlus8Rows + i * kKTileSize) + : uint32_t(0); + + out[i].vals[2] = m0InBounds + ? *reinterpret_cast(aPtr + i * kKTileSize + 8) + : uint32_t(0); + out[i].vals[3] = m1InBounds ? 
*reinterpret_cast( + aPtrPlus8Rows + i * kKTileSize + 8) + : uint32_t(0); + } } - // sum.z, sum.w at +8 rows from cPtr - if (outRow + 8 < m) { - *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23; + static __device__ void store( + void* C, + int32_t m, + int32_t n, + int32_t mOutTiles, + int32_t mTile, + int32_t nOutTiles, + int32_t nTile, + int32_t laneId, + const float4& out) { + static_assert(ReduceType == KReductionType::None, ""); + + if constexpr (ReduceType == KReductionType::None) { + // sum.x / sum.y are written at + // [laneId / 4], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // sum.z / sum.w are written at + // [8 + (laneId / 4)], [(laneId % 4) * 2, (laneId % 4) * 2 + 1] + // i.e., same columns, different row. + const int outRow = mTile * kMTileSize + (laneId / 4); + const int outCol = nTile * kNTileSize + (laneId % 4) * 2; + + // Pointer where sum.x / sum.y is written + auto cPtr = reinterpret_cast<__nv_bfloat16*>(C) + outRow * n + outCol; + + auto v01 = __float22bfloat162_rn(float2{out.x, out.y}); + auto v23 = __float22bfloat162_rn(float2{out.z, out.w}); + + if (outRow < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr) = v01; + } + + // sum.z, sum.w at +8 rows from cPtr + if (outRow + 8 < m) { + *reinterpret_cast<__nv_bfloat162*>(cPtr + 8 * n) = v23; + } + } } - } - } -}; - -template -struct BLayout_TC_int4 { - static constexpr int32_t kInnerKTiles = InnerKTiles; - static constexpr int32_t kMTileSize = 16; - static constexpr int32_t kNTileSize = 8; - static constexpr int32_t kKTileSize = 16; - - template - static __device__ void load( - // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2] - // n / 8: n-tiles (n8) - // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16) - // 32: value per warp lane - // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile. - // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest - // value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a - // uint32x4 (128 bits) - const void* __restrict__ B, - // size [k / qGroupSize][n][2] - // Contains the scale and zero point of each of the quantized int4 values - // within B - // v_reconstructed = (bf16(B_int4_val) * scale) - zero - const void* __restrict__ quantizationInfo, - int32_t n, - int32_t k, - int32_t nTiles, - int32_t nTile, - int32_t kTiles, - int32_t kTileStart, - int32_t laneId, - bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) { - // offset [nTile][kTileStart / InnerKTiles][laneId][0] - auto bPtr = reinterpret_cast(B) + - (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) * - kWarpSize) + - laneId) * - (InnerKTiles / 2); - - int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2]; + }; + + template + struct BLayout_TC_int4 { + static constexpr int32_t kInnerKTiles = InnerKTiles; + static constexpr int32_t kMTileSize = 16; + static constexpr int32_t kNTileSize = 8; + static constexpr int32_t kKTileSize = 16; + + template + static __device__ void load( + // type uint32, size [n / 8][k / (InnerKTiles * 16)][32][InnerKTiles / 2] + // n / 8: n-tiles (n8) + // k / (InnerKTiles * 16): TC size per k-tile is 16 (m16n8k16) + // 32: value per warp lane + // (InnerKTiles / 2): B layout has 4 values per lane (16 bits) per k-tile. 
+ // 2 k-tiles packed is a uint32 (hence InnerKTiles == 2 is our smallest + // value) 4 k-tiles packed is a uint32x2 (64 bits) 8 k-tiles packed is a + // uint32x4 (128 bits) + const void* __restrict__ B, + // size [k / qGroupSize][n][2] + // Contains the scale and zero point of each of the quantized int4 values + // within B + // v_reconstructed = (bf16(B_int4_val) * scale) - zero + const void* __restrict__ quantizationInfo, + int32_t n, + int32_t k, + int32_t nTiles, + int32_t nTile, + int32_t kTiles, + int32_t kTileStart, + int32_t laneId, + bf16x2x4_u32 out[KTilesToLoad / InnerKTiles][InnerKTiles / 2]) { + // offset [nTile][kTileStart / InnerKTiles][laneId][0] + auto bPtr = reinterpret_cast(B) + + (((nTile * (kTiles / InnerKTiles) + (kTileStart / InnerKTiles)) * + kWarpSize) + + laneId) * + (InnerKTiles / 2); + + int32_t b_int4[KTilesToLoad / InnerKTiles][InnerKTiles / 2]; #pragma unroll - for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { - auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2); + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { + auto bPtrCur = bPtr + i * kWarpSize * (InnerKTiles / 2); - if constexpr (InnerKTiles == 2) { - b_int4[i][0] = bPtrCur[0]; - } + if constexpr (InnerKTiles == 2) { + b_int4[i][0] = bPtrCur[0]; + } - if constexpr (InnerKTiles == 4) { - // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" - // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]) - // : "l"(bPtrCur)); + if constexpr (InnerKTiles == 4) { + // asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]) + // : "l"(bPtrCur)); - int2 load8 = reinterpret_cast(bPtrCur)[0]; - b_int4[i][0] = load8.x; - b_int4[i][1] = load8.y; - } + int2 load8 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load8.x; + b_int4[i][1] = load8.y; + } - if constexpr (InnerKTiles == 8) { - // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" - // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]), - // "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur)); + if constexpr (InnerKTiles == 8) { + // asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n" + // : "=r"(b_int4[i][0]), "=r"(b_int4[i][1]), + // "=r"(b_int4[i][2]), "=r"(b_int4[i][3]) : "l"(bPtrCur)); - int4 load16 = reinterpret_cast(bPtrCur)[0]; - b_int4[i][0] = load16.x; - b_int4[i][1] = load16.y; - b_int4[i][2] = load16.z; - b_int4[i][3] = load16.w; - } - } + int4 load16 = reinterpret_cast(bPtrCur)[0]; + b_int4[i][0] = load16.x; + b_int4[i][1] = load16.y; + b_int4[i][2] = load16.z; + b_int4[i][3] = load16.w; + } + } - // Load needed info for dequantization + // Load needed info for dequantization - static_assert(isPowerOf2(QGroupSize), ""); - static_assert(isEvenDivisor(QGroupSize, kKTileSize), ""); - // smallest quantization group size is 32 (2 k-tiles are packed in an int32) - static_assert(QGroupSize >= kKTileSize * 2, ""); - constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize); - // a q-group could be larger than what we are handling in a single warp - constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1 - ? 1 - : (KTilesToLoad / kKTilesPerQGroup); + static_assert(isPowerOf2(QGroupSize), ""); + static_assert(isEvenDivisor(QGroupSize, kKTileSize), ""); + // smallest quantization group size is 32 (2 k-tiles are packed in an int32) + static_assert(QGroupSize >= kKTileSize * 2, ""); + constexpr int kKTilesPerQGroup = (QGroupSize / kKTileSize); + // a q-group could be larger than what we are handling in a single warp + constexpr int kNumQGroups = (KTilesToLoad / kKTilesPerQGroup) < 1 + ? 
1 + : (KTilesToLoad / kKTilesPerQGroup); - __nv_bfloat162 qScaleAndZero[kNumQGroups]; - { - int32_t laneN = nTile * kNTileSize + (laneId / 4); - int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize; + __nv_bfloat162 qScaleAndZero[kNumQGroups]; + { + int32_t laneN = nTile * kNTileSize + (laneId / 4); + int32_t groupStart = (kTileStart * kKTileSize) / QGroupSize; - int32_t n = nTiles * kNTileSize; + int32_t n = nTiles * kNTileSize; - // offset [qScale_kGroup][qScale_n][0] - auto qInfoPtr = reinterpret_cast(quantizationInfo) + - (groupStart * n + laneN) * 2; + // offset [qScale_kGroup][qScale_n][0] + auto qInfoPtr = reinterpret_cast(quantizationInfo) + + (groupStart * n + laneN) * 2; #pragma unroll - for (int i = 0; i < kNumQGroups; ++i) { - qScaleAndZero[i] = - *reinterpret_cast(qInfoPtr + i * n * 2); - } - } + for (int i = 0; i < kNumQGroups; ++i) { + qScaleAndZero[i] = + *reinterpret_cast(qInfoPtr + i * n * 2); + } + } - // - // De-quantize int4 values to bf16. Values are dequantized as truly int4 - // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero - // - { - // FIXME: does this negatively affect register counts, or will nvcc - // move this expansion (and data loads above) closer to the point of use? - __nv_bfloat162 qScale[kNumQGroups]; - __nv_bfloat162 qZero[kNumQGroups]; + // + // De-quantize int4 values to bf16. Values are dequantized as truly int4 + // [-8, 7] range; dequant = (bf16(int4_value) * bf16_scale) + bf16_zero + // + { + // FIXME: does this negatively affect register counts, or will nvcc + // move this expansion (and data loads above) closer to the point of use? + __nv_bfloat162 qScale[kNumQGroups]; + __nv_bfloat162 qZero[kNumQGroups]; #pragma unroll - for (int i = 0; i < kNumQGroups; ++i) { - qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x); - qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y); - } + for (int i = 0; i < kNumQGroups; ++i) { + qScale[i] = __bfloat162bfloat162(qScaleAndZero[i].x); + qZero[i] = __bfloat162bfloat162(qScaleAndZero[i].y); + } #pragma unroll - for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { + for (int i = 0; i < KTilesToLoad / InnerKTiles; ++i) { #pragma unroll - for (int j = 0; j < InnerKTiles / 2; ++j) { - bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]); + for (int j = 0; j < InnerKTiles / 2; ++j) { + bf16x2x4 v = convert_i4x8_to_bf16x2x4(b_int4[i][j]); - int curKTile = i * InnerKTiles + j * 2; - int curQGroup = (curKTile * kKTileSize) / QGroupSize; + int curKTile = i * InnerKTiles + j * 2; + int curQGroup = (curKTile * kKTileSize) / QGroupSize; - // The dequantized values in `v` for a given lane have the same n - // dimension (the B tensor core layout has all values in the same - // thread along the same n) but different k dimension, but all are - // guaranteed to occur within the same quantization group, so we need - // only load a single scale + zero to cover what this lane has + // The dequantized values in `v` for a given lane have the same n + // dimension (the B tensor core layout has all values in the same + // thread along the same n) but different k dimension, but all are + // guaranteed to occur within the same quantization group, so we need + // only load a single scale + zero to cover what this lane has #pragma unroll - for (int k = 0; k < 4; ++k) { - v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]); + for (int k = 0; k < 4; ++k) { + v.vals[k] = __hfma2(v.vals[k], qScale[curQGroup], qZero[curQGroup]); + } + + // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct 
and + // can't be used as a 32-bit asm register argument for `mma` + static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), ""); + std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32)); + } } - - // type pun, the __nv_bfloat162 value in bf16x2x4 is a struct and - // can't be used as a 32-bit asm register argument for `mma` - static_assert(sizeof(bf16x2x4) == sizeof(out[0][0]), ""); - std::memcpy(&out[i][j], &v, sizeof(bf16x2x4_u32)); } } - } - } -}; - -template < - typename ALayout, - typename BLayout, - typename CLayout, - int Warps, - int KTilesPerIteration> -__global__ -__launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( - // Data for the A matrix, loaded as per ALayout - const void* const __restrict__ A, - - // Data for the B matrix, loaded as per BLayout - const void* const __restrict__ B, - - // Optional quantization data for dequantizing B, loaded as per BLayout - const void* const __restrict__ B_quantizationInfo, - - // Output data for the C matrix, stored as per CLayout - void* __restrict__ C, - - // The size of the matrix multiplication - int32_t m, - int32_t n, - int32_t k, - - // The size of the matrix multiplication, in multiples of our TC tile size - int32_t mTiles, - int32_t nTiles, - int32_t kTiles) { - constexpr int32_t kMTileSize = 16; - constexpr int32_t kNTileSize = 8; - constexpr int32_t kKTileSize = 16; - - static_assert( - ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize && - ALayout::kKTileSize == kKTileSize, - ""); - - static_assert( - BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize && - BLayout::kKTileSize == kKTileSize, - ""); - - static_assert( - CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize && - CLayout::kKTileSize == kKTileSize, - ""); - - constexpr int kInnerKTiles = BLayout::kInnerKTiles; - - // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads - static_assert( - kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, ""); - - // We always process at least kInnerKTiles k-tiles back to back in a warp - static_assert( - KTilesPerIteration >= kInnerKTiles && - isEvenDivisor(KTilesPerIteration, kInnerKTiles), - ""); - - auto warpId = threadIdx.y; - auto laneId = threadIdx.x; - - int32_t mTile = blockIdx.z; - int32_t nTile = blockIdx.y; - - float4 c{0.0f, 0.0f, 0.0f, 0.0f}; - - // First, handle whole multiples of KTilesPerIteration - auto kTilesLimit = roundDown(kTiles, KTilesPerIteration); - - // Each warp handles a set of KTilesPerIteration under the above limit - for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration; - kTileBase < kTilesLimit; - kTileBase += Warps * KTilesPerIteration) { - // - // Load data from A - // - bf16x2x4_u32 a[KTilesPerIteration]; - ALayout::template load( - A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a); - - // - // Load data from B and de-quantize as needed - // Each k-tile is bf16x2x2 - // - bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2]; - BLayout::template load( - B, - B_quantizationInfo, - n, - k, - nTiles, - nTile, - kTiles, - kTileBase, - laneId, - b); - - // - // Now, perform the matrix multiplication - // + }; + + template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerIteration> + __global__ + __launch_bounds__(Warps* kWarpSize) void tinygemm_m16n8k16_chunk_kernel( + // Data for the A matrix, loaded as per ALayout + const void* const __restrict__ A, + + // Data for the B matrix, loaded as per BLayout + const void* const 
__restrict__ B, + + // Optional quantization data for dequantizing B, loaded as per BLayout + const void* const __restrict__ B_quantizationInfo, + + // Output data for the C matrix, stored as per CLayout + void* __restrict__ C, + + // The size of the matrix multiplication + int32_t m, + int32_t n, + int32_t k, + + // The size of the matrix multiplication, in multiples of our TC tile size + int32_t mTiles, + int32_t nTiles, + int32_t kTiles) { + constexpr int32_t kMTileSize = 16; + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; - // We accumulate across k-tiles here + static_assert( + ALayout::kMTileSize == kMTileSize && ALayout::kNTileSize == kNTileSize && + ALayout::kKTileSize == kKTileSize, + ""); + + static_assert( + BLayout::kMTileSize == kMTileSize && BLayout::kNTileSize == kNTileSize && + BLayout::kKTileSize == kKTileSize, + ""); + + static_assert( + CLayout::kMTileSize == kMTileSize && CLayout::kNTileSize == kNTileSize && + CLayout::kKTileSize == kKTileSize, + ""); + + constexpr int kInnerKTiles = BLayout::kInnerKTiles; + + // 2/4/8 inner k-tiles correspond to 4, 8 and 16 byte innermost loads + static_assert( + kInnerKTiles == 2 || kInnerKTiles == 4 || kInnerKTiles == 8, ""); + + // We always process at least kInnerKTiles k-tiles back to back in a warp + static_assert( + KTilesPerIteration >= kInnerKTiles && + isEvenDivisor(KTilesPerIteration, kInnerKTiles), + ""); + + auto warpId = threadIdx.y; + auto laneId = threadIdx.x; + + int32_t mTile = blockIdx.z; + int32_t nTile = blockIdx.y; + + float4 c{0.0f, 0.0f, 0.0f, 0.0f}; + + // First, handle whole multiples of KTilesPerIteration + auto kTilesLimit = roundDown(kTiles, KTilesPerIteration); + + // Each warp handles a set of KTilesPerIteration under the above limit + for (int32_t kTileBase = (blockIdx.x * Warps + warpId) * KTilesPerIteration; + kTileBase < kTilesLimit; + kTileBase += Warps * KTilesPerIteration) { + // + // Load data from A + // + bf16x2x4_u32 a[KTilesPerIteration]; + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBase, laneId, a); + + // + // Load data from B and de-quantize as needed + // Each k-tile is bf16x2x2 + // + bf16x2x4_u32 b[KTilesPerIteration / kInnerKTiles][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBase, + laneId, + b); + + // + // Now, perform the matrix multiplication + // + + // We accumulate across k-tiles here #pragma unroll - for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) { - static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, ""); + for (int i = 0; i < KTilesPerIteration / kInnerKTiles; ++i) { + static_assert(isEvenDivisor(kInnerKTiles, 2) && kInnerKTiles >= 2, ""); #pragma unroll - for (int j = 0; j < kInnerKTiles / 2; ++j) { - // We don't simply accumulate into `c` as this creates a too-strong - // execution dependency. Instead, we only periodically accumulate into - // `c` - float4 cTmp[2]; + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. 
Instead, we only periodically accumulate into + // `c` + float4 cTmp[2]; #pragma unroll - for (int k = 0; k < 2; ++k) { - cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; - } + for (int k = 0; k < 2; ++k) { + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; + } #pragma unroll - for (int k = 0; k < 2; ++k) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" - : "=f"(cTmp[k].x), - "=f"(cTmp[k].y), - "=f"(cTmp[k].z), - "=f"(cTmp[k].w) - : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]), - "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]), - "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]), - "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]), - "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), - "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), - "f"(cTmp[k].x), - "f"(cTmp[k].y), - "f"(cTmp[k].z), - "f"(cTmp[k].w)); - } + for (int k = 0; k < 2; ++k) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), + "=f"(cTmp[k].y), + "=f"(cTmp[k].z), + "=f"(cTmp[k].w) + : "r"(a[i * kInnerKTiles + j * 2 + k].vals[0]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[1]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[2]), + "r"(a[i * kInnerKTiles + j * 2 + k].vals[3]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[i][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); + } #pragma unroll - for (int k = 0; k < 2; ++k) { - c.x += cTmp[k].x; - c.y += cTmp[k].y; - c.z += cTmp[k].z; - c.w += cTmp[k].w; + for (int k = 0; k < 2; ++k) { + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; + } + } } - } - } - } // for all tiles under kTilesLimit - - // Now, there could be a remainder of 1 to KTilesPerIteration - 1 k-tiles - // remaining. We guarantee that the number of warps is >= KTilesPerIteration / - // kInnerKTiles, so that each warp can simply load kInnerKTiles and do its - // thing without needing more warps - static_assert(Warps >= KTilesPerIteration / kInnerKTiles, ""); - - auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles; - - // If we have any remainder k-tiles, some warps will handle them, processing - // kInnerKTiles k-tiles at a time - if (kTileBaseRemaining < kTiles) { - bf16x2x4_u32 a[kInnerKTiles]; - ALayout::template load( - A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a); - - bf16x2x4_u32 b[1][kInnerKTiles / 2]; - BLayout::template load( - B, - B_quantizationInfo, - n, - k, - nTiles, - nTile, - kTiles, - kTileBaseRemaining, - laneId, - b); + } // for all tiles under kTilesLimit + + // Now, there could be a remainder of 1 to KTilesPerIteration - 1 k-tiles + // remaining. 
We guarantee that the number of warps is >= KTilesPerIteration / + // kInnerKTiles, so that each warp can simply load kInnerKTiles and do its + // thing without needing more warps + static_assert(Warps >= KTilesPerIteration / kInnerKTiles, ""); + + auto kTileBaseRemaining = kTilesLimit + warpId * kInnerKTiles; + + // If we have any remainder k-tiles, some warps will handle them, processing + // kInnerKTiles k-tiles at a time + if (kTileBaseRemaining < kTiles) { + bf16x2x4_u32 a[kInnerKTiles]; + ALayout::template load( + A, m, k, mTiles, mTile, kTiles, kTileBaseRemaining, laneId, a); + + bf16x2x4_u32 b[1][kInnerKTiles / 2]; + BLayout::template load( + B, + B_quantizationInfo, + n, + k, + nTiles, + nTile, + kTiles, + kTileBaseRemaining, + laneId, + b); #pragma unroll - for (int j = 0; j < kInnerKTiles / 2; ++j) { - // We don't simply accumulate into `c` as this creates a too-strong - // execution dependency. Instead, we only periodically accumulate into - // `c` - float4 cTmp[2]; + for (int j = 0; j < kInnerKTiles / 2; ++j) { + // We don't simply accumulate into `c` as this creates a too-strong + // execution dependency. Instead, we only periodically accumulate into + // `c` + float4 cTmp[2]; #pragma unroll - for (int k = 0; k < 2; ++k) { - cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; - } + for (int k = 0; k < 2; ++k) { + cTmp[k] = float4{0.0f, 0.0f, 0.0f, 0.0f}; + } #pragma unroll - for (int k = 0; k < 2; ++k) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" - : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w) - : "r"(a[j * 2 + k].vals[0]), - "r"(a[j * 2 + k].vals[1]), - "r"(a[j * 2 + k].vals[2]), - "r"(a[j * 2 + k].vals[3]), - "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), - "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), - "f"(cTmp[k].x), - "f"(cTmp[k].y), - "f"(cTmp[k].z), - "f"(cTmp[k].w)); - } + for (int k = 0; k < 2; ++k) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};" + : "=f"(cTmp[k].x), "=f"(cTmp[k].y), "=f"(cTmp[k].z), "=f"(cTmp[k].w) + : "r"(a[j * 2 + k].vals[0]), + "r"(a[j * 2 + k].vals[1]), + "r"(a[j * 2 + k].vals[2]), + "r"(a[j * 2 + k].vals[3]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 0]), + "r"(b[0][(j * 2 + k) / 2].vals[((j * 2 + k) % 2) * 2 + 1]), + "f"(cTmp[k].x), + "f"(cTmp[k].y), + "f"(cTmp[k].z), + "f"(cTmp[k].w)); + } #pragma unroll - for (int k = 0; k < 2; ++k) { - c.x += cTmp[k].x; - c.y += cTmp[k].y; - c.z += cTmp[k].z; - c.w += cTmp[k].w; + for (int k = 0; k < 2; ++k) { + c.x += cTmp[k].x; + c.y += cTmp[k].y; + c.z += cTmp[k].z; + c.w += cTmp[k].w; + } + } } - } - } - // - // Reduce independent k-tiles (same m/n) across warps - // - __shared__ float4 smem_sum[Warps][kWarpSize]; + // + // Reduce independent k-tiles (same m/n) across warps + // + __shared__ float4 smem_sum[Warps][kWarpSize]; - // FIXME: this likely doesn't need to be a true reduction tree, can just be a - // serial sum, maybe (unless nvcc/ptxas goes back to its old ways) - // smem_sum[warpId][laneId] = TreeReduce4::reduce(c); - smem_sum[warpId][laneId] = c; + // FIXME: this likely doesn't need to be a true reduction tree, can just be a + // serial sum, maybe (unless nvcc/ptxas goes back to its old ways) + // smem_sum[warpId][laneId] = TreeReduce4::reduce(c); + smem_sum[warpId][laneId] = c; - __syncthreads(); + __syncthreads(); - if (warpId == 0) { - float4 sum_f32{0.0f, 
0.0f, 0.0f, 0.0f}; + if (warpId == 0) { + float4 sum_f32{0.0f, 0.0f, 0.0f, 0.0f}; - // Reduce across the block in the first warp - for (int i = 0; i < Warps; ++i) { - float4 v = smem_sum[i][laneId]; - sum_f32.x += v.x; - sum_f32.y += v.y; - sum_f32.z += v.z; - sum_f32.w += v.w; - } - - // Write the reduced result (in the first warp) into the output - CLayout::store( - C, - m, - n, - mTiles, - mTile, - // n for C output becomes k for A input, so for m16n8k16, - // we need to halve the tiles - nTiles / 2, - nTile, - laneId, - sum_f32); - } -} - -// FIXME: parallelize better, smem staging etc? - template - __global__ void - matrix_to_m16n8k16_Bint4_layout( - // size [n][k] - const int32_t* in, - int32_t n, - int32_t depth, - int32_t depth_output1, - int32_t depth_output2, - int32_t depth_output3, - // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] - int32_t* out) { - // int4 values are packed into int32 values, which require at least 8. Given - // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of - // innermost k-tiles that we can use is 2. - static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), ""); - - constexpr int32_t kNTileSize = 8; - constexpr int32_t kKTileSize = 16; + // Reduce across the block in the first warp + for (int i = 0; i < Warps; ++i) { + float4 v = smem_sum[i][laneId]; + sum_f32.x += v.x; + sum_f32.y += v.y; + sum_f32.z += v.z; + sum_f32.w += v.w; + } - // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles - auto kOuterTile = blockIdx.x; - auto nTile = blockIdx.y; - auto t = threadIdx.x; + // Write the reduced result (in the first warp) into the output + CLayout::store( + C, + m, + n, + mTiles, + mTile, + // n for C output becomes k for A input, so for m16n8k16, + // we need to halve the tiles + nTiles / 2, + nTile, + laneId, + sum_f32); + } + } - // Two k-tiles are packed into an int32 at a time + // FIXME: parallelize better, smem staging etc? + template + __global__ void + matrix_to_m16n8k16_Bint4_layout( + // size [n][k] + const int32_t* in, + int32_t n, + int32_t depth, + int32_t depth_output1, + int32_t depth_output2, + int32_t depth_output3, + // size [ceil(n / 8)][ceil(k / (InnerKTiles * 16))][32][InnerKTiles / 2] + int32_t* out) { + // int4 values are packed into int32 values, which require at least 8. Given + // m16n8k16 B layout requires 4 scalar values/lane, the minimum number of + // innermost k-tiles that we can use is 2. 
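For reference, the interleaved nibble order emitted per lane by this layout kernel (eight unsigned 4-bit values packed into a single 32-bit word, two k-tiles at a time) can be reproduced with a small host-side helper. This is only a sketch and not part of the patch; pack_b_int4 is a hypothetical name, and its body simply mirrors the shift/OR expression used by the kernel below.

    #include <cstdint>

    // Pack eight 4-bit values (each assumed to be in 0..15) into one 32-bit
    // word in the m16n8k16 B-layout order, i.e. nibbles v7,v5,v3,v1,v6,v4,v2,v0
    // from the high bits down, exactly as the kernel's final OR chain does.
    static inline uint32_t pack_b_int4(const uint32_t v[8]) {
      return (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16) |
             (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0];
    }

The kernel stores the same packed word into an int32_t output slot, so this helper and the device code agree bit for bit.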
+ static_assert(InnerKTiles >= 2 && isPowerOf2(InnerKTiles), ""); + + constexpr int32_t kNTileSize = 8; + constexpr int32_t kKTileSize = 16; + + // gridDim.x corresponds to the number of k-tiles divided by InnerKTiles + auto kOuterTile = blockIdx.x; + auto nTile = blockIdx.y; + auto t = threadIdx.x; + + // Two k-tiles are packed into an int32 at a time #pragma unroll - for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { - // n dimension that this lane loads from - auto n0 = nTile * kNTileSize + (t / 4); + for (int innerKTile = 0; innerKTile < InnerKTiles; innerKTile += 2) { + // n dimension that this lane loads from + auto n0 = nTile * kNTileSize + (t / 4); - bool n0Valid = n0 < n; + bool n0Valid = n0 < n; - int32_t ks[8]; + int32_t ks[8]; - auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; - ks[0] = kBase0 + (t % 4) * 2; - ks[1] = ks[0] + 1; - ks[2] = ks[0] + 8; - ks[3] = ks[0] + 8 + 1; + auto kBase0 = (kOuterTile * InnerKTiles + innerKTile) * kKTileSize; + ks[0] = kBase0 + (t % 4) * 2; + ks[1] = ks[0] + 1; + ks[2] = ks[0] + 8; + ks[3] = ks[0] + 8 + 1; - auto kBase1 = kBase0 + kKTileSize; - ks[4] = kBase1 + (t % 4) * 2; - ks[5] = ks[4] + 1; - ks[6] = ks[4] + 8; - ks[7] = ks[4] + 8 + 1; + auto kBase1 = kBase0 + kKTileSize; + ks[4] = kBase1 + (t % 4) * 2; + ks[5] = ks[4] + 1; + ks[6] = ks[4] + 8; + ks[7] = ks[4] + 8 + 1; - auto pIn = in + (n0 * depth); + auto pIn = in + (n0 * depth); - uint32_t v[8]; + uint32_t v[8]; #pragma unroll - for (int i = 0; i < 8; ++i) { - v[i] = (n0Valid && (ks[i] < depth)) ? pIn[ks[i]] : uint32_t(0); - } + for (int i = 0; i < 8; ++i) { + v[i] = (n0Valid && (ks[i] < depth)) ? pIn[ks[i]] : uint32_t(0); + } - int32_t pack = (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16) | - (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0]; + int32_t pack = (v[7] << 28) | (v[5] << 24) | (v[3] << 20) | (v[1] << 16) | + (v[6] << 12) | (v[4] << 8) | (v[2] << 4) | v[0]; - // inner k-tiles pack two at a time - out[nTile * depth_output3 + kOuterTile * depth_output2 + t * depth_output1 + innerKTile / 2] = pack; - } - } + // inner k-tiles pack two at a time + out[nTile * depth_output3 + kOuterTile * depth_output2 + t * depth_output1 + innerKTile / 2] = pack; + } + } -template < - typename ALayout, - typename BLayout, - typename CLayout, - int Warps, - int KTilesPerWarp> -void launch_tinygemm_kernel( - const StorageView& A, - const StorageView& B, - const StorageView* qScaleAndZeros, /* optional */ - StorageView& C_final, - int32_t mTiles, - int32_t nTiles, - int32_t kTiles, - int32_t m, - int32_t n, - int32_t k, - cudaStream_t stream) { - // After intra-block reduction across the k dimension, we are left with this - // many tiles - // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp); - int32_t postKernelKTiles = 1; // we loop - - auto grid = dim3(postKernelKTiles, nTiles, mTiles); - auto block = dim3(kWarpSize, Warps); - - auto func = - tinygemm_m16n8k16_chunk_kernel; - - func<<>>( - A.data(), - B.data(), - qScaleAndZeros ? 
qScaleAndZeros->data() : nullptr, - C_final.data(), - m, - n, - k, - mTiles, - nTiles, - kTiles); -} -//#endif + template < + typename ALayout, + typename BLayout, + typename CLayout, + int Warps, + int KTilesPerWarp> + void launch_tinygemm_kernel( + const StorageView& A, + const StorageView& B, + const StorageView* qScaleAndZeros, /* optional */ + StorageView& C_final, + int32_t mTiles, + int32_t nTiles, + int32_t kTiles, + int32_t m, + int32_t n, + int32_t k, + cudaStream_t stream) { + // After intra-block reduction across the k dimension, we are left with this + // many tiles + // int32_t postKernelKTiles = kTiles / (Warps * KTilesPerWarp); + int32_t postKernelKTiles = 1; // we loop + + auto grid = dim3(postKernelKTiles, nTiles, mTiles); + auto block = dim3(kWarpSize, Warps); + + auto func = + tinygemm_m16n8k16_chunk_kernel; + + func<<>>( + A.data(), + B.data(), + qScaleAndZeros ? qScaleAndZeros->data() : nullptr, + C_final.data(), + m, + n, + k, + mTiles, + nTiles, + kTiles); + } template void Gemm::compute(const StorageView& a, @@ -870,7 +807,6 @@ void launch_tinygemm_kernel( auto numQGroups = scaleAndZero.dim(0); // Output is a standard row-major matrix c.resize({m, n}); -//#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) auto stream = cuda::get_cuda_stream(); #define RUN_GEMM(WARPS, K_TILES_PER_WARP, Q_GROUP_SIZE, REDUCE_TYPE) \ do { \ @@ -970,7 +906,6 @@ void launch_tinygemm_kernel( #undef HANDLE_Q_GROUP #undef RUN_GEMM -//#endif } template <> @@ -992,7 +927,6 @@ void launch_tinygemm_kernel( // 2 k-tiles are a single int32 b.resize({nTiles, kSuperTiles, 32, innerKTiles / 2}); -//#if (defined(CUDA_VERSION) && CUDA_VERSION >= 12000) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) auto stream = ctranslate2::cuda::get_cuda_stream(); dim3 grid(kSuperTiles, nTiles); @@ -1024,7 +958,6 @@ void launch_tinygemm_kernel( b.stride(0), b.data()); } -//#endif } #define DECLARE_IMPL(T) \ diff --git a/src/storage_view.cc b/src/storage_view.cc index 1ad45caf3..2728c4480 100644 --- a/src/storage_view.cc +++ b/src/storage_view.cc @@ -3,7 +3,6 @@ #include "ctranslate2/primitives.h" #include "dispatch.h" -#include #define PRINT_MAX_VALUES 6 namespace ctranslate2 { @@ -440,12 +439,6 @@ namespace ctranslate2 { return os; } - template<> - std::ostream& print_value(std::ostream& os, const uint8_t& val) { - os << static_cast(val); - return os; - } - std::ostream& operator<<(std::ostream& os, const StorageView& storage) { StorageView printable(storage.dtype()); printable.copy_from(storage); diff --git a/src/type_dispatch.h b/src/type_dispatch.h index 9f6c87b36..7ecab93b5 100644 --- a/src/type_dispatch.h +++ b/src/type_dispatch.h @@ -41,7 +41,6 @@ namespace ctranslate2 { } MATCH_TYPE_AND_ENUM(float, DataType::FLOAT32); - MATCH_TYPE_AND_ENUM(uint8_t, DataType::UINT8); MATCH_TYPE_AND_ENUM(int8_t, DataType::INT8); MATCH_TYPE_AND_ENUM(int16_t, DataType::INT16); MATCH_TYPE_AND_ENUM(int32_t, DataType::INT32); @@ -61,7 +60,6 @@ namespace ctranslate2 { #define TYPE_DISPATCH(TYPE_ENUM, STMTS) \ switch (TYPE_ENUM) { \ TYPE_CASE(float, SINGLE_ARG(STMTS)) \ - TYPE_CASE(uint8_t, SINGLE_ARG(STMTS)) \ TYPE_CASE(int8_t, SINGLE_ARG(STMTS)) \ TYPE_CASE(int16_t, SINGLE_ARG(STMTS)) \ TYPE_CASE(int32_t, SINGLE_ARG(STMTS)) \ @@ -71,7 +69,6 @@ namespace ctranslate2 { #define DECLARE_ALL_TYPES(FUNC) \ FUNC(float) \ - FUNC(uint8_t) \ FUNC(int8_t) \ FUNC(int16_t) \ FUNC(int32_t) \ diff --git a/src/types.cc b/src/types.cc index 5ecc23449..b1440216a 100644 --- 
a/src/types.cc +++ b/src/types.cc @@ -15,8 +15,6 @@ namespace ctranslate2 { switch (type) { case DataType::FLOAT32: return "float32"; - case DataType::UINT8: - return "uint8"; case DataType::INT8: return "int8"; case DataType::INT16: From ae7a1243d6435ebe3e642c369e8e3230cefb99a7 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Thu, 11 Jul 2024 10:33:00 +0200 Subject: [PATCH 09/10] fix stuffs --- include/ctranslate2/types.h | 1 - src/ops/int4gemm_gpu.cu | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/include/ctranslate2/types.h b/include/ctranslate2/types.h index 90a974777..8fa47b7ff 100644 --- a/include/ctranslate2/types.h +++ b/include/ctranslate2/types.h @@ -20,7 +20,6 @@ namespace ctranslate2 { INT32, FLOAT16, BFLOAT16, - UINT8, }; std::string dtype_name(DataType type); diff --git a/src/ops/int4gemm_gpu.cu b/src/ops/int4gemm_gpu.cu index 71c890ba6..5f564031c 100644 --- a/src/ops/int4gemm_gpu.cu +++ b/src/ops/int4gemm_gpu.cu @@ -3,6 +3,7 @@ namespace ctranslate2 { namespace ops { +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) template constexpr __host__ __device__ auto divDown(U a, V b) -> decltype(a + b) { static_assert(std::is_integral::value && std::is_integral::value, ""); @@ -968,5 +969,6 @@ namespace ctranslate2 { StorageView& c) const; DECLARE_IMPL(bfloat16_t) +#endif } } \ No newline at end of file From f80e61f03670552f3f3e4b2535c8f25434fa4663 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Thu, 11 Jul 2024 17:35:26 +0200 Subject: [PATCH 10/10] fix stuffs --- include/ctranslate2/models/model.h | 2 ++ include/ctranslate2/ops/gemm.h | 2 +- python/ctranslate2/specs/model_spec.py | 1 - python/ctranslate2/specs/transformer_spec.py | 2 ++ src/layers/common.cc | 3 +-- src/models/model.cc | 13 +++++++++++++ src/ops/gemm.cc | 2 +- 7 files changed, 20 insertions(+), 5 deletions(-) diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 86d44c1c4..f2ab02db2 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -119,6 +119,8 @@ namespace ctranslate2 { // Copy the model to another device. 
std::shared_ptr copy_to(Device device, int device_index = 0) const; + template + T get_config_if_exists(const std::string& name) const; const StorageView* get_variable_if_exists(const std::string& name) const; const StorageView& get_variable(const std::string& name) const; std::unordered_map get_variables() const; diff --git a/include/ctranslate2/ops/gemm.h b/include/ctranslate2/ops/gemm.h index 8ecd96083..41cb78fe2 100644 --- a/include/ctranslate2/ops/gemm.h +++ b/include/ctranslate2/ops/gemm.h @@ -19,7 +19,7 @@ namespace ctranslate2 { bool a_is_packed = false, bool b_is_packed = false, const ActivationType* activation_type = nullptr, - int _group_size = 32); + const int _group_size = 0); void operator()(const StorageView& a, const StorageView& b, diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 116bba912..5a086191a 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -318,7 +318,6 @@ def _quantize(spec, name, value): scale = hqq_linear.meta['scale'].cpu() zero = hqq_linear.meta['zero'].cpu() old_shape = hqq_linear.meta['shape'] - new_shape = old_shape[:-1] + (old_shape[-1] // 2,) value = Quantizer.unpack[hqq_linear.meta["packing"]](value) value, scale = self._hqq_quants_to_torch_quants(value, scale, zero, old_shape) diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index abb812c8b..954819cfa 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -233,7 +233,9 @@ def __init__( if quant_type is not None: self._config["quantization_type"] = quant_type + if quant_bits is not None: self._config["quantization_bits"] = quant_bits + if quant_group_size is not None: self._config["quantization_group_size"] = quant_group_size @property diff --git a/src/layers/common.cc b/src/layers/common.cc index 82862eef1..bf1c7b226 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -291,8 +291,7 @@ namespace ctranslate2 { /*a_is_packed=*/false, _packed_weight, _quantized_gemm ? nullptr : activation_type, - (model.config.contains("quantization_group_size") ? - model.config["quantization_group_size"] : nullptr)) + (model.get_config_if_exists("quantization_group_size"))) , _quantize_op(model.use_global_int16_scale() ? ops::Quantize::ScaleType::GLOBAL : ops::Quantize::ScaleType::PER_LAYER, diff --git a/src/models/model.cc b/src/models/model.cc index 672b2c520..c84f085e3 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -244,6 +244,15 @@ namespace ctranslate2 { } } + template + T Model::get_config_if_exists(const std::string &name) const { + T value = 0; + if (config.contains(name)) { + value = config[name]; + } + return value; + } + const StorageView* Model::get_variable_if_exists(const std::string& name) const { auto it = _variable_index.find(name); if (it == _variable_index.end()) @@ -912,6 +921,10 @@ namespace ctranslate2 { return models; } +#define DECLARE_IMPL(T) \ + template T \ + Model::get_config_if_exists(const std::string& name) const; + DECLARE_IMPL(int) } } diff --git a/src/ops/gemm.cc b/src/ops/gemm.cc index 62d758925..465d3af4a 100644 --- a/src/ops/gemm.cc +++ b/src/ops/gemm.cc @@ -26,7 +26,7 @@ namespace ctranslate2 { bool a_is_packed, bool b_is_packed, const ActivationType* activation_type, - int group_size) + const int group_size) : _alpha(alpha) , _beta(beta) , _trans_a(trans_a)
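Because get_config_if_exists is defined in src/models/model.cc rather than in the header, only the types explicitly instantiated there (currently int, via DECLARE_IMPL) are usable by callers. A hypothetical call site, shown here only as a sketch, reads the optional group size and falls back to 0 when the model configuration has no such entry:

    // Hypothetical caller: fetch the optional int4 group size from the model
    // config. A model converted without group-wise quantization has no
    // "quantization_group_size" key, so the getter returns 0.
    const int group_size = model.get_config_if_exists<int>("quantization_group_size");

Lowering the Gemm default for _group_size from 32 to 0 keeps this consistent: a caller that never passes a group size ends up with the same value the getter returns for models without the config key.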