Skip to content

Commit

Permalink
chat_passthrough WIP 1
Browse files Browse the repository at this point in the history
  • Loading branch information
olegklimov committed Oct 23, 2023
1 parent 993a615 commit 78010d9
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 77 deletions.
47 changes: 44 additions & 3 deletions src/cached_tokenizers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
use reqwest::header::AUTHORIZATION;
use tracing::info;
use tokio::io::AsyncWriteExt;
use std::path::Path;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use tokio::sync::RwLock as ARwLock;
use tokenizers::Tokenizer;
use reqwest::header::AUTHORIZATION;
use tracing::info;

use crate::global_context::GlobalContext;
use crate::caps::CodeAssistantCaps;


#[derive(Debug, Clone, PartialEq, Eq)]
Expand All @@ -10,7 +17,7 @@ pub struct Error {
pub data: Option<serde_json::Value>,
}

pub async fn download_tokenizer_file(
async fn _download_tokenizer_file(
http_client: &reqwest::Client,
http_path: &str,
api_token: String,
Expand Down Expand Up @@ -46,3 +53,37 @@ pub async fn download_tokenizer_file(
).await.map_err(|e| format!("failed to write to file: {}", e))?;
Ok(())
}

pub async fn cached_tokenizer(
caps: Arc<StdRwLock<CodeAssistantCaps>>,
global_context: Arc<ARwLock<GlobalContext>>,
model_name: String,
) -> Result<Arc<StdRwLock<Tokenizer>>, String> {
let mut cx_locked = global_context.write().await;
let client2 = cx_locked.http_client.clone();
let cache_dir = cx_locked.cache_dir.clone();
let tokenizer_arc = match cx_locked.tokenizer_map.get(&model_name) {
Some(arc) => arc.clone(),
None => {
let tokenizer_cache_dir = std::path::PathBuf::from(cache_dir).join("tokenizers");
tokio::fs::create_dir_all(&tokenizer_cache_dir)
.await
.expect("failed to create cache dir");
let path = tokenizer_cache_dir.join(model_name.clone()).join("tokenizer.json");
// Download it while it's locked, so another download won't start.
let http_path;
{
// To avoid deadlocks, in all other places locks must be in the same order
let caps_locked = caps.read().unwrap();
let rewritten_model_name = caps_locked.tokenizer_rewrite_path.get(&model_name).unwrap_or(&model_name);
http_path = caps_locked.tokenizer_path_template.replace("$MODEL", rewritten_model_name);();
}
_download_tokenizer_file(&client2, http_path.as_str(), cx_locked.cmdline.api_key.clone(), &path).await?;
let tokenizer = Tokenizer::from_file(path).map_err(|e| format!("failed to load tokenizer: {}", e))?;
let arc = Arc::new(StdRwLock::new(tokenizer));
cx_locked.tokenizer_map.insert(model_name.clone(), arc.clone());
arc
}
};
Ok(tokenizer_arc)
}
8 changes: 5 additions & 3 deletions src/global_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use crate::caps::CodeAssistantCaps;
use crate::completion_cache::CompletionCache;
use crate::telemetry_storage;
use crate::vecdb_search::VecdbSearch;
use crate::custom_error::ScratchError;
use hyper::StatusCode;


#[derive(Debug, StructOpt, Clone)]
Expand Down Expand Up @@ -77,7 +79,7 @@ pub async fn caps_background_reload(

pub async fn try_load_caps_quickly_if_not_present(
global_context: Arc<ARwLock<GlobalContext>>,
) -> Result<Arc<StdRwLock<CodeAssistantCaps>>, String> {
) -> Result<Arc<StdRwLock<CodeAssistantCaps>>, ScratchError> {
let caps_last_attempted_ts;
{
let cx_locked = global_context.write().await;
Expand All @@ -88,7 +90,7 @@ pub async fn try_load_caps_quickly_if_not_present(
}
let now = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
if caps_last_attempted_ts + CAPS_RELOAD_BACKOFF > now {
return Err("server is not reachable, no caps available".to_string());
return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, "server is not reachable, no caps available".to_string()));
}
let caps_result = crate::caps::load_caps(
CommandLine::from_args()
Expand All @@ -104,7 +106,7 @@ pub async fn try_load_caps_quickly_if_not_present(
Ok(caps)
},
Err(e) => {
Err(format!("server is not reachable: {}", e))
return Err(ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, format!("server is not reachable: {}", e)));
}
}
}
Expand Down
88 changes: 23 additions & 65 deletions src/http_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ use hyper::{Body, Request, Response, Server, Method, StatusCode};
use hyper::server::conn::AddrStream;
use hyper::service::{make_service_fn, service_fn};
use serde_json::json;
use tokenizers::Tokenizer;

