Skip to content

Commit

Permalink
Reorganize sampler files.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Feb 28, 2024
1 parent 543faa1 commit 08bb356
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use half::f16;
use itertools::Itertools;
use memmap2::Mmap;
use safetensors::SafeTensors;
use sampler::NucleusSampler;
use sampler::nucleus::NucleusSampler;
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, RwLock};
use tower_http::{cors::CorsLayer, services::ServeDir};
Expand Down
2 changes: 1 addition & 1 deletion src/oai/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use crate::{
sampler::{NucleusParams, NucleusSampler},
sampler::nucleus::{NucleusParams, NucleusSampler},
utils::request_info,
Array, FinishReason, GenerateRequest, ThreadRequest, ThreadState, Token, TokenCounter,
};
Expand Down
2 changes: 1 addition & 1 deletion src/oai/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use crate::{
sampler::{NucleusParams, NucleusSampler},
sampler::nucleus::{NucleusParams, NucleusSampler},
utils::request_info,
Array, FinishReason, GenerateRequest, ThreadRequest, ThreadState, Token, TokenCounter,
};
Expand Down
Empty file added src/sampler/mirostat.rs
Empty file.
13 changes: 13 additions & 0 deletions src/sampler/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
pub mod mirostat;
pub mod nucleus;

pub trait Sampler {
/// Initialize the sampler state.
fn init(&mut self, model_tokens: &[u16]);
/// Update the raw model output.
fn transform(&self, output: &mut [f32]);
/// Select one token from the distribution.
fn sample(&self, probs: &[f32]) -> u16;
/// Update the sampler state after a token is chosen.
fn update(&mut self, token: u16);
}
11 changes: 1 addition & 10 deletions src/sampler.rs → src/sampler/nucleus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,7 @@ use std::collections::HashMap;
use derivative::Derivative;
use itertools::Itertools;

pub trait Sampler {
/// Initialize the sampler state.
fn init(&mut self, model_tokens: &[u16]);
/// Update the raw model output.
fn transform(&self, output: &mut [f32]);
/// Select one token from the distribution.
fn sample(&self, probs: &[f32]) -> u16;
/// Update the sampler state after a token is chosen.
fn update(&mut self, token: u16);
}
use super::Sampler;

#[derive(Debug, Clone, Derivative)]
#[derivative(Default)]
Expand Down

0 comments on commit 08bb356

Please sign in to comment.