Skip to content

Commit

Permalink
✅🚧 implemented token generator
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Dec 21, 2023
1 parent 920e951 commit 6cbc5b0
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions src/llm/token_generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,27 +96,43 @@ pub struct TokenGenerator2 {
index: usize,
stop_tokens: HashSet<u32>,
parameter: GenerateParameter,
tokens: Vec<u32>,
prompt_tokens: Vec<u32>,
sampler: Box<dyn Sampler>,
model: Box<dyn ModelProcessor>,
next_token: Option<u32>,
}

impl TokenGenerator2 {
pub fn new(
tokens: Vec<u32>,
prompt_tokens: Vec<u32>,
stop_tokens: HashSet<u32>,
parameter: GenerateParameter,
model: Box<dyn ModelProcessor>,
sampler: Box<dyn Sampler>,
) -> Self {
Self {
let mut token_generator = Self {
index: 0,
stop_tokens,
parameter,
tokens,
prompt_tokens,
model,
sampler,
}
next_token: None,
};
token_generator.next_token = Some(token_generator.next_token().unwrap_or_default());
token_generator
}
}

impl TokenGenerator2 {
    /// Runs the model over the whole prompt and samples one token from the result.
    ///
    /// The prompt token ids are loaded into a CPU tensor, given a leading
    /// dimension with `unsqueeze(0)`, and passed to `model.forward` together
    /// with `0` as the second argument (presumably the start position —
    /// confirm against `ModelProcessor::forward`). The leading dimension is
    /// removed again before the logits are handed to the sampler.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while building or reshaping the tensor,
    /// running the model, or sampling.
    fn next_token(&mut self) -> Result<u32> {
        let prompt = Tensor::new(self.prompt_tokens.as_slice(), &Device::Cpu)?;
        let batched = prompt.unsqueeze(0)?;
        let logits = self.model.forward(&batched, 0)?.squeeze(0)?;
        let sampled = self.sampler.sample(&logits)?;
        Ok(sampled)
    }
}

Expand All @@ -127,15 +143,20 @@ impl TokenGeneratorTrait for TokenGenerator2 {
}

let next_token = {
let input = Tensor::new(self.tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let logits = self.model.forward(&input, 0)?;
let input = Tensor::new(&[self.next_token.unwrap()], &Device::Cpu)?.unsqueeze(0)?;
let logits = self
.model
.forward(&input, self.prompt_tokens.len() + self.index)?;
let logits = logits.squeeze(0)?;
self.sampler.sample(&logits)?
};

// todo: repeat penalty

if self.stop_tokens.contains(&next_token) {
return Ok(TokenGeneratorResult::Finish(FinishReason::EosToken));
}
self.next_token = Some(next_token);
self.index += 1;
Ok(TokenGeneratorResult::Token((next_token, 1.0)))
}
Expand All @@ -156,7 +177,8 @@ mod tests {
Box::new(DummyModelProcessor::new()),
Box::new(DummySampler::new()),
);
for index in 0..10 {
// Start at 1 because `new` has already run the model processor and the sampler once.
for index in 1..11 {
assert_eq!(
token_generator.next().unwrap(),
TokenGeneratorResult::Token((index, 1.0))
Expand All @@ -178,7 +200,7 @@ mod tests {
Box::new(DummyModelProcessor::new()),
Box::new(DummySampler::new()),
);
for index in 0..3 {
for index in 1..3 {
assert_eq!(
token_generator.next().unwrap(),
TokenGeneratorResult::Token((index, 1.0))
Expand Down

0 comments on commit 6cbc5b0

Please sign in to comment.