Skip to content

Commit

Permalink
🚧 implements token generator
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Dec 21, 2023
1 parent 6cbc5b0 commit 01118fa
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 162 deletions.
119 changes: 0 additions & 119 deletions src/llm/dummy_text_generator.rs

This file was deleted.

12 changes: 8 additions & 4 deletions src/llm/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
// source: https://github.com/huggingface/candle/blob/main/candle-examples/examples/mistral/main.rs

pub mod dummy_text_generator;
mod generate_parameter;
pub mod model_processor;
pub mod models;
Expand All @@ -10,8 +8,7 @@ pub mod text_generator;
pub mod token_generator;
pub mod token_output_stream;

pub use dummy_text_generator::DummyTextGenerator;
pub use text_generator::TextGenerator;
pub use text_generator::TextGeneratorTrait;

use std::path::PathBuf;

Expand All @@ -27,6 +24,13 @@ use tokenizers::Tokenizer;

use self::models::Models;

/// Reason why a generator stopped producing tokens.
///
/// The enum carries no data, so it is cheap to copy; deriving `Clone`, `Copy`
/// and `Eq` lets callers keep using a value after handing it to a result type
/// and compare reasons with full equality semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FinishReason {
    /// The configured limit of newly generated tokens was reached.
    Length,
    /// NOTE(review): presumably reported when an end-of-sequence token is
    /// produced — the producing code is not visible here; confirm.
    EosToken,
    /// NOTE(review): presumably reported when a configured stop sequence is
    /// matched — the producing code is not visible here; confirm.
    StopSequence,
}