use crate::cached_tokenizers;
use crate::caps;
use crate::scratchpads;
use crate::call_validation::{CodeCompletionPost, ChatPost};
Expand All @@ -21,49 +19,12 @@ use crate::custom_error::ScratchError;
use crate::telemetry_basic;
use crate::telemetry_snippets;
use crate::completion_cache;
// use crate::vecdb_search::VecdbSearch;


async fn _get_caps_and_tokenizer(
global_context: Arc<ARwLock<GlobalContext>>,
model_name: String,
) -> Result<(Arc<StdRwLock<CodeAssistantCaps>>, Arc<StdRwLock<Tokenizer>>, reqwest::Client, String), String> {
let caps = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone()).await?;
let mut cx_locked = global_context.write().await;
let client1 = cx_locked.http_client.clone();
let client2 = cx_locked.http_client.clone();
let cache_dir = cx_locked.cache_dir.clone();
let tokenizer_arc = match cx_locked.tokenizer_map.get(&model_name) {
Some(arc) => arc.clone(),
None => {
let tokenizer_cache_dir = std::path::PathBuf::from(cache_dir).join("tokenizers");
tokio::fs::create_dir_all(&tokenizer_cache_dir)
.await
.expect("failed to create cache dir");
let path = tokenizer_cache_dir.join(model_name.clone()).join("tokenizer.json");
// Download it while it's locked, so another download won't start.
let http_path;
{
// To avoid deadlocks, in all other places locks must be in the same order
let caps_locked = caps.read().unwrap();
let rewritten_model_name = caps_locked.tokenizer_rewrite_path.get(&model_name).unwrap_or(&model_name);
http_path = caps_locked.tokenizer_path_template.replace("$MODEL", rewritten_model_name);();
}
cached_tokenizers::download_tokenizer_file(&client2, http_path.as_str(), cx_locked.cmdline.api_key.clone(), &path).await?;
let tokenizer = Tokenizer::from_file(path).map_err(|e| format!("failed to load tokenizer: {}", e))?;
let arc = Arc::new(StdRwLock::new(tokenizer));
cx_locked.tokenizer_map.insert(model_name.clone(), arc.clone());
arc
}
};
Ok((caps, tokenizer_arc, client1, cx_locked.cmdline.api_key.clone()))
}

