Skip to content

Commit

Permalink
char_passthrough WIP 2
Browse files Browse the repository at this point in the history
  • Loading branch information
olegklimov committed Oct 24, 2023
1 parent 78010d9 commit b767fce
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 31 deletions.
2 changes: 1 addition & 1 deletion src/call_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub struct CodeCompletionPost {
pub no_cache: bool,
}

#[derive(Debug, Deserialize, Clone)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatMessage {
pub role: String,
pub content: String,
Expand Down
8 changes: 3 additions & 5 deletions src/caps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub struct CodeAssistantCaps {
pub cloud_name: String,
pub endpoint_template: String,
pub endpoint_style: String,
pub endpoint_chat_passthrough: String,
pub tokenizer_path_template: String,
pub tokenizer_rewrite_path: HashMap<String, String>,
pub telemetry_basic_dest: String,
Expand Down Expand Up @@ -117,15 +118,12 @@ const KNOWN_MODELS: &str = r#"
}
}
},
"gpt-3.5": {
"gpt-3.5-turbo": {
"n_ctx": 4096,
"supports_scratchpads": {
"PASSTHROUGH": {}
},
"similar_models": [
"gpt3.5",
"gpt-4",
"gpt4"
]
},
"starchat/15b/beta": {
Expand Down Expand Up @@ -208,7 +206,7 @@ const REFACT_DEFAULT_CAPS: &str = r#"
"code_completion_default_model": "smallcloudai/Refact-1_6B-fim",
"code_chat_default_model": "smallcloudai/Refact-1_6B-fim",
"telemetry_basic_dest": "https://www.smallcloud.ai/v1/telemetry-basic",
"running_models": ["smallcloudai/Refact-1_6B-fim"]
"running_models": ["smallcloudai/Refact-1_6B-fim", "gpt-3.5-turbo"]
}
"#;

Expand Down
42 changes: 33 additions & 9 deletions src/forward_to_openai_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,41 @@ use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest_eventsource::EventSource;
use serde_json::json;
use crate::call_validation;
use crate::call_validation::SamplingParameters;
use tracing::info;


pub async fn forward_to_openai_style_endpoint(
mut save_url: &String,
save_url: &mut String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
endpoint_chat_passthrough: &String,
sampling_parameters: &SamplingParameters,
) -> Result<serde_json::Value, String> {
let url = endpoint_template.replace("$MODEL", model_name);
let is_passthrough = prompt.starts_with("MESSAGES ");
let url = if !is_passthrough { endpoint_template.replace("$MODEL", model_name) } else { endpoint_chat_passthrough.clone() };
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
headers.insert(AUTHORIZATION, HeaderValue::from_str(format!("Bearer {}", bearer).as_str()).unwrap());
}
let data = json!({
let mut data = json!({
"model": model_name,
"prompt": prompt,
"echo": false,
"stream": false,
"temperature": sampling_parameters.temperature,
"max_tokens": sampling_parameters.max_new_tokens,
});
if is_passthrough {
_passthrough_messages_to_json(&mut data, prompt);
} else {
data["prompt"] = serde_json::Value::String(prompt.to_string());
}
let req = client.post(&url)
.headers(headers)
.body(data.to_string())
Expand All @@ -49,29 +57,35 @@ pub async fn forward_to_openai_style_endpoint(
}

pub async fn forward_to_openai_style_endpoint_streaming(
mut save_url: &String,
save_url: &mut String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
endpoint_chat_passthrough: &String,
sampling_parameters: &SamplingParameters,
) -> Result<EventSource, String> {
let url = endpoint_template.replace("$MODEL", model_name);
let is_passthrough = prompt.starts_with("MESSAGES ");
let url = if !is_passthrough { endpoint_template.replace("$MODEL", model_name) } else { endpoint_chat_passthrough.clone() };
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
headers.insert(AUTHORIZATION, HeaderValue::from_str(format!("Bearer {}", bearer).as_str()).unwrap());
}
let data = json!({
let mut data = json!({
"model": model_name,
"prompt": prompt,
"echo": false,
"stream": true,
"temperature": sampling_parameters.temperature,
"max_tokens": sampling_parameters.max_new_tokens,
"echo": false,
});
if is_passthrough {
_passthrough_messages_to_json(&mut data, prompt);
} else {
data["prompt"] = serde_json::Value::String(prompt.to_string());
}
let builder = client.post(&url)
.headers(headers)
.body(data.to_string());
Expand All @@ -80,3 +94,13 @@ pub async fn forward_to_openai_style_endpoint_streaming(
)?;
Ok(event_source)
}

/// Converts a passthrough prompt of the form `"MESSAGES <json>"` into a
/// chat-style request body: decodes the JSON payload into `ChatMessage`s
/// and stores it under `data["messages"]`, as expected by OpenAI-style
/// chat endpoints.
///
/// # Panics
/// Panics if `prompt` does not start with the `"MESSAGES "` marker or if
/// the remainder is not valid JSON for `Vec<ChatMessage>`. Both indicate a
/// caller bug — the prompt is assembled in this crate (scratchpad side)
/// as `"MESSAGES ".to_string() + &serde_json::to_string(&msgs)`.
fn _passthrough_messages_to_json(
    data: &mut serde_json::Value,
    prompt: &str,
) {
    // strip_prefix replaces the assert + magic index 9: the marker string
    // and the slice offset can no longer drift apart.
    let messages_str = prompt
        .strip_prefix("MESSAGES ")
        .expect("passthrough prompt must start with \"MESSAGES \"");
    let messages: Vec<call_validation::ChatMessage> =
        serde_json::from_str(messages_str)
        .expect("passthrough prompt payload must be valid ChatMessage JSON");
    data["messages"] = serde_json::json!(messages);
}
10 changes: 6 additions & 4 deletions src/restream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ pub async fn scratchpad_interaction_not_stream(
parameters: &SamplingParameters,
) -> Result<Response<Body>, ScratchError> {
let t2 = std::time::SystemTime::now();
let (endpoint_style, endpoint_template, tele_storage) = {
let (endpoint_style, endpoint_template, endpoint_chat_passthrough, tele_storage) = {
let cx = global_context.write().await;
let caps = cx.caps.clone().unwrap();
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), caps_locked.endpoint_chat_passthrough.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
let model_says = if endpoint_style == "hf" {
Expand All @@ -52,6 +52,7 @@ pub async fn scratchpad_interaction_not_stream(
&prompt,
&client,
&endpoint_template,
&endpoint_chat_passthrough,
&parameters,
).await
}.map_err(|e| {
Expand Down Expand Up @@ -137,11 +138,11 @@ pub async fn scratchpad_interaction_stream(
let t1 = std::time::SystemTime::now();
let evstream = stream! {
let scratch: &mut Box<dyn ScratchpadAbstract> = &mut scratchpad;
let (endpoint_style, endpoint_template, tele_storage) = {
let (endpoint_style, endpoint_template, endpoint_chat_passthrough, tele_storage) = {
let cx = global_context.write().await;
let caps = cx.caps.clone().unwrap();
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), caps_locked.endpoint_chat_passthrough.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
loop {
Expand All @@ -163,6 +164,7 @@ pub async fn scratchpad_interaction_stream(
&prompt,
&client,
&endpoint_template,
&endpoint_chat_passthrough,
&parameters,
).await
};
Expand Down
21 changes: 11 additions & 10 deletions src/scratchpads/chat_passthrough.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use tracing::info;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use tokio::sync::Mutex as AMutex;
use async_trait::async_trait;

Expand All @@ -9,10 +8,11 @@ use crate::call_validation::ChatPost;
use crate::call_validation::ChatMessage;
use crate::call_validation::SamplingParameters;
use crate::scratchpads::chat_utils_limit_history::limit_messages_history_in_bytes;
use crate::vecdb_search::{VecdbSearch, embed_vecdb_results};
// use crate::vecdb_search::{VecdbSearch, embed_vecdb_results};
use crate::vecdb_search::VecdbSearch;


const DEBUG: bool = true;
// const DEBUG: bool = true;


// #[derive(Debug)]
Expand All @@ -21,9 +21,10 @@ pub struct ChatPassthrough {
pub default_system_message: String,
pub limit_bytes: usize,
pub vecdb_search: Arc<AMutex<Box<dyn VecdbSearch + Send>>>,
pub limited_msgs: Vec<ChatMessage>,
}

const DEFAULT_LIMIT_BYTES: usize = 4096*3;

impl ChatPassthrough {
pub fn new(
post: ChatPost,
Expand All @@ -32,9 +33,8 @@ impl ChatPassthrough {
ChatPassthrough {
post,
default_system_message: "".to_string(),
limit_bytes: 4096*3, // one token translates to 3 bytes (not unicode chars)
limit_bytes: DEFAULT_LIMIT_BYTES, // one token translates to 3 bytes (not unicode chars)
vecdb_search,
limited_msgs: Vec::new(),
}
}
}
Expand All @@ -46,7 +46,7 @@ impl ScratchpadAbstract for ChatPassthrough {
patch: &serde_json::Value,
) -> Result<(), String> {
self.default_system_message = patch.get("default_system_message").and_then(|x| x.as_str()).unwrap_or("").to_string();
self.limit_bytes = patch.get("limit_bytes").and_then(|x| x.as_u64()).unwrap_or(4096*3) as usize;
self.limit_bytes = patch.get("limit_bytes").and_then(|x| x.as_u64()).unwrap_or(DEFAULT_LIMIT_BYTES as u64) as usize;
Ok(())
}

Expand All @@ -55,9 +55,10 @@ impl ScratchpadAbstract for ChatPassthrough {
context_size: usize,
sampling_parameters_to_patch: &mut SamplingParameters,
) -> Result<String, String> {
let limited_msgs: Vec<ChatMessage> = limit_messages_history_in_bytes(&self.post, context_size, self.limit_bytes, &self.default_system_message)?;
info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &self.limited_msgs.len());
Ok("".to_string())
let limited_msgs: Vec<ChatMessage> = limit_messages_history_in_bytes(&self.post, self.limit_bytes, &self.default_system_message)?;
info!("chat passthrough {} messages -> {} messages after applying limits and possibly adding the default system message", &limited_msgs.len(), &limited_msgs.len());
let prompt = "MESSAGES ".to_string() + &serde_json::to_string(&limited_msgs).unwrap();
Ok(prompt.to_string())
}

fn response_n_choices(
Expand Down
1 change: 0 additions & 1 deletion src/scratchpads/chat_utils_limit_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ pub fn limit_messages_history(

pub fn limit_messages_history_in_bytes(
post: &ChatPost,
context_size: usize,
bytes_limit: usize,
default_system_mesage: &String,
) -> Result<Vec<ChatMessage>, String>
Expand Down
3 changes: 2 additions & 1 deletion src/scratchpads/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ pub async fn create_chat_scratchpad(
vecdb_search: Arc<AMutex<Box<dyn vecdb_search::VecdbSearch + Send>>>,
) -> Result<Box<dyn ScratchpadAbstract>, String> {
let mut result: Box<dyn ScratchpadAbstract>;
let tokenizer_arc: Arc<StdRwLock<Tokenizer>> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?;
if scratchpad_name == "CHAT-GENERIC" {
let tokenizer_arc: Arc<StdRwLock<Tokenizer>> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?;
result = Box::new(chat_generic::GenericChatScratchpad::new(tokenizer_arc, post, vecdb_search));
} else if scratchpad_name == "CHAT-LLAMA2" {
let tokenizer_arc: Arc<StdRwLock<Tokenizer>> = cached_tokenizers::cached_tokenizer(caps, global_context, model_name_for_tokenizer).await?;
result = Box::new(chat_llama2::ChatLlama2::new(tokenizer_arc, post, vecdb_search));
} else if scratchpad_name == "PASSTHROUGH" {
result = Box::new(chat_passthrough::ChatPassthrough::new(post, vecdb_search));
Expand Down

0 comments on commit b767fce

Please sign in to comment.