# Text generation

CTranslate2 exposes high-level classes to run generative language models such as GPT-2. The main entrypoint is the `Generator` class, which provides several methods:

| Method name | Description | Example |
| --- | --- | --- |
| `generate_batch` | Generate text from a batch of prompts or start tokens. | {ref}`guides/transformers:gpt-2` |
| `score_batch` | Compute the token-level log-likelihood and the sequence perplexity. | {ref}`guides/fairseq:wmt19 language model` |
| `generate_tokens` | Stream the generated tokens. | {ref}`generation:token streaming`<br>[Chat with Llama 2](https://github.com/OpenNMT/CTranslate2/tree/master/examples/llama2) |
| `forward_batch` | Get the full output logits (or log probs) for a sequence. | |
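
For instance, here is a minimal sketch of `score_batch` computing a sequence perplexity. The model path `ct2_model/` and the SentencePiece file `tokenizer.model` are placeholder names, not files shipped with CTranslate2:

```python
import math

import ctranslate2
import sentencepiece as spm

generator = ctranslate2.Generator("ct2_model/")  # placeholder path to a converted model
sp = spm.SentencePieceProcessor("tokenizer.model")  # placeholder tokenizer file

# Tokenize the text to score.
tokens = sp.encode("What is the meaning of life?", out_type=str)

# score_batch returns one result per input sequence with token-level log-likelihoods.
result = generator.score_batch([tokens])[0]

# The sequence perplexity is the exponential of the negative mean log-likelihood.
perplexity = math.exp(-sum(result.log_probs) / len(result.log_probs))
print(perplexity)
```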

## Token streaming

`generate_tokens` is a convenience method to return tokens as they are generated by the model. This can be useful when running large models in an interactive environment.

The example below shows how to use this method and progressively decode SentencePiece tokens. It should be adapted if the model uses a different tokenizer or the generated language does not use a space to separate words.

```python
import ctranslate2
import sentencepiece as spm

generator = ctranslate2.Generator("ct2_model/")
sp = spm.SentencePieceProcessor("tokenizer.model")

prompt = "What is the meaning of life?"
prompt_tokens = sp.encode(prompt, out_type=str)

step_results = generator.generate_tokens(
    prompt_tokens,
    sampling_temperature=0.8,
    sampling_topk=20,
    max_length=1024,
)

output_ids = []

for step_result in step_results:
    is_new_word = step_result.token.startswith("▁")

    if is_new_word and output_ids:
        word = sp.decode(output_ids)
        print(word, end=" ", flush=True)
        output_ids = []

    output_ids.append(step_result.token_id)

if output_ids:
    word = sp.decode(output_ids)
    print(word)
```
If you `break` out of the loop, the generation will still run to completion in the background. To stop the generation early, you should close the generator, for example using `step_results.close()`.

The `callback` argument of the method `generate_batch` can also be used to implement token streaming. This is what `generate_tokens` uses internally.

The example [Chat with Llama 2](https://github.com/OpenNMT/CTranslate2/tree/master/examples/llama2) uses token streaming in an interactive chat session.
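
For example, here is a minimal sketch of stopping the generation early, reusing the `generator` and `sp` objects from the example above with an arbitrary token budget:

```python
max_new_tokens = 64  # arbitrary budget for this sketch

step_results = generator.generate_tokens(prompt_tokens, max_length=1024)

generated_ids = []
for step_result in step_results:
    generated_ids.append(step_result.token_id)
    if len(generated_ids) >= max_new_tokens:
        # Closing the generator stops the generation running in the background.
        step_results.close()
        break

print(sp.decode(generated_ids))
```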

## Prompt caching

The methods `generate_batch` and `generate_tokens` have an argument `static_prompt` that can be used for models that always start with the same prompt (also known as a system prompt). The model is run once on this static prompt, and the model state is cached and reused for future calls with the same static prompt.

For example, StableLM uses a system prompt, which could be implemented like this:

```python
import ctranslate2
import transformers

generator = ctranslate2.Generator("stablelm-ct2/", device="cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b")

system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
system_prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(system_prompt))

prompt = "<|USER|>What's your mood today?<|ASSISTANT|>"
prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))

step_results = generator.generate_tokens(
    prompt=prompt_tokens,
    static_prompt=system_prompt_tokens,
    max_length=512,
    sampling_topk=10,
    sampling_temperature=0.7,
    end_token=[50278, 50279, 50277, 1, 0],
)
```
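
Subsequent calls that pass the same `static_prompt` reuse the cached state, so the system prompt is only processed once. Here is a minimal sketch of such a loop, reusing the objects from the example above with hypothetical user inputs:

```python
for user_input in ["What's your mood today?", "Tell me a joke."]:
    prompt = "<|USER|>" + user_input + "<|ASSISTANT|>"
    prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))

    # The model state computed for system_prompt_tokens is cached and reused here.
    step_results = generator.generate_tokens(
        prompt=prompt_tokens,
        static_prompt=system_prompt_tokens,
        max_length=512,
        sampling_topk=10,
        sampling_temperature=0.7,
        end_token=[50278, 50279, 50277, 1, 0],
    )

    output_ids = [step_result.token_id for step_result in step_results]
    print(tokenizer.decode(output_ids))
```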
At this time, the cache size is unlimited and the cache is only cleared when the model is unloaded. Also, if the model is loaded on multiple GPUs, each model replica manages its own cache to avoid copying the state between devices.

## Special tokens

Special tokens such as the decoder start token `<s>` should be explicitly included in the input if required by the model. No special tokens are added by the generator methods.

This is different from the translator methods, which usually include these special tokens implicitly.
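
For example, for a model that expects a beginning-of-sentence token, a minimal sketch is to prepend it to the prompt tokens before calling the generator (this assumes `<s>` is the token used by the model's SentencePiece vocabulary, and reuses the `generator` and `sp` objects from the token streaming example):

```python
# Explicitly prepend the decoder start token to the encoded prompt.
prompt_tokens = ["<s>"] + sp.encode("What is the meaning of life?", out_type=str)

results = generator.generate_batch(
    [prompt_tokens],
    max_length=256,
    sampling_topk=10,
    include_prompt_in_result=False,
)
print(sp.decode(results[0].sequences_ids[0]))
```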