diff --git a/src/caps.rs b/src/caps.rs index a7cd22531..18c10b327 100644 --- a/src/caps.rs +++ b/src/caps.rs @@ -31,6 +31,7 @@ pub struct CodeAssistantCaps { pub tokenizer_path_template: String, pub tokenizer_rewrite_path: HashMap, pub telemetry_basic_dest: String, + pub telemetry_corrected_snippets_dest: String, #[serde(default)] pub code_completion_models: HashMap, pub code_completion_default_model: String, diff --git a/src/forward_to_hf_endpoint.rs b/src/forward_to_hf_endpoint.rs index e070646c3..5a4c342f5 100644 --- a/src/forward_to_hf_endpoint.rs +++ b/src/forward_to_hf_endpoint.rs @@ -11,16 +11,12 @@ use crate::call_validation::SamplingParameters; pub async fn forward_to_hf_style_endpoint( - save_url: &mut String, + url: &String, bearer: String, - model_name: &str, prompt: &str, client: &reqwest::Client, - endpoint_template: &String, sampling_parameters: &SamplingParameters, ) -> Result { - let url = endpoint_template.replace("$MODEL", model_name); - save_url.clone_from(&&url); let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); if !bearer.is_empty() { @@ -34,7 +30,7 @@ pub async fn forward_to_hf_style_endpoint( "inputs": prompt, "parameters": params_json, }); - let req = client.post(&url) + let req = client.post(url) .headers(headers) .body(data.to_string()) .send() @@ -52,16 +48,12 @@ pub async fn forward_to_hf_style_endpoint( pub async fn forward_to_hf_style_endpoint_streaming( - save_url: &mut String, + url: &String, bearer: String, - model_name: &str, prompt: &str, client: &reqwest::Client, - endpoint_template: &String, sampling_parameters: &SamplingParameters, ) -> Result { - let url = endpoint_template.replace("$MODEL", model_name); - save_url.clone_from(&&url); let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); if !bearer.is_empty() { @@ -77,7 +69,7 @@ pub async fn forward_to_hf_style_endpoint_streaming( "stream": true, }); - let builder = client.post(&url) + let builder = client.post(url) .headers(headers) .body(data.to_string()); let event_source: EventSource = EventSource::new(builder).map_err(|e| diff --git a/src/forward_to_openai_endpoint.rs b/src/forward_to_openai_endpoint.rs index bdca278ab..c682ee798 100644 --- a/src/forward_to_openai_endpoint.rs +++ b/src/forward_to_openai_endpoint.rs @@ -8,16 +8,13 @@ use crate::call_validation::SamplingParameters; pub async fn forward_to_openai_style_endpoint( - mut save_url: &String, + url: &String, bearer: String, model_name: &str, prompt: &str, client: &reqwest::Client, - endpoint_template: &String, sampling_parameters: &SamplingParameters, ) -> Result { - let url = endpoint_template.replace("$MODEL", model_name); - save_url.clone_from(&&url); let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); if !bearer.is_empty() { @@ -31,7 +28,7 @@ pub async fn forward_to_openai_style_endpoint( "temperature": sampling_parameters.temperature, "max_tokens": sampling_parameters.max_new_tokens, }); - let req = client.post(&url) + let req = client.post(url) .headers(headers) .body(data.to_string()) .send() @@ -49,16 +46,13 @@ pub async fn forward_to_openai_style_endpoint( } pub async fn forward_to_openai_style_endpoint_streaming( - mut save_url: &String, + url: &String, bearer: String, model_name: &str, prompt: &str, client: &reqwest::Client, - endpoint_template: &String, sampling_parameters: &SamplingParameters, ) -> Result { - let url = endpoint_template.replace("$MODEL", model_name); - save_url.clone_from(&&url); let mut headers = HeaderMap::new(); headers.insert(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()); if !bearer.is_empty() { @@ -72,7 +66,7 @@ pub async fn forward_to_openai_style_endpoint_streaming( "temperature": sampling_parameters.temperature, "max_tokens": sampling_parameters.max_new_tokens, }); - let builder = client.post(&url) + let builder = client.post(url) .headers(headers) .body(data.to_string()); let event_source: EventSource = EventSource::new(builder).map_err(|e| diff --git a/src/global_context.rs b/src/global_context.rs index 0a716f45c..b845e7394 100644 --- a/src/global_context.rs +++ b/src/global_context.rs @@ -30,6 +30,8 @@ pub struct CommandLine { pub enduser_client_version: String, #[structopt(long, short="b", help="Send basic telemetry (counters and errors)")] pub basic_telemetry: bool, + #[structopt(long, short="s", help="Send snippet telemetry (code snippets)")] + pub snippet_telemetry: bool, #[structopt(long, default_value="0", help="Bind 127.0.0.1: and act as an LSP server. This is compatible with having an HTTP server at the same time.")] pub lsp_port: u16, #[structopt(long, default_value="0", help="Act as an LSP server, use stdin stdout for communication. This is compatible with having an HTTP server at the same time. But it's not compatible with LSP port.")] diff --git a/src/main.rs b/src/main.rs index fd261ccd2..2f8803d21 100644 --- a/src/main.rs +++ b/src/main.rs @@ -51,6 +51,7 @@ async fn main() { let gcx4 = gcx.clone(); let caps_reload_task = tokio::spawn(global_context::caps_background_reload(gcx.clone())); let tele_backgr_task = tokio::spawn(telemetry_storage::telemetry_background_task(gcx.clone())); + let tele_snip_backgr_task = tokio::spawn(telemetry_snippets::tele_snip_background_task(gcx.clone())); let http_server_task = tokio::spawn(async move { let gcx_clone = gcx.clone(); let server = http_server::start_server(gcx_clone); @@ -109,6 +110,8 @@ async fn main() { let _ = caps_reload_task.await; tele_backgr_task.abort(); let _ = tele_backgr_task.await; + tele_snip_backgr_task.abort(); + let _ = tele_snip_backgr_task.await; lsp_task.abort(); let _ = lsp_task.await; info!("saving telemetry without sending, so should be quick"); diff --git a/src/restream.rs b/src/restream.rs index 2133e6dfd..f8f17cfd3 100644 --- a/src/restream.rs +++ b/src/restream.rs @@ -33,30 +33,28 @@ pub async fn scratchpad_interaction_not_stream( let caps_locked = caps.read().unwrap(); (caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone()) }; - let mut save_url: String = String::new(); + let url: String = endpoint_template.replace("$MODEL", &model_name).clone(); + let model_says = if endpoint_style == "hf" { forward_to_hf_endpoint::forward_to_hf_style_endpoint( - &mut save_url, + &url, bearer.clone(), - &model_name, &prompt, &client, - &endpoint_template, ¶meters, ).await } else { forward_to_openai_endpoint::forward_to_openai_style_endpoint( - &mut save_url, + &url, bearer.clone(), &model_name, &prompt, &client, - &endpoint_template, ¶meters, ).await }.map_err(|e| { tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new( - save_url.clone(), + url.clone(), scope.clone(), false, e.to_string(), @@ -64,7 +62,7 @@ pub async fn scratchpad_interaction_not_stream( ScratchError::new_but_skip_telemetry(StatusCode::INTERNAL_SERVER_ERROR, format!("forward_to_endpoint: {}", e)) })?; tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new( - save_url.clone(), + url.clone(), scope.clone(), true, "".to_string(), @@ -143,26 +141,23 @@ pub async fn scratchpad_interaction_stream( let caps_locked = caps.read().unwrap(); (caps_locked.endpoint_style.clone(), caps_locked.endpoint_template.clone(), cx.telemetry.clone()) }; - let mut save_url: String = String::new(); + let url: String = endpoint_template.replace("$MODEL", &model_name).clone(); loop { let event_source_maybe = if endpoint_style == "hf" { forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming( - &mut save_url, + &url, bearer.clone(), - &model_name, &prompt, &client, - &endpoint_template, ¶meters, ).await } else { forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming( - &mut save_url, + &url, bearer.clone(), &model_name, &prompt, &client, - &endpoint_template, ¶meters, ).await }; @@ -171,7 +166,7 @@ pub async fn scratchpad_interaction_stream( Err(e) => { let e_str = format!("forward_to_endpoint: {:?}", e); tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new( - save_url.clone(), + url.clone(), scope.clone(), false, e_str.to_string(), @@ -227,7 +222,7 @@ pub async fn scratchpad_interaction_stream( let problem_str = format!("restream error: {}", err); { tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new( - save_url.clone(), + url.clone(), scope.clone(), false, problem_str.clone(), @@ -256,7 +251,7 @@ pub async fn scratchpad_interaction_stream( info!("yield: [DONE]"); yield Result::<_, String>::Ok("data: [DONE]\n\n".to_string()); tele_storage.write().unwrap().tele_net.push(telemetry_basic::TelemetryNetwork::new( - save_url.clone(), + url.clone(), scope.clone(), true, "".to_string(), diff --git a/src/telemetry_basic.rs b/src/telemetry_basic.rs index 89139ea73..da7beb881 100644 --- a/src/telemetry_basic.rs +++ b/src/telemetry_basic.rs @@ -135,13 +135,13 @@ pub async fn compress_basic_telemetry_to_file( } // even if there's an error with i/o, storage is now clear, preventing infinite memory growth info!("basic telemetry save \"{}\"", fn_net.to_str().unwrap()); - let io_result = _file_save(fn_net.clone(), big_json_net).await; + let io_result = file_save(fn_net.clone(), big_json_net).await; if io_result.is_err() { error!("error: {}", io_result.err().unwrap()); } } -async fn _file_save(path: PathBuf, json: serde_json::Value) -> Result<(), String> { +pub async fn file_save(path: PathBuf, json: serde_json::Value) -> Result<(), String> { let mut f = tokio::fs::File::create(path).await.map_err(|e| format!("{:?}", e))?; f.write_all(serde_json::to_string_pretty(&json).unwrap().as_bytes()).await.map_err(|e| format!("{}", e))?; Ok(()) diff --git a/src/telemetry_snippets.rs b/src/telemetry_snippets.rs index 3cf05be67..c9ea4530c 100644 --- a/src/telemetry_snippets.rs +++ b/src/telemetry_snippets.rs @@ -1,16 +1,12 @@ -use tracing::info; +use tracing::{error, info}; use std::sync::Arc; use tokio::sync::RwLock as ARwLock; use std::sync::RwLock as StdRwLock; -// use std::collections::HashMap; -// use reqwest_eventsource::Event; -// use futures::StreamExt; -// use async_stream::stream; -// use serde_json::json; -// use crate::caps::CodeAssistantCaps; + use crate::call_validation; use serde::Deserialize; use serde::Serialize; +use serde_json::json; use crate::global_context; use crate::completion_cache; use crate::telemetry_storage; @@ -24,6 +20,8 @@ use difference; // 3. LSP looks at file changes (LSP can be replaced with reaction to a next completion?) // 4. Changes are translated to "after_walkaway_remaining50to95" etc +const SNIP_FINISHED_AFTER : i64 = 300; + #[derive(Debug, Clone)] pub struct SaveSnippet { @@ -56,6 +54,7 @@ pub struct SnippetTelemetry { // pub remaining_percent_300s: f64, // pub remaining_percent_walkaway: f64, // pub walkaway_ms: u64, + pub created_at: i64 } pub fn snippet_register( @@ -71,6 +70,7 @@ pub fn snippet_register( accepted: false, corrected_by_user: "".to_string(), remaining_percent_30s: 0.0, + created_at: chrono::Local::now().timestamp(), }; storage_locked.tele_snippet_next_id += 1; storage_locked.tele_snippets.push(snip); @@ -135,10 +135,12 @@ pub async fn sources_changed( if !orig_text.is_some() { continue; } + // let time_from_creation = chrono::Local::now().timestamp() - snip.created_at; let (valid1, mut gray_suggested) = if_head_tail_equal_return_added_text( orig_text.unwrap(), text ); + snip.corrected_by_user = gray_suggested.clone(); gray_suggested = gray_suggested.replace("\r", ""); info!("valid1: {:?}, gray_suggested: {:?}", valid1, gray_suggested); info!("orig grey_text: {:?}", snip.grey_text); @@ -219,3 +221,76 @@ pub fn unchanged_percentage( let largest_of_two = text_a.len().max(text_b.len()); (common as f64) / (largest_of_two as f64) } + +async fn manage_finished_snippets(gcx: Arc>) { + let tele_storage; + let now = chrono::Local::now().timestamp(); + let enduser_client_version; + let api_key: String; + let caps; + let mothership_enabled: bool; + let mut telemetry_corrected_snippets_dest = String::new(); + { + let cx = gcx.read().await; + enduser_client_version = cx.cmdline.enduser_client_version.clone(); + tele_storage = cx.telemetry.clone(); + api_key = cx.cmdline.api_key.clone(); + caps = cx.caps.clone(); + mothership_enabled = cx.cmdline.snippet_telemetry; + } + if let Some(caps) = &caps { + telemetry_corrected_snippets_dest = caps.read().unwrap().telemetry_corrected_snippets_dest.clone(); + } + + let mut snips_send: Vec = vec![]; + { + let mut to_remove: Vec = vec![]; + let mut storage_locked = tele_storage.write().unwrap(); + for (idx, snip) in &mut storage_locked.tele_snippets.iter().enumerate() { + if now - snip.created_at >= SNIP_FINISHED_AFTER { + if snip.accepted { + snips_send.push(snip.clone()); + } + to_remove.push(idx); + } + } + for idx in to_remove.iter().rev() { + storage_locked.tele_snippets.remove(*idx); + } + } + + if !mothership_enabled { + info!("telemetry snippets sending not enabled, skip"); + return; + } + + for snip in snips_send { + let json_dict = serde_json::to_value(snip).unwrap(); + let big_json_snip = json!({ + "records": [json_dict], + "ts_start": now, + "ts_end": chrono::Local::now().timestamp(), + "teletype": "snippets", + "enduser_client_version": enduser_client_version, + }); + let resp_maybe = telemetry_storage::send_telemetry_data( + big_json_snip.to_string(), + &telemetry_corrected_snippets_dest, + &api_key + ).await; + if resp_maybe.is_err() { + error!("snippet send failed: {}", resp_maybe.err().unwrap()); + error!("too bad snippet is lost now"); + continue; + } + } +} + +pub async fn tele_snip_background_task( + global_context: Arc>, +) -> () { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; + manage_finished_snippets(global_context.clone()).await; + } +} diff --git a/src/telemetry_storage.rs b/src/telemetry_storage.rs index a4b8451a4..cf4754e4d 100644 --- a/src/telemetry_storage.rs +++ b/src/telemetry_storage.rs @@ -83,6 +83,33 @@ async fn _read_file(path: PathBuf) -> Result { Ok(contents) } +pub async fn send_telemetry_data( + contents: String, + telemetry_dest: &String, + api_key: &String, +) -> Result<(), String>{ + let resp_maybe = reqwest::Client::new().post(telemetry_dest.clone()) + .body(contents) + .header(reqwest::header::AUTHORIZATION, format!("Bearer {}", api_key)) + .header(reqwest::header::CONTENT_TYPE, format!("application/json")) + .send().await; + if resp_maybe.is_err() { + return Err(format!("telemetry send failed: {}\ndest url was\n{}", resp_maybe.err().unwrap(), telemetry_dest)); + } + let resp = resp_maybe.unwrap(); + if resp.status()!= reqwest::StatusCode::OK { + return Err(format!("telemetry send failed: {}\ndest url was\n{}", resp.status(), telemetry_dest)); + } + let resp_body = resp.text().await.unwrap_or_else(|_| "-empty-".to_string()); + info!("telemetry send success, response:\n{}", resp_body); + let resp_json = serde_json::from_str::(&resp_body).unwrap_or_else(|_| json!({})); + let retcode = resp_json["retcode"].as_str().unwrap_or("").to_string(); + if retcode != "OK" { + return Err("retcode is not OK".to_string()); + } + Ok(()) +} + pub async fn send_telemetry_files_to_mothership( dir_compressed: PathBuf, dir_sent: PathBuf, @@ -91,36 +118,23 @@ pub async fn send_telemetry_files_to_mothership( ) { // Send files found in dir_compressed, move to dir_sent if successful. let files = _sorted_files(dir_compressed.clone()).await; - let http_client = reqwest::Client::new(); for path in files { let contents_maybe = _read_file(path.clone()).await; if contents_maybe.is_err() { error!("cannot read {}: {}", path.display(), contents_maybe.err().unwrap()); - break; + continue } let contents = contents_maybe.unwrap(); - info!("sending telemetry file\n{}\nto url\n{}", path.to_str().unwrap(), telemetry_basic_dest); - let resp_maybe = http_client.post(telemetry_basic_dest.clone()) - .body(contents) - .header(reqwest::header::AUTHORIZATION, format!("Bearer {}", api_key)) - .header(reqwest::header::CONTENT_TYPE, format!("application/json")) - .send().await; - if resp_maybe.is_err() { - error!("telemetry send failed: {}\ndest url was\n{}", resp_maybe.err().unwrap(), telemetry_basic_dest); - break; - } - let resp = resp_maybe.unwrap(); - if resp.status()!= reqwest::StatusCode::OK { - error!("telemetry send failed: {}\ndest url was\n{}", resp.status(), telemetry_basic_dest); - break; - } - let resp_body = resp.text().await.unwrap_or_else(|_| "-empty-".to_string()); - info!("telemetry send success, response:\n{}", resp_body); - let resp_json = serde_json::from_str::(&resp_body).unwrap_or_else(|_| json!({})); - let retcode = resp_json["retcode"].as_str().unwrap_or("").to_string(); - if retcode != "OK" { - error!("retcode is not OK"); - break; + + if path.to_str().unwrap().ends_with("-net.json") { + info!("sending telemetry file\n{}\nto url\n{}", path.to_str().unwrap(), telemetry_basic_dest); + let resp = send_telemetry_data(contents, &telemetry_basic_dest, &api_key).await; + if resp.is_err() { + error!("telemetry send failed: {}", resp.err().unwrap()); + continue; + } + } else { + continue; } let new_path = dir_sent.join(path.file_name().unwrap()); info!("success, moving file to {}", new_path.to_str().unwrap()); @@ -151,13 +165,18 @@ pub async fn telemetry_full_cycle( mothership_enabled = cx.cmdline.basic_telemetry; } if caps.is_some() { - telemetry_basic_dest = caps.unwrap().read().unwrap().telemetry_basic_dest.clone(); + telemetry_basic_dest = caps.clone().unwrap().read().unwrap().telemetry_basic_dest.clone(); } telemetry_basic::compress_basic_telemetry_to_file(global_context.clone()).await; let dir_compressed = cache_dir.join("telemetry").join("compressed"); let dir_sent = cache_dir.join("telemetry").join("sent"); if mothership_enabled && !telemetry_basic_dest.is_empty() && !skip_sending_part { - send_telemetry_files_to_mothership(dir_compressed.clone(), dir_sent.clone(), telemetry_basic_dest, api_key).await; + send_telemetry_files_to_mothership( + dir_compressed.clone(), + dir_sent.clone(), + telemetry_basic_dest, + api_key + ).await; } if !mothership_enabled { info!("telemetry sending not enabled, skip"); @@ -170,7 +189,7 @@ pub async fn telemetry_background_task( global_context: Arc>, ) -> () { loop { - tokio::time::sleep(std::time::Duration::from_secs(TELEMETRY_COMPRESSION_SECONDS)).await; + tokio::time::sleep(tokio::time::Duration::from_secs(TELEMETRY_COMPRESSION_SECONDS)).await; telemetry_full_cycle(global_context.clone(), false).await; } }