From cafe5781e86751f57ba6e010bbdbbce82b9c5e48 Mon Sep 17 00:00:00 2001
From: Christian M
Date: Tue, 26 Dec 2023 09:31:39 +0100
Subject: [PATCH] :white_check_mark::memo: adds docs and tests

---
 src/llm/generate_parameter.rs | 89 +++++++++++++++++++++++++++++++----
 src/llm/loader.rs             | 40 ++++++++++++++++
 src/llm/mod.rs                |  4 +-
 src/llm/model_processor.rs    | 31 +++++++++++-
 src/llm/sampler.rs            | 29 +++++++++++-
 5 files changed, 179 insertions(+), 14 deletions(-)

diff --git a/src/llm/generate_parameter.rs b/src/llm/generate_parameter.rs
index 6825460..cb8e873 100644
--- a/src/llm/generate_parameter.rs
+++ b/src/llm/generate_parameter.rs
@@ -1,16 +1,89 @@
-/// Parameters used to generate samples
-#[derive(Debug, Clone, Default)]
+//! Generate Parameters Module.
+//!
+//! This module defines parameters used for controlling text generation.
+
+use serde::{Deserialize, Serialize};
+
+/// Parameters used to generate samples.
+///
+/// This struct defines various settings that influence the behavior of the text generation process,
+/// such as token limits, sampling temperature, and repeat penalties.
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct GenerateParameter {
-    /// Maximum number of tokens to generate
+    /// Maximum number of new tokens to generate.
+    #[serde(default = "default_max_new_tokens")]
     pub max_new_tokens: usize,
-    /// The seed used to generate samples
+
+    /// Seed used for deterministic generation.
+    #[serde(default = "default_seed")]
     pub seed: u64,
-    /// The temperature used to generate samples
+
+    /// Temperature for sampling.
+    #[serde(default = "default_temperature")]
     pub temperature: f64,
-    /// Nucleus sampling probability cutoff
+
+    /// Nucleus sampling probability cutoff.
+    #[serde(default = "default_top_p")]
     pub top_p: f64,
-    /// Penalty to be applied for repeating tokens, 1. means no penalty
+
+    /// Penalty for repeating tokens.
+    #[serde(default = "default_repeat_penalty")]
     pub repeat_penalty: f32,
-    /// The context size to consider for the repeat penalty
+
+    /// The number of last tokens to consider for applying the repeat penalty.
+    #[serde(default = "default_repeat_last_n")]
     pub repeat_last_n: usize,
 }
+
+fn default_max_new_tokens() -> usize {
+    50
+}
+
+fn default_seed() -> u64 {
+    299792458
+}
+
+fn default_temperature() -> f64 {
+    1.0
+}
+
+fn default_top_p() -> f64 {
+    0.9
+}
+
+fn default_repeat_penalty() -> f32 {
+    1.0
+}
+
+fn default_repeat_last_n() -> usize {
+    64
+}
+
+impl Default for GenerateParameter {
+    fn default() -> Self {
+        serde_json::from_str("{}").unwrap_or_else(|_| GenerateParameter {
+            max_new_tokens: default_max_new_tokens(),
+            seed: default_seed(),
+            temperature: default_temperature(),
+            top_p: default_top_p(),
+            repeat_penalty: default_repeat_penalty(),
+            repeat_last_n: default_repeat_last_n(),
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_default_generate_parameter() {
+        let param = GenerateParameter::default();
+        assert_eq!(param.max_new_tokens, default_max_new_tokens());
+        assert_eq!(param.seed, default_seed());
+        assert_eq!(param.temperature, default_temperature());
+        assert_eq!(param.top_p, default_top_p());
+        assert_eq!(param.repeat_penalty, default_repeat_penalty());
+        assert_eq!(param.repeat_last_n, default_repeat_last_n());
+    }
+}
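
// Usage sketch (illustrative only, not part of the diff above).
// The #[serde(default = "...")] attributes mean a partial JSON payload still
// deserializes into a fully populated GenerateParameter. The import path
// `llm::generate_parameter::GenerateParameter` and the `serde_json` dependency
// are assumptions here; adjust them to the actual crate layout.
use llm::generate_parameter::GenerateParameter;

fn main() -> Result<(), serde_json::Error> {
    // Only `temperature` is supplied; every other field falls back to its
    // `default_*` function (max_new_tokens = 50, seed = 299792458, top_p = 0.9,
    // repeat_penalty = 1.0, repeat_last_n = 64).
    let param: GenerateParameter = serde_json::from_str(r#"{ "temperature": 0.7 }"#)?;
    assert_eq!(param.max_new_tokens, 50);
    assert_eq!(param.repeat_last_n, 64);

    // `GenerateParameter::default()` deserializes "{}" internally, so it yields
    // exactly the same fallback values.
    assert_eq!(GenerateParameter::default().seed, 299792458);
    Ok(())
}
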
diff --git a/src/llm/loader.rs b/src/llm/loader.rs
index 38a2771..dc40512 100644
--- a/src/llm/loader.rs
+++ b/src/llm/loader.rs
@@ -1,3 +1,8 @@
+//! Model Loader Module.
+//!
+//! This module contains functions for loading model weights and tokenizers for text generation.
+//! It supports various models and uses the Hugging Face Hub for downloading model files.
+
 use std::path::PathBuf;
 
 use super::models::Models;
@@ -10,6 +15,7 @@ use hf_hub::{Repo, RepoType};
 use log::{debug, info};
 use tokenizers::Tokenizer;
 
+/// Formats the size in bytes into a human-readable string.
 fn format_size(size_in_bytes: usize) -> String {
     if size_in_bytes < 1_000 {
         format!("{}B", size_in_bytes)
@@ -22,6 +28,17 @@ fn format_size(size_in_bytes: usize) -> String {
     }
 }
 
+/// Creates and loads model weights from the Hugging Face Hub.
+///
+/// # Arguments
+///
+/// * `model` - The model enum specifying the model to load.
+/// * `cache_dir` - Optional directory for caching downloaded models.
+///
+/// # Returns
+///
+/// Returns a result containing a tuple of `ModelWeights` and `Device`,
+/// or an error if loading fails.
 pub fn create_model(
     model: Models,
     cache_dir: &Option,
@@ -116,6 +133,16 @@ pub fn create_model(
     Ok((model, Device::Cpu))
 }
 
+/// Creates and loads a tokenizer from the Hugging Face Hub.
+///
+/// # Arguments
+///
+/// * `model` - The model enum specifying the tokenizer to load.
+///
+/// # Returns
+///
+/// Returns a result containing the `Tokenizer`,
+/// or an error if loading fails.
 pub fn create_tokenizer(model: Models) -> Result> {
     let tokenizer_path = {
         let api = hf_hub::api::sync::Api::new()?;
@@ -126,3 +153,16 @@ pub fn create_tokenizer(model: Models) -> Result

diff --git a/src/llm/model_processor.rs b/src/llm/model_processor.rs
 pub trait ModelProcessor {
     fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor>;
 }
 
+/// Implementation of `ModelProcessor` for the `ModelWeights` from `candle_transformers`.
 impl ModelProcessor for ModelWeights {
     fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         Self::forward(self, x, index_pos)
     }
 }
 
+/// A dummy implementation of `ModelProcessor` for testing purposes.
+///
+/// This processor simulates model outputs by returning incrementing tensors.
 pub struct DummyModelProcessor {
     index: usize,
 }
 
 impl DummyModelProcessor {
+    /// Creates a new `DummyModelProcessor`.
     pub fn new() -> Self {
         Self { index: 0 }
     }
 }
 
+/// Provides a default instance of `DummyModelProcessor`.
 impl Default for DummyModelProcessor {
     fn default() -> Self {
         Self::new()
     }
 }
 
+/// Implementation of `ModelProcessor` for `DummyModelProcessor`.
 impl ModelProcessor for DummyModelProcessor {
     fn forward(&mut self, x: &Tensor, _index_pos: usize) -> Result<Tensor> {
         self.index += 1;
@@ -37,10 +64,10 @@ impl ModelProcessor for DummyModelProcessor {
 
 #[cfg(test)]
 mod tests {
-    use candle_core::Device;
-
     use super::*;
 
+    use candle_core::Device;
+    /// Tests the `DummyModelProcessor` to ensure it returns incrementing tensors.
     #[test]
     fn test_dummy_model_processor() {
         let mut model_processor = DummyModelProcessor::new();
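
// Composition sketch (illustrative only, not part of the diff above).
// Shows how the `ModelProcessor` trait documented here and the `Sampler` trait
// in the next file fit together in a decoding loop. The module paths and the
// loop itself are assumptions for illustration; this is not the crate's actual
// generation code, and the input tensor is a stand-in for real token ids.
use candle_core::{DType, Device, Result, Tensor};
use llm::model_processor::{DummyModelProcessor, ModelProcessor};
use llm::sampler::{DummySampler, Sampler};

fn generate_n_tokens(
    model: &mut impl ModelProcessor,
    sampler: &mut impl Sampler,
    n: usize,
) -> Result<Vec<u32>> {
    let mut tokens = Vec::new();
    for index_pos in 0..n {
        // A real caller would feed the current token ids here.
        let input = Tensor::zeros((1, 1), DType::U32, &Device::Cpu)?;
        // forward() produces logits; sample() turns them into a token id.
        let logits = model.forward(&input, index_pos)?;
        tokens.push(sampler.sample(&logits)?);
    }
    Ok(tokens)
}

fn main() -> Result<()> {
    // The dummy implementations from this patch make the loop runnable without
    // any model weights, which is what the unit tests rely on.
    let tokens = generate_n_tokens(&mut DummyModelProcessor::new(), &mut DummySampler::new(), 3)?;
    println!("{tokens:?}");
    Ok(())
}
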
diff --git a/src/llm/sampler.rs b/src/llm/sampler.rs
index 90d0205..e9b8a04 100644
--- a/src/llm/sampler.rs
+++ b/src/llm/sampler.rs
@@ -1,32 +1,57 @@
+//! Sampler module for text generation.
+//!
+//! This module contains the `Sampler` trait and its implementations which are
+//! used for sampling tokens based on the output logits from a language model.
+
 use candle_core::{Result, Tensor};
 use candle_transformers::generation::LogitsProcessor;
 
+/// A trait for sampling a token based on logits output.
+///
+/// This trait defines a method for sampling a single token from a distribution
+/// represented by logits.
 pub trait Sampler {
+    /// Samples a token based on provided logits.
+    ///
+    /// # Arguments
+    ///
+    /// * `logits` - A reference to a tensor containing logits output from the model.
+    ///
+    /// # Returns
+    ///
+    /// Returns a `Result` containing the sampled token's ID.
     fn sample(&mut self, logits: &Tensor) -> Result<u32>;
 }
 
+/// Implementation of `Sampler` for the `LogitsProcessor` from `candle_transformers`.
 impl Sampler for LogitsProcessor {
     fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         Self::sample(self, logits)
     }
 }
 
+/// A dummy implementation of `Sampler` for testing purposes.
+///
+/// This sampler sequentially returns incrementing integers as tokens.
 pub struct DummySampler {
     index: usize,
 }
 
 impl DummySampler {
+    /// Creates a new `DummySampler`.
     pub fn new() -> Self {
         Self { index: 0 }
     }
 }
 
+/// Provides a default instance of `DummySampler`.
 impl Default for DummySampler {
     fn default() -> Self {
         Self::new()
     }
 }
 
+/// Implementation of `Sampler` for `DummySampler`.
 impl Sampler for DummySampler {
     fn sample(&mut self, _logits: &Tensor) -> Result<u32> {
         self.index += 1;
@@ -36,10 +61,10 @@ impl Sampler for DummySampler {
 
 #[cfg(test)]
 mod tests {
-    use candle_core::Device;
-
     use super::*;
 
+    use candle_core::Device;
+    /// Tests the `DummySampler` to ensure it returns incrementing integers.
     #[test]
     fn test_dummy_sampler() {
         let mut sampler = DummySampler::new();