From 4bfc6b334978c4fd4026ca5f17a54dbfb0c3dfe7 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Tue, 23 May 2023 14:55:39 +0200 Subject: [PATCH 1/5] Keep FFN output layer in float32 for T5 models --- include/ctranslate2/models/model.h | 3 +++ python/ctranslate2/converters/transformers.py | 2 ++ python/ctranslate2/specs/common_spec.py | 1 + python/ctranslate2/specs/model_spec.py | 6 ++++++ src/layers/common.cc | 3 +++ src/models/model.cc | 14 ++++++++++++++ 6 files changed, 29 insertions(+) diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 923b978a2..4b26c2111 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -133,6 +133,9 @@ namespace ctranslate2 { // Returns true if the variable can be converted to another type. virtual bool is_convertible(const StorageView& variable, const std::string& name) const; + // Returns true if the variable should be kept in float32 precision. + virtual bool keep_in_float32(const std::string& variable_name) const; + // Models can override these methods to execute some transformations if needed // (e.g. a variable name changed in a newer spec revision). virtual void register_variable(std::string name, StorageView variable); diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py index 0c62bce34..394d922a5 100644 --- a/python/ctranslate2/converters/transformers.py +++ b/python/ctranslate2/converters/transformers.py @@ -955,6 +955,8 @@ def set_ffn(self, spec, module): self.set_linear(spec.linear_0, module.DenseReluDense.wi) self.set_linear(spec.linear_1, module.DenseReluDense.wo) + spec.linear_1.keep_in_float32 = True + self.set_layer_norm(spec.layer_norm, module.layer_norm) def set_self_attention(self, spec, module): diff --git a/python/ctranslate2/specs/common_spec.py b/python/ctranslate2/specs/common_spec.py index 6f0189650..f026cbf16 100644 --- a/python/ctranslate2/specs/common_spec.py +++ b/python/ctranslate2/specs/common_spec.py @@ -36,6 +36,7 @@ def __init__(self): self.weight = None self.weight_scale = model_spec.OPTIONAL self.bias = model_spec.OPTIONAL + self.keep_in_float32 = False def has_bias(self): return isinstance(self.bias, np.ndarray) diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index fbb45b925..e14607501 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -166,6 +166,12 @@ def _quantize(spec, name, value): return key = _split_scope(name)[-1] + + if getattr(spec, "keep_in_float32", False): + if value.dtype == np.float16: + setattr(spec, key, value.astype(np.float32)) + return + scale = None is_quantizable = hasattr(spec, "%s_scale" % key) diff --git a/src/layers/common.cc b/src/layers/common.cc index 679ddf754..59e82486e 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -353,6 +353,9 @@ namespace ctranslate2 { /*trans_b=*/true, output, bias); + } else if (input.dtype() != weight->dtype()) { + _gemm_op(input.to(weight->dtype()), *weight, output, nullptr, bias); + output = output.to(input.dtype()); } else { _gemm_op(input, *weight, output, nullptr, bias); } diff --git a/src/models/model.cc b/src/models/model.cc index c88943826..32ea8bff5 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -170,6 +170,9 @@ namespace ctranslate2 { const auto& name = variable_pair.first; auto& variable = *variable_pair.second; + if (keep_in_float32(name)) + continue; + // Convert "weight" variables to the expected compute type. // Other float variables (e.g. biases) may be converted from or to float16. if (is_quantizable(name)) @@ -249,6 +252,15 @@ namespace ctranslate2 { return !variable.is_scalar() && name.find("_scale") == std::string::npos; } + bool Model::keep_in_float32(const std::string& variable_name) const { + const size_t pos = variable_name.rfind('/'); + if (pos == std::string::npos) + return false; + + const std::string scope = variable_name.substr(0, pos); + return get_flag_with_default(scope + "/keep_in_float32", false); + } + void Model::ensure_dtype(const std::string& name, StorageView& variable, const DataType target_dtype) { @@ -330,6 +342,8 @@ namespace ctranslate2 { for (const auto& variable_pair : _variable_index) { const std::string& name = variable_pair.first; const StorageView& variable = *variable_pair.second; + if (keep_in_float32(name)) + continue; if (is_quantizable(name)) { weight_type = variable.dtype(); } else if (is_convertible(variable, name)) { From 0265b9c98ec79c20bc9cf314e1059f22c14c19a1 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Wed, 24 May 2023 12:44:42 +0200 Subject: [PATCH 2/5] Fix dtype error in float16 --- src/layers/common.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/layers/common.cc b/src/layers/common.cc index 59e82486e..de183ed89 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -354,8 +354,9 @@ namespace ctranslate2 { output, bias); } else if (input.dtype() != weight->dtype()) { - _gemm_op(input.to(weight->dtype()), *weight, output, nullptr, bias); - output = output.to(input.dtype()); + StorageView tmp_output(_weight.dtype(), device); + _gemm_op(input.to(weight->dtype()), *weight, tmp_output, nullptr, bias); + output = tmp_output.to(output.dtype()); } else { _gemm_op(input, *weight, output, nullptr, bias); } From afd9d5f56c4509acfcb6e583be06c664c85b8118 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Wed, 24 May 2023 14:19:20 +0200 Subject: [PATCH 3/5] Fix compilation --- src/layers/common.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/common.cc b/src/layers/common.cc index de183ed89..b15c531b3 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -354,7 +354,7 @@ namespace ctranslate2 { output, bias); } else if (input.dtype() != weight->dtype()) { - StorageView tmp_output(_weight.dtype(), device); + StorageView tmp_output(weight->dtype(), weight->device()); _gemm_op(input.to(weight->dtype()), *weight, tmp_output, nullptr, bias); output = tmp_output.to(output.dtype()); } else { From 0e6db4fa51d29a565da1a7b225a8e8cbfbdd9a68 Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Mon, 24 Jul 2023 13:16:00 +0200 Subject: [PATCH 4/5] Fix condition --- python/ctranslate2/specs/model_spec.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 01c46dbb5..ea6a3a4da 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -180,16 +180,15 @@ def _quantize(spec, name, value): key = _split_scope(name)[-1] - if getattr(spec, "keep_in_float32", False): - if value.dtype == np.float16: - setattr(spec, key, value.astype(np.float32)) - return - scale = None is_quantizable = hasattr(spec, "%s_scale" % key) is_convertible = value.dtype in ("float32", "float16", "bfloat16") - if is_quantizable: + if hasattr(spec, "keep_in_float32") and spec.keep_in_float32.numpy(): + if is_convertible: + value = value.to("float32") + + elif is_quantizable: if quantization == "int16": value = value.to("float32").numpy() # Represent the value with 10 bits so the multiplication is 20 bits From c65323c493fed3f41e072406ee7af3b8bfe1630b Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Mon, 24 Jul 2023 13:20:02 +0200 Subject: [PATCH 5/5] Cleanup diff --- python/ctranslate2/specs/model_spec.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index ea6a3a4da..77ec49ed4 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -179,7 +179,6 @@ def _quantize(spec, name, value): return key = _split_scope(name)[-1] - scale = None is_quantizable = hasattr(spec, "%s_scale" % key) is_convertible = value.dtype in ("float32", "float16", "bfloat16")