Skip to content

Commit

Permalink
✅📝 adds docs and tests for token and text generator
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Dec 26, 2023
1 parent cafe578 commit 5a3c4fc
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 8 deletions.
62 changes: 59 additions & 3 deletions src/llm/text_generator/dummy_text_generator.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,82 @@
use crate::llm::FinishReason;

use super::{TextGeneratorResult, TextGeneratorTrait};
use anyhow::Result;

/// A basic implementation of the `TextGeneratorTrait` for testing and demonstration purposes.
///
/// This struct uses a simple approach for text generation, primarily intended to serve as a placeholder
/// or for testing the framework without involving complex models.
pub struct DummyTextGenerator {
// The internal text that will be used for generation.
text: String,
}

impl DummyTextGenerator {
/// Constructs a new `DummyTextGenerator` with a given text.
///
/// # Arguments
///
/// * `text` - The initial text to be used by the generator.
pub fn new(text: String) -> Self {
Self { text }
}
}

impl TextGeneratorTrait for DummyTextGenerator {
/// Initializes the generator with a given prompt.
///
/// This method sets the internal text to the provided prompt.
///
/// # Arguments
///
/// * `prompt` - A `String` serving as the initial text.
///
/// # Returns
///
/// Always returns `Ok(())` as there is no complex initialization process.
fn init(&mut self, prompt: String) -> Result<()> {
self.text = prompt;
Ok(())
}

/// Generates the next piece of text.
///
/// For the `DummyTextGenerator`, this method returns the entire internal text at once and
/// then signifies completion in subsequent calls.
///
/// # Returns
///
/// A `Result` wrapping a `TextGeneratorResult`, which is either the entire text as a token
/// or an indication that the generation process has finished.
fn next(&mut self) -> Result<TextGeneratorResult> {
todo!()
if !self.text.is_empty() {
let text = std::mem::take(&mut self.text);
Ok(TextGeneratorResult::Token((text, 1.0)))
} else {
Ok(TextGeneratorResult::Finish(FinishReason::Length))
}
}
}

#[cfg(test)]
mod tests {}
mod tests {
use super::*;

#[test]
fn test_dummy_text_generator() {
let mut generator = DummyTextGenerator::new("Hello World".to_string());
generator.init("Test".to_string()).unwrap();

// First call should return the entire text.
match generator.next().unwrap() {
TextGeneratorResult::Token((text, _)) => assert_eq!(text, "Test"),
_ => panic!("Unexpected result on first call to next"),
}

// Subsequent calls should indicate that the generation process has finished.
assert_eq!(
generator.next().unwrap(),
TextGeneratorResult::Finish(FinishReason::Length)
);
}
}
51 changes: 50 additions & 1 deletion src/llm/text_generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,75 @@ use anyhow::Result;
use candle_examples::token_output_stream::TokenOutputStream;
mod dummy_text_generator;

/// Represents the probability associated with a piece of generated text.
pub type TextProbability = (String, f32);

/// Enumerates possible results from a text generation process.
///
/// This enum is used to encapsulate the outcomes of text generation, including
/// both the generation of a new token and the conclusion of the generation process.
#[derive(Debug, PartialEq)]
pub enum TextGeneratorResult {
/// Represents a generated piece of text along with its probability.
///
/// The `String` is the generated text, and the `f32` is the probability associated with it.
Token(TextProbability),

/// Indicates the completion of the text generation process.
///
/// This variant is used when the generation process reaches an end, either due to reaching
/// a specified limit or encountering a stopping condition.
Finish(FinishReason),
}

/// Trait for text generation functionality.
/// A trait defining the core functionality for text generation.
///
/// This trait encapsulates the necessary methods for initializing the generation process with a
/// prompt and then producing text iteratively.
pub trait TextGeneratorTrait {
/// Initializes the text generation process with a given prompt.
///
/// This method sets up the necessary state for text generation based on the provided prompt.
///
/// # Arguments
///
/// * `prompt` - A `String` that serves as the starting point for text generation.
///
/// # Returns
///
/// A `Result` indicating success or failure of the initialization process.
fn init(&mut self, prompt: String) -> Result<()>;

/// Generates the next piece of text in the sequence.
///
/// This method should be called iteratively to generate text progressively.
/// It provides the next piece of text based on the current state of the generator.
///
/// # Returns
///
/// A `Result` wrapping a `TextGeneratorResult`, which can be either a generated token
/// or an indication that the generation process has finished.
fn next(&mut self) -> Result<TextGeneratorResult>;
}

/// Handles the text generation process.
///
/// This struct is responsible for managing the token generation and converting tokens into text.
pub struct TextGenerator {
/// The tokenizer used to encode the prompt and decode the generated tokens.
tokenizer: TokenOutputStream,

/// The token generator that produces tokens based on the model's output.
token_generator: Box<dyn TokenGeneratorTrait>,
}

