Skip to content

Commit

Permalink
add tests and small refact
Browse files Browse the repository at this point in the history
  • Loading branch information
reymondzzzz committed Jan 15, 2024
1 parent 284ed51 commit 369d3ef
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 19 deletions.
133 changes: 133 additions & 0 deletions src/call_validation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use axum::http::StatusCode;
use ropey::Rope;
use crate::custom_error::ScratchError;


#[derive(Debug, Serialize, Deserialize, Clone, Default)]
Expand Down Expand Up @@ -41,6 +44,136 @@ pub struct CodeCompletionPost {
pub no_cache: bool,
}

pub(crate) fn validate_post(code_completion_post: CodeCompletionPost) -> axum::response::Result<(), ScratchError> {
let pos = code_completion_post.inputs.cursor.clone();
let Some(source) = code_completion_post.inputs.sources.get(&code_completion_post.inputs.cursor.file) else {
return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string()))
};
let text = Rope::from_str(&*source);
let line_number = pos.line as usize;
if line_number >= text.len_lines() {
return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string()))
}
let line = text.line(line_number);
let col = pos.character as usize;
if col > line.len_chars() {
return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string()))
}
Ok(())
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use crate::call_validation::{CodeCompletionInputs, CursorPosition, SamplingParameters};
use super::*;

#[test]
fn test_valid_post1() {
let post = CodeCompletionPost {
inputs: CodeCompletionInputs {
sources: HashMap::from_iter([("hello.py".to_string(), "def hello_world():".to_string())]),
cursor: CursorPosition {
file: "hello.py".to_string(),
line: 0,
character: 18,
},
multiline: true,
},
parameters: SamplingParameters {
max_new_tokens: 20,
temperature: Some(0.1),
top_p: None,
stop: None,
},
model: "".to_string(),
scratchpad: "".to_string(),
stream: false,
no_cache: false,
};
assert!(validate_post(post).is_ok());
}

#[test]
fn test_valid_post2() {
let post = CodeCompletionPost {
inputs: CodeCompletionInputs {
sources: HashMap::from_iter([("hello.py".to_string(), "你好世界Ωßåß🤖".to_string())]),
cursor: CursorPosition {
file: "hello.py".to_string(),
line: 0,
character: 10,
},
multiline: true,
},
parameters: SamplingParameters {
max_new_tokens: 20,
temperature: Some(0.1),
top_p: None,
stop: None,
},
model: "".to_string(),
scratchpad: "".to_string(),
stream: false,
no_cache: false,
};
assert!(validate_post(post).is_ok());
}

#[test]
fn test_invalid_post_incorrect_line() {
let post = CodeCompletionPost {
inputs: CodeCompletionInputs {
sources: HashMap::from_iter([("hello.py".to_string(), "def hello_world():".to_string())]),
cursor: CursorPosition {
file: "hello.py".to_string(),
line: 2,
character: 18,
},
multiline: true,
},
parameters: SamplingParameters {
max_new_tokens: 20,
temperature: Some(0.1),
top_p: None,
stop: None,
},
model: "".to_string(),
scratchpad: "".to_string(),
stream: false,
no_cache: false,
};
assert!(validate_post(post).is_err());
}

#[test]
fn test_invalid_post_incorrect_col() {
let post = CodeCompletionPost {
inputs: CodeCompletionInputs {
sources: HashMap::from_iter([("hello.py".to_string(), "def hello_world():".to_string())]),
cursor: CursorPosition {
file: "hello.py".to_string(),
line: 0,
character: 80,
},
multiline: true,
},
parameters: SamplingParameters {
max_new_tokens: 20,
temperature: Some(0.1),
top_p: None,
stop: None,
},
model: "".to_string(),
scratchpad: "".to_string(),
stream: false,
no_cache: false,
};
assert!(validate_post(post).is_err());
}

}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ContextFile {
pub file_name: String,
Expand Down
20 changes: 1 addition & 19 deletions src/http/routers/v1/code_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use hyper::{Body, Response, StatusCode};
use ropey::Rope;
use tracing::{error, info};

use crate::call_validation::CodeCompletionPost;
use crate::call_validation::{CodeCompletionPost, validate_post};
use crate::caps;
use crate::caps::CodeAssistantCaps;
use crate::completion_cache;
Expand Down Expand Up @@ -38,24 +38,6 @@ async fn _lookup_code_completion_scratchpad(
Ok((model_name, sname.clone(), patch.clone(), n_ctx))
}

fn validate_post(code_completion_post: CodeCompletionPost) -> Result<(), ScratchError> {
let pos = code_completion_post.inputs.cursor.clone();
let Some(source) = code_completion_post.inputs.sources.get(&code_completion_post.inputs.cursor.file) else {
return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string()))
};
let text = Rope::from_str(&*source);
let line_number = pos.line as usize;
if line_number >= text.len_lines() {
return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string()))
}
let line = text.line(line_number);
let col = pos.character as usize;
if col > line.len_chars() {
return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string()))
}
Ok(())
}

pub async fn handle_v1_code_completion(
global_context: Arc<ARwLock<GlobalContext>>,
code_completion_post: &mut CodeCompletionPost,
Expand Down

0 comments on commit 369d3ef

Please sign in to comment.