Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
olegklimov authored and valaises committed Oct 19, 2023
1 parent a85a5c4 commit d3ba5e9
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 71 deletions.
1 change: 1 addition & 0 deletions src/caps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct CodeAssistantCaps {
pub tokenizer_path_template: String,
pub tokenizer_rewrite_path: HashMap<String, String>,
pub telemetry_basic_dest: String,
pub telemetry_corrected_snippets_dest: String,
#[serde(default)]
pub code_completion_models: HashMap<String, ModelRecord>,
pub code_completion_default_model: String,
Expand Down
22 changes: 7 additions & 15 deletions src/forward_to_hf_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,12 @@ use crate::call_validation::SamplingParameters;


pub async fn forward_to_hf_style_endpoint(
save_url: &mut String,
save_url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<serde_json::Value, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
Expand All @@ -34,34 +30,30 @@ pub async fn forward_to_hf_style_endpoint(
"inputs": prompt,
"parameters": params_json,
});
let req = client.post(&url)
let req = client.post(save_url)
.headers(headers)
.body(data.to_string())
.send()
.await;
let resp = req.map_err(|e| format!("{}", e))?;
let status_code = resp.status().as_u16();
let response_txt = resp.text().await.map_err(|e|
format!("reading from socket {}: {}", url, e)
format!("reading from socket {}: {}", save_url, e)
)?;
if status_code != 200 {
return Err(format!("{} status={} text {}", url, status_code, response_txt));
return Err(format!("{} status={} text {}", save_url, status_code, response_txt));
}
Ok(serde_json::from_str(&response_txt).unwrap())
}


pub async fn forward_to_hf_style_endpoint_streaming(
save_url: &mut String,
save_url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<EventSource, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
Expand All @@ -77,11 +69,11 @@ pub async fn forward_to_hf_style_endpoint_streaming(
"stream": true,
});

let builder = client.post(&url)
let builder = client.post(save_url)
.headers(headers)
.body(data.to_string());
let event_source: EventSource = EventSource::new(builder).map_err(|e|
format!("can't stream from {}: {}", url, e)
format!("can't stream from {}: {}", save_url, e)
)?;
Ok(event_source)
}
20 changes: 7 additions & 13 deletions src/forward_to_openai_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,13 @@ use crate::call_validation::SamplingParameters;


pub async fn forward_to_openai_style_endpoint(
mut save_url: &String,
save_url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<serde_json::Value, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
Expand All @@ -31,34 +28,31 @@ pub async fn forward_to_openai_style_endpoint(
"temperature": sampling_parameters.temperature,
"max_tokens": sampling_parameters.max_new_tokens,
});
let req = client.post(&url)
let req = client.post(save_url)
.headers(headers)
.body(data.to_string())
.send()
.await;
let resp = req.map_err(|e| format!("{}", e))?;
let status_code = resp.status().as_u16();
let response_txt = resp.text().await.map_err(|e|
format!("reading from socket {}: {}", url, e)
format!("reading from socket {}: {}", save_url, e)
)?;
// info!("forward_to_openai_style_endpoint: {} {}\n{}", url, status_code, response_txt);
if status_code != 200 {
return Err(format!("{} status={} text {}", url, status_code, response_txt));
return Err(format!("{} status={} text {}", save_url, status_code, response_txt));
}
Ok(serde_json::from_str(&response_txt).unwrap())
}

