diff --git a/crates/pgls_lsp/src/handlers/code_actions.rs b/crates/pgls_lsp/src/handlers/code_actions.rs index 5fd1546d9..04d8b243d 100644 --- a/crates/pgls_lsp/src/handlers/code_actions.rs +++ b/crates/pgls_lsp/src/handlers/code_actions.rs @@ -53,6 +53,21 @@ pub fn get_actions( .map(|reason| CodeActionDisabled { reason }), ..Default::default() }), + CommandActionCategory::InvalidateSchemaCache => Some(CodeAction { + title: title.clone(), + kind: Some(lsp_types::CodeActionKind::EMPTY), + command: Some({ + Command { + title: title.clone(), + command: command_id, + arguments: None, + } + }), + disabled: action + .disabled_reason + .map(|reason| CodeActionDisabled { reason }), + ..Default::default() + }), } } @@ -68,7 +83,8 @@ pub fn get_actions( pub fn command_id(command: &CommandActionCategory) -> String { match command { - CommandActionCategory::ExecuteStatement(_) => "pgt.executeStatement".into(), + CommandActionCategory::ExecuteStatement(_) => "pgls.executeStatement".into(), + CommandActionCategory::InvalidateSchemaCache => "pgls.invalidateSchemaCache".into(), } } @@ -80,7 +96,7 @@ pub async fn execute_command( let command = params.command; match command.as_str() { - "pgt.executeStatement" => { + "pgls.executeStatement" => { let statement_id = serde_json::from_value::( params.arguments[0].clone(), )?; @@ -105,7 +121,16 @@ pub async fn execute_command( Ok(None) } + "pgls.invalidateSchemaCache" => { + session.workspace.invalidate_schema_cache(true)?; + session + .client + .show_message(MessageType::INFO, "Schema cache invalidated") + .await; + + Ok(None) + } any => Err(anyhow!(format!("Unknown command: {}", any))), } } diff --git a/crates/pgls_lsp/src/server.rs b/crates/pgls_lsp/src/server.rs index 18f38f007..b5bc92ac2 100644 --- a/crates/pgls_lsp/src/server.rs +++ b/crates/pgls_lsp/src/server.rs @@ -461,6 +461,7 @@ impl ServerFactory { workspace_method!(builder, get_completions); workspace_method!(builder, register_project_folder); workspace_method!(builder, unregister_project_folder); + workspace_method!(builder, invalidate_schema_cache); let (service, socket) = builder.finish(); ServerConnection { socket, service } diff --git a/crates/pgls_lsp/tests/server.rs b/crates/pgls_lsp/tests/server.rs index e4798c536..fd88774b6 100644 --- a/crates/pgls_lsp/tests/server.rs +++ b/crates/pgls_lsp/tests/server.rs @@ -916,7 +916,7 @@ async fn test_execute_statement(test_db: PgPool) -> Result<()> { .find_map(|action_or_cmd| match action_or_cmd { lsp::CodeActionOrCommand::CodeAction(code_action) => { let command = code_action.command.as_ref(); - if command.is_some_and(|cmd| &cmd.command == "pgt.executeStatement") { + if command.is_some_and(|cmd| &cmd.command == "pgls.executeStatement") { let command = command.unwrap(); let arguments = command.arguments.as_ref().unwrap().clone(); Some((command.command.clone(), arguments)) @@ -952,6 +952,164 @@ async fn test_execute_statement(test_db: PgPool) -> Result<()> { Ok(()) } +#[sqlx::test(migrator = "pgls_test_utils::MIGRATIONS")] +async fn test_invalidate_schema_cache(test_db: PgPool) -> Result<()> { + let factory = ServerFactory::default(); + let mut fs = MemoryFileSystem::default(); + + let database = test_db + .connect_options() + .get_database() + .unwrap() + .to_string(); + let host = test_db.connect_options().get_host().to_string(); + + // Setup: Create a table with only id column (no name column yet) + let setup = r#" + create table public.users ( + id serial primary key + ); + "#; + + test_db + .execute(setup) + .await + .expect("Failed to setup test database"); + + let mut conf = PartialConfiguration::init(); + conf.merge_with(PartialConfiguration { + db: Some(PartialDatabaseConfiguration { + database: Some(database), + host: Some(host), + ..Default::default() + }), + ..Default::default() + }); + + fs.insert( + url!("postgres-language-server.jsonc") + .to_file_path() + .unwrap(), + serde_json::to_string_pretty(&conf).unwrap(), + ); + + let (service, client) = factory + .create_with_fs(None, DynRef::Owned(Box::new(fs))) + .into_inner(); + + let (stream, sink) = client.split(); + let mut server = Server::new(service); + + let (sender, _receiver) = channel(CHANNEL_BUFFER_SIZE); + let reader = tokio::spawn(client_handler(stream, sink, sender)); + + server.initialize().await?; + server.initialized().await?; + + server.load_configuration().await?; + + // Open a document to get completions from + let doc_content = "select from public.users;"; + server.open_document(doc_content).await?; + + // Get completions before adding the column - 'name' should NOT be present + let completions_before = server + .get_completion(CompletionParams { + work_done_progress_params: WorkDoneProgressParams::default(), + partial_result_params: PartialResultParams::default(), + context: None, + text_document_position: TextDocumentPositionParams { + text_document: TextDocumentIdentifier { + uri: url!("document.sql"), + }, + position: Position { + line: 0, + character: 7, + }, + }, + }) + .await? + .unwrap(); + + let items_before = match completions_before { + CompletionResponse::Array(ref a) => a, + CompletionResponse::List(ref l) => &l.items, + }; + + let has_name_before = items_before.iter().any(|item| { + item.label == "name" + && item.label_details.as_ref().is_some_and(|d| { + d.description + .as_ref() + .is_some_and(|desc| desc.contains("public.users")) + }) + }); + + assert!( + !has_name_before, + "Column 'name' should not be in completions before it's added to the table" + ); + + // Add the missing column to the database + let alter_table = r#" + alter table public.users + add column name text; + "#; + + test_db + .execute(alter_table) + .await + .expect("Failed to add column to table"); + + // Invalidate the schema cache (all = false for current connection only) + server + .request::("pgt/invalidate_schema_cache", "_invalidate_cache", false) + .await?; + + // Get completions after invalidating cache - 'name' should NOW be present + let completions_after = server + .get_completion(CompletionParams { + work_done_progress_params: WorkDoneProgressParams::default(), + partial_result_params: PartialResultParams::default(), + context: None, + text_document_position: TextDocumentPositionParams { + text_document: TextDocumentIdentifier { + uri: url!("document.sql"), + }, + position: Position { + line: 0, + character: 7, + }, + }, + }) + .await? + .unwrap(); + + let items_after = match completions_after { + CompletionResponse::Array(ref a) => a, + CompletionResponse::List(ref l) => &l.items, + }; + + let has_name_after = items_after.iter().any(|item| { + item.label == "name" + && item.label_details.as_ref().is_some_and(|d| { + d.description + .as_ref() + .is_some_and(|desc| desc.contains("public.users")) + }) + }); + + assert!( + has_name_after, + "Column 'name' should be in completions after schema cache invalidation" + ); + + server.shutdown().await?; + reader.abort(); + + Ok(()) +} + #[sqlx::test(migrator = "pgls_test_utils::MIGRATIONS")] async fn test_issue_281(test_db: PgPool) -> Result<()> { let factory = ServerFactory::default(); diff --git a/crates/pgls_workspace/src/features/code_actions.rs b/crates/pgls_workspace/src/features/code_actions.rs index 55a3d0ca0..fb16e49d2 100644 --- a/crates/pgls_workspace/src/features/code_actions.rs +++ b/crates/pgls_workspace/src/features/code_actions.rs @@ -48,6 +48,7 @@ pub struct CommandAction { #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] pub enum CommandActionCategory { ExecuteStatement(StatementId), + InvalidateSchemaCache, } #[derive(Debug, serde::Serialize, serde::Deserialize)] diff --git a/crates/pgls_workspace/src/workspace.rs b/crates/pgls_workspace/src/workspace.rs index dda370a48..d82ca5058 100644 --- a/crates/pgls_workspace/src/workspace.rs +++ b/crates/pgls_workspace/src/workspace.rs @@ -158,6 +158,14 @@ pub trait Workspace: Send + Sync + RefUnwindSafe { &self, params: ExecuteStatementParams, ) -> Result; + + /// Invalidate the schema cache. + /// + /// # Arguments + /// * `all` - If true, clears all cached schemas. If false, clears only the current connection's cache. + /// + /// The schema will be reloaded lazily on the next operation that requires it. + fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError>; } /// Convenience function for constructing a server instance of [Workspace] diff --git a/crates/pgls_workspace/src/workspace/client.rs b/crates/pgls_workspace/src/workspace/client.rs index 70f7c20a3..431e029fd 100644 --- a/crates/pgls_workspace/src/workspace/client.rs +++ b/crates/pgls_workspace/src/workspace/client.rs @@ -168,4 +168,8 @@ where ) -> Result { self.request("pgt/on_hover", params) } + + fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError> { + self.request("pgt/invalidate_schema_cache", all) + } } diff --git a/crates/pgls_workspace/src/workspace/server.rs b/crates/pgls_workspace/src/workspace/server.rs index f7ae3225d..a3a388681 100644 --- a/crates/pgls_workspace/src/workspace/server.rs +++ b/crates/pgls_workspace/src/workspace/server.rs @@ -358,7 +358,7 @@ impl Workspace for WorkspaceServer { None => Some("Statement execution not allowed against database.".into()), }; - let actions = parser + let mut actions: Vec = parser .iter_with_filter( DefaultMapper, CursorPositionFilter::new(params.cursor_position), @@ -379,6 +379,20 @@ impl Workspace for WorkspaceServer { }) .collect(); + let invalidate_disabled_reason = if self.get_current_connection().is_some() { + None + } else { + Some("No database connection available.".into()) + }; + + actions.push(CodeAction { + title: "Invalidate Schema Cache".into(), + kind: CodeActionKind::Command(CommandAction { + category: CommandActionCategory::InvalidateSchemaCache, + }), + disabled_reason: invalidate_disabled_reason, + }); + Ok(CodeActionsResult { actions }) } @@ -424,6 +438,19 @@ impl Workspace for WorkspaceServer { }) } + fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError> { + if all { + self.schema_cache.clear_all(); + } else { + // Only clear current connection if one exists + if let Some(pool) = self.get_current_connection() { + self.schema_cache.clear(&pool); + } + // If no connection, nothing to clear - just return Ok + } + Ok(()) + } + #[ignored_path(path=¶ms.path)] fn pull_diagnostics( &self, diff --git a/crates/pgls_workspace/src/workspace/server/schema_cache_manager.rs b/crates/pgls_workspace/src/workspace/server/schema_cache_manager.rs index 9786946bf..fe7ee41c4 100644 --- a/crates/pgls_workspace/src/workspace/server/schema_cache_manager.rs +++ b/crates/pgls_workspace/src/workspace/server/schema_cache_manager.rs @@ -46,4 +46,17 @@ impl SchemaCacheManager { schemas.insert(key, schema_cache.clone()); Ok(schema_cache) } + + /// Clear the schema cache for a specific connection + pub fn clear(&self, pool: &PgPool) { + let key: ConnectionKey = pool.into(); + let mut schemas = self.schemas.write().unwrap(); + schemas.remove(&key); + } + + /// Clear all schema caches + pub fn clear_all(&self) { + let mut schemas = self.schemas.write().unwrap(); + schemas.clear(); + } }