Cache the LLM model inside the service (#2258)
* Cache the LLM model inside the service

Instead of fetching it for every query.

* Update the LLM README file

Remove the TODO for reusing the model across queries, and detail that
the model is only fetched when handling the first prompt.
jvff authored Jul 18, 2024
1 parent 85c3121 commit 1475e74
Showing 3 changed files with 32 additions and 22 deletions.
examples/llm/README.md (11 changes: 7 additions & 4 deletions)

@@ -13,10 +13,13 @@ CAVEAT:
 ([#1981](https://github.com/linera-io/linera-protocol/issues/1981)) or in an external
 decentralized storage.
 
-* We should also not download the model at every query ([#1999](https://github.com/linera-io/linera-protocol/issues/1999)).
-
 * Running larger LLMs with acceptable performance will likely require hardware acceleration ([#1931](https://github.com/linera-io/linera-protocol/issues/1931)).
 
+* The service currently is restarted when the wallet receives a new block for the chain where the
+  application is running from. That means it fetches the model again, which is inefficient. The
+  service should be allowed to continue executing in that case
+  ([#2160](https://github.com/linera-io/linera-protocol/issues/2160)).
+
 
 # How It Works
 
@@ -26,8 +29,8 @@ at `model.bin` and `tokenizer.json`.
 The application's service exposes a single GraphQL field called `prompt` which takes a prompt
 as input and returns a response.
 
-When a prompt is submitted, the application's service uses the `fetch_url`
-system API to inject the model and tokenizer. Subsequently, the model bytes are converted
+When the first prompt is submitted, the application's service uses the `fetch_url`
+system API to fetch the model and tokenizer. Subsequently, the model bytes are converted
 to the GGUF format where it can be used for inference.
 
 # Usage
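To make the GraphQL interface described in the README concrete, here is a minimal, self-contained async-graphql sketch. It is not the application's own code: the `prompt` resolver below simply echoes its input where the real service runs the cached model, but the field signature and the query a client sends have the same shape. It assumes only the `async-graphql` and `tokio` crates.

```rust
// Minimal sketch of the service's GraphQL surface, with the model call replaced
// by an echo so the example runs on its own.
use async_graphql::{EmptyMutation, EmptySubscription, Object, Request, Schema};

struct QueryRoot;

#[Object]
impl QueryRoot {
    /// Stand-in for the real resolver, which runs the cached model on the prompt.
    async fn prompt(&self, prompt: String) -> String {
        format!("echo: {prompt}")
    }
}

#[tokio::main]
async fn main() {
    let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).finish();
    // The same query shape a client sends to the deployed application.
    let request = Request::new(r#"query { prompt(prompt: "Tell me a story") }"#);
    let response = schema.execute(request).await;
    println!("{:?}", response.data);
}
```
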
examples/llm/src/lib.rs (11 changes: 7 additions & 4 deletions)

@@ -15,10 +15,13 @@ CAVEAT:
 ([#1981](https://github.com/linera-io/linera-protocol/issues/1981)) or in an external
 decentralized storage.
 
-* We should also not download the model at every query ([#1999](https://github.com/linera-io/linera-protocol/issues/1999)).
-
 * Running larger LLMs with acceptable performance will likely require hardware acceleration ([#1931](https://github.com/linera-io/linera-protocol/issues/1931)).
 
+* The service currently is restarted when the wallet receives a new block for the chain where the
+  application is running from. That means it fetches the model again, which is inefficient. The
+  service should be allowed to continue executing in that case
+  ([#2160](https://github.com/linera-io/linera-protocol/issues/2160)).
+
 
 # How It Works
 
@@ -28,8 +31,8 @@ at `model.bin` and `tokenizer.json`.
 The application's service exposes a single GraphQL field called `prompt` which takes a prompt
 as input and returns a response.
 
-When a prompt is submitted, the application's service uses the `fetch_url`
-system API to inject the model and tokenizer. Subsequently, the model bytes are converted
+When the first prompt is submitted, the application's service uses the `fetch_url`
+system API to fetch the model and tokenizer. Subsequently, the model bytes are converted
 to the GGUF format where it can be used for inference.
 
 # Usage
examples/llm/src/service.rs (32 changes: 18 additions & 14 deletions)

@@ -7,7 +7,10 @@ mod random;
 mod state;
 mod token;
 
-use std::io::{Cursor, Seek, SeekFrom};
+use std::{
+    io::{Cursor, Seek, SeekFrom},
+    sync::Arc,
+};
 
 use async_graphql::{Context, EmptyMutation, EmptySubscription, Object, Request, Response, Schema};
 use candle_core::{
@@ -25,7 +28,7 @@ use tokenizers::Tokenizer;
 use crate::token::TokenOutputStream;
 
 pub struct LlmService {
-    runtime: ServiceRuntime<Self>,
+    model_context: Arc<ModelContext>,
 }
 
 linera_sdk::service!(LlmService);
@@ -39,7 +42,7 @@ struct QueryRoot {}
 #[Object]
 impl QueryRoot {
     async fn prompt(&self, ctx: &Context<'_>, prompt: String) -> String {
-        let model_context = ctx.data::<ModelContext>().unwrap();
+        let model_context = ctx.data::<Arc<ModelContext>>().unwrap();
         model_context.run_model(&prompt).unwrap()
     }
 }
@@ -73,23 +76,24 @@ impl Service for LlmService {
     type Parameters = ();
 
     async fn new(runtime: ServiceRuntime<Self>) -> Self {
-        LlmService { runtime }
+        let raw_weights = runtime
+            .fetch_url("https://huggingface.co/karpathy/tinyllamas/resolve/main/stories42M.bin");
+        info!("got weights: {}B", raw_weights.len());
+        let tokenizer_bytes = runtime.fetch_url(
+            "https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json",
+        );
+        let model_context = Arc::new(ModelContext {
+            model: raw_weights,
+            tokenizer: tokenizer_bytes,
+        });
+        LlmService { model_context }
     }
 
     async fn handle_query(&self, request: Request) -> Response {
         let query_string = &request.query;
        info!("query: {}", query_string);
-        let raw_weights = self.runtime.fetch_url("http://localhost:10001/model.bin");
-        info!("got weights: {}B", raw_weights.len());
-        let tokenizer_bytes = self
-            .runtime
-            .fetch_url("http://localhost:10001/tokenizer.json");
-        let model_context = ModelContext {
-            model: raw_weights,
-            tokenizer: tokenizer_bytes,
-        };
         let schema = Schema::build(QueryRoot {}, EmptyMutation, EmptySubscription)
-            .data(model_context)
+            .data(self.model_context.clone())
             .finish();
         schema.execute(request).await
     }
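The heart of the change is visible in `new` and `handle_query` above: the model and tokenizer are fetched once when the service is constructed, and every query's GraphQL schema receives a cheap `Arc` clone of the cached `ModelContext` instead of triggering a new download. A stripped-down sketch of that pattern, with a hypothetical `ExpensiveResource` standing in for the real model context, looks like this:

```rust
// Illustration only (hypothetical types, not the service's real code): cache an
// expensive resource in the constructor and share it per query via Arc.
use std::sync::Arc;

struct ExpensiveResource {
    bytes: Vec<u8>,
}

struct CachingService {
    resource: Arc<ExpensiveResource>,
}

impl CachingService {
    fn new() -> Self {
        // Built exactly once, at service construction time; the real service
        // calls `fetch_url` for the model weights and tokenizer here.
        let resource = Arc::new(ExpensiveResource {
            bytes: vec![0u8; 1024],
        });
        CachingService { resource }
    }

    fn handle_query(&self) -> usize {
        // Every query reuses the cached resource; cloning the Arc only bumps a
        // reference count instead of re-downloading or copying the bytes.
        let shared = Arc::clone(&self.resource);
        shared.bytes.len()
    }
}

fn main() {
    let service = CachingService::new();
    assert_eq!(service.handle_query(), 1024);
    assert_eq!(service.handle_query(), 1024); // second query, no second fetch
}
```

Sharing through `Arc` matters here because the schema is rebuilt for every request and `.data(...)` takes ownership of the value it stores; cloning an `Arc` is a reference-count bump, whereas cloning the raw weight bytes would copy the entire model on each query.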
