diff --git a/src/call_validation.rs b/src/call_validation.rs index b22a0609e..03347800d 100644 --- a/src/call_validation.rs +++ b/src/call_validation.rs @@ -1,6 +1,9 @@ use serde::Deserialize; use serde::Serialize; use std::collections::HashMap; +use axum::http::StatusCode; +use ropey::Rope; +use crate::custom_error::ScratchError; #[derive(Debug, Serialize, Deserialize, Clone, Default)] @@ -41,6 +44,136 @@ pub struct CodeCompletionPost { pub no_cache: bool, } +pub(crate) fn validate_post(code_completion_post: CodeCompletionPost) -> axum::response::Result<(), ScratchError> { + let pos = code_completion_post.inputs.cursor.clone(); + let Some(source) = code_completion_post.inputs.sources.get(&code_completion_post.inputs.cursor.file) else { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + }; + let text = Rope::from_str(&*source); + let line_number = pos.line as usize; + if line_number >= text.len_lines() { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + } + let line = text.line(line_number); + let col = pos.character as usize; + if col > line.len_chars() { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use crate::call_validation::{CodeCompletionInputs, CursorPosition, SamplingParameters}; + use super::*; + + #[test] + fn test_valid_post1() { + let post = CodeCompletionPost { + inputs: CodeCompletionInputs { + sources: HashMap::from_iter([("hello.py".to_string(), "def hello_world():".to_string())]), + cursor: CursorPosition { + file: "hello.py".to_string(), + line: 0, + character: 18, + }, + multiline: true, + }, + parameters: SamplingParameters { + max_new_tokens: 20, + temperature: Some(0.1), + top_p: None, + stop: None, + }, + model: "".to_string(), + scratchpad: "".to_string(), + stream: false, + no_cache: false, + }; + assert!(validate_post(post).is_ok()); + } + + #[test] + fn test_valid_post2() { + let post = CodeCompletionPost { + inputs: CodeCompletionInputs { + sources: HashMap::from_iter([("hello.py".to_string(), "你好世界Ωßåß🤖".to_string())]), + cursor: CursorPosition { + file: "hello.py".to_string(), + line: 0, + character: 10, + }, + multiline: true, + }, + parameters: SamplingParameters { + max_new_tokens: 20, + temperature: Some(0.1), + top_p: None, + stop: None, + }, + model: "".to_string(), + scratchpad: "".to_string(), + stream: false, + no_cache: false, + }; + assert!(validate_post(post).is_ok()); + } + + #[test] + fn test_invalid_post_incorrect_line() { + let post = CodeCompletionPost { + inputs: CodeCompletionInputs { + sources: HashMap::from_iter([("hello.py".to_string(), "def hello_world():".to_string())]), + cursor: CursorPosition { + file: "hello.py".to_string(), + line: 2, + character: 18, + }, + multiline: true, + }, + parameters: SamplingParameters { + max_new_tokens: 20, + temperature: Some(0.1), + top_p: None, + stop: None, + }, + model: "".to_string(), + scratchpad: "".to_string(), + stream: false, + no_cache: false, + }; + assert!(validate_post(post).is_err()); + } + + #[test] + fn test_invalid_post_incorrect_col() { + let post = CodeCompletionPost { + inputs: CodeCompletionInputs { + sources: HashMap::from_iter([("hello.py".to_string(), "def hello_world():".to_string())]), + cursor: CursorPosition { + file: "hello.py".to_string(), + line: 0, + character: 80, + }, + multiline: true, + }, + parameters: SamplingParameters { + max_new_tokens: 20, + temperature: Some(0.1), + top_p: None, + stop: None, + }, + model: "".to_string(), + scratchpad: "".to_string(), + stream: false, + no_cache: false, + }; + assert!(validate_post(post).is_err()); + } + +} + #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ContextFile { pub file_name: String, diff --git a/src/http/routers/v1/code_completion.rs b/src/http/routers/v1/code_completion.rs index 8fc603ed7..7b00ed0ad 100644 --- a/src/http/routers/v1/code_completion.rs +++ b/src/http/routers/v1/code_completion.rs @@ -9,7 +9,7 @@ use hyper::{Body, Response, StatusCode}; use ropey::Rope; use tracing::{error, info}; -use crate::call_validation::CodeCompletionPost; +use crate::call_validation::{CodeCompletionPost, validate_post}; use crate::caps; use crate::caps::CodeAssistantCaps; use crate::completion_cache; @@ -38,24 +38,6 @@ async fn _lookup_code_completion_scratchpad( Ok((model_name, sname.clone(), patch.clone(), n_ctx)) } -fn validate_post(code_completion_post: CodeCompletionPost) -> Result<(), ScratchError> { - let pos = code_completion_post.inputs.cursor.clone(); - let Some(source) = code_completion_post.inputs.sources.get(&code_completion_post.inputs.cursor.file) else { - return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) - }; - let text = Rope::from_str(&*source); - let line_number = pos.line as usize; - if line_number >= text.len_lines() { - return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) - } - let line = text.line(line_number); - let col = pos.character as usize; - if col > line.len_chars() { - return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) - } - Ok(()) -} - pub async fn handle_v1_code_completion( global_context: Arc>, code_completion_post: &mut CodeCompletionPost,