Skip to content

Commit

Permalink
Better naming.
Browse files — browse the repository at this point in the history
  • Loading branch information
cryscan committed May 5, 2024
1 parent 1f38d78 commit 0390835
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
20 changes: 10 additions & 10 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ fn load_vocab(tokenizer: &Tokenizer) -> Vocabulary {
}
}

async fn load_init_state(
async fn load_initial_state(
context: &Context,
info: &ModelInfo,
model: SafeTensors<'_>,
Expand Down Expand Up @@ -353,14 +353,14 @@ async fn load_runtime(
LoadType::SafeTensors => {
let model = SafeTensors::deserialize(&data)?;

let init_state = match state {
let state = match state {
Some(state) => {
let file = File::open(state.path).await?;
let data = unsafe { Mmap::map(&file) }?;
let model = SafeTensors::deserialize(&data)?;
load_init_state(context, &info, model).await
load_initial_state(context, &info, model).await
}
None => load_init_state(context, &info, model).await,
None => load_initial_state(context, &info, model).await,
};

let model = SafeTensors::deserialize(&data)?;
Expand Down Expand Up @@ -395,32 +395,32 @@ async fn load_runtime(
(ModelVersion::V4, Precision::Fp16) => {
let model = Build::<v4::Model>::build(builder).await?;
let builder = v4::ModelRuntime::<f16>::new(model, max_batch);
Runtime::new(context, builder, reload, init_state, tokenizer, vocab).await
Runtime::new(context, builder, reload, state, tokenizer, vocab).await
}
(ModelVersion::V5, Precision::Fp16) => {
let model = Build::<v5::Model>::build(builder).await?;
let builder = v5::ModelRuntime::<f16>::new(model, max_batch);
Runtime::new(context, builder, reload, init_state, tokenizer, vocab).await
Runtime::new(context, builder, reload, state, tokenizer, vocab).await
}
(ModelVersion::V6, Precision::Fp16) => {
let model = Build::<v6::Model>::build(builder).await?;
let builder = v6::ModelRuntime::<f16>::new(model, max_batch);
Runtime::new(context, builder, reload, init_state, tokenizer, vocab).await
Runtime::new(context, builder, reload, state, tokenizer, vocab).await
}
(ModelVersion::V4, Precision::Fp32) => {
let model = Build::<v4::Model>::build(builder).await?;
let builder = v4::ModelRuntime::<f32>::new(model, max_batch);
Runtime::new(context, builder, reload, init_state, tokenizer, vocab).await
Runtime::new(context, builder, reload, state, tokenizer, vocab).await
}
(ModelVersion::V5, Precision::Fp32) => {
let model = Build::<v5::Model>::build(builder).await?;
let builder = v5::ModelRuntime::<f32>::new(model, max_batch);
Runtime::new(context, builder, reload, init_state, tokenizer, vocab).await
Runtime::new(context, builder, reload, state, tokenizer, vocab).await
}
(ModelVersion::V6, Precision::Fp32) => {
let model = Build::<v6::Model>::build(builder).await?;
let builder = v6::ModelRuntime::<f32>::new(model, max_batch);
Runtime::new(context, builder, reload, init_state, tokenizer, vocab).await
Runtime::new(context, builder, reload, state, tokenizer, vocab).await
}
}
}
Expand Down
16 changes: 10 additions & 6 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ impl<T> CachedItem<T> {
}
}

pub fn renew(cached: CachedItem<T>) -> Self {
/// Update an existing cache item's timestamp.
pub fn update(cached: CachedItem<T>) -> Self {
Self {
item: cached.item,
instant: Instant::now(),
Expand Down Expand Up @@ -297,7 +298,7 @@ pub struct Runtime {
state: Arc<dyn State + Send + Sync>,
model: Arc<dyn ModelSerialize + Send + Sync>,
runtime: JobRuntime<InferInput, InferOutput>,
init_state: Option<TensorCpu<f32>>,
initial_state: Option<TensorCpu<f32>>,
tokenizer: Arc<Tokenizer>,
vocab: Arc<Vocabulary>,
slots: Mutex<Vec<SlotState>>,
Expand All @@ -309,14 +310,16 @@ impl Runtime {
context: Context,
builder: B,
reload: ReloadRequest,
init_state: Option<TensorCpu<f32>>,
state: Option<TensorCpu<f32>>,
tokenizer: Tokenizer,
vocab: Vocabulary,
) -> Self
where
J: Job<Info = InferInfo, Input = InferChunk, Output = InferOutput>,
B: JobBuilder<J, Info = InferInfo> + ModelRuntime,
{
let initial_state = state;

let slots = (0..reload.max_batch)
.map(|_| SlotState::default())
.collect();
Expand All @@ -333,7 +336,7 @@ impl Runtime {
state,
model,
runtime,
init_state,
initial_state,
tokenizer: Arc::new(tokenizer),
vocab: Arc::new(vocab),
slots: Mutex::new(slots),
Expand Down Expand Up @@ -387,9 +390,10 @@ impl Runtime {
log::info!("slot {} checks out backed cache of length {}", batch, len);

let prefix = prefix[0..len].to_vec();
let state = self.initial_state.clone();
let reload = match cache.remove(prefix[..].as_token_slice()) {
Some(reload) => CachedItem::renew(reload),
None => CachedItem::new(self.init_state.clone().unwrap_or_else(|| self.state.init())),
Some(reload) => CachedItem::update(reload),
None => CachedItem::new(state.unwrap_or_else(|| self.state.init())),
};
if len > 0 {
let key = Tokens(prefix.clone());
Expand Down

0 comments on commit 0390835

Please sign in to comment.