From 5a3c4fce810becfc9dc42d3f133d7ffc0f1fff82 Mon Sep 17 00:00:00 2001 From: Christian M Date: Tue, 26 Dec 2023 09:57:20 +0100 Subject: [PATCH] :white_check_mark::memo: adds docs and tests for token and text generator --- .../text_generator/dummy_text_generator.rs | 62 ++++++++++++++++++- src/llm/text_generator/mod.rs | 51 ++++++++++++++- src/llm/token_generator/dummy.rs | 61 ++++++++++++++++-- src/llm/token_generator/mod.rs | 35 +++++++++++ 4 files changed, 201 insertions(+), 8 deletions(-) diff --git a/src/llm/text_generator/dummy_text_generator.rs b/src/llm/text_generator/dummy_text_generator.rs index db92eab..4af36b4 100644 --- a/src/llm/text_generator/dummy_text_generator.rs +++ b/src/llm/text_generator/dummy_text_generator.rs @@ -1,26 +1,82 @@ +use crate::llm::FinishReason; + use super::{TextGeneratorResult, TextGeneratorTrait}; use anyhow::Result; - +/// A basic implementation of the `TextGeneratorTrait` for testing and demonstration purposes. +/// +/// This struct uses a simple approach for text generation, primarily intended to serve as a placeholder +/// or for testing the framework without involving complex models. pub struct DummyTextGenerator { + // The internal text that will be used for generation. text: String, } impl DummyTextGenerator { + /// Constructs a new `DummyTextGenerator` with a given text. + /// + /// # Arguments + /// + /// * `text` - The initial text to be used by the generator. pub fn new(text: String) -> Self { Self { text } } } impl TextGeneratorTrait for DummyTextGenerator { + /// Initializes the generator with a given prompt. + /// + /// This method sets the internal text to the provided prompt. + /// + /// # Arguments + /// + /// * `prompt` - A `String` serving as the initial text. + /// + /// # Returns + /// + /// Always returns `Ok(())` as there is no complex initialization process. fn init(&mut self, prompt: String) -> Result<()> { self.text = prompt; Ok(()) } + /// Generates the next piece of text. 
+ /// + /// For the `DummyTextGenerator`, this method returns the entire internal text at once and + /// then signifies completion in subsequent calls. + /// + /// # Returns + /// + /// A `Result` wrapping a `TextGeneratorResult`, which is either the entire text as a token + /// or an indication that the generation process has finished. fn next(&mut self) -> Result { - todo!() + if !self.text.is_empty() { + let text = std::mem::take(&mut self.text); + Ok(TextGeneratorResult::Token((text, 1.0))) + } else { + Ok(TextGeneratorResult::Finish(FinishReason::Length)) + } } } #[cfg(test)] -mod tests {} +mod tests { + use super::*; + + #[test] + fn test_dummy_text_generator() { + let mut generator = DummyTextGenerator::new("Hello World".to_string()); + generator.init("Test".to_string()).unwrap(); + + // First call should return the entire text. + match generator.next().unwrap() { + TextGeneratorResult::Token((text, _)) => assert_eq!(text, "Test"), + _ => panic!("Unexpected result on first call to next"), + } + + // Subsequent calls should indicate that the generation process has finished. + assert_eq!( + generator.next().unwrap(), + TextGeneratorResult::Finish(FinishReason::Length) + ); + } +} diff --git a/src/llm/text_generator/mod.rs b/src/llm/text_generator/mod.rs index 16d2ef6..d9239b7 100644 --- a/src/llm/text_generator/mod.rs +++ b/src/llm/text_generator/mod.rs @@ -6,26 +6,75 @@ use anyhow::Result; use candle_examples::token_output_stream::TokenOutputStream; mod dummy_text_generator; +/// A piece of generated text paired with the probability associated with it. pub type TextProbability = (String, f32); +/// Enumerates possible results from a text generation process. +/// +/// This enum is used to encapsulate the outcomes of text generation, including +/// both the generation of a new token and the conclusion of the generation process. #[derive(Debug, PartialEq)] pub enum TextGeneratorResult { + /// Represents a generated piece of text along with its probability. 
+ /// + /// The `String` is the generated text, and the `f32` is the probability associated with it. Token(TextProbability), + + /// Indicates the completion of the text generation process. + /// + /// This variant is used when the generation process reaches an end, either due to reaching + /// a specified limit or encountering a stopping condition. Finish(FinishReason), } -/// Trait for text generation functionality. +/// A trait defining the core functionality for text generation. +/// +/// This trait encapsulates the necessary methods for initializing the generation process with a +/// prompt and then producing text iteratively. pub trait TextGeneratorTrait { + /// Initializes the text generation process with a given prompt. + /// + /// This method sets up the necessary state for text generation based on the provided prompt. + /// + /// # Arguments + /// + /// * `prompt` - A `String` that serves as the starting point for text generation. + /// + /// # Returns + /// + /// A `Result` indicating success or failure of the initialization process. fn init(&mut self, prompt: String) -> Result<()>; + + /// Generates the next piece of text in the sequence. + /// + /// This method should be called iteratively to generate text progressively. + /// It provides the next piece of text based on the current state of the generator. + /// + /// # Returns + /// + /// A `Result` wrapping a `TextGeneratorResult`, which can be either a generated token + /// or an indication that the generation process has finished. fn next(&mut self) -> Result; } +/// Handles the text generation process. +/// +/// This struct is responsible for managing the token generation and converting tokens into text. pub struct TextGenerator { + /// The tokenizer used to encode the prompt and decode the generated tokens. tokenizer: TokenOutputStream, + + /// The token generator that produces tokens based on the model's output. token_generator: Box, } impl TextGenerator { + /// Constructs a new `TextGenerator`. 
+ /// + /// # Arguments + /// + /// * `tokenizer` - Tokenizer for encoding prompts and decoding generated tokens. + /// * `token_generator` - Token generator that provides the logic for generating tokens. pub fn new( tokenizer: TokenOutputStream, token_generator: Box, diff --git a/src/llm/token_generator/dummy.rs b/src/llm/token_generator/dummy.rs index ce11fb4..eaaf6ee 100644 --- a/src/llm/token_generator/dummy.rs +++ b/src/llm/token_generator/dummy.rs @@ -8,6 +8,10 @@ use candle_core::{Device, Tensor}; use super::{FinishReason, TokenGeneratorResult, TokenGeneratorTrait}; +/// A dummy implementation of the `TokenGeneratorTrait` used for testing purposes. +/// +/// This token generator produces a predefined sequence of tokens based on the index counter. +/// It's a simplified version intended for unit testing without involving complex models or samplers. pub struct DummyTokenGenerator { parameter: GenerateParameter, index: usize, @@ -18,6 +22,15 @@ pub struct DummyTokenGenerator { unsafe impl Send for DummyTokenGenerator {} impl DummyTokenGenerator { + /// Creates a new instance of `DummyTokenGenerator` with specified generation parameters. + /// + /// # Arguments + /// + /// * `parameter` - Parameters that control the generation process such as max tokens, temperature, etc. + /// + /// # Returns + /// + /// A new instance of `DummyTokenGenerator`. 
pub fn new(parameter: GenerateParameter) -> Self { let sampler = Box::new(DummySampler::new()); let model = Box::new(DummyModelProcessor::new()); @@ -32,6 +45,7 @@ impl DummyTokenGenerator { impl TokenGeneratorTrait for DummyTokenGenerator { fn init(&mut self, _prompt_tokens: Vec) -> Result<()> { + self.index = 0; Ok(()) } fn next(&mut self) -> Result { @@ -51,20 +65,59 @@ mod tests { use super::*; #[test] - fn test_dummy_token_generator() { + fn test_dummy_token_generator_with_zero_max_tokens() { let mut token_generator = DummyTokenGenerator::new(GenerateParameter { - max_new_tokens: 10, + max_new_tokens: 0, ..Default::default() }); - for index in 0..10 { + assert_eq!( + token_generator.next().unwrap(), + TokenGeneratorResult::Finish(FinishReason::Length) + ); + } + + #[test] + fn test_dummy_token_generator_with_repeat_penalty() { + let mut token_generator = DummyTokenGenerator::new(GenerateParameter { + max_new_tokens: 5, + repeat_penalty: 1.5, // Non-default value + ..Default::default() + }); + // Just verify that it can still produce tokens normally + for index in 0..5 { assert_eq!( token_generator.next().unwrap(), TokenGeneratorResult::Token((index, 1.0)) ); } + } + + #[test] + fn test_dummy_token_generator_with_high_max_tokens() { + let mut token_generator = DummyTokenGenerator::new(GenerateParameter { + max_new_tokens: 1000, // High value + ..Default::default() + }); + // Run through multiple iterations and ensure it stops correctly + for _ in 0..1000 { + if let TokenGeneratorResult::Finish(_) = token_generator.next().unwrap() { + break; + } + } + } + + #[test] + fn test_dummy_token_generator_initialization() { + let mut token_generator = DummyTokenGenerator::new(Default::default()); + token_generator.init(vec![1, 2, 3]).unwrap(); // Initial set of tokens assert_eq!( token_generator.next().unwrap(), - TokenGeneratorResult::Finish(FinishReason::Length) + TokenGeneratorResult::Token((0, 1.0)) + ); + token_generator.init(vec![4, 5, 6]).unwrap(); // 
Re-initialize with new tokens + assert_eq!( + token_generator.next().unwrap(), + TokenGeneratorResult::Token((1, 1.0)) ); } } diff --git a/src/llm/token_generator/mod.rs b/src/llm/token_generator/mod.rs index f037447..904d222 100644 --- a/src/llm/token_generator/mod.rs +++ b/src/llm/token_generator/mod.rs @@ -18,11 +18,34 @@ pub enum TokenGeneratorResult { Finish(FinishReason), } +/// A trait defining the behavior of a token generator. +/// +/// This trait is implemented by objects that can generate tokens based on some internal logic. +/// The trait provides methods to initialize the generator and to retrieve the next token in the sequence. pub trait TokenGeneratorTrait: Send { + /// Initializes the token generator with a given set of prompt tokens. + /// + /// # Arguments + /// + /// * `prompt_tokens` - A vector of initial tokens used to start the token generation process. + /// + /// # Returns + /// + /// A `Result` indicating the success or failure of the initialization. fn init(&mut self, prompt_tokens: Vec) -> Result<()>; + + /// Retrieves the next token from the generator. + /// + /// # Returns + /// + /// A `Result` containing the `TokenGeneratorResult`, which can be either a token or a signal to finish generation. fn next(&mut self) -> Result; } +/// A token generator that generates tokens based on provided parameters, model processor, and sampler. +/// +/// This struct implements the `TokenGeneratorTrait` and provides functionality to generate tokens +/// for text generation tasks. pub struct TokenGenerator { index: usize, stop_tokens: HashSet, @@ -37,6 +60,18 @@ pub struct TokenGenerator { unsafe impl Send for TokenGenerator {} impl TokenGenerator { + /// Creates a new `TokenGenerator` with the specified parameters. + /// + /// # Arguments + /// + /// * `stop_tokens` - A set of token IDs that signal the end of token generation. + /// * `parameter` - The parameters to use for token generation. + /// * `model` - A model processor to generate logits. 
+ /// * `sampler` - A sampler to sample tokens from logits. + /// + /// # Returns + /// + /// A new instance of `TokenGenerator`. pub fn new( stop_tokens: HashSet, parameter: GenerateParameter,