CTranslate2 exposes high-level classes to run generative language models such as GPT-2. The main entrypoint is the `Generator` class, which provides several methods:
| Method name | Description | Example |
| --- | --- | --- |
| `generate_batch` | Generate text from a batch of prompts or start tokens. | {ref}`guides/transformers:gpt-2` |
| `score_batch` | Compute the token-level log-likelihood and the sequence perplexity. | {ref}`guides/fairseq:wmt19 language model` |
| `generate_tokens` | Stream the generated tokens. | {ref}`generation:token streaming`, [Chat with Llama 2](https://github.com/OpenNMT/CTranslate2/tree/master/examples/llama2) |
| `forward_batch` | Get the full output logits (or log probs) for a sequence. | |
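For instance, `score_batch` can be used to compute the perplexity of a text. Below is a minimal sketch, assuming a converted model in `ct2_model/` and a SentencePiece tokenizer (the same files as in the token streaming example further down):

```python
import math

import ctranslate2
import sentencepiece as spm

generator = ctranslate2.Generator("ct2_model/")
sp = spm.SentencePieceProcessor("tokenizer.model")

tokens = sp.encode("What is the meaning of life?", out_type=str)

# score_batch returns one result per input sequence; each result holds
# the scored tokens and their token-level log-likelihoods.
result = generator.score_batch([tokens])[0]
perplexity = math.exp(-sum(result.log_probs) / len(result.log_probs))
print(perplexity)
```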
`generate_tokens` is a convenience method to return tokens as they are generated by the model. This can be useful when running large models in an interactive environment.
The example below shows how to use this method and progressively decode SentencePiece tokens. It should be adapted if the model uses a different tokenizer or the generated language does not use a space to separate words.
```python
import ctranslate2
import sentencepiece as spm

generator = ctranslate2.Generator("ct2_model/")
sp = spm.SentencePieceProcessor("tokenizer.model")

prompt = "What is the meaning of life?"
prompt_tokens = sp.encode(prompt, out_type=str)

step_results = generator.generate_tokens(
    prompt_tokens,
    sampling_temperature=0.8,
    sampling_topk=20,
    max_length=1024,
)

output_ids = []

for step_result in step_results:
    is_new_word = step_result.token.startswith("▁")

    if is_new_word and output_ids:
        word = sp.decode(output_ids)
        print(word, end=" ", flush=True)
        output_ids = []

    output_ids.append(step_result.token_id)

if output_ids:
    word = sp.decode(output_ids)
    print(word)
```
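As a hedged illustration of adapting this loop to a non-SentencePiece tokenizer, the sketch below re-decodes the accumulated token IDs with a Hugging Face tokenizer and prints only the new suffix. The `"gpt2"` tokenizer name is a placeholder, and `step_results` is assumed to be a fresh iterator from a model converted with that tokenizer:

```python
import transformers

# Placeholder tokenizer; use the tokenizer matching the converted model.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

output_ids = []
printed_text = ""

for step_result in step_results:
    output_ids.append(step_result.token_id)
    # Re-decode the full sequence and print only the new suffix, which
    # avoids breaking multi-byte characters across token boundaries.
    text = tokenizer.decode(output_ids)
    print(text[len(printed_text):], end="", flush=True)
    printed_text = text
print()
```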
If you `break` out of the loop, the generation will still run to completion in the background. To stop the generation early, you should close the results generator, for example with `step_results.close()`.
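For example (a minimal sketch, assuming `step_results` is a fresh iterator returned by `generate_tokens` as above; the 100-token limit is arbitrary):

```python
max_tokens = 100  # arbitrary limit for this sketch

for i, step_result in enumerate(step_results):
    print(step_result.token, end=" ", flush=True)
    if i + 1 >= max_tokens:
        break

# Close the results generator so the background generation is canceled
# instead of running to completion.
step_results.close()
```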
The `callback` argument of the method `generate_batch` can also be used to implement token streaming. This is what `generate_tokens` uses internally.
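A minimal sketch of streaming with this callback, assuming the model and `prompt_tokens` from the first example (note that the callback is only called when `beam_size=1`):

```python
def on_token(step_result):
    # step_result contains the generated token, its id, the decoding
    # step, and the index of the example in the batch.
    print(step_result.token, end=" ", flush=True)
    # Returning True stops the decoding for this batch.
    return False

generator.generate_batch(
    [prompt_tokens],
    beam_size=1,  # the callback requires greedy or sampling search
    sampling_temperature=0.8,
    sampling_topk=20,
    max_length=1024,
    callback=on_token,
)
```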
The example [Chat with Llama 2](https://github.com/OpenNMT/CTranslate2/tree/master/examples/llama2) uses token streaming in an interactive chat session.
The methods `generate_batch` and `generate_tokens` have an argument `static_prompt` that can be used for models that always start with the same prompt (also known as a system prompt). The model is run once on this static prompt, and the model state is cached and reused for future calls with the same static prompt.
For example, StableLM uses a system prompt that could be implemented like this:
```python
import ctranslate2
import transformers

generator = ctranslate2.Generator("stablelm-ct2/", device="cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")

system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
system_prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(system_prompt))

prompt = "<|USER|>What's your mood today?<|ASSISTANT|>"
prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))

step_results = generator.generate_tokens(
    prompt=prompt_tokens,
    static_prompt=system_prompt_tokens,
    max_length=512,
    sampling_topk=10,
    sampling_temperature=0.7,
    end_token=[50278, 50279, 50277, 1, 0],
)
```
At this time the cache size is unlimited and the cache is only cleared when the model is unloaded. Also if the model is loaded on multiple GPUs, each model replica manages its own cache to avoid copying the state between devices.
Special tokens such as the decoder start token `<s>` should be explicitly included in the input if required by the model. No special tokens are added by the generator methods. This is different from the translator methods, which usually include these special tokens implicitly.
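For example, if a model expects a decoder start token, it can be prepended manually (a minimal sketch reusing the SentencePiece setup from the first example; whether `<s>` is required depends on the model):

```python
prompt = "What is the meaning of life?"

# Explicitly prepend the decoder start token; the generator will not
# add it automatically.
prompt_tokens = ["<s>"] + sp.encode(prompt, out_type=str)

step_results = generator.generate_tokens(prompt_tokens, max_length=256)
```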