diff --git a/src/llm/text_generation.rs b/src/llm/text_generation.rs
index ec2a17c..e277006 100644
--- a/src/llm/text_generation.rs
+++ b/src/llm/text_generation.rs
@@ -3,17 +3,23 @@ use crate::{
     llm::token_generator::TokenGenerator,
 };
 
+use crate::llm::generate_parameter::GenerateParameter;
 use anyhow::{Error as E, Result};
 use candle_core::Device;
 use candle_transformers::{generation::LogitsProcessor, models::quantized_llama::ModelWeights};
 use futures::Stream;
 use log::{debug, info, trace};
-use std::sync::Arc;
+use std::{collections::HashSet, sync::Arc};
 use tokenizers::Tokenizer;
 use tokio::sync::Mutex;
 use tokio_stream::wrappers::ReceiverStream;
 
-use super::token_output_stream::TokenOutputStream;
+use super::{
+    text_generator::{self, TextGenerator},
+    token_generator::{TokenGenerator2, TokenGeneratorTrait},
+    token_output_stream::TokenOutputStream,
+    TextGeneratorTrait,
+};
 
 pub struct TextGeneration {
     model: Arc<Mutex<ModelWeights>>,
@@ -43,87 +49,61 @@ impl TextGeneration {
         }
     }
 
-    fn prepare_tokens(tokenizer: &mut TokenOutputStream, prompt: String) -> Result<Vec<u32>> {
-        tokenizer.clear();
-        let binding = tokenizer.tokenizer().encode(prompt, true).map_err(E::msg)?;
-        Ok(binding.get_ids().to_vec())
-    }
+    pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<Option<String>> {
+        let locked_tokenizer = self.tokenizer.try_lock().unwrap();
+        let locked_model = self.model.try_lock().unwrap();
+
+        let stop_tokens = vec!["<|endoftext|>", "</s>"];
+
+        let eos_tokens: HashSet<u32> = stop_tokens
+            .into_iter()
+            .map(|token| {
+                locked_tokenizer
+                    .tokenizer()
+                    .token_to_id(token)
+                    .unwrap_or_default()
+            })
+            .collect::<HashSet<u32>>();
+
+        let parameter = GenerateParameter {
+            max_new_tokens: sample_len,
+        };
+
+        let model = Box::new(locked_model.clone());
+        let sampler = Box::new(LogitsProcessor::new(42, None, None));
+
+        let token_generator: Box<dyn TokenGeneratorTrait> =
+            Box::new(TokenGenerator2::new(eos_tokens, parameter, model, sampler));
+
+        let mut text_generator = TextGenerator::new(
+            TokenOutputStream::new(locked_tokenizer.tokenizer().clone()),
+            token_generator,
+        );
+
+        text_generator.init(prompt.to_string())?;
 
-    fn process_tokens<F>(
-        tokenizer: &Mutex<TokenOutputStream>,
-        model: &Mutex<ModelWeights>,
-        prompt: String,
-        sample_len: usize,
-        stop_tokens: Option<Vec<String>>,
-        mut handle_token: F,
-    ) -> Result<()>
-    where
-        F: FnMut(String, usize, bool) -> Result<()>,
-    {
-        let mut tokenizer = tokenizer.try_lock().unwrap();
-        let mut model = model.try_lock().unwrap();
-        let mut generated_tokens = Self::prepare_tokens(&mut tokenizer, prompt).unwrap();
-
-        let repeat_penalty = 1.1;
-        let repeat_last_n = 64;
-        let mut logits_processor = LogitsProcessor::new(42, None, None);
         let start_gen = std::time::Instant::now();
+        let mut token_count = 0;
 
-        let mut token_generator = TokenGenerator::new();
-        token_generator.set_stop_tokens(stop_tokens, &mut tokenizer);
-
-        let mut generated_text: Vec<String> = Vec::new();
-
-        for _ in 0..sample_len {
-            let next_token = {
-                token_generator.next(
-                    &generated_tokens,
-                    &mut logits_processor,
-                    &mut model,
-                    repeat_penalty,
-                    repeat_last_n,
-                )?
-            };
-
-            if let Some(token) = next_token {
-                generated_tokens.push(token);
-                generated_text.push(tokenizer.next_token(token).unwrap_or_default().unwrap());
-                if token_generator.is_stop_token(&token) {
+        let mut generated_text = String::new();
+        while let Ok(result) = text_generator.next() {
+            token_count += 1;
+            match result {
+                text_generator::TextGeneratorResult::Token((text, _)) => {
+                    generated_text.push_str(&text);
+                }
+                text_generator::TextGeneratorResult::Finish(_) => {
                     break;
                 }
             }
         }
 
-        for (index, text) in generated_text.iter().enumerate() {
-            handle_token(text.clone(), index, false)?;
-        }
-
         info!(
             "{} tokens generated ({:.2} token/s)",
-            generated_tokens.len(),
-            generated_tokens.len() as f64 / start_gen.elapsed().as_secs_f64(),
+            token_count,
+            token_count as f64 / start_gen.elapsed().as_secs_f64(),
         );
 
-        Ok(())
-    }
-
-    pub fn run(&mut self, prompt: &str, sample_len: usize) -> Result<Option<String>> {
-        let mut generated_text = String::new();
-        Self::process_tokens(
-            &self.tokenizer,
-            &self.model,
-            prompt.to_string(),
-            sample_len,
-            None,
-            |text, _, is_final| {
-                if is_final {
-                    Ok(())
-                } else {
-                    generated_text.push_str(&text);
-                    Ok(())
-                }
-            },
-        )?;
         Ok(Some(generated_text))
     }