fn format_size(size_in_bytes: usize) -> String {
if size_in_bytes < 1_000 {
format!("{}B", size_in_bytes)
Expand Down
21 changes: 0 additions & 21 deletions src/llm/text_generator.rs

This file was deleted.

26 changes: 26 additions & 0 deletions src/llm/text_generator/dummy_text_generator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use super::{TextGeneratorResult, TextGeneratorTrait};
use anyhow::Result;

/// Placeholder text generator that only stores a piece of text.
pub struct DummyTextGenerator {
// Text held by the generator; set by `new` and overwritten by `init`.
text: String,
}

impl DummyTextGenerator {
pub fn new(text: String) -> Self {
Self { text }
}
}

impl TextGeneratorTrait for DummyTextGenerator {
    /// Replaces the stored text with `prompt`; never fails.
    fn init(&mut self, prompt: String) -> Result<()> {
        // The previously stored text is dropped by this assignment.
        self.text = prompt;
        Ok(())
    }

    /// Not implemented yet: calling this panics via `todo!()`.
    fn next(&mut self) -> Result<TextGeneratorResult> {
        todo!()
    }
}

// Placeholder test module: no tests are defined for the dummy generator yet.
#[cfg(test)]
mod tests {}
97 changes: 97 additions & 0 deletions src/llm/text_generator/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use super::{
token_generator::{TokenGeneratorResult, TokenGeneratorTrait},
token_output_stream::TokenOutputStream,
FinishReason,
};
use anyhow::Result;
mod dummy_text_generator;

/// A piece of generated text paired with the probability value reported for it.
pub type TextProbability = (String, f32);

/// Outcome of a single text generation step.
#[derive(Debug, PartialEq)]
pub enum TextGeneratorResult {
/// A piece of text together with its probability value.
Token(TextProbability),
/// Generation has ended; carries the reason.
Finish(FinishReason),
}

/// Trait for step-wise text generation.
///
/// Implementors are set up with a prompt through `init` and then polled with
/// `next`, which yields either a text piece or a finish reason per call.
pub trait TextGeneratorTrait {
/// Prepares the generator with the given prompt.
fn init(&mut self, prompt: String) -> Result<()>;
/// Produces the next generation result.
fn next(&mut self) -> Result<TextGeneratorResult>;
}

/// Text generator that turns the output of a token generator into text.
pub struct TextGenerator {
// Used to encode the prompt and to turn generated token ids into text.
tokenizer: TokenOutputStream,
// Source of generated tokens (or a finish reason) for each step.
token_generator: Box<dyn TokenGeneratorTrait>,
}

impl TextGenerator {
pub fn new(
tokenizer: TokenOutputStream,
token_generator: Box<dyn TokenGeneratorTrait>,
) -> Self {
Self {
tokenizer,
token_generator,
}
}
}

impl TextGeneratorTrait for TextGenerator {
    /// Encodes `prompt` with the wrapped tokenizer and hands the resulting
    /// token ids to the token generator.
    ///
    /// # Errors
    /// Returns an error if encoding fails or if the token generator rejects
    /// the prompt tokens.
    fn init(&mut self, prompt: String) -> Result<()> {
        let encoding = self
            .tokenizer
            .tokenizer()
            .encode(prompt, true)
            .map_err(anyhow::Error::msg)?;
        let prompt_ids = encoding.get_ids().to_vec();
        self.token_generator.init(prompt_ids)?;
        Ok(())
    }

    /// Pulls the next result from the token generator and converts it to text.
    ///
    /// A finish reason is forwarded as is. A token is fed to the token output
    /// stream; when the stream yields no text for it, an empty string is
    /// returned instead.
    ///
    /// # Errors
    /// Propagates errors from the token generator and from the output stream.
    fn next(&mut self) -> Result<TextGeneratorResult> {
        let (token, probability) = match self.token_generator.next()? {
            TokenGeneratorResult::Finish(reason) => {
                return Ok(TextGeneratorResult::Finish(reason));
            }
            TokenGeneratorResult::Token(pair) => pair,
        };

        let piece = match self.tokenizer.next_token(token)? {
            Some(text) => (text, probability),
            // NOTE(review): the token's own probability is replaced by 1.0
            // when no text is emitted — confirm this is intended.
            None => ("".to_string(), 1.0),
        };
        Ok(TextGeneratorResult::Token(piece))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::{
        generate_parameter::GenerateParameter, token_generator::dummy::DummyTokenGenerator,
    };

    /// With a dummy token generator limited to 10 new tokens, the first ten
    /// calls must yield tokens and the eleventh must report `Length`.
    #[test]
    fn test_text_generator() {
        let tokenizer =
            tokenizers::tokenizer::Tokenizer::new(tokenizers::models::bpe::BPE::default());
        let output_stream = TokenOutputStream::new(tokenizer);
        let token_source = DummyTokenGenerator::new(GenerateParameter { max_new_tokens: 10 });
        let mut text_generator = TextGenerator::new(output_stream, Box::new(token_source));

        text_generator.init("Hello World".to_string()).unwrap();

        for _ in 0..10 {
            let step = text_generator.next().unwrap();
            assert!(matches!(step, TextGeneratorResult::Token(_)));
        }

        let last = text_generator.next().unwrap();
        assert_eq!(last, TextGeneratorResult::Finish(FinishReason::Length));
    }
}
5 changes: 4 additions & 1 deletion src/llm/token_generator/dummy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use candle_core::{Device, Tensor};

use super::{FinishReason, TokenGeneratorResult, TokenGeneratorTrait};

struct DummyTokenGenerator {
pub struct DummyTokenGenerator {
parameter: GenerateParameter,
index: usize,
sampler: Box<dyn Sampler>,
Expand All @@ -29,6 +29,9 @@ impl DummyTokenGenerator {
}

impl TokenGeneratorTrait for DummyTokenGenerator {
fn init(&mut self, _prompt_tokens: Vec<u32>) -> Result<()> {
Ok(())
}
fn next(&mut self) -> Result<TokenGeneratorResult> {
self.index += 1;
if self.index > self.parameter.max_new_tokens {
Expand Down
Loading

0 comments on commit 01118fa

Please sign in to comment.