diff --git a/src/llm/token_generator/mod.rs b/src/llm/token_generator/mod.rs index d385173..85732d4 100644 --- a/src/llm/token_generator/mod.rs +++ b/src/llm/token_generator/mod.rs @@ -96,27 +96,43 @@ pub struct TokenGenerator2 { index: usize, stop_tokens: HashSet, parameter: GenerateParameter, - tokens: Vec, + prompt_tokens: Vec, sampler: Box, model: Box, + next_token: Option, } impl TokenGenerator2 { pub fn new( - tokens: Vec, + prompt_tokens: Vec, stop_tokens: HashSet, parameter: GenerateParameter, model: Box, sampler: Box, ) -> Self { - Self { + let mut token_generator = Self { index: 0, stop_tokens, parameter, - tokens, + prompt_tokens, model, sampler, - } + next_token: None, + }; + token_generator.next_token = Some(token_generator.next_token().unwrap_or_default()); + token_generator + } +} + +impl TokenGenerator2 { + fn next_token(&mut self) -> Result { + let next_token = { + let input = Tensor::new(self.prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let logits = self.model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + self.sampler.sample(&logits)? + }; + Ok(next_token) } } @@ -127,15 +143,20 @@ impl TokenGeneratorTrait for TokenGenerator2 { } let next_token = { - let input = Tensor::new(self.tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?; - let logits = self.model.forward(&input, 0)?; + let input = Tensor::new(&[self.next_token.unwrap()], &Device::Cpu)?.unsqueeze(0)?; + let logits = self + .model + .forward(&input, self.prompt_tokens.len() + self.index)?; let logits = logits.squeeze(0)?; self.sampler.sample(&logits)? }; + // todo: repeat penalty + if self.stop_tokens.contains(&next_token) { return Ok(TokenGeneratorResult::Finish(FinishReason::EosToken)); } + self.next_token = Some(next_token); self.index += 1; Ok(TokenGeneratorResult::Token((next_token, 1.0))) } @@ -156,7 +177,8 @@ mod tests { Box::new(DummyModelProcessor::new()), Box::new(DummySampler::new()), ); - for index in 0..10 { + // starting at 1 because model processor and sampler run already in the new function. + for index in 1..11 { assert_eq!( token_generator.next().unwrap(), TokenGeneratorResult::Token((index, 1.0)) @@ -178,7 +200,7 @@ mod tests { Box::new(DummyModelProcessor::new()), Box::new(DummySampler::new()), ); - for index in 0..3 { + for index in 1..3 { assert_eq!( token_generator.next().unwrap(), TokenGeneratorResult::Token((index, 1.0))