Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions crates/pgls_lsp/src/handlers/code_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ pub fn get_actions(
.map(|reason| CodeActionDisabled { reason }),
..Default::default()
}),
CommandActionCategory::InvalidateSchemaCache => Some(CodeAction {
title: title.clone(),
kind: Some(lsp_types::CodeActionKind::EMPTY),
command: Some({
Command {
title: title.clone(),
command: command_id,
arguments: None,
}
}),
disabled: action
.disabled_reason
.map(|reason| CodeActionDisabled { reason }),
..Default::default()
}),
}
}

Expand All @@ -68,7 +83,8 @@ pub fn get_actions(

pub fn command_id(command: &CommandActionCategory) -> String {
match command {
CommandActionCategory::ExecuteStatement(_) => "pgt.executeStatement".into(),
CommandActionCategory::ExecuteStatement(_) => "pgls.executeStatement".into(),
CommandActionCategory::InvalidateSchemaCache => "pgls.invalidateSchemaCache".into(),
}
}

Expand All @@ -80,7 +96,7 @@ pub async fn execute_command(
let command = params.command;

match command.as_str() {
"pgt.executeStatement" => {
"pgls.executeStatement" => {
let statement_id = serde_json::from_value::<pgls_workspace::workspace::StatementId>(
params.arguments[0].clone(),
)?;
Expand All @@ -105,7 +121,16 @@ pub async fn execute_command(

Ok(None)
}
"pgls.invalidateSchemaCache" => {
session.workspace.invalidate_schema_cache(true)?;

session
.client
.show_message(MessageType::INFO, "Schema cache invalidated")
.await;

Ok(None)
}
any => Err(anyhow!(format!("Unknown command: {}", any))),
}
}
1 change: 1 addition & 0 deletions crates/pgls_lsp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ impl ServerFactory {
workspace_method!(builder, get_completions);
workspace_method!(builder, register_project_folder);
workspace_method!(builder, unregister_project_folder);
workspace_method!(builder, invalidate_schema_cache);

let (service, socket) = builder.finish();
ServerConnection { socket, service }
Expand Down
160 changes: 159 additions & 1 deletion crates/pgls_lsp/tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ async fn test_execute_statement(test_db: PgPool) -> Result<()> {
.find_map(|action_or_cmd| match action_or_cmd {
lsp::CodeActionOrCommand::CodeAction(code_action) => {
let command = code_action.command.as_ref();
if command.is_some_and(|cmd| &cmd.command == "pgt.executeStatement") {
if command.is_some_and(|cmd| &cmd.command == "pgls.executeStatement") {
let command = command.unwrap();
let arguments = command.arguments.as_ref().unwrap().clone();
Some((command.command.clone(), arguments))
Expand Down Expand Up @@ -952,6 +952,164 @@ async fn test_execute_statement(test_db: PgPool) -> Result<()> {
Ok(())
}

#[sqlx::test(migrator = "pgls_test_utils::MIGRATIONS")]
async fn test_invalidate_schema_cache(test_db: PgPool) -> Result<()> {
let factory = ServerFactory::default();
let mut fs = MemoryFileSystem::default();

let database = test_db
.connect_options()
.get_database()
.unwrap()
.to_string();
let host = test_db.connect_options().get_host().to_string();

// Setup: Create a table with only id column (no name column yet)
let setup = r#"
create table public.users (
id serial primary key
);
"#;

test_db
.execute(setup)
.await
.expect("Failed to setup test database");

let mut conf = PartialConfiguration::init();
conf.merge_with(PartialConfiguration {
db: Some(PartialDatabaseConfiguration {
database: Some(database),
host: Some(host),
..Default::default()
}),
..Default::default()
});

fs.insert(
url!("postgres-language-server.jsonc")
.to_file_path()
.unwrap(),
serde_json::to_string_pretty(&conf).unwrap(),
);

let (service, client) = factory
.create_with_fs(None, DynRef::Owned(Box::new(fs)))
.into_inner();

let (stream, sink) = client.split();
let mut server = Server::new(service);

let (sender, _receiver) = channel(CHANNEL_BUFFER_SIZE);
let reader = tokio::spawn(client_handler(stream, sink, sender));

server.initialize().await?;
server.initialized().await?;

server.load_configuration().await?;

// Open a document to get completions from
let doc_content = "select from public.users;";
server.open_document(doc_content).await?;

// Get completions before adding the column - 'name' should NOT be present
let completions_before = server
.get_completion(CompletionParams {
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
context: None,
text_document_position: TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: url!("document.sql"),
},
position: Position {
line: 0,
character: 7,
},
},
})
.await?
.unwrap();

let items_before = match completions_before {
CompletionResponse::Array(ref a) => a,
CompletionResponse::List(ref l) => &l.items,
};

let has_name_before = items_before.iter().any(|item| {
item.label == "name"
&& item.label_details.as_ref().is_some_and(|d| {
d.description
.as_ref()
.is_some_and(|desc| desc.contains("public.users"))
})
});

assert!(
!has_name_before,
"Column 'name' should not be in completions before it's added to the table"
);

// Add the missing column to the database
let alter_table = r#"
alter table public.users
add column name text;
"#;

test_db
.execute(alter_table)
.await
.expect("Failed to add column to table");

// Invalidate the schema cache (all = false for current connection only)
server
.request::<bool, ()>("pgt/invalidate_schema_cache", "_invalidate_cache", false)
.await?;

// Get completions after invalidating cache - 'name' should NOW be present
let completions_after = server
.get_completion(CompletionParams {
work_done_progress_params: WorkDoneProgressParams::default(),
partial_result_params: PartialResultParams::default(),
context: None,
text_document_position: TextDocumentPositionParams {
text_document: TextDocumentIdentifier {
uri: url!("document.sql"),
},
position: Position {
line: 0,
character: 7,
},
},
})
.await?
.unwrap();

let items_after = match completions_after {
CompletionResponse::Array(ref a) => a,
CompletionResponse::List(ref l) => &l.items,
};

let has_name_after = items_after.iter().any(|item| {
item.label == "name"
&& item.label_details.as_ref().is_some_and(|d| {
d.description
.as_ref()
.is_some_and(|desc| desc.contains("public.users"))
})
});

assert!(
has_name_after,
"Column 'name' should be in completions after schema cache invalidation"
);

server.shutdown().await?;
reader.abort();

Ok(())
}

#[sqlx::test(migrator = "pgls_test_utils::MIGRATIONS")]
async fn test_issue_281(test_db: PgPool) -> Result<()> {
let factory = ServerFactory::default();
Expand Down
1 change: 1 addition & 0 deletions crates/pgls_workspace/src/features/code_actions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub struct CommandAction {
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub enum CommandActionCategory {
ExecuteStatement(StatementId),
InvalidateSchemaCache,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
Expand Down
8 changes: 8 additions & 0 deletions crates/pgls_workspace/src/workspace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ pub trait Workspace: Send + Sync + RefUnwindSafe {
&self,
params: ExecuteStatementParams,
) -> Result<ExecuteStatementResult, WorkspaceError>;

/// Invalidate the schema cache.
///
/// # Arguments
/// * `all` - If true, clears all cached schemas. If false, clears only the current connection's cache.
///
/// The schema will be reloaded lazily on the next operation that requires it.
fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError>;
}

/// Convenience function for constructing a server instance of [Workspace]
Expand Down
4 changes: 4 additions & 0 deletions crates/pgls_workspace/src/workspace/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,8 @@ where
) -> Result<crate::features::on_hover::OnHoverResult, WorkspaceError> {
self.request("pgt/on_hover", params)
}

fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError> {
self.request("pgt/invalidate_schema_cache", all)
}
}
29 changes: 28 additions & 1 deletion crates/pgls_workspace/src/workspace/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ impl Workspace for WorkspaceServer {
None => Some("Statement execution not allowed against database.".into()),
};

let actions = parser
let mut actions: Vec<CodeAction> = parser
.iter_with_filter(
DefaultMapper,
CursorPositionFilter::new(params.cursor_position),
Expand All @@ -379,6 +379,20 @@ impl Workspace for WorkspaceServer {
})
.collect();

let invalidate_disabled_reason = if self.get_current_connection().is_some() {
None
} else {
Some("No database connection available.".into())
};

actions.push(CodeAction {
title: "Invalidate Schema Cache".into(),
kind: CodeActionKind::Command(CommandAction {
category: CommandActionCategory::InvalidateSchemaCache,
}),
disabled_reason: invalidate_disabled_reason,
});

Ok(CodeActionsResult { actions })
}

Expand Down Expand Up @@ -424,6 +438,19 @@ impl Workspace for WorkspaceServer {
})
}

fn invalidate_schema_cache(&self, all: bool) -> Result<(), WorkspaceError> {
if all {
self.schema_cache.clear_all();
} else {
// Only clear current connection if one exists
if let Some(pool) = self.get_current_connection() {
self.schema_cache.clear(&pool);
}
// If no connection, nothing to clear - just return Ok
}
Ok(())
}

#[ignored_path(path=&params.path)]
fn pull_diagnostics(
&self,
Expand Down
13 changes: 13 additions & 0 deletions crates/pgls_workspace/src/workspace/server/schema_cache_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,17 @@ impl SchemaCacheManager {
schemas.insert(key, schema_cache.clone());
Ok(schema_cache)
}

/// Clear the schema cache for a specific connection
pub fn clear(&self, pool: &PgPool) {
let key: ConnectionKey = pool.into();
let mut schemas = self.schemas.write().unwrap();
schemas.remove(&key);
}

/// Clear all schema caches
pub fn clear_all(&self) {
let mut schemas = self.schemas.write().unwrap();
schemas.clear();
}
}