From 9a28bee12808555c7311e0bf9d8fbbd24fd80da8 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 11:22:47 +0100 Subject: [PATCH 01/10] initial commit --- Cargo.lock | 11 ++ Cargo.toml | 1 + crates/pg_completions/Cargo.toml | 1 + crates/pg_completions/src/complete.rs | 4 +- crates/pg_completions/src/context.rs | 86 ++++++---- .../pg_completions/src/providers/functions.rs | 8 +- crates/pg_completions/src/providers/tables.rs | 8 +- crates/pg_test_utils/Cargo.toml | 4 + .../pg_test_utils/src/bin/tree_query_debug.rs | 91 ++++++++++ crates/pg_treesitter_queries/Cargo.toml | 25 +++ crates/pg_treesitter_queries/src/lib.rs | 162 ++++++++++++++++++ .../pg_treesitter_queries/src/queries/mod.rs | 41 +++++ .../src/queries/relations.rs | 85 +++++++++ 13 files changed, 486 insertions(+), 41 deletions(-) create mode 100644 crates/pg_test_utils/src/bin/tree_query_debug.rs create mode 100644 crates/pg_treesitter_queries/Cargo.toml create mode 100644 crates/pg_treesitter_queries/src/lib.rs create mode 100644 crates/pg_treesitter_queries/src/queries/mod.rs create mode 100644 crates/pg_treesitter_queries/src/queries/relations.rs diff --git a/Cargo.lock b/Cargo.lock index 222e79b5f..bfcce706b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2313,6 +2313,7 @@ dependencies = [ "async-std", "pg_schema_cache", "pg_test_utils", + "pg_treesitter_queries", "sqlx", "text-size", "tokio", @@ -2665,6 +2666,16 @@ dependencies = [ "text-size", ] +[[package]] +name = "pg_treesitter_queries" +version = "0.0.0" +dependencies = [ + "clap", + "tokio", + "tree-sitter", + "tree_sitter_sql", +] + [[package]] name = "pg_type_resolver" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 54a18bd88..5b6fb00a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -75,6 +75,7 @@ pg_schema_cache = { path = "./crates/pg_schema_cache", version = "0.0. pg_statement_splitter = { path = "./crates/pg_statement_splitter", version = "0.0.0" } pg_syntax = { path = "./crates/pg_syntax", version = "0.0.0" } pg_text_edit = { path = "./crates/pg_text_edit", version = "0.0.0" } +pg_treesitter_queries = { path = "./crates/pg_treesitter_queries", version = "0.0.0" } pg_type_resolver = { path = "./crates/pg_type_resolver", version = "0.0.0" } pg_typecheck = { path = "./crates/pg_typecheck", version = "0.0.0" } pg_workspace = { path = "./crates/pg_workspace", version = "0.0.0" } diff --git a/crates/pg_completions/Cargo.toml b/crates/pg_completions/Cargo.toml index 89a370593..8852bf5a7 100644 --- a/crates/pg_completions/Cargo.toml +++ b/crates/pg_completions/Cargo.toml @@ -19,6 +19,7 @@ text-size.workspace = true pg_schema_cache.workspace = true tree-sitter.workspace = true tree_sitter_sql.workspace = true +pg_treesitter_queries.workspace = true sqlx.workspace = true diff --git a/crates/pg_completions/src/complete.rs b/crates/pg_completions/src/complete.rs index 6ea1a1398..0495f2ed2 100644 --- a/crates/pg_completions/src/complete.rs +++ b/crates/pg_completions/src/complete.rs @@ -30,8 +30,8 @@ impl IntoIterator for CompletionResult { } } -pub fn complete(params: CompletionParams) -> CompletionResult { - let ctx = CompletionContext::new(¶ms); +pub async fn complete<'a>(params: CompletionParams<'a>) -> CompletionResult { + let ctx = CompletionContext::new(¶ms).await; let mut builder = CompletionBuilder::new(); diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs index 79354cf6d..c271983b2 100644 --- a/crates/pg_completions/src/context.rs +++ b/crates/pg_completions/src/context.rs @@ -1,4 +1,7 @@ +use std::ops::Range; + use pg_schema_cache::SchemaCache; +use pg_treesitter_queries::{queries, TreeSitterQueriesExecutor}; use crate::CompletionParams; @@ -52,10 +55,13 @@ pub(crate) struct CompletionContext<'a> { pub schema_name: Option, pub wrapping_clause_type: Option, pub is_invocation: bool, + pub wrapping_statement_range: Option>, + + pub ts_query_executor: Option>, } impl<'a> CompletionContext<'a> { - pub fn new(params: &'a CompletionParams) -> Self { + pub async fn new(params: &'a CompletionParams<'a>) -> Self { let mut ctx = Self { tree: params.tree, text: ¶ms.text, @@ -65,14 +71,30 @@ impl<'a> CompletionContext<'a> { ts_node: None, schema_name: None, wrapping_clause_type: None, + wrapping_statement_range: None, is_invocation: false, + ts_query_executor: None, }; ctx.gather_tree_context(); + ctx.dispatch_ts_queries().await; ctx } + async fn dispatch_ts_queries(&mut self) { + let tree = match self.tree.as_ref() { + None => return, + Some(t) => t, + }; + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), self.text); + + executor.add_query_results::().await; + + self.ts_query_executor = Some(executor); + } + pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> { let source = self.text; match ts_node.utf8_text(source.as_bytes()) { @@ -100,36 +122,38 @@ impl<'a> CompletionContext<'a> { * We'll therefore adjust the cursor position such that it meets the last node of the AST. * `select * from use {}` becomes `select * from use{}`. */ - let current_node_kind = cursor.node().kind(); + let current_node = cursor.node(); while cursor.goto_first_child_for_byte(self.position).is_none() && self.position > 0 { self.position -= 1; } - self.gather_context_from_node(cursor, current_node_kind); + self.gather_context_from_node(cursor, current_node); } fn gather_context_from_node( &mut self, mut cursor: tree_sitter::TreeCursor<'a>, - previous_node_kind: &str, + previous_node: tree_sitter::Node<'a>, ) { let current_node = cursor.node(); - let current_node_kind = current_node.kind(); // prevent infinite recursion – this can happen if we only have a PROGRAM node - if current_node_kind == previous_node_kind { + if current_node.kind() == previous_node.kind() { self.ts_node = Some(current_node); return; } - match previous_node_kind { - "statement" => self.wrapping_clause_type = current_node_kind.try_into().ok(), + match previous_node.kind() { + "statement" => { + self.wrapping_clause_type = current_node.kind().try_into().ok(); + self.wrapping_statement_range = Some(previous_node.byte_range()); + } "invocation" => self.is_invocation = true, _ => {} } - match current_node_kind { + match current_node.kind() { "object_reference" => { let txt = self.get_ts_node_content(current_node); if let Some(txt) = txt { @@ -159,7 +183,7 @@ impl<'a> CompletionContext<'a> { } cursor.goto_first_child_for_byte(self.position); - self.gather_context_from_node(cursor, current_node_kind); + self.gather_context_from_node(cursor, current_node); } } @@ -179,8 +203,8 @@ mod tests { parser.parse(input, None).expect("Unable to parse tree") } - #[test] - fn identifies_clauses() { + #[tokio::test] + async fn identifies_clauses() { let test_cases = vec![ (format!("Select {}* from users;", CURSOR_POS), "select"), (format!("Select * from u{};", CURSOR_POS), "from"), @@ -220,14 +244,14 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms).await; assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok()); } } - #[test] - fn identifies_schema() { + #[tokio::test] + async fn identifies_schema() { let test_cases = vec![ ( format!("Select * from private.u{}", CURSOR_POS), @@ -252,14 +276,14 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms).await; assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string())); } } - #[test] - fn identifies_invocation() { + #[tokio::test] + async fn identifies_invocation() { let test_cases = vec![ (format!("Select * from u{}sers", CURSOR_POS), false), (format!("Select * from u{}sers()", CURSOR_POS), true), @@ -286,14 +310,14 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms).await; assert_eq!(ctx.is_invocation, is_invocation); } } - #[test] - fn does_not_fail_on_leading_whitespace() { + #[tokio::test] + async fn does_not_fail_on_leading_whitespace() { let cases = vec![ format!("{} select * from", CURSOR_POS), format!(" {} select * from", CURSOR_POS), @@ -311,7 +335,7 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms).await; let node = ctx.ts_node.map(|n| n.clone()).unwrap(); @@ -324,8 +348,8 @@ mod tests { } } - #[test] - fn does_not_fail_on_trailing_whitespace() { + #[tokio::test] + async fn does_not_fail_on_trailing_whitespace() { let query = format!("select * from {}", CURSOR_POS); let (position, text) = get_text_and_position(query.as_str()); @@ -339,7 +363,7 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms).await; let node = ctx.ts_node.map(|n| n.clone()).unwrap(); @@ -350,8 +374,8 @@ mod tests { ); } - #[test] - fn does_not_fail_with_empty_statements() { + #[tokio::test] + async fn does_not_fail_with_empty_statements() { let query = format!("{}", CURSOR_POS); let (position, text) = get_text_and_position(query.as_str()); @@ -365,7 +389,7 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms).await; let node = ctx.ts_node.map(|n| n.clone()).unwrap(); @@ -373,8 +397,8 @@ mod tests { assert_eq!(ctx.wrapping_clause_type, None); } - #[test] - fn does_not_fail_on_incomplete_keywords() { + #[tokio::test] + async fn does_not_fail_on_incomplete_keywords() { // Instead of autocompleting "FROM", we'll assume that the user // is selecting a certain column name, such as `frozen_account`. let query = format!("select * fro{}", CURSOR_POS); @@ -390,7 +414,7 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms); + let ctx = CompletionContext::new(¶ms).await; let node = ctx.ts_node.map(|n| n.clone()).unwrap(); diff --git a/crates/pg_completions/src/providers/functions.rs b/crates/pg_completions/src/providers/functions.rs index d6c9db4c7..09ea9419e 100644 --- a/crates/pg_completions/src/providers/functions.rs +++ b/crates/pg_completions/src/providers/functions.rs @@ -45,7 +45,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params); + let results = complete(params).await; let CompletionItem { label, .. } = results .into_iter() @@ -78,7 +78,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params); + let results = complete(params).await; let CompletionItem { label, kind, .. } = results .into_iter() @@ -112,7 +112,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params); + let results = complete(params).await; let CompletionItem { label, kind, .. } = results .into_iter() @@ -146,7 +146,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params); + let results = complete(params).await; let CompletionItem { label, kind, .. } = results .into_iter() diff --git a/crates/pg_completions/src/providers/tables.rs b/crates/pg_completions/src/providers/tables.rs index 70574ec85..0bc642516 100644 --- a/crates/pg_completions/src/providers/tables.rs +++ b/crates/pg_completions/src/providers/tables.rs @@ -43,7 +43,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params); + let results = complete(params).await; assert!(!results.items.is_empty()); @@ -81,7 +81,7 @@ mod tests { for (query, expected_label) in test_cases { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params); + let results = complete(params).await; assert!(!results.items.is_empty()); @@ -126,7 +126,7 @@ mod tests { for (query, expected_label) in test_cases { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params); + let results = complete(params).await; assert!(!results.items.is_empty()); @@ -163,7 +163,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params); + let results = complete(params).await; let CompletionItem { label, kind, .. } = results .into_iter() diff --git a/crates/pg_test_utils/Cargo.toml b/crates/pg_test_utils/Cargo.toml index ea8377937..6325193ad 100644 --- a/crates/pg_test_utils/Cargo.toml +++ b/crates/pg_test_utils/Cargo.toml @@ -15,6 +15,10 @@ version = "0.0.0" name = "tree_print" path = "src/bin/tree_print.rs" +[[bin]] +name = "query_debug" +path = "src/bin/tree_query_debug.rs" + [dependencies] anyhow = "1.0.81" clap = { version = "4.5.23", features = ["derive"] } diff --git a/crates/pg_test_utils/src/bin/tree_query_debug.rs b/crates/pg_test_utils/src/bin/tree_query_debug.rs new file mode 100644 index 000000000..4e56f673b --- /dev/null +++ b/crates/pg_test_utils/src/bin/tree_query_debug.rs @@ -0,0 +1,91 @@ +use clap::*; + +#[derive(Parser)] +#[command(name = "query-debugger", about = "Debugs a query")] +struct Args { + #[arg(long = "file", short = 'f')] + file: String, +} + +fn main() { + let args = Args::parse(); + + let stmt = std::fs::read_to_string(&args.file).expect("Failed to read file."); + + let mut parser = tree_sitter::Parser::new(); + let lang = tree_sitter_sql::language(); + parser + .set_language(lang.clone()) + .expect("Setting Language failed."); + + let tree = parser + .parse(stmt.clone(), None) + .expect("Failed to parse Statement"); + + let results = relation_matches(tree.root_node(), &stmt); + + for r in results { + println!("{}", r.to_full_name(&stmt)) + } +} + +struct RelationMatch<'a> { + schema: Option>, + table: tree_sitter::Node<'a>, +} + +impl<'a> RelationMatch<'a> { + fn to_full_name(&self, stmt: &str) -> String { + match self.schema { + Some(s) => format!( + "{}.{}", + s.utf8_text(stmt.as_bytes()).unwrap(), + self.table.utf8_text(stmt.as_bytes()).unwrap() + ), + None => format!("{}", self.table.utf8_text(stmt.as_bytes()).unwrap()), + } + } +} + +fn relation_matches<'a>(root_node: tree_sitter::Node<'a>, stmt: &str) -> Vec> { + static QUERY: &str = r#" + (relation + (object_reference + (identifier)+ @schema_or_table + "." + (identifier) @table + )+ + ) + "#; + + let query = + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY).expect("Invalid Query!"); + + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&query, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(RelationMatch { + schema: None, + table: capture, + }); + } + + if m.captures.len() == 2 { + let schema = m.captures[0].node; + let table = m.captures[1].node; + + to_return.push(RelationMatch { + schema: Some(schema), + table, + }); + } + } + + to_return +} diff --git a/crates/pg_treesitter_queries/Cargo.toml b/crates/pg_treesitter_queries/Cargo.toml new file mode 100644 index 000000000..edc917def --- /dev/null +++ b/crates/pg_treesitter_queries/Cargo.toml @@ -0,0 +1,25 @@ +[package] +authors.workspace = true +categories.workspace = true +description = "" +edition.workspace = true +homepage.workspace = true +keywords.workspace = true +license.workspace = true +name = "pg_treesitter_queries" +repository.workspace = true +version = "0.0.0" + + +[dependencies] +clap = { version = "4.5.23", features = ["derive"] } +tree-sitter.workspace = true +tree_sitter_sql.workspace = true +tokio.workspace = true + +[dev-dependencies] + +[lib] +doctest = false + +[features] diff --git a/crates/pg_treesitter_queries/src/lib.rs b/crates/pg_treesitter_queries/src/lib.rs new file mode 100644 index 000000000..dbb30e6cc --- /dev/null +++ b/crates/pg_treesitter_queries/src/lib.rs @@ -0,0 +1,162 @@ +pub mod queries; + +use std::{ops::Range, slice::Iter}; + +use queries::{Query, QueryResult}; + +pub struct TreeSitterQueriesExecutor<'a> { + root_node: tree_sitter::Node<'a>, + stmt: &'a str, + results: Vec>, +} + +impl<'a> TreeSitterQueriesExecutor<'a> { + pub fn new(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Self { + Self { + root_node, + stmt, + results: vec![], + } + } + + #[allow(private_bounds)] + pub async fn add_query_results>(&mut self) { + let mut results = Q::execute(self.root_node, &self.stmt).await; + self.results.append(&mut results); + } + + pub fn get_iter(&self, range: Option>) -> QueryResultIter { + match range { + Some(r) => QueryResultIter::new(&self.results).within_range(r), + None => QueryResultIter::new(&self.results), + } + } +} + +pub struct QueryResultIter<'a> { + inner: Iter<'a, QueryResult<'a>>, + range: Option>, +} + +impl<'a> QueryResultIter<'a> { + pub(crate) fn new(results: &'a Vec>) -> Self { + Self { + inner: results.iter(), + range: None, + } + } + + pub fn within_range(mut self, r: Range) -> Self { + self.range = Some(r); + self + } +} + +impl<'a> Iterator for QueryResultIter<'a> { + type Item = &'a QueryResult<'a>; + fn next(&mut self) -> Option { + match self.inner.next() { + None => return None, + Some(n) => { + if self.range.as_ref().is_some_and(|r| !n.within_range(r)) { + return self.next(); + } + + Some(n) + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::{queries::RelationMatch, TreeSitterQueriesExecutor}; + + #[tokio::test] + async fn finds_all_relations_and_ignores_functions() { + let sql = r#" +select + * +from + ( + select + something + from + public.cool_table pu + join private.cool_tableau pr on pu.id = pr.id + where + x = '123' + union + select + something_else + from + another_table puat + inner join private.another_tableau prat on puat.id = prat.id + union + select + x, + y + from + public.get_something_cool () + ) +where + col = 17; +"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(&sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), &sql); + + executor.add_query_results::().await; + + let results: Vec<&RelationMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!( + results[0] + .schema + .map(|s| s.utf8_text(&sql.as_bytes()).unwrap()), + Some("public") + ); + assert_eq!( + results[0].table.utf8_text(&sql.as_bytes()).unwrap(), + "cool_table" + ); + + assert_eq!( + results[1] + .schema + .map(|s| s.utf8_text(&sql.as_bytes()).unwrap()), + Some("private") + ); + assert_eq!( + results[1].table.utf8_text(&sql.as_bytes()).unwrap(), + "cool_tableau" + ); + + assert_eq!(results[2].schema, None); + assert_eq!( + results[2].table.utf8_text(&sql.as_bytes()).unwrap(), + "another_table" + ); + + assert_eq!( + results[3] + .schema + .map(|s| s.utf8_text(&sql.as_bytes()).unwrap()), + Some("private") + ); + assert_eq!( + results[3].table.utf8_text(&sql.as_bytes()).unwrap(), + "another_tableau" + ); + + // we have exhausted the matches: function invocations are ignored. + assert!(results.len() == 4); + } +} diff --git a/crates/pg_treesitter_queries/src/queries/mod.rs b/crates/pg_treesitter_queries/src/queries/mod.rs new file mode 100644 index 000000000..793c8a2e5 --- /dev/null +++ b/crates/pg_treesitter_queries/src/queries/mod.rs @@ -0,0 +1,41 @@ +mod relations; + +use std::ops::Range; + +pub use relations::*; + +pub enum QueryResult<'a> { + Relation(RelationMatch<'a>), +} + +impl<'a> QueryResult<'a> { + pub fn within_range(&self, range: &Range) -> bool { + match self { + Self::Relation(rm) => { + let tb_range = rm.table.byte_range(); + + let start = match rm.schema { + Some(s) => s.byte_range().start, + None => tb_range.start, + }; + + let end = tb_range.end; + + range.contains(&start) && range.contains(&end) + } + } + } +} + +// This trait enforces that for any `Self` that implements `Query`, +// its &Self must implement TryFrom<&QueryResult> +pub(crate) trait QueryTryFrom<'a>: Sized { + type Ref: for<'any> TryFrom<&'a QueryResult<'a>, Error = String>; +} + +pub(crate) trait Query<'a>: QueryTryFrom<'a> { + async fn execute( + root_node: tree_sitter::Node<'a>, + stmt: &'a str, + ) -> Vec>; +} diff --git a/crates/pg_treesitter_queries/src/queries/relations.rs b/crates/pg_treesitter_queries/src/queries/relations.rs new file mode 100644 index 000000000..3d262c8c2 --- /dev/null +++ b/crates/pg_treesitter_queries/src/queries/relations.rs @@ -0,0 +1,85 @@ +use std::sync::Arc; +use tokio::sync::OnceCell; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static INSTANCE: OnceCell> = OnceCell::const_new(); + +static QUERY: &'static str = r#" + (relation + (object_reference + . + (identifier) @schema_or_table + "."? + (identifier)? @table + )+ + ) +"#; + +pub struct RelationMatch<'a> { + pub(crate) schema: Option>, + pub(crate) table: tree_sitter::Node<'a>, +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a RelationMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::Relation(r) => Ok(&r), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for RelationMatch<'a> { + type Ref = &'a RelationMatch<'a>; +} + +impl<'a> Query<'a> for RelationMatch<'a> { + async fn execute( + root_node: tree_sitter::Node<'a>, + stmt: &'a str, + ) -> Vec> { + let query = INSTANCE + .get_or_init(|| async { + Arc::new( + tree_sitter::Query::new(tree_sitter_sql::language(), &QUERY) + .expect("Invalid Query."), + ) + }) + .await; + + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&query, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(QueryResult::Relation(RelationMatch { + schema: None, + table: capture, + })); + } + + if m.captures.len() == 2 { + let schema = m.captures[0].node; + let table = m.captures[1].node; + + to_return.push(QueryResult::Relation(RelationMatch { + schema: Some(schema), + table, + })); + } + } + + to_return + } +} From 4ca0f963a23a6c1900fe6090a7c7e2d62aee6739 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 12:20:13 +0100 Subject: [PATCH 02/10] jeez --- Cargo.lock | 1 - crates/pg_completions/src/complete.rs | 7 +- crates/pg_completions/src/context.rs | 92 ++++++++++++------- .../pg_completions/src/providers/columns.rs | 20 ++++ .../pg_completions/src/providers/functions.rs | 8 +- crates/pg_completions/src/providers/mod.rs | 2 + crates/pg_completions/src/providers/tables.rs | 8 +- crates/pg_completions/src/relevance.rs | 57 +++++++++++- crates/pg_schema_cache/src/lib.rs | 1 + crates/pg_treesitter_queries/Cargo.toml | 1 - crates/pg_treesitter_queries/src/lib.rs | 16 ++-- .../pg_treesitter_queries/src/queries/mod.rs | 5 +- .../src/queries/relations.rs | 31 +++---- 13 files changed, 172 insertions(+), 77 deletions(-) create mode 100644 crates/pg_completions/src/providers/columns.rs diff --git a/Cargo.lock b/Cargo.lock index f3a3fa993..f36aea5e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2674,7 +2674,6 @@ name = "pg_treesitter_queries" version = "0.0.0" dependencies = [ "clap", - "tokio", "tree-sitter", "tree_sitter_sql", ] diff --git a/crates/pg_completions/src/complete.rs b/crates/pg_completions/src/complete.rs index 4e30b5e6f..4b6f79d74 100644 --- a/crates/pg_completions/src/complete.rs +++ b/crates/pg_completions/src/complete.rs @@ -5,7 +5,7 @@ use crate::{ builder::CompletionBuilder, context::CompletionContext, item::CompletionItem, - providers::{complete_functions, complete_tables}, + providers::{complete_columns, complete_functions, complete_tables}, }; pub const LIMIT: usize = 50; @@ -31,13 +31,14 @@ impl IntoIterator for CompletionResult { } } -pub async fn complete<'a>(params: CompletionParams<'a>) -> CompletionResult { - let ctx = CompletionContext::new(¶ms).await; +pub fn complete<'a>(params: CompletionParams<'a>) -> CompletionResult { + let ctx = CompletionContext::new(¶ms); let mut builder = CompletionBuilder::new(); complete_tables(&ctx, &mut builder); complete_functions(&ctx, &mut builder); + complete_columns(&ctx, &mut builder); builder.finish() } diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs index 4d4d4c04f..3d0eb14fc 100644 --- a/crates/pg_completions/src/context.rs +++ b/crates/pg_completions/src/context.rs @@ -1,7 +1,14 @@ -use std::ops::Range; +use std::{ + collections::{HashMap, HashSet}, + hash::Hash, + ops::Range, +}; use pg_schema_cache::SchemaCache; -use pg_treesitter_queries::{queries, TreeSitterQueriesExecutor}; +use pg_treesitter_queries::{ + queries::{self, QueryResult}, + TreeSitterQueriesExecutor, +}; use crate::CompletionParams; @@ -57,11 +64,11 @@ pub(crate) struct CompletionContext<'a> { pub is_invocation: bool, pub wrapping_statement_range: Option>, - pub ts_query_executor: Option>, + pub mentioned_relations: HashMap, HashSet>, } impl<'a> CompletionContext<'a> { - pub async fn new(params: &'a CompletionParams<'a>) -> Self { + pub fn new(params: &'a CompletionParams<'a>) -> Self { let mut ctx = Self { tree: params.tree, text: ¶ms.text, @@ -73,26 +80,49 @@ impl<'a> CompletionContext<'a> { wrapping_clause_type: None, wrapping_statement_range: None, is_invocation: false, - ts_query_executor: None, + mentioned_relations: HashMap::new(), }; ctx.gather_tree_context(); - ctx.dispatch_ts_queries().await; + ctx.gather_info_from_ts_queries(); ctx } - async fn dispatch_ts_queries(&mut self) { + fn gather_info_from_ts_queries(&mut self) { let tree = match self.tree.as_ref() { None => return, Some(t) => t, }; - let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), self.text); + let stmt_range = self.wrapping_statement_range.as_ref(); + let sql = self.text; - executor.add_query_results::().await; + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), self.text); - self.ts_query_executor = Some(executor); + executor.add_query_results::(); + + for relation_match in executor.get_iter(stmt_range) { + match relation_match { + QueryResult::Relation(r) => { + let schema_name = r.get_schema(sql); + let table_name = r.get_table(sql); + + let current = self.mentioned_relations.get_mut(&schema_name); + + match current { + Some(c) => { + c.insert(table_name); + } + None => { + let mut new = HashSet::new(); + new.insert(table_name); + self.mentioned_relations.insert(schema_name, new); + } + }; + } + }; + } } pub fn get_ts_node_content(&self, ts_node: tree_sitter::Node<'a>) -> Option<&'a str> { @@ -203,8 +233,8 @@ mod tests { parser.parse(input, None).expect("Unable to parse tree") } - #[tokio::test] - async fn identifies_clauses() { + #[test] + fn identifies_clauses() { let test_cases = vec![ (format!("Select {}* from users;", CURSOR_POS), "select"), (format!("Select * from u{};", CURSOR_POS), "from"), @@ -244,14 +274,14 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms).await; + let ctx = CompletionContext::new(¶ms); assert_eq!(ctx.wrapping_clause_type, expected_clause.try_into().ok()); } } - #[tokio::test] - async fn identifies_schema() { + #[test] + fn identifies_schema() { let test_cases = vec![ ( format!("Select * from private.u{}", CURSOR_POS), @@ -276,14 +306,14 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms).await; + let ctx = CompletionContext::new(¶ms); assert_eq!(ctx.schema_name, expected_schema.map(|f| f.to_string())); } } - #[tokio::test] - async fn identifies_invocation() { + #[test] + fn identifies_invocation() { let test_cases = vec![ (format!("Select * from u{}sers", CURSOR_POS), false), (format!("Select * from u{}sers()", CURSOR_POS), true), @@ -310,14 +340,14 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms).await; + let ctx = CompletionContext::new(¶ms); assert_eq!(ctx.is_invocation, is_invocation); } } - #[tokio::test] - async fn does_not_fail_on_leading_whitespace() { + #[test] + fn does_not_fail_on_leading_whitespace() { let cases = vec![ format!("{} select * from", CURSOR_POS), format!(" {} select * from", CURSOR_POS), @@ -335,7 +365,7 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms).await; + let ctx = CompletionContext::new(¶ms); let node = ctx.ts_node.unwrap(); @@ -348,8 +378,8 @@ mod tests { } } - #[tokio::test] - async fn does_not_fail_on_trailing_whitespace() { + #[test] + fn does_not_fail_on_trailing_whitespace() { let query = format!("select * from {}", CURSOR_POS); let (position, text) = get_text_and_position(query.as_str()); @@ -363,7 +393,7 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms).await; + let ctx = CompletionContext::new(¶ms); let node = ctx.ts_node.unwrap(); @@ -374,8 +404,8 @@ mod tests { ); } - #[tokio::test] - async fn does_not_fail_with_empty_statements() { + #[test] + fn does_not_fail_with_empty_statements() { let query = format!("{}", CURSOR_POS); let (position, text) = get_text_and_position(query.as_str()); @@ -389,7 +419,7 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms).await; + let ctx = CompletionContext::new(¶ms); let node = ctx.ts_node.unwrap(); @@ -397,8 +427,8 @@ mod tests { assert_eq!(ctx.wrapping_clause_type, None); } - #[tokio::test] - async fn does_not_fail_on_incomplete_keywords() { + #[test] + fn does_not_fail_on_incomplete_keywords() { // Instead of autocompleting "FROM", we'll assume that the user // is selecting a certain column name, such as `frozen_account`. let query = format!("select * fro{}", CURSOR_POS); @@ -414,7 +444,7 @@ mod tests { schema: &pg_schema_cache::SchemaCache::new(), }; - let ctx = CompletionContext::new(¶ms).await; + let ctx = CompletionContext::new(¶ms); let node = ctx.ts_node.unwrap(); diff --git a/crates/pg_completions/src/providers/columns.rs b/crates/pg_completions/src/providers/columns.rs new file mode 100644 index 000000000..2d84cba6d --- /dev/null +++ b/crates/pg_completions/src/providers/columns.rs @@ -0,0 +1,20 @@ +use crate::{ + builder::CompletionBuilder, context::CompletionContext, relevance::CompletionRelevanceData, + CompletionItem, CompletionItemKind, +}; + +pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder) { + let available_columns = &ctx.schema_cache.columns; + + for col in available_columns { + let item = CompletionItem { + label: col.name.clone(), + score: CompletionRelevanceData::Column(col).get_score(ctx), + description: format!("Table: {}.{}", col.schema_name, col.table_name), + preselected: false, + kind: CompletionItemKind::Function, + }; + + builder.add_item(item); + } +} diff --git a/crates/pg_completions/src/providers/functions.rs b/crates/pg_completions/src/providers/functions.rs index 09ea9419e..d6c9db4c7 100644 --- a/crates/pg_completions/src/providers/functions.rs +++ b/crates/pg_completions/src/providers/functions.rs @@ -45,7 +45,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params).await; + let results = complete(params); let CompletionItem { label, .. } = results .into_iter() @@ -78,7 +78,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params).await; + let results = complete(params); let CompletionItem { label, kind, .. } = results .into_iter() @@ -112,7 +112,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params).await; + let results = complete(params); let CompletionItem { label, kind, .. } = results .into_iter() @@ -146,7 +146,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params).await; + let results = complete(params); let CompletionItem { label, kind, .. } = results .into_iter() diff --git a/crates/pg_completions/src/providers/mod.rs b/crates/pg_completions/src/providers/mod.rs index 105482062..930551290 100644 --- a/crates/pg_completions/src/providers/mod.rs +++ b/crates/pg_completions/src/providers/mod.rs @@ -1,5 +1,7 @@ +mod columns; mod functions; mod tables; +pub use columns::*; pub use functions::*; pub use tables::*; diff --git a/crates/pg_completions/src/providers/tables.rs b/crates/pg_completions/src/providers/tables.rs index 0bc642516..70574ec85 100644 --- a/crates/pg_completions/src/providers/tables.rs +++ b/crates/pg_completions/src/providers/tables.rs @@ -43,7 +43,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params).await; + let results = complete(params); assert!(!results.items.is_empty()); @@ -81,7 +81,7 @@ mod tests { for (query, expected_label) in test_cases { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params).await; + let results = complete(params); assert!(!results.items.is_empty()); @@ -126,7 +126,7 @@ mod tests { for (query, expected_label) in test_cases { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params).await; + let results = complete(params); assert!(!results.items.is_empty()); @@ -163,7 +163,7 @@ mod tests { let (tree, cache) = get_test_deps(setup, &query).await; let params = get_test_params(&tree, &cache, &query); - let results = complete(params).await; + let results = complete(params); let CompletionItem { label, kind, .. } = results .into_iter() diff --git a/crates/pg_completions/src/relevance.rs b/crates/pg_completions/src/relevance.rs index 5408a8e46..c113ef966 100644 --- a/crates/pg_completions/src/relevance.rs +++ b/crates/pg_completions/src/relevance.rs @@ -4,6 +4,7 @@ use crate::context::{ClauseType, CompletionContext}; pub(crate) enum CompletionRelevanceData<'a> { Table(&'a pg_schema_cache::Table), Function(&'a pg_schema_cache::Function), + Column(&'a pg_schema_cache::Column), } impl<'a> CompletionRelevanceData<'a> { @@ -34,6 +35,7 @@ impl<'a> CompletionRelevance<'a> { self.check_if_catalog(ctx); self.check_is_invocation(ctx); self.check_matching_clause_type(ctx); + self.check_relations_in_stmt(ctx); self.score } @@ -49,6 +51,7 @@ impl<'a> CompletionRelevance<'a> { let name = match self.data { CompletionRelevanceData::Function(f) => f.name.as_str(), CompletionRelevanceData::Table(t) => t.name.as_str(), + CompletionRelevanceData::Column(c) => c.name.as_str(), }; if name.starts_with(content) { @@ -79,6 +82,11 @@ impl<'a> CompletionRelevance<'a> { ClauseType::From => 0, _ => -50, }, + CompletionRelevanceData::Column(_) => match clause_type { + ClauseType::Select => 15, + ClauseType::Where => 15, + _ => -15, + }, } } @@ -107,10 +115,7 @@ impl<'a> CompletionRelevance<'a> { Some(n) => n, }; - let data_schema = match self.data { - CompletionRelevanceData::Function(f) => f.schema.as_str(), - CompletionRelevanceData::Table(t) => t.schema.as_str(), - }; + let data_schema = self.get_schema_name(); if schema_name == data_schema { self.score += 25; @@ -119,6 +124,22 @@ impl<'a> CompletionRelevance<'a> { } } + fn get_schema_name(&self) -> &str { + match self.data { + CompletionRelevanceData::Function(f) => f.schema.as_str(), + CompletionRelevanceData::Table(t) => t.schema.as_str(), + CompletionRelevanceData::Column(c) => c.schema_name.as_str(), + } + } + + fn get_table_name(&self) -> Option<&str> { + match self.data { + CompletionRelevanceData::Column(c) => Some(c.table_name.as_str()), + CompletionRelevanceData::Table(t) => Some(t.name.as_str()), + _ => None, + } + } + fn check_if_catalog(&mut self, ctx: &CompletionContext) { if ctx.schema_name.as_ref().is_some_and(|n| n == "pg_catalog") { return; @@ -126,4 +147,32 @@ impl<'a> CompletionRelevance<'a> { self.score -= 5; // unlikely that the user wants schema data } + + fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) { + match self.data { + CompletionRelevanceData::Table(_) => return, + CompletionRelevanceData::Function(_) => return, + _ => {} + } + + let schema = self.get_schema_name().to_string(); + let table_name = match self.get_table_name() { + Some(t) => t, + None => return, + }; + + if ctx + .mentioned_relations + .get(&Some(schema.to_string())) + .is_some_and(|tables| tables.contains(table_name)) + { + self.score += 45; + } else if ctx + .mentioned_relations + .get(&None) + .is_some_and(|tables| tables.contains(table_name)) + { + self.score += 30; + } + } } diff --git a/crates/pg_schema_cache/src/lib.rs b/crates/pg_schema_cache/src/lib.rs index 719da4049..c6dad0b7d 100644 --- a/crates/pg_schema_cache/src/lib.rs +++ b/crates/pg_schema_cache/src/lib.rs @@ -10,6 +10,7 @@ mod tables; mod types; mod versions; +pub use columns::*; pub use functions::{Behavior, Function, FunctionArg, FunctionArgs}; pub use schema_cache::SchemaCache; pub use tables::{ReplicaIdentity, Table}; diff --git a/crates/pg_treesitter_queries/Cargo.toml b/crates/pg_treesitter_queries/Cargo.toml index edc917def..2d92fca31 100644 --- a/crates/pg_treesitter_queries/Cargo.toml +++ b/crates/pg_treesitter_queries/Cargo.toml @@ -15,7 +15,6 @@ version = "0.0.0" clap = { version = "4.5.23", features = ["derive"] } tree-sitter.workspace = true tree_sitter_sql.workspace = true -tokio.workspace = true [dev-dependencies] diff --git a/crates/pg_treesitter_queries/src/lib.rs b/crates/pg_treesitter_queries/src/lib.rs index dbb30e6cc..978910c8e 100644 --- a/crates/pg_treesitter_queries/src/lib.rs +++ b/crates/pg_treesitter_queries/src/lib.rs @@ -20,12 +20,12 @@ impl<'a> TreeSitterQueriesExecutor<'a> { } #[allow(private_bounds)] - pub async fn add_query_results>(&mut self) { - let mut results = Q::execute(self.root_node, &self.stmt).await; + pub fn add_query_results>(&mut self) { + let mut results = Q::execute(self.root_node, &self.stmt); self.results.append(&mut results); } - pub fn get_iter(&self, range: Option>) -> QueryResultIter { + pub fn get_iter(&self, range: Option<&'a Range>) -> QueryResultIter { match range { Some(r) => QueryResultIter::new(&self.results).within_range(r), None => QueryResultIter::new(&self.results), @@ -35,7 +35,7 @@ impl<'a> TreeSitterQueriesExecutor<'a> { pub struct QueryResultIter<'a> { inner: Iter<'a, QueryResult<'a>>, - range: Option>, + range: Option<&'a Range>, } impl<'a> QueryResultIter<'a> { @@ -46,7 +46,7 @@ impl<'a> QueryResultIter<'a> { } } - pub fn within_range(mut self, r: Range) -> Self { + fn within_range(mut self, r: &'a Range) -> Self { self.range = Some(r); self } @@ -72,8 +72,8 @@ impl<'a> Iterator for QueryResultIter<'a> { mod tests { use crate::{queries::RelationMatch, TreeSitterQueriesExecutor}; - #[tokio::test] - async fn finds_all_relations_and_ignores_functions() { + #[test] + fn finds_all_relations_and_ignores_functions() { let sql = r#" select * @@ -110,7 +110,7 @@ where let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), &sql); - executor.add_query_results::().await; + executor.add_query_results::(); let results: Vec<&RelationMatch> = executor .get_iter(None) diff --git a/crates/pg_treesitter_queries/src/queries/mod.rs b/crates/pg_treesitter_queries/src/queries/mod.rs index 793c8a2e5..9e3abbee4 100644 --- a/crates/pg_treesitter_queries/src/queries/mod.rs +++ b/crates/pg_treesitter_queries/src/queries/mod.rs @@ -34,8 +34,5 @@ pub(crate) trait QueryTryFrom<'a>: Sized { } pub(crate) trait Query<'a>: QueryTryFrom<'a> { - async fn execute( - root_node: tree_sitter::Node<'a>, - stmt: &'a str, - ) -> Vec>; + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec>; } diff --git a/crates/pg_treesitter_queries/src/queries/relations.rs b/crates/pg_treesitter_queries/src/queries/relations.rs index 3d262c8c2..588f82eb6 100644 --- a/crates/pg_treesitter_queries/src/queries/relations.rs +++ b/crates/pg_treesitter_queries/src/queries/relations.rs @@ -1,12 +1,7 @@ -use std::sync::Arc; -use tokio::sync::OnceCell; - use crate::{Query, QueryResult}; use super::QueryTryFrom; -static INSTANCE: OnceCell> = OnceCell::const_new(); - static QUERY: &'static str = r#" (relation (object_reference @@ -23,6 +18,17 @@ pub struct RelationMatch<'a> { pub(crate) table: tree_sitter::Node<'a>, } +impl<'a> RelationMatch<'a> { + pub fn get_schema(&self, sql: &str) -> Option { + let str = self.schema.as_ref()?.utf8_text(sql.as_bytes()).unwrap(); + Some(str.to_string()) + } + + pub fn get_table(&self, sql: &str) -> String { + self.table.utf8_text(sql.as_bytes()).unwrap().to_string() + } +} + impl<'a> TryFrom<&'a QueryResult<'a>> for &'a RelationMatch<'a> { type Error = String; @@ -41,18 +47,9 @@ impl<'a> QueryTryFrom<'a> for RelationMatch<'a> { } impl<'a> Query<'a> for RelationMatch<'a> { - async fn execute( - root_node: tree_sitter::Node<'a>, - stmt: &'a str, - ) -> Vec> { - let query = INSTANCE - .get_or_init(|| async { - Arc::new( - tree_sitter::Query::new(tree_sitter_sql::language(), &QUERY) - .expect("Invalid Query."), - ) - }) - .await; + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let query = + tree_sitter::Query::new(tree_sitter_sql::language(), &QUERY).expect("Invalid Query."); let mut cursor = tree_sitter::QueryCursor::new(); From 646440eda869cbc75d96fd48f2c937cf5a8b764c Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 12:25:59 +0100 Subject: [PATCH 03/10] jeez --- crates/pg_completions/src/item.rs | 1 + .../pg_completions/src/providers/columns.rs | 2 +- crates/pg_completions/src/relevance.rs | 3 +- crates/pg_test_utils/Cargo.toml | 4 - .../pg_test_utils/src/bin/tree_query_debug.rs | 91 ------------------- crates/pg_treesitter_queries/src/lib.rs | 15 ++- 6 files changed, 9 insertions(+), 107 deletions(-) delete mode 100644 crates/pg_test_utils/src/bin/tree_query_debug.rs diff --git a/crates/pg_completions/src/item.rs b/crates/pg_completions/src/item.rs index 06771f926..d14485c2f 100644 --- a/crates/pg_completions/src/item.rs +++ b/crates/pg_completions/src/item.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; pub enum CompletionItemKind { Table, Function, + Column, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/pg_completions/src/providers/columns.rs b/crates/pg_completions/src/providers/columns.rs index 2d84cba6d..4260f5369 100644 --- a/crates/pg_completions/src/providers/columns.rs +++ b/crates/pg_completions/src/providers/columns.rs @@ -12,7 +12,7 @@ pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder score: CompletionRelevanceData::Column(col).get_score(ctx), description: format!("Table: {}.{}", col.schema_name, col.table_name), preselected: false, - kind: CompletionItemKind::Function, + kind: CompletionItemKind::Column, }; builder.add_item(item); diff --git a/crates/pg_completions/src/relevance.rs b/crates/pg_completions/src/relevance.rs index c113ef966..d4d05252e 100644 --- a/crates/pg_completions/src/relevance.rs +++ b/crates/pg_completions/src/relevance.rs @@ -150,8 +150,7 @@ impl<'a> CompletionRelevance<'a> { fn check_relations_in_stmt(&mut self, ctx: &CompletionContext) { match self.data { - CompletionRelevanceData::Table(_) => return, - CompletionRelevanceData::Function(_) => return, + CompletionRelevanceData::Table(_) | CompletionRelevanceData::Function(_) => return, _ => {} } diff --git a/crates/pg_test_utils/Cargo.toml b/crates/pg_test_utils/Cargo.toml index 6325193ad..ea8377937 100644 --- a/crates/pg_test_utils/Cargo.toml +++ b/crates/pg_test_utils/Cargo.toml @@ -15,10 +15,6 @@ version = "0.0.0" name = "tree_print" path = "src/bin/tree_print.rs" -[[bin]] -name = "query_debug" -path = "src/bin/tree_query_debug.rs" - [dependencies] anyhow = "1.0.81" clap = { version = "4.5.23", features = ["derive"] } diff --git a/crates/pg_test_utils/src/bin/tree_query_debug.rs b/crates/pg_test_utils/src/bin/tree_query_debug.rs deleted file mode 100644 index 4e56f673b..000000000 --- a/crates/pg_test_utils/src/bin/tree_query_debug.rs +++ /dev/null @@ -1,91 +0,0 @@ -use clap::*; - -#[derive(Parser)] -#[command(name = "query-debugger", about = "Debugs a query")] -struct Args { - #[arg(long = "file", short = 'f')] - file: String, -} - -fn main() { - let args = Args::parse(); - - let stmt = std::fs::read_to_string(&args.file).expect("Failed to read file."); - - let mut parser = tree_sitter::Parser::new(); - let lang = tree_sitter_sql::language(); - parser - .set_language(lang.clone()) - .expect("Setting Language failed."); - - let tree = parser - .parse(stmt.clone(), None) - .expect("Failed to parse Statement"); - - let results = relation_matches(tree.root_node(), &stmt); - - for r in results { - println!("{}", r.to_full_name(&stmt)) - } -} - -struct RelationMatch<'a> { - schema: Option>, - table: tree_sitter::Node<'a>, -} - -impl<'a> RelationMatch<'a> { - fn to_full_name(&self, stmt: &str) -> String { - match self.schema { - Some(s) => format!( - "{}.{}", - s.utf8_text(stmt.as_bytes()).unwrap(), - self.table.utf8_text(stmt.as_bytes()).unwrap() - ), - None => format!("{}", self.table.utf8_text(stmt.as_bytes()).unwrap()), - } - } -} - -fn relation_matches<'a>(root_node: tree_sitter::Node<'a>, stmt: &str) -> Vec> { - static QUERY: &str = r#" - (relation - (object_reference - (identifier)+ @schema_or_table - "." - (identifier) @table - )+ - ) - "#; - - let query = - tree_sitter::Query::new(tree_sitter_sql::language(), QUERY).expect("Invalid Query!"); - - let mut cursor = tree_sitter::QueryCursor::new(); - - let matches = cursor.matches(&query, root_node, stmt.as_bytes()); - - let mut to_return = vec![]; - - for m in matches { - if m.captures.len() == 1 { - let capture = m.captures[0].node; - to_return.push(RelationMatch { - schema: None, - table: capture, - }); - } - - if m.captures.len() == 2 { - let schema = m.captures[0].node; - let table = m.captures[1].node; - - to_return.push(RelationMatch { - schema: Some(schema), - table, - }); - } - } - - to_return -} diff --git a/crates/pg_treesitter_queries/src/lib.rs b/crates/pg_treesitter_queries/src/lib.rs index 978910c8e..0bcdbdefa 100644 --- a/crates/pg_treesitter_queries/src/lib.rs +++ b/crates/pg_treesitter_queries/src/lib.rs @@ -55,16 +55,13 @@ impl<'a> QueryResultIter<'a> { impl<'a> Iterator for QueryResultIter<'a> { type Item = &'a QueryResult<'a>; fn next(&mut self) -> Option { - match self.inner.next() { - None => return None, - Some(n) => { - if self.range.as_ref().is_some_and(|r| !n.within_range(r)) { - return self.next(); - } - - Some(n) - } + let next = self.inner.next()?; + + if self.range.as_ref().is_some_and(|r| !next.within_range(r)) { + return self.next(); } + + Some(next) } } From 089c50e2591cc1c3f7203c02e9b9439611821f96 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 12:27:34 +0100 Subject: [PATCH 04/10] not necessary --- crates/pg_completions/src/complete.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/pg_completions/src/complete.rs b/crates/pg_completions/src/complete.rs index 4b6f79d74..c45d01aca 100644 --- a/crates/pg_completions/src/complete.rs +++ b/crates/pg_completions/src/complete.rs @@ -31,7 +31,7 @@ impl IntoIterator for CompletionResult { } } -pub fn complete<'a>(params: CompletionParams<'a>) -> CompletionResult { +pub fn complete(params: CompletionParams) -> CompletionResult { let ctx = CompletionContext::new(¶ms); let mut builder = CompletionBuilder::new(); From 4ee920deee2b2b91eab459d0b668f67b04fe75fe Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 17:07:38 +0100 Subject: [PATCH 05/10] beautiful --- crates/pg_completions/src/context.rs | 31 +++-- .../pg_completions/src/providers/columns.rs | 94 +++++++++++++++ .../pg_completions/src/providers/functions.rs | 16 +-- crates/pg_completions/src/providers/tables.rs | 17 +-- crates/pg_completions/src/test_helper.rs | 43 +++++-- crates/pg_lsp_new/src/handlers/completions.rs | 1 + crates/pg_treesitter_queries/src/lib.rs | 113 +++++++++++------- .../pg_treesitter_queries/src/queries/mod.rs | 15 +-- .../src/queries/relations.rs | 13 +- 9 files changed, 245 insertions(+), 98 deletions(-) diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs index 3d0eb14fc..b284b1668 100644 --- a/crates/pg_completions/src/context.rs +++ b/crates/pg_completions/src/context.rs @@ -1,8 +1,4 @@ -use std::{ - collections::{HashMap, HashSet}, - hash::Hash, - ops::Range, -}; +use std::collections::{HashMap, HashSet}; use pg_schema_cache::SchemaCache; use pg_treesitter_queries::{ @@ -62,7 +58,7 @@ pub(crate) struct CompletionContext<'a> { pub schema_name: Option, pub wrapping_clause_type: Option, pub is_invocation: bool, - pub wrapping_statement_range: Option>, + pub wrapping_statement_range: Option, pub mentioned_relations: HashMap, HashSet>, } @@ -74,7 +70,6 @@ impl<'a> CompletionContext<'a> { text: ¶ms.text, schema_cache: params.schema, position: usize::from(params.position), - ts_node: None, schema_name: None, wrapping_clause_type: None, @@ -86,6 +81,8 @@ impl<'a> CompletionContext<'a> { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); + dbg!(ctx.wrapping_statement_range); + ctx } @@ -98,6 +95,8 @@ impl<'a> CompletionContext<'a> { let stmt_range = self.wrapping_statement_range.as_ref(); let sql = self.text; + dbg!(sql); + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), self.text); executor.add_query_results::(); @@ -174,9 +173,9 @@ impl<'a> CompletionContext<'a> { } match previous_node.kind() { - "statement" => { + "statement" | "subquery" => { self.wrapping_clause_type = current_node.kind().try_into().ok(); - self.wrapping_statement_range = Some(previous_node.byte_range()); + self.wrapping_statement_range = Some(previous_node.range()); } "invocation" => self.is_invocation = true, @@ -263,7 +262,7 @@ mod tests { ]; for (query, expected_clause) in test_cases { - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); @@ -296,7 +295,7 @@ mod tests { ]; for (query, expected_schema) in test_cases { - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); let params = crate::CompletionParams { @@ -330,7 +329,7 @@ mod tests { ]; for (query, is_invocation) in test_cases { - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); let params = crate::CompletionParams { @@ -354,7 +353,7 @@ mod tests { ]; for query in cases { - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); @@ -382,7 +381,7 @@ mod tests { fn does_not_fail_on_trailing_whitespace() { let query = format!("select * from {}", CURSOR_POS); - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); @@ -408,7 +407,7 @@ mod tests { fn does_not_fail_with_empty_statements() { let query = format!("{}", CURSOR_POS); - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); @@ -433,7 +432,7 @@ mod tests { // is selecting a certain column name, such as `frozen_account`. let query = format!("select * fro{}", CURSOR_POS); - let (position, text) = get_text_and_position(query.as_str()); + let (position, text) = get_text_and_position(query.as_str().into()); let tree = get_tree(text.as_str()); diff --git a/crates/pg_completions/src/providers/columns.rs b/crates/pg_completions/src/providers/columns.rs index 4260f5369..87539c026 100644 --- a/crates/pg_completions/src/providers/columns.rs +++ b/crates/pg_completions/src/providers/columns.rs @@ -18,3 +18,97 @@ pub fn complete_columns(ctx: &CompletionContext, builder: &mut CompletionBuilder builder.add_item(item); } } + +#[cfg(test)] +mod tests { + use crate::{ + complete, + test_helper::{get_test_deps, get_test_params, InputQuery, CURSOR_POS}, + CompletionItem, + }; + + struct TestCase { + query: String, + message: &'static str, + label: &'static str, + description: &'static str, + } + + impl TestCase { + fn get_input_query(&self) -> InputQuery { + let strs: Vec<&str> = self.query.split_whitespace().collect(); + strs.join(" ").as_str().into() + } + } + + #[tokio::test] + async fn completes_columns() { + let setup = r#" + create schema private; + + create table public.users ( + id serial primary key, + name text + ); + + create table public.audio_books ( + id serial primary key, + narrator text + ); + + create table private.audio_books ( + id serial primary key, + narrator_id text + ); + "#; + + let queries: Vec = vec![ + TestCase { + message: "correctly prefers the columns of present tables", + query: format!(r#"select na{} from public.audio_books;"#, CURSOR_POS), + label: "narrator", + description: "Table: public.audio_books", + }, + TestCase { + message: "correctly handles nested queries", + query: format!( + r#" + select + * + from ( + select id, na{} + from private.audio_books + ) as subquery + join public.users u + on u.id = subquery.id; + "#, + CURSOR_POS + ), + label: "narrator_id", + description: "Table: private.audio_books", + }, + TestCase { + message: "works without a schema", + query: format!(r#"select na{} from users;"#, CURSOR_POS), + label: "name", + description: "Table: public.users", + }, + ]; + + for q in queries { + let (tree, cache) = get_test_deps(setup, q.get_input_query()).await; + let params = get_test_params(&tree, &cache, q.get_input_query()); + let results = complete(params); + + let CompletionItem { + label, description, .. + } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, q.label, "{}", q.message); + assert_eq!(description, q.description, "{}", q.message); + } + } +} diff --git a/crates/pg_completions/src/providers/functions.rs b/crates/pg_completions/src/providers/functions.rs index d6c9db4c7..e8e530205 100644 --- a/crates/pg_completions/src/providers/functions.rs +++ b/crates/pg_completions/src/providers/functions.rs @@ -43,8 +43,8 @@ mod tests { let query = format!("select coo{}", CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, .. } = results @@ -76,8 +76,8 @@ mod tests { let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, kind, .. } = results @@ -110,8 +110,8 @@ mod tests { let query = format!(r#"select coo{}"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, kind, .. } = results @@ -144,8 +144,8 @@ mod tests { let query = format!(r#"select * from coo{}()"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, kind, .. } = results diff --git a/crates/pg_completions/src/providers/tables.rs b/crates/pg_completions/src/providers/tables.rs index 70574ec85..c3d924259 100644 --- a/crates/pg_completions/src/providers/tables.rs +++ b/crates/pg_completions/src/providers/tables.rs @@ -23,6 +23,7 @@ pub fn complete_tables(ctx: &CompletionContext, builder: &mut CompletionBuilder) #[cfg(test)] mod tests { + use crate::{ complete, test_helper::{get_test_deps, get_test_params, CURSOR_POS}, @@ -41,8 +42,8 @@ mod tests { let query = format!("select * from u{}", CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); assert!(!results.items.is_empty()); @@ -79,8 +80,8 @@ mod tests { ]; for (query, expected_label) in test_cases { - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); assert!(!results.items.is_empty()); @@ -124,8 +125,8 @@ mod tests { ]; for (query, expected_label) in test_cases { - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); assert!(!results.items.is_empty()); @@ -161,8 +162,8 @@ mod tests { let query = format!(r#"select * from coo{}"#, CURSOR_POS); - let (tree, cache) = get_test_deps(setup, &query).await; - let params = get_test_params(&tree, &cache, &query); + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); let results = complete(params); let CompletionItem { label, kind, .. } = results diff --git a/crates/pg_completions/src/test_helper.rs b/crates/pg_completions/src/test_helper.rs index 4c29d1e7a..83f9cdd90 100644 --- a/crates/pg_completions/src/test_helper.rs +++ b/crates/pg_completions/src/test_helper.rs @@ -6,9 +6,34 @@ use crate::CompletionParams; pub static CURSOR_POS: char = '€'; +pub struct InputQuery { + sql: String, + position: usize, +} + +impl From<&str> for InputQuery { + fn from(value: &str) -> Self { + let position = value + .find(CURSOR_POS) + .map(|p| p.saturating_sub(1)) + .expect("Insert Cursor Position into your Query."); + + InputQuery { + sql: value.replace(CURSOR_POS, ""), + position, + } + } +} + +impl ToString for InputQuery { + fn to_string(&self) -> String { + self.sql.clone() + } +} + pub(crate) async fn get_test_deps( setup: &str, - input: &str, + input: InputQuery, ) -> (tree_sitter::Tree, pg_schema_cache::SchemaCache) { let test_db = get_new_test_db().await; @@ -26,27 +51,19 @@ pub(crate) async fn get_test_deps( .set_language(tree_sitter_sql::language()) .expect("Error loading sql language"); - let tree = parser.parse(input, None).unwrap(); + let tree = parser.parse(&input.to_string(), None).unwrap(); (tree, schema_cache) } -pub(crate) fn get_text_and_position(sql: &str) -> (usize, String) { - // the cursor is to the left of the `CURSOR_POS` - let position = sql - .find(CURSOR_POS) - .expect("Please insert the CURSOR_POS into your query.") - .saturating_sub(1); - - let text = sql.replace(CURSOR_POS, ""); - - (position, text) +pub(crate) fn get_text_and_position(q: InputQuery) -> (usize, String) { + (q.position, q.sql) } pub(crate) fn get_test_params<'a>( tree: &'a tree_sitter::Tree, schema_cache: &'a pg_schema_cache::SchemaCache, - sql: &'a str, + sql: InputQuery, ) -> CompletionParams<'a> { let (position, text) = get_text_and_position(sql); diff --git a/crates/pg_lsp_new/src/handlers/completions.rs b/crates/pg_lsp_new/src/handlers/completions.rs index 5f7a13091..4a1775033 100644 --- a/crates/pg_lsp_new/src/handlers/completions.rs +++ b/crates/pg_lsp_new/src/handlers/completions.rs @@ -56,5 +56,6 @@ fn to_lsp_types_completion_item_kind( match pg_comp_kind { pg_completions::CompletionItemKind::Function | pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, + pg_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, } } diff --git a/crates/pg_treesitter_queries/src/lib.rs b/crates/pg_treesitter_queries/src/lib.rs index 0bcdbdefa..8d29db38e 100644 --- a/crates/pg_treesitter_queries/src/lib.rs +++ b/crates/pg_treesitter_queries/src/lib.rs @@ -1,6 +1,6 @@ pub mod queries; -use std::{ops::Range, slice::Iter}; +use std::slice::Iter; use queries::{Query, QueryResult}; @@ -25,7 +25,7 @@ impl<'a> TreeSitterQueriesExecutor<'a> { self.results.append(&mut results); } - pub fn get_iter(&self, range: Option<&'a Range>) -> QueryResultIter { + pub fn get_iter(&self, range: Option<&'a tree_sitter::Range>) -> QueryResultIter { match range { Some(r) => QueryResultIter::new(&self.results).within_range(r), None => QueryResultIter::new(&self.results), @@ -35,7 +35,7 @@ impl<'a> TreeSitterQueriesExecutor<'a> { pub struct QueryResultIter<'a> { inner: Iter<'a, QueryResult<'a>>, - range: Option<&'a Range>, + range: Option<&'a tree_sitter::Range>, } impl<'a> QueryResultIter<'a> { @@ -46,7 +46,7 @@ impl<'a> QueryResultIter<'a> { } } - fn within_range(mut self, r: &'a Range) -> Self { + fn within_range(mut self, r: &'a tree_sitter::Range) -> Self { self.range = Some(r); self } @@ -67,6 +67,7 @@ impl<'a> Iterator for QueryResultIter<'a> { #[cfg(test)] mod tests { + use crate::{queries::RelationMatch, TreeSitterQueriesExecutor}; #[test] @@ -114,46 +115,74 @@ where .filter_map(|q| q.try_into().ok()) .collect(); - assert_eq!( - results[0] - .schema - .map(|s| s.utf8_text(&sql.as_bytes()).unwrap()), - Some("public") - ); - assert_eq!( - results[0].table.utf8_text(&sql.as_bytes()).unwrap(), - "cool_table" - ); - - assert_eq!( - results[1] - .schema - .map(|s| s.utf8_text(&sql.as_bytes()).unwrap()), - Some("private") - ); - assert_eq!( - results[1].table.utf8_text(&sql.as_bytes()).unwrap(), - "cool_tableau" - ); - - assert_eq!(results[2].schema, None); - assert_eq!( - results[2].table.utf8_text(&sql.as_bytes()).unwrap(), - "another_table" - ); - - assert_eq!( - results[3] - .schema - .map(|s| s.utf8_text(&sql.as_bytes()).unwrap()), - Some("private") - ); - assert_eq!( - results[3].table.utf8_text(&sql.as_bytes()).unwrap(), - "another_tableau" - ); + assert_eq!(results[0].get_schema(sql), Some("public".into())); + assert_eq!(results[0].get_table(sql), "cool_table"); + + assert_eq!(results[1].get_schema(sql), Some("private".into())); + assert_eq!(results[1].get_table(sql), "cool_tableau"); + + assert_eq!(results[2].get_schema(sql), None); + assert_eq!(results[2].get_table(sql), "another_table"); + + assert_eq!(results[3].get_schema(sql), Some("private".into())); + assert_eq!(results[3].get_table(sql), "another_tableau"); // we have exhausted the matches: function invocations are ignored. assert!(results.len() == 4); } + + #[test] + fn only_considers_nodes_in_requested_range() { + let sql = r#" +select + * +from ( + select * + from ( + select * + from private.something + ) as sq2 + join private.tableau pt1 + on sq2.id = pt1.id + ) as sq1 +join private.table pt +on sq1.id = pt.id; +"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(&sql, None).unwrap(); + + // trust me bro + let range = { + let mut cursor = tree.root_node().walk(); + cursor.goto_first_child(); // statement + cursor.goto_first_child(); // select + cursor.goto_next_sibling(); // from + cursor.goto_first_child(); // keyword_from + cursor.goto_next_sibling(); // relation + cursor.goto_first_child(); // subquery (1) + cursor.goto_first_child(); // "(" + cursor.goto_next_sibling(); // select + cursor.goto_next_sibling(); // from + cursor.goto_first_child(); // keyword_from + cursor.goto_next_sibling(); // relation + cursor.goto_first_child(); // subquery (2) + cursor.node().range() + }; + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), &sql); + + executor.add_query_results::(); + + let results: Vec<&RelationMatch> = executor + .get_iter(Some(&range)) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].get_schema(sql), Some("private".into())); + assert_eq!(results[0].get_table(sql), "something"); + } } diff --git a/crates/pg_treesitter_queries/src/queries/mod.rs b/crates/pg_treesitter_queries/src/queries/mod.rs index 9e3abbee4..92e3b06c2 100644 --- a/crates/pg_treesitter_queries/src/queries/mod.rs +++ b/crates/pg_treesitter_queries/src/queries/mod.rs @@ -1,27 +1,24 @@ mod relations; -use std::ops::Range; - pub use relations::*; +#[derive(Debug)] pub enum QueryResult<'a> { Relation(RelationMatch<'a>), } impl<'a> QueryResult<'a> { - pub fn within_range(&self, range: &Range) -> bool { + pub fn within_range(&self, range: &tree_sitter::Range) -> bool { match self { Self::Relation(rm) => { - let tb_range = rm.table.byte_range(); - let start = match rm.schema { - Some(s) => s.byte_range().start, - None => tb_range.start, + Some(s) => s.start_position(), + None => rm.table.start_position(), }; - let end = tb_range.end; + let end = rm.table.end_position(); - range.contains(&start) && range.contains(&end) + start >= range.start_point && end <= range.end_point } } } diff --git a/crates/pg_treesitter_queries/src/queries/relations.rs b/crates/pg_treesitter_queries/src/queries/relations.rs index 588f82eb6..ba9324c96 100644 --- a/crates/pg_treesitter_queries/src/queries/relations.rs +++ b/crates/pg_treesitter_queries/src/queries/relations.rs @@ -13,6 +13,7 @@ static QUERY: &'static str = r#" ) "#; +#[derive(Debug)] pub struct RelationMatch<'a> { pub(crate) schema: Option>, pub(crate) table: tree_sitter::Node<'a>, @@ -20,12 +21,20 @@ pub struct RelationMatch<'a> { impl<'a> RelationMatch<'a> { pub fn get_schema(&self, sql: &str) -> Option { - let str = self.schema.as_ref()?.utf8_text(sql.as_bytes()).unwrap(); + let str = self + .schema + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get schema from RelationMatch"); + Some(str.to_string()) } pub fn get_table(&self, sql: &str) -> String { - self.table.utf8_text(sql.as_bytes()).unwrap().to_string() + self.table + .utf8_text(sql.as_bytes()) + .expect("Failed to get schema from RelationMatch") + .to_string() } } From 7413aa10ea48b7ae92bb32821425dbfa1465a7ad Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 17:16:20 +0100 Subject: [PATCH 06/10] fixie fixie --- crates/pg_completions/src/context.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/crates/pg_completions/src/context.rs b/crates/pg_completions/src/context.rs index b284b1668..a5fb0c6be 100644 --- a/crates/pg_completions/src/context.rs +++ b/crates/pg_completions/src/context.rs @@ -64,7 +64,7 @@ pub(crate) struct CompletionContext<'a> { } impl<'a> CompletionContext<'a> { - pub fn new(params: &'a CompletionParams<'a>) -> Self { + pub fn new(params: &'a CompletionParams) -> Self { let mut ctx = Self { tree: params.tree, text: ¶ms.text, @@ -81,8 +81,6 @@ impl<'a> CompletionContext<'a> { ctx.gather_tree_context(); ctx.gather_info_from_ts_queries(); - dbg!(ctx.wrapping_statement_range); - ctx } @@ -95,9 +93,7 @@ impl<'a> CompletionContext<'a> { let stmt_range = self.wrapping_statement_range.as_ref(); let sql = self.text; - dbg!(sql); - - let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), self.text); + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); executor.add_query_results::(); From 1003f82cfa8050a9e723bdfd98c649161854c2b4 Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 17:38:46 +0100 Subject: [PATCH 07/10] randomly change score until tests do what i want --- crates/pg_completions/src/relevance.rs | 12 ++++++++---- crates/pg_lsp/src/utils/to_lsp_types.rs | 1 + crates/pg_lsp_new/src/handlers/completions.rs | 4 ++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/crates/pg_completions/src/relevance.rs b/crates/pg_completions/src/relevance.rs index d4d05252e..706318c5f 100644 --- a/crates/pg_completions/src/relevance.rs +++ b/crates/pg_completions/src/relevance.rs @@ -70,6 +70,8 @@ impl<'a> CompletionRelevance<'a> { Some(ct) => ct, }; + let has_mentioned_tables = ctx.mentioned_relations.len() > 0; + self.score += match self.data { CompletionRelevanceData::Table(_) => match clause_type { ClauseType::From => 5, @@ -78,13 +80,15 @@ impl<'a> CompletionRelevance<'a> { _ => -50, }, CompletionRelevanceData::Function(_) => match clause_type { - ClauseType::Select => 5, + ClauseType::Select if !has_mentioned_tables => 15, + ClauseType::Select if has_mentioned_tables => 0, ClauseType::From => 0, _ => -50, }, CompletionRelevanceData::Column(_) => match clause_type { - ClauseType::Select => 15, - ClauseType::Where => 15, + ClauseType::Select if has_mentioned_tables => 10, + ClauseType::Select if !has_mentioned_tables => 0, + ClauseType::Where => 10, _ => -15, }, } @@ -96,7 +100,7 @@ impl<'a> CompletionRelevance<'a> { if ctx.is_invocation { 30 } else { - -30 + -10 } } _ => { diff --git a/crates/pg_lsp/src/utils/to_lsp_types.rs b/crates/pg_lsp/src/utils/to_lsp_types.rs index ca5f3f421..24dcc443c 100644 --- a/crates/pg_lsp/src/utils/to_lsp_types.rs +++ b/crates/pg_lsp/src/utils/to_lsp_types.rs @@ -6,5 +6,6 @@ pub fn to_completion_kind( match kind { pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, pg_completions::CompletionItemKind::Function => lsp_types::CompletionItemKind::FUNCTION, + pg_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, } } diff --git a/crates/pg_lsp_new/src/handlers/completions.rs b/crates/pg_lsp_new/src/handlers/completions.rs index 4a1775033..4efba2108 100644 --- a/crates/pg_lsp_new/src/handlers/completions.rs +++ b/crates/pg_lsp_new/src/handlers/completions.rs @@ -54,8 +54,8 @@ fn to_lsp_types_completion_item_kind( pg_comp_kind: pg_completions::CompletionItemKind, ) -> lsp_types::CompletionItemKind { match pg_comp_kind { - pg_completions::CompletionItemKind::Function - | pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, + pg_completions::CompletionItemKind::Function => lsp_types::CompletionItemKind::FUNCTION, + pg_completions::CompletionItemKind::Table => lsp_types::CompletionItemKind::CLASS, pg_completions::CompletionItemKind::Column => lsp_types::CompletionItemKind::FIELD, } } From 5b25305d4b37551b78b9fef646d7489166e441bb Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 17:40:39 +0100 Subject: [PATCH 08/10] i like the syntax --- crates/pg_completions/src/relevance.rs | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/crates/pg_completions/src/relevance.rs b/crates/pg_completions/src/relevance.rs index 706318c5f..f7a42b16f 100644 --- a/crates/pg_completions/src/relevance.rs +++ b/crates/pg_completions/src/relevance.rs @@ -96,20 +96,10 @@ impl<'a> CompletionRelevance<'a> { fn check_is_invocation(&mut self, ctx: &CompletionContext) { self.score += match self.data { - CompletionRelevanceData::Function(_) => { - if ctx.is_invocation { - 30 - } else { - -10 - } - } - _ => { - if ctx.is_invocation { - -10 - } else { - 0 - } - } + CompletionRelevanceData::Function(_) if ctx.is_invocation => 30, + CompletionRelevanceData::Function(_) if !ctx.is_invocation => -10, + _ if ctx.is_invocation => -10, + _ => 0, }; } From e4bb9462d9f813d7ce83e623ca80d4225c8c825c Mon Sep 17 00:00:00 2001 From: Julian Date: Wed, 8 Jan 2025 17:43:25 +0100 Subject: [PATCH 09/10] format TOML --- crates/pg_completions/Cargo.toml | 10 +++++----- crates/pg_lsp_new/Cargo.toml | 2 +- crates/pg_treesitter_queries/Cargo.toml | 2 +- crates/pg_workspace_new/Cargo.toml | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/crates/pg_completions/Cargo.toml b/crates/pg_completions/Cargo.toml index e27cccec6..140ef9105 100644 --- a/crates/pg_completions/Cargo.toml +++ b/crates/pg_completions/Cargo.toml @@ -16,12 +16,12 @@ async-std = "1.12.0" text-size.workspace = true -serde = { workspace = true, features = ["derive"] } -serde_json = { workspace = true } -pg_schema_cache.workspace = true -tree-sitter.workspace = true -tree_sitter_sql.workspace = true +pg_schema_cache.workspace = true pg_treesitter_queries.workspace = true +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tree-sitter.workspace = true +tree_sitter_sql.workspace = true sqlx.workspace = true diff --git a/crates/pg_lsp_new/Cargo.toml b/crates/pg_lsp_new/Cargo.toml index 0454893e7..8e20b521f 100644 --- a/crates/pg_lsp_new/Cargo.toml +++ b/crates/pg_lsp_new/Cargo.toml @@ -16,10 +16,10 @@ anyhow = { workspace = true } biome_deserialize = { workspace = true } futures = "0.3.31" pg_analyse = { workspace = true } +pg_completions = { workspace = true } pg_configuration = { workspace = true } pg_console = { workspace = true } pg_diagnostics = { workspace = true } -pg_completions = { workspace = true } pg_fs = { workspace = true } pg_lsp_converters = { workspace = true } pg_text_edit = { workspace = true } diff --git a/crates/pg_treesitter_queries/Cargo.toml b/crates/pg_treesitter_queries/Cargo.toml index 2d92fca31..bb85c4482 100644 --- a/crates/pg_treesitter_queries/Cargo.toml +++ b/crates/pg_treesitter_queries/Cargo.toml @@ -12,7 +12,7 @@ version = "0.0.0" [dependencies] -clap = { version = "4.5.23", features = ["derive"] } +clap = { version = "4.5.23", features = ["derive"] } tree-sitter.workspace = true tree_sitter_sql.workspace = true diff --git a/crates/pg_workspace_new/Cargo.toml b/crates/pg_workspace_new/Cargo.toml index 9da718cf2..c48bb6e2b 100644 --- a/crates/pg_workspace_new/Cargo.toml +++ b/crates/pg_workspace_new/Cargo.toml @@ -18,10 +18,10 @@ futures = "0.3.31" ignore = { workspace = true } pg_analyse = { workspace = true, features = ["serde"] } pg_analyser = { workspace = true } +pg_completions = { workspace = true } pg_configuration = { workspace = true } pg_console = { workspace = true } pg_diagnostics = { workspace = true } -pg_completions = { workspace = true } pg_fs = { workspace = true, features = ["serde"] } pg_query_ext = { workspace = true } pg_schema_cache = { workspace = true } From ffba6ebe23c0312b3ed108024a27e3b473fde7a3 Mon Sep 17 00:00:00 2001 From: Julian Date: Fri, 10 Jan 2025 15:50:28 +0100 Subject: [PATCH 10/10] use lazyLock --- .../pg_treesitter_queries/src/queries/relations.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/crates/pg_treesitter_queries/src/queries/relations.rs b/crates/pg_treesitter_queries/src/queries/relations.rs index ba9324c96..2ca27a055 100644 --- a/crates/pg_treesitter_queries/src/queries/relations.rs +++ b/crates/pg_treesitter_queries/src/queries/relations.rs @@ -1,8 +1,11 @@ +use std::sync::LazyLock; + use crate::{Query, QueryResult}; use super::QueryTryFrom; -static QUERY: &'static str = r#" +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &'static str = r#" (relation (object_reference . @@ -12,6 +15,8 @@ static QUERY: &'static str = r#" )+ ) "#; + tree_sitter::Query::new(tree_sitter_sql::language(), &QUERY_STR).expect("Invalid TS Query") +}); #[derive(Debug)] pub struct RelationMatch<'a> { @@ -57,12 +62,9 @@ impl<'a> QueryTryFrom<'a> for RelationMatch<'a> { impl<'a> Query<'a> for RelationMatch<'a> { fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { - let query = - tree_sitter::Query::new(tree_sitter_sql::language(), &QUERY).expect("Invalid Query."); - let mut cursor = tree_sitter::QueryCursor::new(); - let matches = cursor.matches(&query, root_node, stmt.as_bytes()); + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); let mut to_return = vec![];