diff --git a/include/ctranslate2/models/language_model.h b/include/ctranslate2/models/language_model.h index 7532b9a3a..04c083da9 100644 --- a/include/ctranslate2/models/language_model.h +++ b/include/ctranslate2/models/language_model.h @@ -22,6 +22,7 @@ namespace ctranslate2 { protected: void initialize(ModelReader& model_reader) override; + void initialize(std::unordered_map>& vocabularies) override; private: std::shared_ptr _vocabulary; diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 43a4ea5b9..2e8e9baf2 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -31,6 +31,17 @@ namespace ctranslate2 { Device device = Device::CPU, int device_index = 0, ComputeType compute_type = ComputeType::DEFAULT); + static std::shared_ptr load(const std::string& spec, + const size_t& spec_version, + const size_t& binary_version, + std::unordered_map& alias, + std::unordered_map>& vocabularies, + std::unordered_map& variables, + const std::string& config, + Device device = Device::CPU, + int device_index = 0, + ComputeType compute_type = ComputeType::DEFAULT); + virtual std::unique_ptr as_sequence_to_sequence() const; virtual std::unique_ptr as_sequence_generator() const; @@ -86,6 +97,10 @@ namespace ctranslate2 { return ScopedDeviceSetter(_device, _device_index); } + void set_config(const std::string& config_str); + void set_revision(const size_t revision); + void set_binary_version(const size_t binary_version); + // If the model contains variables, they will be moved to the new device. void set_device(const Device device, const int index = 0); @@ -143,6 +158,7 @@ namespace ctranslate2 { // Runs some initialization after the model is loaded. virtual void initialize(ModelReader&) {} + virtual void initialize(std::unordered_map>&) {} virtual std::unique_ptr clone() const = 0; diff --git a/include/ctranslate2/models/sequence_to_sequence.h b/include/ctranslate2/models/sequence_to_sequence.h index e1d79327f..125d0d77f 100644 --- a/include/ctranslate2/models/sequence_to_sequence.h +++ b/include/ctranslate2/models/sequence_to_sequence.h @@ -32,6 +32,7 @@ namespace ctranslate2 { protected: virtual void initialize(ModelReader& model_reader) override; + virtual void initialize(std::unordered_map>& vocabularies) override; private: std::vector> _source_vocabularies; diff --git a/include/ctranslate2/models/transformer.h b/include/ctranslate2/models/transformer.h index 4e97f85e0..74ee851c5 100644 --- a/include/ctranslate2/models/transformer.h +++ b/include/ctranslate2/models/transformer.h @@ -34,6 +34,7 @@ namespace ctranslate2 { protected: bool is_linear_weight(const std::string& variable_name) const override; void initialize(ModelReader& model_reader) override; + void initialize(std::unordered_map>& vocabularies) override; std::unique_ptr clone() const override; }; diff --git a/include/ctranslate2/models/wav2vec2.h b/include/ctranslate2/models/wav2vec2.h index d1034ef88..5427d2631 100644 --- a/include/ctranslate2/models/wav2vec2.h +++ b/include/ctranslate2/models/wav2vec2.h @@ -41,6 +41,7 @@ namespace ctranslate2 { protected: void initialize(ModelReader& model_reader) override; + void initialize(std::unordered_map>& vocabularies) override; private: std::shared_ptr _vocabulary; }; diff --git a/include/ctranslate2/models/whisper.h b/include/ctranslate2/models/whisper.h index 7ade2bd20..c3ccb873b 100644 --- a/include/ctranslate2/models/whisper.h +++ b/include/ctranslate2/models/whisper.h @@ -90,6 +90,7 @@ namespace ctranslate2 { protected: 
void initialize(ModelReader& model_reader) override; + void initialize(std::unordered_map>& vocabularies) override; private: std::shared_ptr _vocabulary; diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index 981c6da68..2f1cc53b3 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -1,6 +1,7 @@ #include "module.h" #include +#include #include "replica_pool.h" @@ -158,6 +159,44 @@ namespace ctranslate2 { :obj:`model_path` acts as an identifier for this model. )pbdoc") + .def(py::init&, + std::unordered_map>&, std::unordered_map&, const std::string&, const std::string&, const std::variant>, const StringOrMap&, size_t, size_t, long>(), + py::arg("spec"), + py::arg("spec_revision"), + py::arg("binary_version"), + py::arg("aliases"), + py::arg("vocabularies"), + py::arg("variables"), + py::arg("config"), + py::arg("device")="cpu", + py::arg("device_index")=0, + py::arg("compute_type")="default", + py::arg("inter_threads")=1, + py::arg("intra_threads")=0, + py::arg("max_queued_batches")=0, + R"pbdoc( + Initializes the generator. + + Arguments: + spec: The name of the model specification. + spec_revision: The model specification revision. + binary_version: The binary version of the model. + aliases: Dictionary mapping variable aliases to variable names. + vocabularies: Dictionary mapping vocabulary names to lists of tokens. + variables: Dictionary mapping variable names to their storage views. + config: JSON string of the model configuration (normally saved in config.json). + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this generator on. + compute_type: Model computation type or a dictionary mapping a device name + to the computation type (possible values are: default, auto, int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + inter_threads: Maximum number of parallel generations. + intra_threads: Number of OpenMP threads per generator (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, + 0 for an automatic value). When the queue is full, future requests will block + until a free slot is available.
+ )pbdoc") + .def_property_readonly("device", &GeneratorWrapper::device, "Device this generator is running on.") .def_property_readonly("device_index", &GeneratorWrapper::device_index, diff --git a/python/cpp/replica_pool.h b/python/cpp/replica_pool.h index a735ea363..4a26d7459 100644 --- a/python/cpp/replica_pool.h +++ b/python/cpp/replica_pool.h @@ -1,7 +1,11 @@ #pragma once #include +#include +#include +#include +#include #include "utils.h" namespace ctranslate2 { @@ -49,15 +53,53 @@ namespace ctranslate2 { { pybind11::gil_scoped_release nogil; - _model_loader.device = str_to_device(device); - _model_loader.device_indices = std::visit(DeviceIndexResolver(), device_index); - _model_loader.compute_type = std::visit(ComputeTypeResolver(device), compute_type); - _model_loader.num_replicas_per_device = inter_threads; + _model_loader->device = str_to_device(device); + _model_loader->device_indices = std::visit(DeviceIndexResolver(), device_index); + _model_loader->compute_type = std::visit(ComputeTypeResolver(device), compute_type); + _model_loader->num_replicas_per_device = inter_threads; _pool_config.num_threads_per_replica = intra_threads; _pool_config.max_queued_batches = max_queued_batches; - _pool = std::make_unique(_model_loader, _pool_config); + _pool = std::make_unique(_model_loader.value(), _pool_config); + } + + ReplicaPoolHelper(const std::string& spec, + const size_t& spec_version, + const size_t& binary_version, + std::unordered_map& aliases, + std::unordered_map>& vocabularies, + std::unordered_map& variables, + const std::string& config, + const std::string& device, + const std::variant>& device_index, + const StringOrMap& compute_type, + size_t ,//inter_threads + size_t intra_threads, + long max_queued_batches) + { + pybind11::gil_scoped_release nogil; + + // Load the variables. 
+ auto model_device = str_to_device(device); + auto model_device_indices = std::visit(DeviceIndexResolver(), device_index)[0]; + auto model_compute_type = std::visit(ComputeTypeResolver(device), compute_type); + + auto model = models::Model::load(spec, + spec_version, + binary_version, + aliases, + vocabularies, + variables, + config, + model_device, + model_device_indices, + model_compute_type); + + _pool_config.num_threads_per_replica = intra_threads; + _pool_config.max_queued_batches = max_queued_batches; + + _pool = std::make_unique(model, _pool_config); } ~ReplicaPoolHelper() { @@ -66,11 +108,19 @@ namespace ctranslate2 { } std::string device() const { - return device_to_str(_model_loader.device); + if (_model_loader.has_value()) + return device_to_str(_model_loader->device); + if (_device) + return _device.value(); + return ""; } const std::vector& device_index() const { - return _model_loader.device_indices; + if (_model_loader.has_value()) + return _model_loader->device_indices; + if (!_device_index.has_value() || _device_index->empty()) + throw pybind11::type_error("No device index found"); + return _device_index.value(); } std::string compute_type() const { @@ -91,7 +141,9 @@ namespace ctranslate2 { protected: std::unique_ptr _pool; - models::ModelLoader _model_loader; + std::optional _model_loader; + std::optional _device; + std::optional> _device_index; ReplicaPoolConfig _pool_config; const std::shared_ptr& model() const { diff --git a/python/cpp/translator.cc b/python/cpp/translator.cc index d920469fe..a544d6a25 100644 --- a/python/cpp/translator.cc +++ b/python/cpp/translator.cc @@ -42,9 +42,9 @@ namespace ctranslate2 { intra_threads, max_queued_batches, files) - , _device(_model_loader.device) - , _device_index(_model_loader.device_indices) - , _num_replicas_per_device(_model_loader.num_replicas_per_device) + , _device(_model_loader->device) + , _device_index(_model_loader->device_indices) + , _num_replicas_per_device(_model_loader->num_replicas_per_device) , _model_is_loaded(true) { } @@ -324,7 +324,7 @@ namespace ctranslate2 { return; if (_cached_models.empty()) { - _cached_models = _model_loader.load(); + _cached_models = _model_loader->load(); } else { move_cached_models(_device, _device_index, _num_replicas_per_device); } diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index 9c0efac2a..8d453ee2a 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -39,6 +39,7 @@ set_random_seed, ) from ctranslate2.extensions import register_extensions + from ctranslate2.generator_on_the_fly import GeneratorOnTheFly from ctranslate2.logging import get_log_level, set_log_level register_extensions() diff --git a/python/ctranslate2/converters/converter.py b/python/ctranslate2/converters/converter.py index ecede044a..acc182efe 100644 --- a/python/ctranslate2/converters/converter.py +++ b/python/ctranslate2/converters/converter.py @@ -104,6 +104,42 @@ def convert( model_spec.save(output_dir) return output_dir + def convert_on_the_fly( + self, + vmap: Optional[str] = None, + quantization: Optional[str] = None, + ) -> ModelSpec: + """Converts the model to the CTranslate2 format in memory, without saving it to disk. + + Arguments: + vmap: Optional path to a vocabulary mapping file that will be registered + in the converted model specification. + quantization: Weight quantization scheme (possible values are: int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + + Returns: + The converted model specification.
+ + Raises: + NotImplementedError: If the converter cannot convert this model to the + CTranslate2 format. + """ + model_spec = self._load() + if model_spec is None: + raise NotImplementedError( + "This model is not supported by CTranslate2 or this converter" + ) + if vmap is not None: + model_spec.register_vocabulary_mapping(vmap) + + model_spec.validate() + model_spec.optimize(quantization=quantization) + + return model_spec + @abc.abstractmethod def _load(self): raise NotImplementedError() diff --git a/python/ctranslate2/extensions.py b/python/ctranslate2/extensions.py index b6d9fd4b5..e6d9de67a 100644 --- a/python/ctranslate2/extensions.py +++ b/python/ctranslate2/extensions.py @@ -14,6 +14,7 @@ TranslationResult, Translator, ) +from ctranslate2.generator_on_the_fly import GeneratorOnTheFly def register_extensions(): @@ -25,6 +26,16 @@ def register_extensions(): setattr(Generator, "score_iterable", generator_score_iterable) setattr(Generator, "generate_tokens", generator_generate_tokens) setattr(Generator, "async_generate_tokens", generator_async_generate_tokens) + setattr( + GeneratorOnTheFly, "generate_iterable", generator_generate_iterable_on_the_fly + ) + setattr(GeneratorOnTheFly, "score_iterable", generator_score_iterable_on_the_fly) + setattr(GeneratorOnTheFly, "generate_tokens", generator_generate_tokens_on_the_fly) + setattr( + GeneratorOnTheFly, + "async_generate_tokens", + generator_async_generate_tokens_on_the_fly, + ) def translator_translate_iterable( @@ -430,6 +441,244 @@ async def generator_async_generate_tokens( yield step_result +def generator_generate_tokens_on_the_fly( + generator: GeneratorOnTheFly, + prompt: Union[List[str], List[List[str]]], + max_batch_size: int = 0, + batch_type: str = "examples", + *, + max_length: int = 512, + min_length: int = 0, + sampling_topk: int = 1, + sampling_topp: float = 1, + sampling_temperature: float = 1, + return_log_prob: bool = False, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + disable_unk: bool = False, + suppress_sequences: Optional[List[List[str]]] = None, + end_token: Optional[Union[str, List[str], List[int]]] = None, + static_prompt: Optional[List[str]] = None, + cache_static_prompt: bool = True, + callback: Callable[[GenerationStepResult], bool] = None, +) -> Iterable[GenerationStepResult]: + """Yields tokens as they are generated by the model. + + Arguments: + prompt: Batch of start tokens. If the decoder starts from a + special start token like ``<s>``, this token should be added to this input. + max_batch_size: The maximum batch size. + batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens". + max_length: Maximum generation length. + min_length: Minimum generation length. + sampling_topk: Randomly sample predictions from the top K candidates. + sampling_topp: Keep the most probable tokens whose cumulative probability exceeds this value. + sampling_temperature: Sampling temperature to generate more random samples. + return_log_prob: Include the token log probability in the result. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size + (set 0 to disable). + disable_unk: Disable the generation of the unknown token. + suppress_sequences: Disable the generation of some sequences of tokens.
+ end_token: Stop the decoding on one of these tokens (defaults to the model EOS token). + static_prompt: If the model expects a static prompt (a.k.a. system prompt) + it can be set here to simplify the inputs and optionally cache the model + state for this prompt to accelerate future generations. + cache_static_prompt: Cache the model state after the static prompt and + reuse it for future generations using the same static prompt. + callback: Optional function that is called for each generated token when + :obj:`beam_size` is 1. If the callback function returns ``True``, the + decoding will stop for this batch index. + + Returns: + A generator iterator over :class:`ctranslate2.GenerationStepResult` instances. + + Note: + This generation method is not compatible with beam search which requires a complete decoding. + """ + if len(prompt) > 0 and isinstance(prompt[0], str): + prompt = [prompt] + + yield from _generate_tokens( + generator.generate_batch, + prompt, + max_batch_size=max_batch_size, + batch_type=batch_type, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + disable_unk=disable_unk, + suppress_sequences=suppress_sequences, + end_token=end_token, + max_length=max_length, + min_length=min_length, + sampling_topk=sampling_topk, + sampling_topp=sampling_topp, + sampling_temperature=sampling_temperature, + return_scores=return_log_prob, + static_prompt=static_prompt, + cache_static_prompt=cache_static_prompt, + include_prompt_in_result=False, + callback=callback, + ) + + +def generator_generate_iterable_on_the_fly( + generator: GeneratorOnTheFly, + start_tokens: Iterable[List[str]], + max_batch_size: int = 32, + batch_type: str = "examples", + **kwargs, +) -> Iterable[GenerationResult]: + """Generates from an iterable of tokenized prompts. + + This method is built on top of :meth:`ctranslate2.Generator.generate_batch` + to efficiently run generation on an arbitrarily large stream of data. It enables + the following optimizations: + + * stream processing (the iterable is not fully materialized in memory) + * parallel generations (if the generator has multiple workers) + * asynchronous batch prefetching + * local sorting by length + + Arguments: + start_tokens: An iterable of tokenized prompts. + max_batch_size: The maximum batch size. + batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens". + **kwargs: Any generation options accepted by + :meth:`ctranslate2.Generator.generate_batch`. + + Returns: + A generator iterator over :class:`ctranslate2.GenerationResult` instances. + """ + yield from _process_iterable( + generator.generate_batch, + [start_tokens], + max_batch_size, + batch_type, + **kwargs, + ) + + +def generator_score_iterable_on_the_fly( + generator: GeneratorOnTheFly, + tokens: Iterable[List[str]], + max_batch_size: int = 64, + batch_type: str = "examples", + **kwargs, +) -> Iterable[ScoringResult]: + """Scores an iterable of tokenized examples. + + This method is built on top of :meth:`ctranslate2.Generator.score_batch` + to efficiently score an arbitrarily large stream of data. It enables + the following optimizations: + + * stream processing (the iterable is not fully materialized in memory) + * parallel scoring (if the generator has multiple workers) + * asynchronous batch prefetching + * local sorting by length + + Arguments: + tokens: An iterable of tokenized examples. + max_batch_size: The maximum batch size. + batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
+ **kwargs: Any score options accepted by + :meth:`ctranslate2.Generator.score_batch`. + + Returns: + A generator iterator over :class:`ctranslate2.ScoringResult` instances. + """ + yield from _process_iterable( + generator.score_batch, + [tokens], + max_batch_size, + batch_type, + **kwargs, + ) + + +async def generator_async_generate_tokens_on_the_fly( + generator: GeneratorOnTheFly, + prompt: Union[List[str], List[List[str]]], + max_batch_size: int = 0, + batch_type: str = "examples", + *, + max_length: int = 512, + min_length: int = 0, + sampling_topk: int = 1, + sampling_topp: float = 1, + sampling_temperature: float = 1, + return_log_prob: bool = False, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + disable_unk: bool = False, + suppress_sequences: Optional[List[List[str]]] = None, + end_token: Optional[Union[str, List[str], List[int]]] = None, + static_prompt: Optional[List[str]] = None, + cache_static_prompt: bool = True, + callback: Callable[[GenerationStepResult], bool] = None, +) -> AsyncIterable[GenerationStepResult]: + """Yields tokens asynchronously as they are generated by the model. + + Arguments: + prompt: Batch of start tokens. If the decoder starts from a + special start token like , this token should be added to this input. + max_batch_size: The maximum batch size. + batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens". + max_length: Maximum generation length. + min_length: Minimum generation length. + sampling_topk: Randomly sample predictions from the top K candidates. + sampling_topp: Keep the most probable tokens whose cumulative probability exceeds this value. + sampling_temperature: Sampling temperature to generate more random samples. + return_log_prob: Include the token log probability in the result. + repetition_penalty: Penalty applied to the score of previously generated tokens + (set > 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size + (set 0 to disable). + disable_unk: Disable the generation of the unknown token. + suppress_sequences: Disable the generation of some sequences of tokens. + end_token: Stop the decoding on one of these tokens (defaults to the model EOS token). + static_prompt: If the model expects a static prompt (a.k.a. system prompt) + it can be set here to simplify the inputs and optionally cache the model + state for this prompt to accelerate future generations. + cache_static_prompt: Cache the model state after the static prompt and + reuse it for future generations using the same static prompt. + callback: Optional function that is called for each generated token when + obj:`beam_size` is 1. If the callback function returns ``True``, the + decoding will stop for this batch index. + + Returns: + An async generator iterator over :class:`ctranslate2.GenerationStepResult` instances. + + Note: + This generation method is not compatible with beam search which requires a complete decoding. 
+ """ + if len(prompt) > 0 and isinstance(prompt[0], str): + prompt = [prompt] + async for step_result in AsyncGenerator( + generator.generate_batch, + prompt, + max_batch_size=max_batch_size, + batch_type=batch_type, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + disable_unk=disable_unk, + suppress_sequences=suppress_sequences, + end_token=end_token, + max_length=max_length, + min_length=min_length, + sampling_topk=sampling_topk, + sampling_topp=sampling_topp, + sampling_temperature=sampling_temperature, + return_scores=return_log_prob, + static_prompt=static_prompt, + cache_static_prompt=cache_static_prompt, + include_prompt_in_result=False, + callback=callback, + ): + yield step_result + + class AsyncGenerator: def __init__(self, process_func, *args, **kwargs): self.queue = asyncio.Queue() diff --git a/python/ctranslate2/generator_on_the_fly.py b/python/ctranslate2/generator_on_the_fly.py new file mode 100644 index 000000000..f408b2ff4 --- /dev/null +++ b/python/ctranslate2/generator_on_the_fly.py @@ -0,0 +1,97 @@ +import json +import os + +from typing import Optional + +import ctranslate2 + +from ctranslate2.converters.opennmt_py import OpenNMTPyConverter + + +def _get_converter(model_path: str, model_type: str): + if model_type == "OpenNMTPy": + if not os.path.exists(model_path): + raise RuntimeError("No model opennmt-py found in %s" % model_path) + + converter = OpenNMTPyConverter(model_path=model_path) + return converter + else: + raise NotImplementedError( + "Converter on the fly for %s is not implemented." % model_type + ) + + +class GeneratorOnTheFly: + """Initializes the generator on the fly. + + Arguments: + model_path: Path to the CTranslate2 model directory. + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this generator on. + compute_type: Model computation type or a dictionary mapping + a device name to the computation type (possible values are: + default, auto, int8, int8_float32, int8_float16, int8_bfloat16, + int16, float16, bfloat16, float32). + inter_threads: Maximum number of parallel generations. + intra_threads: Number of OpenMP threads per generator + (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the queue + (-1 for unlimited, 0 for an automatic value). + When the queue is full, future requests will block + until a free slot is available. 
+ model_type: type of converter to convert the model + quantization: quantize the model + """ + + def __init__( + self, + model_path: str, + device="cpu", + device_index=0, + compute_type="default", + inter_threads=1, + intra_threads=0, + max_queued_batches=0, + model_type="OpenNMTPy", + quantization: Optional[str] = None, + ): + converter = _get_converter(model_path=model_path, model_type=model_type) + model_spec = converter.convert_on_the_fly(quantization=quantization) + + variables = model_spec.variables(ordered=True) + self.vocabularies = model_spec.get_vocabulary() + self.config = json.dumps(model_spec.config.to_dict()) + aliases = {} + + spec = model_spec.name + spec_revision = model_spec.revision + binary_version = model_spec.binary_version + variables_cpp = dict() + + for key, value in variables: + if isinstance(value, str): + aliases[key] = value + else: + variables_cpp[key] = ctranslate2.StorageView.from_array(value.numpy()) + + self.generator = ctranslate2.Generator( + spec=spec, + spec_revision=spec_revision, + binary_version=binary_version, + aliases=aliases, + vocabularies=self.vocabularies, + variables=variables_cpp, + config=self.config, + device=device, + device_index=device_index, + compute_type=compute_type, + inter_threads=inter_threads, + intra_threads=intra_threads, + max_queued_batches=max_queued_batches, + ) + + def generate_batch(self, *args, **kwargs): + return self.generator.generate_batch(*args, **kwargs) + + def score_batch(self, *args, **kwargs): + return self.generator.score_batch(*args, **kwargs) diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py index 0b3a44c4b..5c13d4f6a 100644 --- a/python/ctranslate2/specs/attention_spec.py +++ b/python/ctranslate2/specs/attention_spec.py @@ -43,18 +43,18 @@ def __init__( self.relative_attention_max_distance = None if rotary_dim is not None: - self.rotary_dim = np.dtype("int32").type(rotary_dim) + self.rotary_dim = np.array(rotary_dim, dtype="int32") self.rotary_interleave = rotary_interleave - self.rotary_base = np.dtype("float32").type(rotary_base) + self.rotary_base = np.array(rotary_base, dtype="float32") if rotary_scaling_type is not None: - self.rotary_scaling_type = np.dtype("int8").type(rotary_scaling_type) - self.rotary_scaling_factor = np.dtype("float32").type( - rotary_scaling_factor + self.rotary_scaling_type = np.array(rotary_scaling_type, dtype="int8") + self.rotary_scaling_factor = np.array( + rotary_scaling_factor, dtype="float32" ) if num_heads_kv is not None: - self.num_heads_kv = np.dtype("int32").type(num_heads_kv) + self.num_heads_kv = np.array(num_heads_kv, dtype="int32") if sliding_window is not None: - self.sliding_window = np.dtype("int32").type(sliding_window) + self.sliding_window = np.array(sliding_window, dtype="int32") diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 4cb765636..fdcdbffc8 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -114,10 +114,10 @@ def _check(spec, name, value): if value.dtype == np.float64: value = value.astype(np.float32) elif isinstance(value, float): - value = np.dtype("float32").type(value) + value = np.array(value, dtype="float32") elif isinstance(value, bool): # Convert bool to an integer type. 
- value = np.dtype("int8").type(value) + value = np.array(value, dtype="int8") elif isinstance(value, str): if value != OPTIONAL: value = np.frombuffer(value.encode("utf-8"), dtype=np.int8) @@ -311,6 +311,11 @@ def __init__(self): self._config = self.get_default_config() self._files = {} + @abc.abstractmethod + def get_vocabulary(self): + """Returns the map vocabulary expected by the model.""" + raise NotImplementedError() + @property def name(self): """The name of the model specification.""" @@ -325,6 +330,11 @@ def revision(self): """ return 1 + @property + def binary_version(self): + """The binary version""" + return CURRENT_BINARY_VERSION + @property def config(self): """The model configuration.""" @@ -455,6 +465,14 @@ def __init__(self): "target": [], } + def get_vocabulary(self): + vocabularies = dict(_flatten_vocabularies(self._vocabularies)) + all_vocabularies = list(vocabularies.values()) + if all(vocabulary == all_vocabularies[0] for vocabulary in all_vocabularies): + vocabularies = {"shared": all_vocabularies[0]} + + return vocabularies + def get_default_config(self): return SequenceToSequenceModelConfig() @@ -566,6 +584,9 @@ def __init__(self): super().__init__() self._vocabulary = [] + def get_vocabulary(self): + return {"vocabulary": self._vocabulary} + def get_default_config(self): return LanguageModelConfig() diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index 7208be8a9..a9232652e 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -45,10 +45,10 @@ def __init__( rms_norm: Use the root mean square layer normalization. multi_query_attention: Use multi-query attention. """ - self.num_heads = np.dtype("int16").type(num_heads) + self.num_heads = np.array(num_heads, dtype="int16") self.pre_norm = pre_norm - self.activation = np.dtype("int8").type(activation) - self.embeddings_merge = np.dtype("int8").type(embeddings_merge) + self.activation = np.array(activation, dtype="int8") + self.embeddings_merge = np.array(embeddings_merge, dtype="int8") self.embeddings = [ common_spec.EmbeddingsSpec() for _ in range(num_source_embeddings) ] @@ -160,11 +160,11 @@ def __init__( % num_heads_kv ) - self.num_heads = np.dtype("int16").type(num_heads) + self.num_heads = np.array(num_heads, dtype="int16") self.pre_norm = pre_norm - self.activation = np.dtype("int8").type(activation) - self.alignment_layer = np.dtype("int16").type(alignment_layer) - self.alignment_heads = np.dtype("int16").type(alignment_heads) + self.activation = np.array(activation, dtype="int8") + self.alignment_layer = np.array(alignment_layer, dtype="int16") + self.alignment_heads = np.array(alignment_heads, dtype="int16") self.embeddings = common_spec.EmbeddingsSpec() self.scale_embeddings = True self.scale_outputs = model_spec.OPTIONAL @@ -172,7 +172,7 @@ def __init__( self.alibi_use_positive_positions = alibi_use_positive_positions self.scale_alibi = scale_alibi if sliding_window is not None: - self.sliding_window = np.dtype("int32").type(sliding_window) + self.sliding_window = np.array(sliding_window, dtype="int32") if ( not relative_position and not relative_attention_bias diff --git a/src/models/language_model.cc b/src/models/language_model.cc index 466e42594..b4899fe48 100644 --- a/src/models/language_model.cc +++ b/src/models/language_model.cc @@ -35,6 +35,22 @@ namespace ctranslate2 { throw std::runtime_error("Cannot load the vocabulary from the model directory"); } + void 
LanguageModel::initialize(std::unordered_map>& vocabularies) { + if (binary_version() < 6) { + config["unk_token"] = get_attribute_with_default("unk_token", ""); + config["bos_token"] = get_attribute_with_default("bos_token", ""); + config["eos_token"] = get_attribute_with_default("eos_token", ""); + } + + VocabularyInfo vocab_info; + vocab_info.unk_token = config["unk_token"]; + vocab_info.bos_token = config["bos_token"]; + vocab_info.eos_token = config["eos_token"]; + + _vocabulary = std::make_shared(vocabularies.at("vocabulary")); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); + } std::vector SequenceGeneratorReplica::score(const std::vector>& tokens, diff --git a/src/models/model.cc b/src/models/model.cc index 0672494ff..48ac37bd1 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -185,6 +185,18 @@ namespace ctranslate2 { } } + void Model::set_config(const std::string& config_str) { + config = nlohmann::json::parse(config_str); + } + + void Model::set_revision(const size_t revision) { + _spec_revision = revision; + } + + void Model::set_binary_version(const size_t binary_version) { + _binary_version = binary_version; + } + const StorageView* Model::get_variable_if_exists(const std::string& name) const { auto it = _variable_index.find(name); if (it == _variable_index.end()) @@ -509,6 +521,50 @@ namespace ctranslate2 { return model; } + std::shared_ptr Model::load(const std::string& spec, + const size_t& spec_version, + const size_t& binary_version, + std::unordered_map& aliases, + std::unordered_map>& vocabularies, + std::unordered_map& variables, + const std::string& config, + Device device, + int device_index, + ComputeType compute_type) { + auto model = models::create_model(spec); + + // Load the variables. + for (auto& variable : variables) { + model->register_variable(variable.first, std::move(variable.second)); + } + // Maybe quantize/dequantize/convert the variables to match the requested compute type. + model->set_compute_type(compute_type, device, device_index); + + // Move variables to the target device. + model->set_device(device, device_index); + + model->set_config(config); + + model->set_revision(spec_version); + model->set_binary_version(binary_version); + // Register variable aliases. + if (binary_version >= 3) { + for (auto& alias_pair : aliases) { + const auto alias = alias_pair.first; + const auto variable_name = alias_pair.second; + model->register_variable_alias(alias, variable_name); + // Also alias the quantization scale that could be associated to variable_name. + model->register_variable_alias(alias + "_scale", variable_name + "_scale"); + } + } + + // Run additional model initialization. 
+ const ScopedDeviceSetter scoped_device_setter(device, device_index); + model->process_linear_weights(); + model->initialize(vocabularies); + return model; + } + std::shared_ptr Model::copy_to(Device device, int device_index) const { auto model = clone(); diff --git a/src/models/sequence_to_sequence.cc b/src/models/sequence_to_sequence.cc index a7e64611f..7b239954a 100644 --- a/src/models/sequence_to_sequence.cc +++ b/src/models/sequence_to_sequence.cc @@ -76,6 +76,59 @@ namespace ctranslate2 { load_vocabularies(model_reader); } + void SequenceToSequenceModel::initialize(std::unordered_map>& vocabularies) { + if (binary_version() < 6) { + config["unk_token"] = get_attribute_with_default("unk_token", ""); + config["bos_token"] = get_attribute_with_default("bos_token", ""); + config["eos_token"] = get_attribute_with_default("eos_token", ""); + config["add_source_bos"] = get_flag_with_default("with_source_bos", false); + config["add_source_eos"] = get_flag_with_default("with_source_eos", false); + + if (get_flag_with_default("user_decoder_start_tokens", false)) + config["decoder_start_token"] = nullptr; + else if (get_flag_with_default("with_target_bos", true)) + config["decoder_start_token"] = config["bos_token"]; + else + config["decoder_start_token"] = config["eos_token"]; + } + + VocabularyInfo vocab_info; + vocab_info.unk_token = config["unk_token"]; + vocab_info.bos_token = config["bos_token"]; + vocab_info.eos_token = config["eos_token"]; + + auto shared_vocabulary = std::make_shared(std::move(vocabularies.at("shared_vocabulary"))); + + if (shared_vocabulary) { + _target_vocabulary = shared_vocabulary; + _source_vocabularies = {shared_vocabulary}; + + } else { + _target_vocabulary = std::make_shared(std::move(vocabularies.at("target_vocabulary"))); + if (!_target_vocabulary) + throw std::runtime_error("Cannot load the target vocabulary from the model directory"); + + auto source_vocabulary = std::make_shared(std::move(vocabularies.at("source_vocabulary"))); + + if (source_vocabulary) { + _source_vocabularies = {source_vocabulary}; + } else { + for (size_t i = 1;; i++) { + const std::string name = "source_" + std::to_string(i) + "_vocabulary"; + auto vocabulary = std::make_shared(std::move(vocabularies.at(name))); + + if (!vocabulary) + break; + + _source_vocabularies.emplace_back(vocabulary); + } + } + + if (_source_vocabularies.empty()) + throw std::runtime_error("Cannot load the source vocabulary from the model directory"); + } + } + size_t SequenceToSequenceModel::num_source_vocabularies() const { return _source_vocabularies.size(); } diff --git a/src/models/transformer.cc b/src/models/transformer.cc index f62984b2e..099ece22a 100644 --- a/src/models/transformer.cc +++ b/src/models/transformer.cc @@ -108,6 +108,16 @@ namespace ctranslate2 { } } + void TransformerDecoderModel::initialize(std::unordered_map>& vocabularies) { + LanguageModel::initialize(vocabularies); + + if (spec_revision() < 2) { + register_variable_alias("decoder/num_heads", "num_heads"); + register_variable_alias("decoder/pre_norm", "pre_norm"); + register_variable_alias("decoder/activation", "activation"); + } + } + std::unique_ptr TransformerDecoderModel::as_sequence_generator() const { const auto scoped_device_setter = get_scoped_device_setter(); diff --git a/src/models/wav2vec2.cc b/src/models/wav2vec2.cc index 79a7a40d4..f4b8ed56c 100644 --- a/src/models/wav2vec2.cc +++ b/src/models/wav2vec2.cc @@ -34,6 +34,17 @@ namespace ctranslate2 { throw std::runtime_error("Cannot load the vocabulary from the model 
directory"); } + void Wav2Vec2Model::initialize(std::unordered_map>& vocabularies) { + VocabularyInfo vocab_info; + vocab_info.unk_token = "[UNK]"; + vocab_info.bos_token = ""; + vocab_info.eos_token = ""; + + _vocabulary = std::make_shared(vocabularies.at("vocabulary")); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); + } + bool Wav2Vec2Model::is_quantizable(const std::string& variable_name) const { return (Model::is_quantizable(variable_name) && variable_name.find("conv") == std::string::npos); diff --git a/src/models/whisper.cc b/src/models/whisper.cc index da12898e9..c5d1e807c 100644 --- a/src/models/whisper.cc +++ b/src/models/whisper.cc @@ -33,6 +33,17 @@ namespace ctranslate2 { throw std::runtime_error("Cannot load the vocabulary from the model directory"); } + void WhisperModel::initialize(std::unordered_map>& vocabularies) { + VocabularyInfo vocab_info; + vocab_info.unk_token = "<|endoftext|>"; + vocab_info.bos_token = "<|startoftranscript|>"; + vocab_info.eos_token = "<|endoftext|>"; + + _vocabulary = std::make_shared(vocabularies.at("vocabulary")); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); + } + bool WhisperModel::is_quantizable(const std::string& variable_name) const { return (Model::is_quantizable(variable_name) && variable_name.find("conv") == std::string::npos);