Merge pull request #61 from SilasMarvin/silas-custom-actions
Custom actions, server shutdown fixes and a bunch of small things
SilasMarvin authored Aug 21, 2024
2 parents ef57dc2 + 57ee3c0 commit 3064bef
Showing 17 changed files with 794 additions and 251 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion crates/lsp-ai/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "lsp-ai"
version = "0.5.1"
version = "0.6.0"

description.workspace = true
repository.workspace = true
@@ -43,6 +43,7 @@ fxhash = "0.2.1"
ordered-float = "4.2.1"
futures = "0.3"
clap = { version = "4.5.14", features = ["derive"] }
regex = "1.10.6"

[build-dependencies]
cc="1"
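The new regex dependency's consumer is not among the files rendered on this page. Purely as an assumption, here is a sketch of the kind of response post-processing it could support for custom actions: pulling a fenced code block out of a model reply. The pattern, function name, and fallback behavior are illustrative, not code from this PR.

use regex::Regex;

// Hypothetical helper (not from this PR): extract the first fenced code block from a
// model response, falling back to the raw response when no block is present.
fn extract_code_block(response: &str) -> anyhow::Result<String> {
    let re = Regex::new(r"(?s)```(?:\w+)?\n(.*?)```")?;
    Ok(re
        .captures(response)
        .and_then(|caps| caps.get(1))
        .map(|m| m.as_str().to_string())
        .unwrap_or_else(|| response.to_string()))
}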
42 changes: 37 additions & 5 deletions crates/lsp-ai/src/config.rs
@@ -9,15 +9,23 @@ const fn max_requests_per_second_default() -> f32 {
1.
}

const fn true_default() -> bool {
true
}

#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PostProcess {
pub extractor: Option<String>,
#[serde(default = "true_default")]
pub remove_duplicate_start: bool,
#[serde(default = "true_default")]
pub remove_duplicate_end: bool,
}

impl Default for PostProcess {
fn default() -> Self {
Self {
extractor: None,
remove_duplicate_start: true,
remove_duplicate_end: true,
}
@@ -353,13 +361,31 @@ pub struct Chat {
pub(crate) parameters: Kwargs,
}

#[derive(Clone, Debug, Deserialize)]
pub struct Action {
// The name to display in the editor
pub(crate) action_display_name: String,
// The model key to use
pub(crate) model: String,
// Args are deserialized by the backend using them
#[serde(default)]
pub(crate) parameters: Kwargs,
// Parameters for post processing
#[serde(default)]
pub(crate) post_process: PostProcess,
}

#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub(crate) struct ValidConfig {
pub(crate) memory: ValidMemoryBackend,
pub(crate) models: HashMap<String, ValidModel>,
pub(crate) completion: Option<Completion>,
pub(crate) chat: Option<Vec<Chat>>,
#[serde(default)]
pub(crate) actions: Vec<Action>,
#[serde(default)]
#[serde(alias = "chat")] // Legacy from when it was called chat, remove soon
pub(crate) chats: Vec<Chat>,
}

#[derive(Clone, Debug, Deserialize, Default)]
@@ -396,8 +422,12 @@ impl Config {
// Helpers for the backends ///////////
///////////////////////////////////////

pub fn get_chat(&self) -> Option<&Vec<Chat>> {
self.config.chat.as_ref()
pub fn get_chats(&self) -> &Vec<Chat> {
&self.config.chats
}

pub fn get_actions(&self) -> &Vec<Action> {
&self.config.actions
}

pub fn is_completions_enabled(&self) -> bool {
@@ -446,7 +476,8 @@ impl Config {
memory: ValidMemoryBackend::FileStore(FileStore { crawl: None }),
models: HashMap::new(),
completion: None,
chat: None,
actions: vec![],
chats: vec![],
},
client_params: ValidClientParams { root_uri: None },
}
@@ -458,7 +489,8 @@
memory: ValidMemoryBackend::VectorStore(vector_store),
models: HashMap::new(),
completion: None,
chat: None,
actions: vec![],
chats: vec![],
},
client_params: ValidClientParams { root_uri: None },
}
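As a reference for the new config surface, a minimal sketch of how serde would parse one entry of the actions array into the Action struct added above. The action name, model key, and parameter values are invented, and Kwargs is assumed to deserialize from an arbitrary JSON map; post_process fields left out of the JSON fall back to the defaults defined in this file.

fn parse_example_action() -> anyhow::Result<()> {
    // Hypothetical config entry; the keys mirror the Action field names since no
    // serde renames are applied to the struct.
    let raw = serde_json::json!({
        "action_display_name": "Refactor Selection",
        "model": "model1",
        "parameters": { "max_tokens": 256 },
        "post_process": { "remove_duplicate_start": false }
    });
    let action: Action = serde_json::from_value(raw)?;
    assert_eq!(action.action_display_name, "Refactor Selection");
    // remove_duplicate_end was omitted above, so true_default() fills it in.
    assert!(action.post_process.remove_duplicate_end);
    Ok(())
}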
27 changes: 19 additions & 8 deletions crates/lsp-ai/src/main.rs
@@ -3,7 +3,7 @@ use clap::Parser;
use directories::BaseDirs;
use lsp_server::{Connection, ExtractError, Message, Notification, Request, RequestId};
use lsp_types::{
request::{CodeActionRequest, CodeActionResolveRequest, Completion},
request::{CodeActionRequest, CodeActionResolveRequest, Completion, Shutdown},
CodeActionOptions, CompletionOptions, DidChangeTextDocumentParams, DidOpenTextDocumentParams,
RenameFilesParams, ServerCapabilities, TextDocumentSyncKind,
};
@@ -127,7 +127,10 @@ fn main() -> Result<()> {
})?;
let initialization_args = connection.initialize(server_capabilities)?;

main_loop(connection, initialization_args)?;
if let Err(e) = main_loop(connection, initialization_args) {
error!("{e:?}");
}

io_threads.join()?;
Ok(())
}
@@ -147,7 +150,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {

// Setup the transformer worker
let memory_backend: Box<dyn MemoryBackend + Send + Sync> = config.clone().try_into()?;
thread::spawn(move || memory_worker::run(memory_backend, memory_rx));
let memory_worker_thread = thread::spawn(move || memory_worker::run(memory_backend, memory_rx));

// Setup our transformer worker
let transformer_backends: HashMap<String, Box<dyn TransformerBackend + Send + Sync>> = config
@@ -160,7 +163,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
let thread_connection = connection.clone();
let thread_memory_tx = memory_tx.clone();
let thread_config = config.clone();
thread::spawn(move || {
let transformer_worker_thread = thread::spawn(move || {
transformer_worker::run(
transformer_backends,
thread_memory_tx,
@@ -173,10 +176,18 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
for msg in &connection.receiver {
match msg {
Message::Request(req) => {
if connection.handle_shutdown(&req)? {
if request_is::<Shutdown>(&req) {
memory_tx.send(memory_worker::WorkerRequest::Shutdown)?;
if let Err(e) = memory_worker_thread.join() {
std::panic::resume_unwind(e)
}
transformer_tx.send(WorkerRequest::Shutdown)?;
if let Err(e) = transformer_worker_thread.join() {
std::panic::resume_unwind(e)
}
connection.handle_shutdown(&req)?;
return Ok(());
}
if request_is::<Completion>(&req) {
} else if request_is::<Completion>(&req) {
match cast::<Completion>(req) {
Ok((id, params)) => {
let completion_request = CompletionRequest::new(id, params);
@@ -224,7 +235,7 @@ fn main_loop(connection: Connection, args: serde_json::Value) -> Result<()> {
Err(err) => error!("{err:?}"),
}
} else {
error!("Unsupported command - see the wiki for a list of supported commands")
error!("Unsupported command - see the wiki for a list of supported commands: {req:?}")
}
}
Message::Notification(not) => {
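The worker side of this shutdown handshake is not shown in the hunks above. A minimal sketch of the shape it presumably takes: the Shutdown variant matches the sends in main_loop, but the loop body is an assumption. Each worker drains its channel and breaks on Shutdown, which is what lets the join() calls return.

enum WorkerRequest {
    Shutdown,
    Other, // stand-in for the real request variants, elided here
}

fn run(rx: std::sync::mpsc::Receiver<WorkerRequest>) {
    for request in rx {
        match request {
            // Breaking out ends the thread, allowing main_loop's join() to return.
            WorkerRequest::Shutdown => break,
            WorkerRequest::Other => { /* dispatch completions, generations, actions, ... */ }
        }
    }
}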
25 changes: 14 additions & 11 deletions crates/lsp-ai/src/memory_backends/file_store.rs
@@ -240,20 +240,22 @@ impl FileStore {
let rope_slice = rope
.get_slice(start..end + "<CURSOR>".chars().count())
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
Prompt::ContextAndCode(ContextAndCodePrompt {
context: "".to_string(),
code: rope_slice.to_string(),
selected_text: None,
})
} else {
let start = cursor_index
.saturating_sub(tokens_to_estimated_characters(params.max_context));
let rope_slice = rope
.get_slice(start..cursor_index)
.context("Error getting rope slice")?;
Prompt::ContextAndCode(ContextAndCodePrompt::new(
"".to_string(),
rope_slice.to_string(),
))
Prompt::ContextAndCode(ContextAndCodePrompt {
context: "".to_string(),
code: rope_slice.to_string(),
selected_text: None,
})
}
}
PromptType::FIM => {
@@ -268,7 +270,10 @@
let suffix = rope
.get_slice(cursor_index..end)
.context("Error getting rope slice")?;
Prompt::FIM(FIMPrompt::new(prefix.to_string(), suffix.to_string()))
Prompt::FIM(FIMPrompt {
prompt: prefix.to_string(),
suffix: suffix.to_string(),
})
}
})
}
@@ -837,8 +842,6 @@ mod tests {

#[test]
fn test_file_store_tree_sitter() -> anyhow::Result<()> {
crate::init_logger();

let config = Config::default_with_file_store_without_models();
let file_store_config = if let config::ValidMemoryBackend::FileStore(file_store_config) =
config.config.memory.clone()
42 changes: 15 additions & 27 deletions crates/lsp-ai/src/memory_backends/mod.rs
@@ -36,12 +36,7 @@ impl From<&Value> for MemoryRunParams {
pub struct ContextAndCodePrompt {
pub context: String,
pub code: String,
}

impl ContextAndCodePrompt {
pub fn new(context: String, code: String) -> Self {
Self { context, code }
}
pub selected_text: Option<String>,
}

#[derive(Debug)]
@@ -50,15 +45,6 @@ pub struct FIMPrompt {
pub suffix: String,
}

impl FIMPrompt {
pub fn new(prefix: String, suffix: String) -> Self {
Self {
prompt: prefix,
suffix,
}
}
}

#[derive(Debug)]
pub enum Prompt {
FIM(FIMPrompt),
@@ -159,23 +145,25 @@ impl TryFrom<Config> for Box<dyn MemoryBackend + Send + Sync> {
#[cfg(test)]
impl Prompt {
pub fn default_with_cursor() -> Self {
Self::ContextAndCode(ContextAndCodePrompt::new(
r#"def test_context():\n pass"#.to_string(),
r#"def test_code():\n <CURSOR>"#.to_string(),
))
Self::ContextAndCode(ContextAndCodePrompt {
context: r#"def test_context():\n pass"#.to_string(),
code: r#"def test_code():\n <CURSOR>"#.to_string(),
selected_text: None,
})
}

pub fn default_fim() -> Self {
Self::FIM(FIMPrompt::new(
r#"def test_context():\n pass"#.to_string(),
r#"def test_code():\n "#.to_string(),
))
Self::FIM(FIMPrompt {
prompt: r#"def test_context():\n pass"#.to_string(),
suffix: r#"def test_code():\n "#.to_string(),
})
}

pub fn default_without_cursor() -> Self {
Self::ContextAndCode(ContextAndCodePrompt::new(
r#"def test_context():\n pass"#.to_string(),
r#"def test_code():\n "#.to_string(),
))
Self::ContextAndCode(ContextAndCodePrompt {
context: r#"def test_context():\n pass"#.to_string(),
code: r#"def test_code():\n "#.to_string(),
selected_text: None,
})
}
}
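With the ContextAndCodePrompt and FIMPrompt constructors removed in favor of struct literals, every call site now has to state explicitly what goes into the new selected_text field. Not part of this diff, but as an illustration of what that field enables for custom actions, a hypothetical prompt renderer; the layout is an assumption, and the real templates live in the transformer backends.

fn render_context_and_code(prompt: &ContextAndCodePrompt) -> String {
    match &prompt.selected_text {
        // A custom action triggered on an editor selection can surface that text explicitly.
        Some(selection) => format!(
            "{}\n\n<SELECTED>\n{}\n</SELECTED>\n\n{}",
            prompt.context, selection, prompt.code
        ),
        // Completions keep the original context + code shape.
        None => format!("{}\n\n{}", prompt.context, prompt.code),
    }
}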
17 changes: 9 additions & 8 deletions crates/lsp-ai/src/memory_backends/postgresml/mod.rs
@@ -589,19 +589,20 @@ impl MemoryBackend for PostgresML {
// Reconstruct the Prompts
Ok(match code {
Prompt::ContextAndCode(context_and_code) => {
Prompt::ContextAndCode(ContextAndCodePrompt::new(
context.to_owned(),
format_file_chunk(
Prompt::ContextAndCode(ContextAndCodePrompt {
context: context.to_owned(),
code: format_file_chunk(
position.text_document.uri.as_ref(),
&context_and_code.code,
self.config.client_params.root_uri.as_deref(),
),
))
selected_text: None,
})
}
Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new(
format!("{context}\n\n{}", fim.prompt),
fim.suffix,
)),
Prompt::FIM(fim) => Prompt::FIM(FIMPrompt {
prompt: format!("{context}\n\n{}", fim.prompt),
suffix: fim.suffix,
}),
})
}

18 changes: 9 additions & 9 deletions crates/lsp-ai/src/memory_backends/vector_store.rs
@@ -726,19 +726,20 @@ impl MemoryBackend for VectorStore {
// Reconstruct the prompts
Ok(match code {
Prompt::ContextAndCode(context_and_code) => {
Prompt::ContextAndCode(ContextAndCodePrompt::new(
context.to_owned(),
format_file_chunk(
Prompt::ContextAndCode(ContextAndCodePrompt {
context: context.to_owned(),
code: format_file_chunk(
position.text_document.uri.as_ref(),
&context_and_code.code,
self.config.client_params.root_uri.as_deref(),
),
))
selected_text: None,
})
}
Prompt::FIM(fim) => Prompt::FIM(FIMPrompt::new(
format!("{context}\n\n{}", fim.prompt),
fim.suffix,
)),
Prompt::FIM(fim) => Prompt::FIM(FIMPrompt {
prompt: format!("{context}\n\n{}", fim.prompt),
suffix: fim.suffix,
}),
})
}
}
@@ -935,7 +936,6 @@ assert multiply_two_numbers(2, 3) == 6

#[tokio::test]
async fn can_build_prompt() -> anyhow::Result<()> {
crate::init_logger();
let text_document1 = generate_filler_text_document(None, None);
let params = lsp_types::DidOpenTextDocumentParams {
text_document: text_document1.clone(),