Commit 89b9dff: rebase

olegklimov authored and valaises committed Oct 24, 2023
1 parent 78010d9 commit 89b9dff
Showing 9 changed files with 156 additions and 75 deletions.
1 change: 1 addition & 0 deletions src/caps.rs
@@ -31,6 +31,7 @@ pub struct CodeAssistantCaps {
pub tokenizer_path_template: String,
pub tokenizer_rewrite_path: HashMap<String, String>,
pub telemetry_basic_dest: String,
pub telemetry_corrected_snippets_dest: String,
#[serde(default)]
pub code_completion_models: HashMap<String, ModelRecord>,
pub code_completion_default_model: String,
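
Note on this change: CodeAssistantCaps gains a `telemetry_corrected_snippets_dest` field, which the new snippet background task (in src/telemetry_snippets.rs below) reads to know where to POST corrected snippets. Below is a minimal sketch of such a field coming out of a caps document, assuming serde/serde_json; the struct is a trimmed-down, hypothetical mirror, not the real `CodeAssistantCaps`.

```rust
use serde::Deserialize;

// Hypothetical mirror of CodeAssistantCaps, reduced to the telemetry
// destinations; only these two fields are taken from the diff.
#[derive(Debug, Deserialize)]
struct TelemetryCaps {
    telemetry_basic_dest: String,
    telemetry_corrected_snippets_dest: String, // new in this commit
}

fn main() {
    let raw = r#"{
        "telemetry_basic_dest": "https://example.com/telemetry-basic",
        "telemetry_corrected_snippets_dest": "https://example.com/telemetry-snippets"
    }"#;
    let caps: TelemetryCaps = serde_json::from_str(raw).unwrap();
    println!("basic telemetry goes to {}", caps.telemetry_basic_dest);
    println!("corrected snippets go to {}", caps.telemetry_corrected_snippets_dest);
}
```
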
16 changes: 4 additions & 12 deletions src/forward_to_hf_endpoint.rs
@@ -11,16 +11,12 @@ use crate::call_validation::SamplingParameters;


pub async fn forward_to_hf_style_endpoint(
save_url: &mut String,
url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<serde_json::Value, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
@@ -34,7 +30,7 @@ pub async fn forward_to_hf_style_endpoint(
"inputs": prompt,
"parameters": params_json,
});
let req = client.post(&url)
let req = client.post(url)
.headers(headers)
.body(data.to_string())
.send()
@@ -52,16 +48,12 @@ pub async fn forward_to_hf_style_endpoint(


pub async fn forward_to_hf_style_endpoint_streaming(
save_url: &mut String,
url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<EventSource, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
@@ -77,7 +69,7 @@ pub async fn forward_to_hf_style_endpoint_streaming(
"stream": true,
});

let builder = client.post(&url)
let builder = client.post(url)
.headers(headers)
.body(data.to_string());
let event_source: EventSource = EventSource::new(builder).map_err(|e|
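
Both HF-style helpers lose the `save_url: &mut String` out-parameter and the `endpoint_template` argument; the caller now builds the final `url` and passes it in (see the matching change in src/restream.rs below). Here is a sketch of the pattern only, with hypothetical function names and a made-up endpoint template, not the crate's real signatures.

```rust
// Sketch of the refactor; names and bodies are illustrative.

// Before: the callee expanded the template and reported the URL back through
// a &mut String out-parameter.
fn call_endpoint_old(save_url: &mut String, endpoint_template: &str, model_name: &str) {
    let url = endpoint_template.replace("$MODEL", model_name);
    save_url.clone_from(&url);
    // ... send the request to `url` ...
}

// After: the caller builds the URL once and passes it in; it already owns the
// value it later logs to telemetry, so no out-parameter is needed.
fn call_endpoint_new(url: &str) {
    // ... send the request to `url` ...
    let _ = url;
}

fn main() {
    let endpoint_template = "https://api-inference.example.com/models/$MODEL"; // illustrative
    let model_name = "bigcode/starcoder";                                      // illustrative

    let mut save_url = String::new();
    call_endpoint_old(&mut save_url, endpoint_template, model_name);

    let url = endpoint_template.replace("$MODEL", model_name);
    call_endpoint_new(&url);
    assert_eq!(save_url, url);
}
```

The new shape avoids threading a mutable out-parameter through every forwarder just so the caller can log a URL it could have computed itself.
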
14 changes: 4 additions & 10 deletions src/forward_to_openai_endpoint.rs
@@ -8,16 +8,13 @@ use crate::call_validation::SamplingParameters;


pub async fn forward_to_openai_style_endpoint(
mut save_url: &String,
url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<serde_json::Value, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
@@ -31,7 +28,7 @@ pub async fn forward_to_openai_style_endpoint(
"temperature": sampling_parameters.temperature,
"max_tokens": sampling_parameters.max_new_tokens,
});
let req = client.post(&url)
let req = client.post(url)
.headers(headers)
.body(data.to_string())
.send()
@@ -49,16 +46,13 @@ pub async fn forward_to_openai_style_endpoint(
}

pub async fn forward_to_openai_style_endpoint_streaming(
mut save_url: &String,
url: &String,
bearer: String,
model_name: &str,
prompt: &str,
client: &reqwest::Client,
endpoint_template: &String,
sampling_parameters: &SamplingParameters,
) -> Result<EventSource, String> {
let url = endpoint_template.replace("$MODEL", model_name);
save_url.clone_from(&&url);
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap());
if !bearer.is_empty() {
@@ -72,7 +66,7 @@ pub async fn forward_to_openai_style_endpoint_streaming(
"temperature": sampling_parameters.temperature,
"max_tokens": sampling_parameters.max_new_tokens,
});
let builder = client.post(&url)
let builder = client.post(url)
.headers(headers)
.body(data.to_string());
let event_source: EventSource = EventSource::new(builder).map_err(|e|
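
The OpenAI-style helpers get the same signature change: `save_url` and `endpoint_template` are gone and the finished `url` comes from the caller. For context on the hunks above, here is a sketch of the JSON body these helpers assemble; only the `temperature` and `max_tokens` lines appear in this diff, so the `prompt` field and the `SamplingParameters` shape are assumptions.

```rust
use serde_json::json;

// Hypothetical stand-in for crate::call_validation::SamplingParameters,
// reduced to the two fields referenced in the hunks above.
struct SamplingParameters {
    temperature: Option<f32>,
    max_new_tokens: usize,
}

fn main() {
    let sampling_parameters = SamplingParameters { temperature: Some(0.2), max_new_tokens: 50 };
    let prompt = "fn add(a: i32, b: i32) -> i32 {"; // illustrative prompt
    let data = json!({
        "prompt": prompt,                                 // assumed field
        "temperature": sampling_parameters.temperature,   // from the diff
        "max_tokens": sampling_parameters.max_new_tokens, // from the diff
    });
    println!("{}", serde_json::to_string_pretty(&data).unwrap());
}
```
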
2 changes: 2 additions & 0 deletions src/global_context.rs
@@ -30,6 +30,8 @@ pub struct CommandLine {
pub enduser_client_version: String,
#[structopt(long, short="b", help="Send basic telemetry (counters and errors)")]
pub basic_telemetry: bool,
#[structopt(long, short="s", help="Send snippet telemetry (code snippets)")]
pub snippet_telemetry: bool,
#[structopt(long, default_value="0", help="Bind 127.0.0.1:<port> and act as an LSP server. This is compatible with having an HTTP server at the same time.")]
pub lsp_port: u16,
#[structopt(long, default_value="0", help="Act as an LSP server, use stdin stdout for communication. This is compatible with having an HTTP server at the same time. But it's not compatible with LSP port.")]
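
The new `--snippet-telemetry` (`-s`) switch sits next to `--basic-telemetry` and gates whether corrected snippets are sent at all (see `manage_finished_snippets` in src/telemetry_snippets.rs below). A self-contained sketch of how such a structopt flag parses, with the struct trimmed to just the two telemetry switches and a placeholder binary name:

```rust
use structopt::StructOpt;

// Trimmed-down sketch of the CommandLine struct: only the two telemetry
// switches from this diff, so the flag parsing can be shown in isolation.
#[derive(StructOpt, Debug)]
struct TelemetryFlags {
    #[structopt(long, short = "b", help = "Send basic telemetry (counters and errors)")]
    basic_telemetry: bool,
    #[structopt(long, short = "s", help = "Send snippet telemetry (code snippets)")]
    snippet_telemetry: bool,
}

fn main() {
    // Parse from an explicit argv instead of the real command line, just for
    // the demonstration; "app" stands in for the binary name.
    let flags = TelemetryFlags::from_iter(vec!["app", "--basic-telemetry", "--snippet-telemetry"]);
    assert!(flags.basic_telemetry && flags.snippet_telemetry);
    println!("{:?}", flags);
}
```
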
3 changes: 3 additions & 0 deletions src/main.rs
@@ -51,6 +51,7 @@ async fn main() {
let gcx4 = gcx.clone();
let caps_reload_task = tokio::spawn(global_context::caps_background_reload(gcx.clone()));
let tele_backgr_task = tokio::spawn(telemetry_storage::telemetry_background_task(gcx.clone()));
let tele_snip_backgr_task = tokio::spawn(telemetry_snippets::tele_snip_background_task(gcx.clone()));
let http_server_task = tokio::spawn(async move {
let gcx_clone = gcx.clone();
let server = http_server::start_server(gcx_clone);
@@ -109,6 +110,8 @@ async fn main() {
let _ = caps_reload_task.await;
tele_backgr_task.abort();
let _ = tele_backgr_task.await;
tele_snip_backgr_task.abort();
let _ = tele_snip_backgr_task.await;
lsp_task.abort();
let _ = lsp_task.await;
info!("saving telemetry without sending, so should be quick");
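
main() now spawns `telemetry_snippets::tele_snip_background_task` next to the other background tasks and, on shutdown, aborts it and awaits the handle like the rest. A minimal sketch of that spawn/abort/await pattern, assuming tokio with the usual runtime features; the task body here is a placeholder, not the real one:

```rust
use std::time::Duration;

async fn tele_snip_background_task() {
    // Placeholder for the real task: wake up periodically and flush finished snippets.
    loop {
        tokio::time::sleep(Duration::from_secs(60)).await;
        // ... manage_finished_snippets(...).await ...
    }
}

#[tokio::main]
async fn main() {
    let tele_snip_backgr_task = tokio::spawn(tele_snip_background_task());

    // ... run the HTTP/LSP servers until shutdown ...
    tokio::time::sleep(Duration::from_millis(100)).await;

    // Shutdown: abort the endless loop, then await the handle so the
    // cancellation is fully consumed and the task is torn down.
    tele_snip_backgr_task.abort();
    let _ = tele_snip_backgr_task.await;
    println!("snippet telemetry task stopped");
}
```
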
29 changes: 12 additions & 17 deletions src/restream.rs
@@ -33,38 +33,36 @@ pub async fn scratchpad_interaction_not_stream(
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
let url: String = endpoint_template.replace("$MODEL", &model_name).clone();

let model_says = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint(
&mut save_url,
&url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
} else {
forward_to_openai_endpoint::forward_to_openai_style_endpoint(
&mut save_url,
&url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
}.map_err(|e| {
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
url.clone(),
scope.clone(),
false,
e.to_string(),
));
ScratchError::new_but_skip_telemetry(StatusCode::INTERNAL_SERVER_ERROR, format!("forward_to_endpoint: {}", e))
})?;
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
url.clone(),
scope.clone(),
true,
"".to_string(),
@@ -143,26 +141,23 @@ pub async fn scratchpad_interaction_stream(
let caps_locked = caps.read().unwrap();
(caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone())
};
let mut save_url: String = String::new();
let url: String = endpoint_template.replace("$MODEL", &model_name).clone();
loop {
let event_source_maybe = if endpoint_style == "hf" {
forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming(
&mut save_url,
&url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
} else {
forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming(
&mut save_url,
&url,
bearer.clone(),
&model_name,
&prompt,
&client,
&endpoint_template,
&parameters,
).await
};
@@ -171,7 +166,7 @@
Err(e) => {
let e_str = format!("forward_to_endpoint: {:?}", e);
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
url.clone(),
scope.clone(),
false,
e_str.to_string(),
@@ -227,7 +222,7 @@
let problem_str = format!("restream error: {}", err);
{
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
url.clone(),
scope.clone(),
false,
problem_str.clone(),
@@ -256,7 +251,7 @@
info!("yield: [DONE]");
yield Result::<_, String>::Ok("data: [DONE]\n\n".to_string());
tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new(
save_url.clone(),
url.clone(),
scope.clone(),
true,
"".to_string(),
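
restream.rs now computes the endpoint URL once with `endpoint_template.replace("$MODEL", &model_name)`, hands it to whichever forwarder is selected, and reuses the same string in every `TelemetryNetwork` record, replacing the old write-back through `&mut save_url`. A sketch of that substitution and reuse; the record type below is a stand-in, not the crate's `TelemetryNetwork`.

```rust
// Stand-in for telemetry_basic::TelemetryNetwork, reduced to the fields used
// in the hunks above.
#[derive(Debug)]
struct NetRecord {
    url: String,
    scope: String,
    success: bool,
    error_message: String,
}

fn main() {
    let endpoint_template = "https://inference.example.com/v1/$MODEL/completions"; // illustrative
    let model_name = "code-completion-model";                                      // illustrative
    let scope = "completion".to_string();

    // Built once; previously the forwarder wrote it back via `&mut save_url`.
    let url: String = endpoint_template.replace("$MODEL", model_name);

    // The same `url` is now cloned into every telemetry record, success or error.
    let on_error = NetRecord {
        url: url.clone(),
        scope: scope.clone(),
        success: false,
        error_message: "forward_to_endpoint: timeout".into(),
    };
    let on_success = NetRecord {
        url: url.clone(),
        scope,
        success: true,
        error_message: String::new(),
    };
    println!("{:?}\n{:?}", on_error, on_success);
}
```
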
4 changes: 2 additions & 2 deletions src/telemetry_basic.rs
@@ -135,13 +135,13 @@ pub async fn compress_basic_telemetry_to_file(
}
// even if there's an error with i/o, storage is now clear, preventing infinite memory growth
info!("basic telemetry save \"{}\"", fn_net.to_str().unwrap());
let io_result = _file_save(fn_net.clone(), big_json_net).await;
let io_result = file_save(fn_net.clone(), big_json_net).await;
if io_result.is_err() {
error!("error: {}", io_result.err().unwrap());
}
}

async fn _file_save(path: PathBuf, json: serde_json::Value) -> Result<(), String> {
pub async fn file_save(path: PathBuf, json: serde_json::Value) -> Result<(), String> {
let mut f = tokio::fs::File::create(path).await.map_err(|e| format!("{:?}", e))?;
f.write_all(serde_json::to_string_pretty(&json).unwrap().as_bytes()).await.map_err(|e| format!("{}", e))?;
Ok(())
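
`_file_save` becomes a public `file_save`, presumably so other telemetry modules can reuse it. A small usage sketch with the same body as the function above, assuming tokio's fs/io features and serde_json; the path and payload are made up:

```rust
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;

// Same shape as the now-public telemetry_basic::file_save shown in the hunk above.
async fn file_save(path: PathBuf, json: serde_json::Value) -> Result<(), String> {
    let mut f = tokio::fs::File::create(path).await.map_err(|e| format!("{:?}", e))?;
    f.write_all(serde_json::to_string_pretty(&json).unwrap().as_bytes()).await.map_err(|e| format!("{}", e))?;
    Ok(())
}

#[tokio::main]
async fn main() {
    // Illustrative path and payload.
    let path = PathBuf::from("telemetry-example.json");
    let payload = serde_json::json!({ "teletype": "net", "records": [] });
    if let Err(e) = file_save(path, payload).await {
        eprintln!("error: {}", e);
    }
}
```
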
89 changes: 82 additions & 7 deletions src/telemetry_snippets.rs
@@ -1,16 +1,12 @@
use tracing::info;
use tracing::{error, info};
use std::sync::Arc;
use tokio::sync::RwLock as ARwLock;
use std::sync::RwLock as StdRwLock;
// use std::collections::HashMap;
// use reqwest_eventsource::Event;
// use futures::StreamExt;
// use async_stream::stream;
// use serde_json::json;
// use crate::caps::CodeAssistantCaps;

use crate::call_validation;
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use crate::global_context;
use crate::completion_cache;
use crate::telemetry_storage;
@@ -24,6 +20,8 @@ use difference;
// 3. LSP looks at file changes (LSP can be replaced with reaction to a next completion?)
// 4. Changes are translated to "after_walkaway_remaining50to95" etc

const SNIP_FINISHED_AFTER : i64 = 300;


#[derive(Debug, Clone)]
pub struct SaveSnippet {
@@ -56,6 +54,7 @@ pub struct SnippetTelemetry {
// pub remaining_percent_300s: f64,
// pub remaining_percent_walkaway: f64,
// pub walkaway_ms: u64,
pub created_at: i64
}

pub fn snippet_register(
@@ -71,6 +70,7 @@
accepted: false,
corrected_by_user: "".to_string(),
remaining_percent_30s: 0.0,
created_at: chrono::Local::now().timestamp(),
};
storage_locked.tele_snippet_next_id += 1;
storage_locked.tele_snippets.push(snip);
@@ -135,10 +135,12 @@ pub async fn sources_changed(
if !orig_text.is_some() {
continue;
}
// let time_from_creation = chrono::Local::now().timestamp() - snip.created_at;
let (valid1, mut gray_suggested) = if_head_tail_equal_return_added_text(
orig_text.unwrap(),
text
);
snip.corrected_by_user = gray_suggested.clone();
gray_suggested = gray_suggested.replace("\r", "");
info!("valid1: {:?}, gray_suggested: {:?}", valid1, gray_suggested);
info!("orig grey_text: {:?}", snip.grey_text);
@@ -219,3 +221,76 @@ pub fn unchanged_percentage(
let largest_of_two = text_a.len().max(text_b.len());
(common as f64) / (largest_of_two as f64)
}

async fn manage_finished_snippets(gcx: Arc<ARwLock<global_context::GlobalContext>>) {
let tele_storage;
let now = chrono::Local::now().timestamp();
let enduser_client_version;
let api_key: String;
let caps;
let mothership_enabled: bool;
let mut telemetry_corrected_snippets_dest = String::new();
{
let cx = gcx.read().await;
enduser_client_version = cx.cmdline.enduser_client_version.clone();
tele_storage = cx.telemetry.clone();
api_key = cx.cmdline.api_key.clone();
caps = cx.caps.clone();
mothership_enabled = cx.cmdline.snippet_telemetry;
}
if let Some(caps) = &caps {
telemetry_corrected_snippets_dest = caps.read().unwrap().telemetry_corrected_snippets_dest.clone();
}

let mut snips_send: Vec<SnippetTelemetry> = vec![];
{
let mut to_remove: Vec<usize> = vec![];
let mut storage_locked = tele_storage.write().unwrap();
for (idx, snip) in &mut storage_locked.tele_snippets.iter().enumerate() {
if now - snip.created_at >= SNIP_FINISHED_AFTER {
if snip.accepted {
snips_send.push(snip.clone());
}
to_remove.push(idx);
}
}
for idx in to_remove.iter().rev() {
storage_locked.tele_snippets.remove(*idx);
}
}

if !mothership_enabled {
info!("telemetry snippets sending not enabled, skip");
return;
}

for snip in snips_send {
let json_dict = serde_json::to_value(snip).unwrap();
let big_json_snip = json!({
"records": [json_dict],
"ts_start": now,
"ts_end": chrono::Local::now().timestamp(),
"teletype": "snippets",
"enduser_client_version": enduser_client_version,
});
let resp_maybe = telemetry_storage::send_telemetry_data(
big_json_snip.to_string(),
&telemetry_corrected_snippets_dest,
&api_key
).await;
if resp_maybe.is_err() {
error!("snippet send failed: {}", resp_maybe.err().unwrap());
error!("too bad snippet is lost now");
continue;
}
}
}

pub async fn tele_snip_background_task(
global_context: Arc<ARwLock<global_context::GlobalContext>>,
) -> () {
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
manage_finished_snippets(global_context.clone()).await;
}
}
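
This is the bulk of the commit: every `SnippetTelemetry` now carries `created_at`, and the new background task wakes every 60 seconds, treats a snippet as finished once `SNIP_FINISHED_AFTER` (300 s) has elapsed, sends the accepted ones to `telemetry_corrected_snippets_dest`, and drops the rest from storage. Below is a compact, dependency-free model of just that retention rule; the real code keeps the snippets behind an RwLock inside telemetry storage and sends them with `telemetry_storage::send_telemetry_data`.

```rust
const SNIP_FINISHED_AFTER: i64 = 300;

#[derive(Debug, Clone)]
struct Snippet {
    snippet_telemetry_id: u64,
    accepted: bool,
    created_at: i64, // unix timestamp, as in the new field on SnippetTelemetry
}

/// Split the stored snippets: everything older than SNIP_FINISHED_AFTER is
/// removed from storage, and only the accepted ones are returned for sending.
fn drain_finished(storage: &mut Vec<Snippet>, now: i64) -> Vec<Snippet> {
    let mut to_send = vec![];
    let mut to_remove = vec![];
    for (idx, snip) in storage.iter().enumerate() {
        if now - snip.created_at >= SNIP_FINISHED_AFTER {
            if snip.accepted {
                to_send.push(snip.clone());
            }
            to_remove.push(idx);
        }
    }
    // Remove back-to-front so the earlier indices stay valid.
    for idx in to_remove.iter().rev() {
        storage.remove(*idx);
    }
    to_send
}

fn main() {
    let now = 1_000_000;
    let mut storage = vec![
        Snippet { snippet_telemetry_id: 1, accepted: true,  created_at: now - 400 }, // finished, sent
        Snippet { snippet_telemetry_id: 2, accepted: false, created_at: now - 400 }, // finished, dropped
        Snippet { snippet_telemetry_id: 3, accepted: true,  created_at: now - 10 },  // still live
    ];
    let send = drain_finished(&mut storage, now);
    assert_eq!(send.len(), 1);
    assert_eq!(storage.len(), 1);
    let ids: Vec<u64> = send.iter().map(|s| s.snippet_telemetry_id).collect();
    println!("sending snippet ids {:?}, {} still pending", ids, storage.len());
}
```
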