Skip to content

Commit

Permalink
Better way to do SamplerParams.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Mar 2, 2024
1 parent 70ed939 commit 29196bf
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
35 changes: 24 additions & 11 deletions src/api/oai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,34 @@ use crate::sampler::{
Sampler,
};

#[derive(Debug, Default, Clone, Deserialize)]
pub struct SamplerParams {
#[serde(flatten)]
nucleus: Option<NucleusParams>,
#[serde(flatten)]
mirostat: Option<MirostatParams>,
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum SamplerParams {
Nucleus {
#[serde(flatten)]
params: NucleusParams,
},
Mirostat {
#[serde(flatten)]
params: MirostatParams,
},
}

impl Default for SamplerParams {
fn default() -> Self {
Self::Nucleus {
params: Default::default(),
}
}
}

impl From<SamplerParams> for Arc<RwLock<dyn Sampler + Send + Sync>> {
fn from(value: SamplerParams) -> Self {
let SamplerParams { nucleus, mirostat } = value;
match (nucleus, mirostat) {
(None, None) => Arc::new(RwLock::new(NucleusSampler::new(Default::default()))),
(None, Some(params)) => Arc::new(RwLock::new(MirostatSampler::new(params))),
(Some(params), _) => Arc::new(RwLock::new(NucleusSampler::new(params))),
match value {
SamplerParams::Nucleus { params } => Arc::new(RwLock::new(NucleusSampler::new(params))),
SamplerParams::Mirostat { params } => {
Arc::new(RwLock::new(MirostatSampler::new(params)))
}
}
}
}
6 changes: 5 additions & 1 deletion src/sampler/mirostat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@ use super::Sampler;

#[derive(Debug, Clone, Derivative, Serialize, Deserialize)]
#[derivative(Default)]
#[serde(default)]
pub struct MirostatParams {
#[derivative(Default(value = "3.0"))]
pub tau: f32,
#[derivative(Default(value = "0.1"))]
pub rate: f32,
#[derivative(Default(value = "128"))]
#[serde(default = "default_threshold")]
pub threshold: usize,
}

fn default_threshold() -> usize {
MirostatParams::default().threshold
}

#[derive(Debug, Clone, Default)]
pub struct MirostatState {
pub max_surprise: f32,
Expand Down

0 comments on commit 29196bf

Please sign in to comment.