Fix sampler serialization.

cryscan committed Mar 1, 2024
1 parent b088699 commit 0e38864

Showing 6 changed files with 33 additions and 28 deletions.
22 changes: 11 additions & 11 deletions assets/configs/Config.toml
@@ -1,15 +1,15 @@
[model]
path = "assets/models/RWKV-4-World-0.4B-v1-20230529-ctx4096.st" # Path to the model.
quant = 0 # Layers to be quantized.
quant_type = "Int8" # Quantization type ("Int8" or "NF4").
turbo = true # Whether to use alternative GEMM kernel to speed-up long prompts.
token_chunk_size = 32 # Size of token chunk that is inferred at once. For high end GPUs, this could be 64 or 128 (faster).
head_chunk_size = 8192 # DO NOT modify this if you don't know what you are doing.
state_chunk_size = 4 # The chunk size of layers in model state.
max_runtime_batch = 8 # The maximum batches that can be scheduled for inference at the same time.
max_batch = 16 # The maximum batches that are cached on GPU.
embed_device = "Cpu" # Device to put the embed tensor ("Cpu" or "Gpu").
stop = ["\n\n"] # Additional stop words in generation.
embed_device = "Cpu" # Device to put the embed tensor ("Cpu" or "Gpu").
head_chunk_size = 8192 # DO NOT modify this if you don't know what you are doing.
max_batch = 16 # The maximum batches that are cached on GPU.
max_runtime_batch = 8 # The maximum batches that can be scheduled for inference at the same time.
path = "assets/models/RWKV-x060-World-3B-v2-20240228-ctx4096.st" # Path to the model.
quant = 0 # Layers to be quantized.
quant_type = "Int8" # Quantization type ("Int8" or "NF4").
state_chunk_size = 4 # The chunk size of layers in model state.
stop = ["\n\n"] # Additional stop words in generation.
token_chunk_size = 128 # Size of token chunk that is inferred at once. For high end GPUs, this could be 64 or 128 (faster).
turbo = true # Whether to use alternative GEMM kernel to speed-up long prompts.

[tokenizer]
path = "assets/tokenizer/rwkv_vocab_v20230424.json" # Path to the tokenizer.
9 changes: 5 additions & 4 deletions src/api/oai/chat.rs
@@ -10,11 +10,12 @@ use futures_util::{Stream, StreamExt};
use itertools::Itertools;
use regex::Regex;
use serde::{Deserialize, Serialize};
+use tokio::sync::RwLock;

use super::SamplerParams;
use crate::{
-    api::request_info, Array, FinishReason, GenerateRequest, ThreadRequest, ThreadState, Token,
-    TokenCounter,
+    api::request_info, sampler::Sampler, Array, FinishReason, GenerateRequest, ThreadRequest,
+    ThreadState, Token, TokenCounter,
};

#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
@@ -67,7 +68,7 @@ impl Default for ChatRequest {
stop: Array::Item("\n\n".into()),
stream: false,
bias: HashMap::new(),
-            sampler: SamplerParams::Nucleus(Default::default()),
+            sampler: Default::default(),
}
}
}
@@ -110,7 +111,7 @@ impl From<ChatRequest> for GenerateRequest {
let max_tokens = max_tokens.min(crate::MAX_TOKENS);
let stop = stop.into();
let bias = Arc::new(bias);
-        let sampler = sampler.into();
+        let sampler: Arc<RwLock<dyn Sampler + Send + Sync>> = sampler.into();

Self {
prompt,
2 changes: 1 addition & 1 deletion src/api/oai/completion.rs
@@ -36,7 +36,7 @@ impl Default for CompletionRequest {
stop: Array::default(),
stream: false,
bias: HashMap::new(),
-            sampler: SamplerParams::Nucleus(Default::default()),
+            sampler: Default::default(),
}
}
}
19 changes: 11 additions & 8 deletions src/api/oai/mod.rs
@@ -19,18 +19,21 @@ use crate::sampler::{
Sampler,
};

-#[derive(Debug, Clone, Deserialize)]
-#[serde(untagged)]
-pub enum SamplerParams {
-    Nucleus(NucleusParams),
-    Mirostat(MirostatParams),
+#[derive(Debug, Default, Clone, Deserialize)]
+pub struct SamplerParams {
+    #[serde(flatten)]
+    nucleus: Option<NucleusParams>,
+    #[serde(flatten)]
+    mirostat: Option<MirostatParams>,
}

impl From<SamplerParams> for Arc<RwLock<dyn Sampler + Send + Sync>> {
fn from(value: SamplerParams) -> Self {
-        match value {
-            SamplerParams::Nucleus(params) => Arc::new(RwLock::new(NucleusSampler::new(params))),
-            SamplerParams::Mirostat(params) => Arc::new(RwLock::new(MirostatSampler::new(params))),
+        let SamplerParams { nucleus, mirostat } = value;
+        match (nucleus, mirostat) {
+            (None, None) => Arc::new(RwLock::new(NucleusSampler::new(Default::default()))),
+            (None, Some(params)) => Arc::new(RwLock::new(MirostatSampler::new(params))),
+            (Some(params), _) => Arc::new(RwLock::new(NucleusSampler::new(params))),
}
}
}
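
A note on what presumably breaks with the old shape: `#[serde(untagged)]` makes serde try the variants in order, and because every `NucleusParams` field had a default (the `#[serde(default)]` attribute removed in src/sampler/nucleus.rs below), any request body — even one carrying only mirostat fields — matched the `Nucleus` variant first, so mirostat settings were silently ignored. A minimal, self-contained sketch of that failure mode, using hypothetical single-field stand-ins for the real parameter structs and the serde/serde_json crates:

use serde::Deserialize;

// Hypothetical, trimmed-down stand-ins for the real parameter structs.
#[derive(Debug, Default, Deserialize)]
#[serde(default)] // the attribute this commit removes from NucleusParams
struct NucleusParams {
    top_p: f32,
}

#[derive(Debug, Deserialize)]
struct MirostatParams {
    tau: f32,
}

// The pre-commit shape: an untagged enum tries the variants in order.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum OldSamplerParams {
    Nucleus(NucleusParams),
    Mirostat(MirostatParams),
}

fn main() {
    // A body that only sets a mirostat field still parses as `Nucleus`,
    // because all `NucleusParams` fields are defaultable and the first
    // untagged variant therefore always succeeds.
    let old: OldSamplerParams = serde_json::from_str(r#"{ "tau": 3.0 }"#).unwrap();
    println!("{old:?}"); // Nucleus(NucleusParams { top_p: 0.0 })
}

The flattened struct keeps both option sets side by side instead, and the `From` impl above prefers explicit nucleus parameters, falls back to mirostat when only those fields are present, and builds a default nucleus sampler when neither is.
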
8 changes: 5 additions & 3 deletions src/sampler/mirostat.rs
@@ -10,7 +10,7 @@ use super::Sampler;
pub struct MirostatParams {
#[derivative(Default(value = "3.0"))]
pub tau: f32,
#[derivative(Default(value = "1.0"))]
#[derivative(Default(value = "0.1"))]
pub rate: f32,
#[derivative(Default(value = "128"))]
pub threshold: usize,
@@ -50,8 +50,9 @@ impl MirostatSampler {

fn compute_k(&self, probs: &[f32], s: f32) -> usize {
let n = probs.len() as f32;
+        let tau = self.state.max_surprise;
let eps = s - 1.0;
-        let k = (eps * 2.0_f32.powf(self.params.tau) / (1.0 - n.powf(-eps))).powf(1.0 / s);
+        let k = (eps * 2.0_f32.powf(tau) / (1.0 - n.powf(-eps))).powf(1.0 / s);
k.round() as usize
}
}
@@ -70,7 +71,8 @@ impl Sampler for MirostatSampler {
let sorted_probs = sorted.iter().map(|x| x.2).collect_vec();

let s = self.estimate_s(&sorted_probs);
let k = self.compute_k(&sorted_probs, s);
let k = self.compute_k(&sorted_probs, s) + 1;
let k = k.min(probs.len() - 1);

let sum = sorted.get(k).map(|&(_, cum, _)| cum).unwrap_or_default();
let rand = fastrand::f32() * sum;
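
For reference, the corrected `compute_k` now reads the running surprise target out of the sampler state instead of the fixed `params.tau`, and the call site adds one and clamps the result to a valid index. A standalone sketch of the formula after this commit (written as a free function with toy inputs, not the crate's actual method):

/// Mirostat top-k estimate: k = (eps * 2^tau / (1 - n^(-eps)))^(1/s),
/// where `s` is the estimated Zipf exponent of the sorted probabilities
/// and `max_surprise` is the running surprise target kept in the sampler
/// state (the value this commit switches to).
fn compute_k(probs: &[f32], s: f32, max_surprise: f32) -> usize {
    let n = probs.len() as f32;
    let eps = s - 1.0;
    let k = (eps * 2.0_f32.powf(max_surprise) / (1.0 - n.powf(-eps))).powf(1.0 / s);
    k.round() as usize
}

fn main() {
    // Toy values purely for illustration; the real inputs are the sorted
    // softmax probabilities over the whole vocabulary.
    let probs = vec![0.4_f32, 0.3, 0.2, 0.1];
    let s = 1.2; // a made-up Zipf estimate
    // Mirror the new call site: add one, then clamp to a valid index.
    let k = (compute_k(&probs, s, 3.0) + 1).min(probs.len() - 1);
    println!("k = {k}");
}
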
1 change: 0 additions & 1 deletion src/sampler/nucleus.rs
@@ -8,7 +8,6 @@ use super::Sampler;

#[derive(Debug, Clone, Derivative, Serialize, Deserialize)]
#[derivative(Default)]
-#[serde(default)]
pub struct NucleusParams {
#[derivative(Default(value = "1.0"))]
pub top_p: f32,