Support BLOOM language models (#1128)
guillaumekln authored Mar 15, 2023
1 parent 5f9aac6 commit cbdd540
Showing 10 changed files with 222 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -7,7 +7,7 @@ CTranslate2 is a C++ and Python library for efficient inference with Transformer
The project implements a custom runtime that applies many performance optimization techniques such as weights quantization, layers fusion, batch reordering, etc., to [accelerate and reduce the memory usage](#benchmarks) of Transformer models on CPU and GPU. The following model types are currently supported:

* Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper
* Decoder-only models: GPT-2, OPT
* Decoder-only models: GPT-2, OPT, BLOOM

Compatible models should be first converted into an optimized model format. The library includes converters for multiple frameworks:

24 changes: 24 additions & 0 deletions docs/guides/transformers.md
@@ -3,6 +3,7 @@
CTranslate2 supports selected models from Hugging Face's [Transformers](https://github.com/huggingface/transformers). The following models are currently supported:

* BART
* BLOOM
* M2M100
* MarianMT
* MBART
@@ -67,6 +68,29 @@ target = results[0].hypotheses[0]
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target), skip_special_tokens=True))
```

## BLOOM

[BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom) is a collection of multilingual language models trained by the [BigScience workshop](https://bigscience.huggingface.co/).

This example uses the smallest model with 560M parameters.

```bash
ct2-transformers-converter --model bigscience/bloom-560m --output_dir bloom-560m
```
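
A quantized conversion is a common variation. The sketch below is not part of this commit's example and assumes the converter's generic `--quantization` option applies to BLOOM as it does to the other supported models:

```bash
# Hypothetical variant: convert with 8-bit weight quantization to reduce model size.
ct2-transformers-converter --model bigscience/bloom-560m --quantization int8 --output_dir bloom-560m-int8
```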

```python
import ctranslate2
import transformers

generator = ctranslate2.Generator("bloom-560m")
tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom-560m")

text = "Hello, I am"
start_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))
results = generator.generate_batch([start_tokens], max_length=30, sampling_topk=10)
print(tokenizer.decode(results[0].sequences_ids[0]))
```
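
The example samples from the top 10 candidates at each step and returns the prompt tokens together with the continuation (the decoded output begins with "Hello, I am", as the test expectation added in `python/tests/test_transformers.py` below shows). A minimal variation, assuming `sampling_temperature` is also available on `generate_batch` in this version:

```python
# Sketch only: sharper sampling (temperature below 1) and a longer continuation.
results = generator.generate_batch(
    [start_tokens],
    max_length=60,
    sampling_topk=10,
    sampling_temperature=0.7,
)
print(tokenizer.decode(results[0].sequences_ids[0]))
```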

## MarianMT

This example uses the English-German model from [MarianMT](https://huggingface.co/docs/transformers/model_doc/marian).
9 changes: 8 additions & 1 deletion include/ctranslate2/layers/attention.h
@@ -10,6 +10,12 @@ namespace ctranslate2 {
dim_t max_position,
bool with_cache = false);

StorageView build_alibi(dim_t batch_size,
dim_t num_heads,
dim_t query_max_length,
dim_t key_max_length,
const StorageView* key_lengths = nullptr);

class MultiHeadAttention : public Layer
{
public:
@@ -30,7 +36,8 @@ namespace ctranslate2 {
StorageView* attention = nullptr,
const Padder* queries_padder = nullptr,
const Padder* values_padder = nullptr,
bool return_normalized_attention = true) const;
bool return_normalized_attention = true,
const StorageView* alibi = nullptr) const;

bool has_relative_position() const {
return _relative_position_keys || _relative_attention_bias;
4 changes: 3 additions & 1 deletion include/ctranslate2/layers/transformer.h
@@ -88,7 +88,8 @@ namespace ctranslate2 {
StorageView* attention = nullptr,
const Padder* input_padder = nullptr,
const Padder* memory_padder = nullptr,
bool return_normalized_attention = true) const;
bool return_normalized_attention = true,
const StorageView* alibi = nullptr) const;

DataType output_type() const override {
return _ff.output_type();
@@ -190,6 +191,7 @@ namespace ctranslate2 {
const ComputeType _compute_type;
const Embeddings _embeddings;
const bool _start_from_zero_embedding;
const bool _use_alibi;
const std::unique_ptr<const StorageView> _embeddings_scale;
std::unique_ptr<const StorageView> _outputs_scale;
const std::unique_ptr<const LayerNorm> _layernorm_embedding;
77 changes: 77 additions & 0 deletions python/ctranslate2/converters/transformers.py
@@ -745,6 +745,83 @@ def architecture_name(self):
return "MT5ForConditionalGeneration"


@register_loader("BloomConfig")
class BloomLoader(ModelLoader):
@property
def architecture_name(self):
return "BloomForCausalLM"

def get_model_spec(self, model):
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
model.config.n_layer,
model.config.n_head,
pre_norm=True,
activation=common_spec.Activation.GELUTanh,
layernorm_embedding=True,
alibi=True,
)

self.set_decoder(spec.decoder, model.transformer)
self.set_linear(spec.decoder.projection, model.lm_head)
return spec

def get_vocabulary(self, model, tokenizer):
tokens = super().get_vocabulary(model, tokenizer)

extra_ids = model.config.vocab_size - len(tokens)
for i in range(extra_ids):
tokens.append("<extra_id_%d>" % i)

return tokens

def set_vocabulary(self, spec, tokens):
spec.register_vocabulary(tokens)

def set_config(self, config, model, tokenizer):
config.bos_token = tokenizer.bos_token
config.eos_token = tokenizer.eos_token
config.unk_token = tokenizer.unk_token

def set_decoder(self, spec, module):
spec.scale_embeddings = False
self.set_embeddings(spec.embeddings, module.word_embeddings)
self.set_layer_norm(spec.layernorm_embedding, module.word_embeddings_layernorm)
self.set_layer_norm(spec.layer_norm, module.ln_f)

for layer_spec, layer in zip(spec.layer, module.h):
self.set_layer_norm(
layer_spec.self_attention.layer_norm, layer.input_layernorm
)
self.set_qkv_linear(
layer_spec.self_attention.linear[0],
layer.self_attention.query_key_value,
layer.self_attention.num_heads,
)
self.set_linear(
layer_spec.self_attention.linear[1], layer.self_attention.dense
)

self.set_layer_norm(
layer_spec.ffn.layer_norm, layer.post_attention_layernorm
)
self.set_linear(layer_spec.ffn.linear_0, layer.mlp.dense_h_to_4h)
self.set_linear(layer_spec.ffn.linear_1, layer.mlp.dense_4h_to_h)

def set_qkv_linear(self, spec, module, num_heads):
weight = module.weight
weight = weight.reshape(num_heads, 3, -1, weight.shape[-1])
weight = weight.transpose(0, 1)
weight = weight.reshape(-1, weight.shape[-1])

bias = module.bias
bias = bias.reshape(num_heads, 3, -1)
bias = bias.transpose(0, 1)
bias = bias.reshape(-1)

spec.weight = weight.numpy()
spec.bias = bias.numpy()
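
The reshape/transpose in `set_qkv_linear` reorders BLOOM's fused projection, whose output rows are grouped per head as (q, k, v), so that all query rows come first, then all key rows, then all value rows. A small NumPy sketch of the same permutation (illustration only, with hypothetical toy dimensions):

```python
import numpy as np

num_heads, head_dim = 2, 3
hidden = num_heads * head_dim

# Label each fused output row: BLOOM orders them head0.q, head0.k, head0.v, head1.q, ...
labels = np.array(
    [h * 100 + p * 10 + d for h in range(num_heads) for p in range(3) for d in range(head_dim)]
)
weight = np.repeat(labels[:, None], hidden, axis=1)  # shape (3 * hidden, hidden)

w = weight.reshape(num_heads, 3, -1, weight.shape[-1])
w = w.transpose(1, 0, 2, 3)  # torch's .transpose(0, 1) swaps the first two axes
w = w.reshape(-1, w.shape[-1])

# Rows are now grouped as all q heads, then all k heads, then all v heads.
print(w[:, 0])  # q rows for both heads, then k rows, then v rows
```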


def main():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
6 changes: 6 additions & 0 deletions python/ctranslate2/specs/transformer_spec.py
@@ -85,6 +85,7 @@ def __init__(
alignment_heads: int = 1,
ffn_glu: bool = False,
rms_norm: bool = False,
alibi: bool = False,
):
"""Initializes a Transformer decoder specification.
@@ -107,6 +108,7 @@ def __init__(
ffn_glu: Use gated linear units in the FFN layers as described in
https://arxiv.org/abs/2002.05202.
rms_norm: Use the root mean square layer normalization.
alibi: Use attention with linear biases.
"""
self.num_heads = np.dtype("int16").type(num_heads)
self.pre_norm = pre_norm
@@ -116,6 +118,7 @@ def __init__(
self.embeddings = common_spec.EmbeddingsSpec()
self.scale_embeddings = True
self.scale_outputs = model_spec.OPTIONAL
self.alibi = alibi
if not relative_position and not relative_attention_bias:
self.position_encodings = PositionEncoderSpec()
if pre_norm and not no_final_norm:
@@ -335,6 +338,7 @@ def from_config(
no_final_norm: bool = False,
project_in_out: bool = False,
with_relative_position: bool = False,
alibi: bool = False,
):
"""Creates a Transformer decoder model specification.
@@ -348,6 +352,7 @@ def from_config(
project_in_out: Add a linear layer after the embedding layer and another one
before the final output projection.
with_relative_position: Enable relative position representations modules.
alibi: Use attention with linear biases.
"""
decoder = TransformerDecoderSpec(
num_layers,
@@ -359,6 +364,7 @@
no_final_norm=no_final_norm,
project_in_out=project_in_out,
relative_position=with_relative_position,
alibi=alibi,
)

return cls(decoder)
7 changes: 7 additions & 0 deletions python/tests/test_transformers.py
@@ -163,6 +163,13 @@ def test_transformers_translation(
100,
"Hello <|endoftext|> Hello Ġ! Ġ: D",
),
(
"bigscience/bloom-560m",
"Hello , ĠI Ġam",
20,
"Hello , ĠI Ġam Ġa Ġnew bie Ġin Ġthe Ġworld Ġof Ġweb Ġdesign Ġand ĠI Ġam "
"Ġlooking Ġfor Ġa Ġweb Ġdeveloper",
),
]


49 changes: 48 additions & 1 deletion src/layers/attention.cc
@@ -102,6 +102,47 @@ namespace ctranslate2 {
return values_t;
}

StorageView build_alibi(dim_t batch_size,
dim_t num_heads,
dim_t query_max_length,
dim_t key_max_length,
const StorageView* key_lengths) {
const float closest_power_of_2_f = std::pow(2.f, std::floor(std::log2f(num_heads)));
const dim_t closest_power_of_2 = closest_power_of_2_f;

const float base = std::pow(2.f, -std::pow(2.f, -(std::log2f(closest_power_of_2_f) - 3.f)));

std::vector<float> slopes;
slopes.reserve(closest_power_of_2);
for (dim_t power = 1; power <= closest_power_of_2; ++power)
slopes.emplace_back(std::pow(base, float(power)));

if (closest_power_of_2 != num_heads) {
const float extra_base = (
std::pow(2.f, -std::pow(2.f, -(std::log2f(2 * closest_power_of_2_f) - 3.f))));
const dim_t num_remaining_heads = std::min(
closest_power_of_2, num_heads - closest_power_of_2);

for (dim_t power = 1; power <= 2 * num_remaining_heads; power += 2)
slopes.emplace_back(std::pow(extra_base, float(power)));
}

StorageView alibi({batch_size, num_heads, query_max_length, key_max_length});

for (dim_t b = 0; b < batch_size; ++b) {
for (dim_t h = 0; h < num_heads; ++h) {
for (dim_t q = 0; q < query_max_length; ++q) {
for (dim_t k = 0; k < key_max_length; ++k) {
dim_t length = key_lengths ? key_lengths->scalar_at<int32_t>({b}) : key_max_length;
alibi.at<float>({b, h, q, k}) = k >= length ? 0 : float(k) * slopes[h];
}
}
}
}

return alibi;
}
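
The slope schedule above follows the standard ALiBi recipe: for `num_heads` = n a power of two, head i gets slope 2^(-8i/n), with additional interpolated slopes when n is not a power of two; each head then adds `k * slope` to the attention logits for key position k (zero beyond the padded length). A Python re-computation of the same arithmetic, for cross-checking only:

```python
import math

def alibi_slopes(num_heads):
    # Mirrors build_alibi: base = 2**(-8 / closest_power_of_2), slopes = base**1..closest_power_of_2.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** -(2.0 ** -(math.log2(closest_power_of_2) - 3.0))
    slopes = [base**power for power in range(1, closest_power_of_2 + 1)]

    if closest_power_of_2 != num_heads:
        extra_base = 2.0 ** -(2.0 ** -(math.log2(2 * closest_power_of_2) - 3.0))
        num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)
        slopes += [extra_base**power for power in range(1, 2 * num_remaining + 1, 2)]

    return slopes

print(alibi_slopes(16))  # [2**-0.5, 2**-1.0, ..., 2**-8.0]
```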

static void matmul_with_relative_representations(const ops::MatMul& matmul_op,
const StorageView& a,
const StorageView& b,
@@ -155,6 +196,7 @@ namespace ctranslate2 {
static void dot_product_attention(const StorageView& queries,
const StorageView& keys,
const StorageView& values,
const StorageView* alibi,
const StorageView* values_lengths,
const StorageView* relative_position_keys,
const StorageView* relative_position_values,
@@ -204,6 +246,9 @@
output.size()));
}

if (alibi)
ops::Add()(output, *alibi, output);

StorageView attn(values.dtype(), values.device());
ops::SoftMax()(output, values_lengths, attn);

@@ -330,7 +375,8 @@
StorageView* attention,
const Padder* queries_padder,
const Padder* values_padder,
bool return_normalized_attention) const {
bool return_normalized_attention,
const StorageView* alibi) const {
PROFILE("MultiHeadAttention");
const Device device = queries.device();
const DataType dtype = queries.dtype();
@@ -394,6 +440,7 @@
dot_product_attention(queries_proj,
keys_proj,
values_proj,
alibi,
values_lengths,
_relative_position_keys,
_relative_position_values,
20 changes: 16 additions & 4 deletions src/layers/transformer.cc
@@ -99,7 +99,8 @@ namespace ctranslate2 {
StorageView* attention,
const Padder* input_padder,
const Padder* memory_padder,
bool return_normalized_attention) const {
bool return_normalized_attention,
const StorageView* alibi) const {
PROFILE("TransformerDecoderLayer");
_self_attention(input,
input,
@@ -109,7 +110,9 @@
cached_self_attn_values,
nullptr,
input_padder,
input_padder);
input_padder,
true,
alibi);

StorageView context(input.dtype(), input.device());
if (_encoder_attention) {
@@ -235,6 +238,7 @@
, _embeddings(model, scope + "/embeddings")
, _start_from_zero_embedding(model.get_flag_with_default(scope + "/start_from_zero_embedding",
false))
, _use_alibi(model.get_flag_with_default(scope + "/alibi", false))
, _embeddings_scale(build_embeddings_scale(model, scope, _embeddings))
, _layernorm_embedding(build_optional_layer<LayerNorm>(model, scope + "/layernorm_embedding"))
, _output_norm(build_optional_layer<LayerNorm>(model, scope + "/layer_norm"))
@@ -246,7 +250,7 @@
_num_heads,
model.get_flag_with_default(scope + "/pre_norm", true),
model.get_enum_value<ops::ActivationType>(scope + "/activation")))
, _position_encoder(_layers.front()->has_relative_position()
, _position_encoder(_layers.front()->has_relative_position() || _use_alibi
? nullptr
: build_position_encoder(model, scope + "/position_encodings", _embeddings))
, _with_encoder_attention(_layers.front()->has_cross_attention())
@@ -434,6 +438,13 @@
}
}

std::unique_ptr<StorageView> alibi;
if (_use_alibi) {
alibi = std::make_unique<StorageView>(
build_alibi(batch_size, _num_heads, max_time, step > 0 ? step + 1 : max_time, lengths));
alibi->move_to(device, dtype);
}

std::vector<StorageView> alignment_heads;
if (attention)
alignment_heads.reserve(_layers.size());
@@ -471,7 +482,8 @@
layer_attention.get(),
input_padder.get(),
memory_padder.get(),
return_normalized_attention());
return_normalized_attention(),
alibi.get());
layer_in = std::move(layer_out);

if (layer_attention) {