impl TextGenerator {
/// Constructs a new `TextGenerator`.
///
/// # Arguments
///
/// * `tokenizer` - Tokenizer for encoding prompts and decoding generated tokens.
/// * `token_generator` - Token generator that provides the logic for generating tokens.
pub fn new(
tokenizer: TokenOutputStream,
token_generator: Box<dyn TokenGeneratorTrait>,
Expand Down
61 changes: 57 additions & 4 deletions src/llm/token_generator/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ use candle_core::{Device, Tensor};

use super::{FinishReason, TokenGeneratorResult, TokenGeneratorTrait};

/// A dummy implementation of the `TokenGeneratorTrait` used for testing purposes.
///
/// This token generator produces a predefined sequence of tokens based on the index counter.
/// It's a simplified version intended for unit testing without involving complex models or samplers.
pub struct DummyTokenGenerator {
parameter: GenerateParameter,
index: usize,
Expand All @@ -18,6 +22,15 @@ pub struct DummyTokenGenerator {
unsafe impl Send for DummyTokenGenerator {}

impl DummyTokenGenerator {
/// Creates a new instance of `DummyTokenGenerator` with specified generation parameters.
///
/// # Arguments
///
/// * `parameter` - Parameters that control the generation process such as max tokens, temperature, etc.
///
/// # Returns
///
/// A new instance of `DummyTokenGenerator`.
pub fn new(parameter: GenerateParameter) -> Self {
let sampler = Box::new(DummySampler::new());
let model = Box::new(DummyModelProcessor::new());
Expand All @@ -32,6 +45,7 @@ impl DummyTokenGenerator {

impl TokenGeneratorTrait for DummyTokenGenerator {
fn init(&mut self, _prompt_tokens: Vec<u32>) -> Result<()> {
self.index = 0;
Ok(())
}
fn next(&mut self) -> Result<TokenGeneratorResult> {
Expand All @@ -51,20 +65,59 @@ mod tests {
use super::*;

#[test]
fn test_dummy_token_generator() {
fn test_dummy_token_generator_with_zero_max_tokens() {
let mut token_generator = DummyTokenGenerator::new(GenerateParameter {
max_new_tokens: 10,
max_new_tokens: 0,
..Default::default()
});
for index in 0..10 {
assert_eq!(
token_generator.next().unwrap(),
TokenGeneratorResult::Finish(FinishReason::Length)
);
}

#[test]
fn test_dummy_token_generator_with_repeat_penalty() {
let mut token_generator = DummyTokenGenerator::new(GenerateParameter {
max_new_tokens: 5,
repeat_penalty: 1.5, // Non-default value
..Default::default()
});
// Just verify that it can still produce tokens normally
for index in 0..5 {
assert_eq!(
token_generator.next().unwrap(),
TokenGeneratorResult::Token((index, 1.0))
);
}
}

#[test]
fn test_dummy_token_generator_with_high_max_tokens() {
let mut token_generator = DummyTokenGenerator::new(GenerateParameter {
max_new_tokens: 1000, // High value
..Default::default()
});
// Run through multiple iterations and ensure it stops correctly
for _ in 0..1000 {
if let TokenGeneratorResult::Finish(_) = token_generator.next().unwrap() {
break;
}
}
}

#[test]
fn test_dummy_token_generator_initialization() {
let mut token_generator = DummyTokenGenerator::new(Default::default());
token_generator.init(vec![1, 2, 3]).unwrap(); // Initial set of tokens
assert_eq!(
token_generator.next().unwrap(),
TokenGeneratorResult::Finish(FinishReason::Length)
TokenGeneratorResult::Token((0, 1.0))
);
token_generator.init(vec![4, 5, 6]).unwrap(); // Re-initialize with new tokens
assert_eq!(
token_generator.next().unwrap(),
TokenGeneratorResult::Token((1, 1.0))
);
}
}
35 changes: 35 additions & 0 deletions src/llm/token_generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,34 @@ pub enum TokenGeneratorResult {
Finish(FinishReason),
}

/// A trait defining the behavior of a token generator.
///
/// This trait is implemented by objects that can generate tokens based on some internal logic.
/// The trait provides methods to initialize the generator and to retrieve the next token in the sequence.
pub trait TokenGeneratorTrait: Send {
/// Initializes the token generator with a given set of prompt tokens.
///
/// # Arguments
///
/// * `prompt_tokens` - A vector of initial tokens used to start the token generation process.
///
/// # Returns
///
/// A `Result` indicating the success or failure of the initialization.
fn init(&mut self, prompt_tokens: Vec<u32>) -> Result<()>;

/// Retrieves the next token from the generator.
///
/// # Returns
///
/// A `Result` containing the `TokenGeneratorResult`, which can be either a token or a signal to finish generation.
fn next(&mut self) -> Result<TokenGeneratorResult>;
}

/// A token generator that generates tokens based on provided parameters, model processor, and sampler.
///
/// This struct implements the `TokenGeneratorTrait` and provides functionality to generate tokens
/// for text generation tasks.
pub struct TokenGenerator {
index: usize,
stop_tokens: HashSet<u32>,
Expand All @@ -37,6 +60,18 @@ pub struct TokenGenerator {
unsafe impl Send for TokenGenerator {}

impl TokenGenerator {
/// Creates a new `TokenGenerator` with the specified parameters.
///
/// # Arguments
///
/// * `stop_tokens` - A set of token IDs that signal the end of token generation.
/// * `parameter` - The parameters to use for token generation.
/// * `model` - A model processor to generate logits.
/// * `sampler` - A sampler to sample tokens from logits.
///
/// # Returns
///
/// A new instance of `TokenGenerator`.
pub fn new(
stop_tokens: HashSet<u32>,
parameter: GenerateParameter,
Expand Down

0 comments on commit 5a3c4fc

Please sign in to comment.