pub async fn forward_to_openai_style_endpoint_streaming(
mut save_url: &String,
save_url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<EventSource, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
Expand All @@ -72,11 +66,11 @@ pub async fn forward_to_openai_style_endpoint_streaming(
"temperature": sampling_parameters.temperature,
"max_tokens": sampling_parameters.max_new_tokens,
});
let builder = client.post(&url)
let builder = client.post(save_url)
.headers(headers)
.body(data.to_string());
let event_source: EventSource = EventSource::new(builder).map_err(|e|
format!("can't stream from {}: {}", url, e)
format!("can't stream from {}: {}", save_url, e)
)?;
Ok(event_source)
}
21 changes: 8 additions & 13 deletions src/restream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,23 @@ pub async fn scratchpad_interaction_not_stream(
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
let save_url: String = endpoint_template.replace("$MODEL", &model_name).clone();

let model_says = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint(
&mut save_url,
&save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
} else {
forward_to_openai_endpoint::forward_to_openai_style_endpoint(
&mut save_url,
&save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
}.map_err(|e| {
Expand Down Expand Up @@ -141,30 +139,27 @@ pub async fn scratchpad_interaction_stream(
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
let save_url: String = endpoint_template.replace("$MODEL", &model_name).clone();
let mut event_source = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming(
&mut save_url,
&save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
} else {
forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming(
&mut save_url,
&save_url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
}.map_err(|e| {
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
endpoint_template.clone().replace("$MODEL", &model_name),
scope.clone(),
false,
e.to_string(),
Expand Down
4 changes: 2 additions & 2 deletions src/telemetry_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,13 @@ pub async fn compress_basic_telemetry_to_file(
}
// even if there's an error with i/o, storage is now clear, preventing infinite memory growth
info!("basic telemetry save \"{}\"", fn_net.to_str().unwrap());
let io_result = _file_save(fn_net.clone(), big_json_net).await;
let io_result = file_save(fn_net.clone(), big_json_net).await;
if io_result.is_err() {
error!("error: {}", io_result.err().unwrap());
}
}

async fn _file_save(path: PathBuf, json: serde_json::Value) -> Result<(), String> {
pub async fn file_save(path: PathBuf, json: serde_json::Value) -> Result<(), String> {
let mut f = tokio::fs::File::create(path).await.map_err(|e| format!("{:?}", e))?;
f.write_all(serde_json::to_string_pretty(&json).unwrap().as_bytes()).await.map_err(|e| format!("{}", e))?;
Ok(())
Expand Down
70 changes: 62 additions & 8 deletions src/telemetry_snippets.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use tracing::info;
use tracing::{error, info};
use std::sync::Arc;
use tokio::sync::RwLock as ARwLock;
use std::sync::RwLock as StdRwLock;
// use std::collections::HashMap;
// use reqwest_eventsource::Event;
// use futures::StreamExt;
// use async_stream::stream;
// use serde_json::json;
// use crate::caps::CodeAssistantCaps;
use std::path::PathBuf;

use crate::telemetry_basic::file_save;
use crate::call_validation;
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use crate::global_context;
use crate::completion_cache;
use crate::telemetry_storage;
Expand Down Expand Up @@ -56,6 +54,7 @@ pub struct SnippetTelemetry {
// pub remaining_percent_300s: f64,
// pub remaining_percent_walkaway: f64,
// pub walkaway_ms: u64,
pub created_at: i64
}

pub fn snippet_register(
Expand All @@ -71,6 +70,7 @@ pub fn snippet_register(
accepted: false,
corrected_by_user: "".to_string(),
remaining_percent_30s: 0.0,
created_at: chrono::Local::now().timestamp(),
};
storage_locked.tele_snippet_next_id += 1;
storage_locked.tele_snippets.push(snip);
Expand Down Expand Up @@ -108,7 +108,7 @@ pub async fn snippet_accepted(
return false;
}

pub async fn sources_changed(
pub async fn sources_changed( // TODO
gcx: Arc<ARwLock<global_context::GlobalContext>>,
uri: &String,
text: &String,
Expand All @@ -135,10 +135,12 @@ pub async fn sources_changed(
if !orig_text.is_some() {
continue;
}
// let time_from_creation = chrono::Local::now().timestamp() - snip.created_at;
let (valid1, mut gray_suggested) = if_head_tail_equal_return_added_text(
orig_text.unwrap(),
text
);
snip.corrected_by_user = gray_suggested.clone();
gray_suggested = gray_suggested.replace("\r", "");
info!("valid1: {:?}, gray_suggested: {:?}", valid1, gray_suggested);
info!("orig grey_text: {:?}", snip.grey_text);
Expand Down Expand Up @@ -219,3 +221,55 @@ pub fn unchanged_percentage(
let largest_of_two = text_a.len().max(text_b.len());
(common as f64) / (largest_of_two as f64)
}

/// Serialize every accumulated snippet-telemetry record into one JSON array.
///
/// Takes a read lock on the shared storage only for the duration of the scan;
/// the stored records themselves are left untouched.
fn _compress_telemetry_snippets(
    storage: Arc<StdRwLock<telemetry_storage::Storage>>,
) -> serde_json::Value {
    let snippets: Vec<serde_json::Value> = storage
        .read()
        .unwrap()
        .tele_snippets
        .iter()
        .map(|snippet| serde_json::to_value(snippet).unwrap())
        .collect();
    serde_json::Value::Array(snippets)
}


pub async fn compress_telemetry_snippets_to_file(
cx: Arc<ARwLock<global_context::GlobalContext>>,
) {
let now = chrono::Local::now();
let cache_dir: PathBuf;
let storage: Arc<StdRwLock<telemetry_storage::Storage>>;
let enduser_client_version;
{
let cx_locked = cx.read().await;
storage = cx_locked.telemetry.clone();
cache_dir = cx_locked.cache_dir.clone();
enduser_client_version = cx_locked.cmdline.enduser_client_version.clone();
}
let dir = cache_dir.join("telemetry").join("compressed");
tokio::fs::create_dir_all(dir.clone()).await.unwrap_or_else(|_| {});

let records = _compress_telemetry_snippets(storage.clone());
let fn_snip = dir.join(format!("{}-snip.json", now.format("%Y%m%d-%H%M%S")));
let mut big_json_snip = json!({
"records": records,
"ts_end": now.timestamp(),
"teletype": "snippets",
"enduser_client_version": enduser_client_version,
});
{
let mut storage_locked = storage.write().unwrap();
storage_locked.tele_snippets.clear();
big_json_snip.as_object_mut().unwrap().insert("ts_start".to_string(), json!(storage_locked.last_flushed_ts));
storage_locked.last_flushed_ts = now.timestamp();
}
let io_result = file_save(fn_snip, big_json_snip).await;
if io_result.is_err() {
error!("cannot save telemetry file: {}", io_result.err().unwrap());
}
}

Loading

0 comments on commit d3ba5e9

Please sign in to comment.