diff --git a/src/http/routers/v1/code_completion.rs b/src/http/routers/v1/code_completion.rs index 57c64bb0f..cc4e8117d 100644 --- a/src/http/routers/v1/code_completion.rs +++ b/src/http/routers/v1/code_completion.rs @@ -3,8 +3,10 @@ use std::sync::RwLock as StdRwLock; use axum::Extension; use axum::response::Result; +use futures_util::future::ok; use hyper::{Body, Response, StatusCode}; -use tracing::info; +use ropey::Rope; +use tracing::{error, info}; use crate::call_validation::CodeCompletionPost; use crate::caps; @@ -35,10 +37,29 @@ async fn _lookup_code_completion_scratchpad( Ok((model_name, sname.clone(), patch.clone(), n_ctx)) } +fn validate_post(code_completion_post: CodeCompletionPost) -> Result<(), ScratchError> { + let pos = code_completion_post.inputs.cursor.clone(); + let Some(source) = code_completion_post.inputs.sources.get(&code_completion_post.inputs.cursor.file) else { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + }; + let text = Rope::from_str(&*source); + let line_number = pos.line as usize; + if line_number >= text.len_lines() { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + } + let line = text.line(line_number); + let col = pos.character as usize; + if col >= line.len_chars() { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "invalid post".to_string())) + } + Ok(()) +} + pub async fn handle_v1_code_completion( global_context: SharedGlobalContext, code_completion_post: &mut CodeCompletionPost, ) -> Result, ScratchError> { + validate_post(code_completion_post.clone())?; let caps = crate::global_context::try_load_caps_quickly_if_not_present(global_context.clone()).await?; let (model_name, scratchpad_name, scratchpad_patch, n_ctx) = _lookup_code_completion_scratchpad( caps.clone(),