diff --git a/src/http/routers/v1/code_completion.rs b/src/http/routers/v1/code_completion.rs index 1e572290e..8fc603ed7 100644 --- a/src/http/routers/v1/code_completion.rs +++ b/src/http/routers/v1/code_completion.rs @@ -4,8 +4,10 @@ use tokio::sync::RwLock as ARwLock; use axum::Extension; use axum::response::Result; +use futures_util::future::ok; use hyper::{Body, Response, StatusCode}; -use tracing::info; +use ropey::Rope; +use tracing::{error, info}; use crate::call_validation::CodeCompletionPost; use crate::caps; @@ -36,10 +38,29 @@ async fn _lookup_code_completion_scratchpad( Ok((model_name, sname.clone(), patch.clone(), n_ctx)) } +fn validate_post(code_completion_post: CodeCompletionPost) -> Result<(), ScratchError> { + let pos = code_completion_post.inputs.cursor.clone(); + let Some(source) = code_completion_post.inputs.sources.get(&code_completion_post.inputs.cursor.file) else { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + }; + let text = Rope::from_str(&*source); + let line_number = pos.line as usize; + if line_number >= text.len_lines() { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + } + let line = text.line(line_number); + let col = pos.character as usize; + if col > line.len_chars() { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + } + Ok(()) +} + pub async fn handle_v1_code_completion( global_context: Arc>, code_completion_post: &mut CodeCompletionPost, ) -> Result, ScratchError> { + validate_post(code_completion_post.clone())?; let caps = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone(), 0).await?; let maybe = _lookup_code_completion_scratchpad( caps.clone(),