1+ use std:: collections:: { HashMap , HashSet } ;
2+
13use pg_schema_cache:: SchemaCache ;
4+ use pg_treesitter_queries:: {
5+ queries:: { self , QueryResult } ,
6+ TreeSitterQueriesExecutor ,
7+ } ;
28
39use crate :: CompletionParams ;
410
@@ -52,6 +58,9 @@ pub(crate) struct CompletionContext<'a> {
5258 pub schema_name : Option < String > ,
5359 pub wrapping_clause_type : Option < ClauseType > ,
5460 pub is_invocation : bool ,
61+ pub wrapping_statement_range : Option < tree_sitter:: Range > ,
62+
63+ pub mentioned_relations : HashMap < Option < String > , HashSet < String > > ,
5564}
5665
5766impl < ' a > CompletionContext < ' a > {
@@ -61,18 +70,56 @@ impl<'a> CompletionContext<'a> {
6170 text : & params. text ,
6271 schema_cache : params. schema ,
6372 position : usize:: from ( params. position ) ,
64-
6573 ts_node : None ,
6674 schema_name : None ,
6775 wrapping_clause_type : None ,
76+ wrapping_statement_range : None ,
6877 is_invocation : false ,
78+ mentioned_relations : HashMap :: new ( ) ,
6979 } ;
7080
7181 ctx. gather_tree_context ( ) ;
82+ ctx. gather_info_from_ts_queries ( ) ;
7283
7384 ctx
7485 }
7586
87+ fn gather_info_from_ts_queries ( & mut self ) {
88+ let tree = match self . tree . as_ref ( ) {
89+ None => return ,
90+ Some ( t) => t,
91+ } ;
92+
93+ let stmt_range = self . wrapping_statement_range . as_ref ( ) ;
94+ let sql = self . text ;
95+
96+ let mut executor = TreeSitterQueriesExecutor :: new ( tree. root_node ( ) , sql) ;
97+
98+ executor. add_query_results :: < queries:: RelationMatch > ( ) ;
99+
100+ for relation_match in executor. get_iter ( stmt_range) {
101+ match relation_match {
102+ QueryResult :: Relation ( r) => {
103+ let schema_name = r. get_schema ( sql) ;
104+ let table_name = r. get_table ( sql) ;
105+
106+ let current = self . mentioned_relations . get_mut ( & schema_name) ;
107+
108+ match current {
109+ Some ( c) => {
110+ c. insert ( table_name) ;
111+ }
112+ None => {
113+ let mut new = HashSet :: new ( ) ;
114+ new. insert ( table_name) ;
115+ self . mentioned_relations . insert ( schema_name, new) ;
116+ }
117+ } ;
118+ }
119+ } ;
120+ }
121+ }
122+
76123 pub fn get_ts_node_content ( & self , ts_node : tree_sitter:: Node < ' a > ) -> Option < & ' a str > {
77124 let source = self . text ;
78125 match ts_node. utf8_text ( source. as_bytes ( ) ) {
@@ -100,36 +147,38 @@ impl<'a> CompletionContext<'a> {
100147 * We'll therefore adjust the cursor position such that it meets the last node of the AST.
101148 * `select * from use {}` becomes `select * from use{}`.
102149 */
103- let current_node_kind = cursor. node ( ) . kind ( ) ;
150+ let current_node = cursor. node ( ) ;
104151 while cursor. goto_first_child_for_byte ( self . position ) . is_none ( ) && self . position > 0 {
105152 self . position -= 1 ;
106153 }
107154
108- self . gather_context_from_node ( cursor, current_node_kind ) ;
155+ self . gather_context_from_node ( cursor, current_node ) ;
109156 }
110157
111158 fn gather_context_from_node (
112159 & mut self ,
113160 mut cursor : tree_sitter:: TreeCursor < ' a > ,
114- previous_node_kind : & str ,
161+ previous_node : tree_sitter :: Node < ' a > ,
115162 ) {
116163 let current_node = cursor. node ( ) ;
117- let current_node_kind = current_node. kind ( ) ;
118164
119165 // prevent infinite recursion – this can happen if we only have a PROGRAM node
120- if current_node_kind == previous_node_kind {
166+ if current_node . kind ( ) == previous_node . kind ( ) {
121167 self . ts_node = Some ( current_node) ;
122168 return ;
123169 }
124170
125- match previous_node_kind {
126- "statement" => self . wrapping_clause_type = current_node_kind. try_into ( ) . ok ( ) ,
171+ match previous_node. kind ( ) {
172+ "statement" | "subquery" => {
173+ self . wrapping_clause_type = current_node. kind ( ) . try_into ( ) . ok ( ) ;
174+ self . wrapping_statement_range = Some ( previous_node. range ( ) ) ;
175+ }
127176 "invocation" => self . is_invocation = true ,
128177
129178 _ => { }
130179 }
131180
132- match current_node_kind {
181+ match current_node . kind ( ) {
133182 "object_reference" => {
134183 let txt = self . get_ts_node_content ( current_node) ;
135184 if let Some ( txt) = txt {
@@ -159,7 +208,7 @@ impl<'a> CompletionContext<'a> {
159208 }
160209
161210 cursor. goto_first_child_for_byte ( self . position ) ;
162- self . gather_context_from_node ( cursor, current_node_kind ) ;
211+ self . gather_context_from_node ( cursor, current_node ) ;
163212 }
164213}
165214
@@ -209,7 +258,7 @@ mod tests {
209258 ] ;
210259
211260 for ( query, expected_clause) in test_cases {
212- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
261+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
213262
214263 let tree = get_tree ( text. as_str ( ) ) ;
215264
@@ -242,7 +291,7 @@ mod tests {
242291 ] ;
243292
244293 for ( query, expected_schema) in test_cases {
245- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
294+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
246295
247296 let tree = get_tree ( text. as_str ( ) ) ;
248297 let params = crate :: CompletionParams {
@@ -276,7 +325,7 @@ mod tests {
276325 ] ;
277326
278327 for ( query, is_invocation) in test_cases {
279- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
328+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
280329
281330 let tree = get_tree ( text. as_str ( ) ) ;
282331 let params = crate :: CompletionParams {
@@ -300,7 +349,7 @@ mod tests {
300349 ] ;
301350
302351 for query in cases {
303- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
352+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
304353
305354 let tree = get_tree ( text. as_str ( ) ) ;
306355
@@ -328,7 +377,7 @@ mod tests {
328377 fn does_not_fail_on_trailing_whitespace ( ) {
329378 let query = format ! ( "select * from {}" , CURSOR_POS ) ;
330379
331- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
380+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
332381
333382 let tree = get_tree ( text. as_str ( ) ) ;
334383
@@ -354,7 +403,7 @@ mod tests {
354403 fn does_not_fail_with_empty_statements ( ) {
355404 let query = format ! ( "{}" , CURSOR_POS ) ;
356405
357- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
406+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
358407
359408 let tree = get_tree ( text. as_str ( ) ) ;
360409
@@ -379,7 +428,7 @@ mod tests {
379428 // is selecting a certain column name, such as `frozen_account`.
380429 let query = format ! ( "select * fro{}" , CURSOR_POS ) ;
381430
382- let ( position, text) = get_text_and_position ( query. as_str ( ) ) ;
431+ let ( position, text) = get_text_and_position ( query. as_str ( ) . into ( ) ) ;
383432
384433 let tree = get_tree ( text. as_str ( ) ) ;
385434
0 commit comments