async fn _lookup_code_completion_scratchpad(
global_context: Arc<ARwLock<GlobalContext>>,
caps: Arc<StdRwLock<CodeAssistantCaps>>,
code_completion_post: &CodeCompletionPost,
) -> Result<(String, String, serde_json::Value), String> {
let caps = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone()).await?;
let caps_locked = caps.read().unwrap();
let (model_name, recommended_model_record) =
caps::which_model_to_use(
Expand All @@ -80,10 +41,9 @@ async fn _lookup_code_completion_scratchpad(
}

async fn _lookup_chat_scratchpad(
global_context: Arc<ARwLock<GlobalContext>>,
caps: Arc<StdRwLock<CodeAssistantCaps>>,
chat_post: &ChatPost,
) -> Result<(String, String, serde_json::Value), String> {
let caps = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone()).await?;
let caps_locked = caps.read().unwrap();
let (model_name, recommended_model_record) =
caps::which_model_to_use(
Expand All @@ -103,8 +63,9 @@ pub async fn handle_v1_code_completion(
global_context: Arc<ARwLock<GlobalContext>>,
code_completion_post: &mut CodeCompletionPost
) -> Result<Response<Body>, ScratchError> {
let caps = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone()).await?;
let (model_name, scratchpad_name, scratchpad_patch) = _lookup_code_completion_scratchpad(
global_context.clone(),
caps.clone(),
&code_completion_post,
).await.map_err(|e| {
ScratchError::new(StatusCode::BAD_REQUEST, format!("{}", e))
Expand All @@ -119,16 +80,11 @@ pub async fn handle_v1_code_completion(
code_completion_post.scratchpad = scratchpad_name.clone();
}
code_completion_post.parameters.temperature = Some(code_completion_post.parameters.temperature.unwrap_or(0.2));
let (_caps, tokenizer_arc, client1, api_key) = _get_caps_and_tokenizer(
global_context.clone(),
model_name.clone(),
).await.map_err(|e|
ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR,format!("Tokenizer: {}", e))
)?;

let cache_arc = global_context.read().await.completions_cache.clone();
let tele_storage = global_context.read().await.telemetry.clone();
if (!code_completion_post.no_cache) {
let (client1, api_key, cache_arc, tele_storage) = {
let cx_locked = global_context.write().await;
(cx_locked.http_client.clone(), cx_locked.cmdline.api_key.clone(), cx_locked.completions_cache.clone(), cx_locked.telemetry.clone())
};
if !code_completion_post.no_cache {
let cache_key = completion_cache::cache_key_from_post(&code_completion_post);
let cached_maybe = completion_cache::cache_get(cache_arc.clone(), cache_key.clone());
if let Some(cached_json_value) = cached_maybe {
Expand All @@ -142,13 +98,15 @@ pub async fn handle_v1_code_completion(
}

let mut scratchpad = scratchpads::create_code_completion_scratchpad(
global_context.clone(),
caps,
model_name.clone(),
code_completion_post.clone(),
&scratchpad_name,
&scratchpad_patch,
tokenizer_arc.clone(),
cache_arc.clone(),
tele_storage.clone(),
).map_err(|e|
).await.map_err(|e|
ScratchError::new(StatusCode::BAD_REQUEST, e)
)?;
let t1 = std::time::Instant::now();
Expand Down Expand Up @@ -185,8 +143,9 @@ async fn handle_v1_chat(
let mut chat_post = serde_json::from_slice::<ChatPost>(&body_bytes).map_err(|e|
ScratchError::new(StatusCode::BAD_REQUEST, format!("JSON problem: {}", e))
)?;
let caps = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone()).await?;
let (model_name, scratchpad_name, scratchpad_patch) = _lookup_chat_scratchpad(
global_context.clone(),
caps.clone(),
&chat_post,
).await.map_err(|e| {
ScratchError::new(StatusCode::BAD_REQUEST, format!("{}", e))
Expand All @@ -196,21 +155,20 @@ async fn handle_v1_chat(
}
chat_post.parameters.temperature = Some(chat_post.parameters.temperature.unwrap_or(0.2));
chat_post.model = model_name.clone();
let (_caps, tokenizer_arc, client1, api_key) = _get_caps_and_tokenizer(
global_context.clone(),
model_name.clone(),
).await.map_err(|e|
ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR,format!("Tokenizer: {}", e))
)?;

let (client1, api_key) = {
let cx_locked = global_context.write().await;
(cx_locked.http_client.clone(), cx_locked.cmdline.api_key.clone())
};
let vecdb_search = global_context.read().await.vecdb_search.clone();
let mut scratchpad = scratchpads::create_chat_scratchpad(
global_context.clone(),
caps,
model_name.clone(),
chat_post.clone(),
&scratchpad_name,
&scratchpad_patch,
tokenizer_arc.clone(),
vecdb_search,
).map_err(|e|
).await.map_err(|e|
ScratchError::new(StatusCode::BAD_REQUEST, e)
)?;
let t1 = std::time::Instant::now();
Expand Down
80 changes: 80 additions & 0 deletions src/scratchpads/chat_passthrough.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use tracing::info;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use tokio::sync::Mutex as AMutex;
use async_trait::async_trait;

use crate::scratchpad_abstract::ScratchpadAbstract;
use crate::call_validation::ChatPost;
use crate::call_validation::ChatMessage;
use crate::call_validation::SamplingParameters;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes;
use crate::vecdb_search::{VecdbSearch, embed_vecdb_results};


const DEBUG: bool = true;


// #[derive(Debug)]
pub struct ChatPassthrough {
pub post: ChatPost,
pub default_system_message: String,
pub limit_bytes: usize,
pub vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
pub limited_msgs: Vec<ChatMessage>,
}

impl ChatPassthrough {
pub fn new(
post: ChatPost,
vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
) -> Self {
ChatPassthrough {
post,
default_system_message: "".to_string(),
limit_bytes: 4096*3, // one token translates to 3 bytes (not unicode chars)
vecdb_search,
limited_msgs: Vec::new(),
}
}
}

#[async_trait]
impl ScratchpadAbstract for ChatPassthrough {
fn apply_model_adaptation_patch(
&mut self,
patch: &serde_json::Value,
) -> Result<(), String> {
self.default_system_message = patch.get("default_system_message").and_then(|x| x.as_str()).unwrap_or("").to_string();
self.limit_bytes = patch.get("limit_bytes").and_then(|x| x.as_u64()).unwrap_or(4096*3) as usize;
Ok(())
}

async fn prompt(
&mut self,
context_size: usize,
sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
let limited_msgs: Vec<ChatMessage> = limit_messages_history_in_bytes(&self.post, context_size, self.limit_bytes, &self.default_system_message)?;
info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &self.limited_msgs.len());
Ok("".to_string())
}

fn response_n_choices(
&mut self,
choices: Vec<String>,
stopped: Vec<bool>,
) -> Result<serde_json::Value, String> {
unimplemented!()
}

fn response_streaming(
&mut self,
delta: String,
stop_toks: bool,
stop_length: bool,
) -> Result<(serde_json::Value, bool), String> {
unimplemented!()
}
}

Loading

0 comments on commit 78010d9

Please sign in to comment.