Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix type symbol completion #1568

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions kclvm/sema/src/advanced_resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ pub struct Context<'ctx> {
maybe_def: bool,
// whether in schema config right value, affect lookup def
in_config_r_value: bool,

is_type_expr: bool,
}

impl<'ctx> Context<'ctx> {
Expand Down Expand Up @@ -114,6 +116,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
cur_node: AstIndex::default(),
maybe_def: false,
in_config_r_value: false,
is_type_expr: false,
},
};
// Scan all scehma symbol
Expand Down
64 changes: 56 additions & 8 deletions kclvm/sema/src/advanced_resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,13 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
.unwrap_or(import_stmt.path.clone())
.get_span_pos();

let mut unresolved =
UnresolvedSymbol::new(import_stmt.path.node.clone(), start_pos, end_pos, None);
let mut unresolved = UnresolvedSymbol::new(
import_stmt.path.node.clone(),
start_pos,
end_pos,
None,
self.ctx.is_type_expr,
);
let package_symbol = match self
.gs
.get_symbols()
Expand Down Expand Up @@ -622,7 +627,13 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {

let (start_pos, end_pos): Range = name.get_span_pos();
let ast_id = name.id.clone();
let mut unresolved = UnresolvedSymbol::new(name.node.clone(), start_pos, end_pos, None);
let mut unresolved = UnresolvedSymbol::new(
name.node.clone(),
start_pos,
end_pos,
None,
self.ctx.is_type_expr,
);
unresolved.def = Some(def_symbol_ref);
let unresolved_ref = self.gs.get_symbols_mut().alloc_unresolved_symbol(
unresolved,
Expand Down Expand Up @@ -1106,6 +1117,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
start_pos.clone(),
end_pos.clone(),
None,
self.ctx.is_type_expr,
);
let name = def_symbol.get_name();
first_unresolved.def = Some(symbol_ref);
Expand Down Expand Up @@ -1172,8 +1184,13 @@ impl<'ctx> AdvancedResolver<'ctx> {

let (start_pos, end_pos): Range = name.get_span_pos();
let ast_id = name.id.clone();
let mut unresolved =
UnresolvedSymbol::new(name.node.clone(), start_pos, end_pos, None);
let mut unresolved = UnresolvedSymbol::new(
name.node.clone(),
start_pos,
end_pos,
None,
self.ctx.is_type_expr,
);
unresolved.def = Some(def_symbol_ref);

unresolved.sema_info = SymbolSemanticInfo {
Expand Down Expand Up @@ -1308,8 +1325,13 @@ impl<'ctx> AdvancedResolver<'ctx> {
// Get an unresolved symbol
if def_start_pos != start_pos || def_end_pos != end_pos {
let ast_id = first_name.id.clone();
let mut first_unresolved =
UnresolvedSymbol::new(first_name.node.clone(), start_pos, end_pos, None);
let mut first_unresolved = UnresolvedSymbol::new(
first_name.node.clone(),
start_pos,
end_pos,
None,
self.ctx.is_type_expr,
);
first_unresolved.def = Some(symbol_ref);
let first_unresolved_ref = self.gs.get_symbols_mut().alloc_unresolved_symbol(
first_unresolved,
Expand Down Expand Up @@ -1352,6 +1374,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
start_pos,
end_pos,
None,
self.ctx.is_type_expr,
);
unresolved.def = Some(def_symbol_ref);
unresolved.sema_info = SymbolSemanticInfo {
Expand Down Expand Up @@ -1641,11 +1664,12 @@ impl<'ctx> AdvancedResolver<'ctx> {
&mut self,
ty_node: Option<&'ctx ast::Node<ast::Type>>,
) -> ResolvedResult {
self.ctx.is_type_expr = true;
if let Some(ty_node) = ty_node {
match &ty_node.node {
ast::Type::Any => {}
ast::Type::Named(identifier) => {
self.walk_identifier(identifier)?;
let r = self.walk_identifier(identifier)?;
}
ast::Type::Basic(_) => {}
ast::Type::List(list_type) => {
Expand Down Expand Up @@ -1673,6 +1697,30 @@ impl<'ctx> AdvancedResolver<'ctx> {
}
}
}

if let Some(ty_node) = ty_node {
match self
.ctx
.node_ty_map
.borrow()
.get(&self.ctx.get_node_key(&ty_node.id))
{
Some(ty) => {
let (_, end) = ty_node.get_span_pos();
let mut expr_symbol =
ExpressionSymbol::new(format!("@{}", ty.ty_str()), end.clone(), end, None);

expr_symbol.sema_info.ty = Some(ty.clone());
self.gs.get_symbols_mut().alloc_expression_symbol(
expr_symbol,
self.ctx.get_node_key(&ty_node.id),
self.ctx.current_pkgpath.clone().unwrap(),
);
}
None => {}
}
}
self.ctx.is_type_expr = false;
Ok(None)
}

Expand Down
39 changes: 33 additions & 6 deletions kclvm/sema/src/core/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ impl SymbolData {
}
}

pub fn get_unresolved_symbol(&self, id: SymbolRef) -> Option<&UnresolvedSymbol> {
if matches!(id.get_kind(), SymbolKind::Unresolved) {
self.unresolved.get(id.get_id())
} else {
None
}
}

pub fn get_symbol(&self, id: SymbolRef) -> Option<&KCLSymbol> {
match id.get_kind() {
SymbolKind::Schema => self
Expand Down Expand Up @@ -1825,6 +1833,7 @@ pub struct UnresolvedSymbol {
pub(crate) owner: Option<SymbolRef>,
pub(crate) sema_info: SymbolSemanticInfo,
pub(crate) hint: Option<SymbolHint>,
pub(crate) is_type: bool,
}

impl Symbol for UnresolvedSymbol {
Expand Down Expand Up @@ -1861,20 +1870,27 @@ impl Symbol for UnresolvedSymbol {
data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
) -> Option<SymbolRef> {
data.get_symbol(self.def?)?
.get_attribute(name, data, module_info)
if self.is_type() {
None
} else {
data.get_symbol(self.def?)?
.get_attribute(name, data, module_info)
}
}

fn get_all_attributes(
&self,
data: &Self::SymbolData,
module_info: Option<&ModuleInfo>,
) -> Vec<SymbolRef> {
if let Some(def) = self.def {
if let Some(def_symbol) = data.get_symbol(def) {
return def_symbol.get_all_attributes(data, module_info);
if !self.is_type() {
if let Some(def) = self.def {
if let Some(def_symbol) = data.get_symbol(def) {
return def_symbol.get_all_attributes(data, module_info);
}
}
}

vec![]
}

Expand Down Expand Up @@ -1928,7 +1944,13 @@ impl Symbol for UnresolvedSymbol {
}

impl UnresolvedSymbol {
pub fn new(name: String, start: Position, end: Position, owner: Option<SymbolRef>) -> Self {
pub fn new(
name: String,
start: Position,
end: Position,
owner: Option<SymbolRef>,
is_type: bool,
) -> Self {
Self {
id: None,
def: None,
Expand All @@ -1938,6 +1960,7 @@ impl UnresolvedSymbol {
sema_info: SymbolSemanticInfo::default(),
owner,
hint: None,
is_type,
}
}

Expand All @@ -1956,6 +1979,10 @@ impl UnresolvedSymbol {

pkg_path + "." + names.last().unwrap()
}

pub fn is_type(&self) -> bool {
self.is_type
}
}

#[derive(Debug, Clone)]
Expand Down
64 changes: 60 additions & 4 deletions kclvm/tools/src/LSP/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use std::io;
use std::{fs, path::Path};

use crate::goto_def::find_def;
use crate::goto_def::{find_def, find_symbol};
use indexmap::IndexSet;
use kclvm_ast::ast::{self, ImportStmt, Program, Stmt};
use kclvm_ast::MAIN_PKG;
Expand Down Expand Up @@ -248,11 +248,27 @@ fn completion_dot(
}

// look_up_exact_symbol
let mut def = find_def(&pre_pos, gs, true);
if def.is_none() {
def = find_def(pos, gs, false);
let mut symbol = find_symbol(&pre_pos, gs, true);
if symbol.is_none() {
symbol = find_symbol(pos, gs, false);
}

let def = match symbol {
Some(symbol_ref) => {
if let SymbolKind::Unresolved = symbol_ref.get_kind() {
let unresolved_symbol = gs.get_symbols().get_unresolved_symbol(symbol_ref).unwrap();
if unresolved_symbol.is_type() {
return Some(into_completion_items(&items).into());
}
}
match gs.get_symbols().get_symbol(symbol_ref) {
Some(symbol) => symbol.get_definition(),
None => None,
}
}
None => None,
};

match def {
Some(def_ref) => {
if let Some(def) = gs.get_symbols().get_symbol(def_ref) {
Expand Down Expand Up @@ -2096,4 +2112,44 @@ mod tests {
12,
None
);

completion_label_test_snapshot!(
schema_attr_ty_0,
"src/test_data/completion_test/dot/schema_attr_ty/schema_attr_ty.k",
5,
13,
Some('.')
);

completion_label_test_snapshot!(
schema_attr_ty_1,
"src/test_data/completion_test/dot/schema_attr_ty/schema_attr_ty.k",
6,
14,
Some('.')
);

completion_label_test_snapshot!(
schema_attr_ty_2,
"src/test_data/completion_test/dot/schema_attr_ty/schema_attr_ty.k",
7,
18,
Some('.')
);

completion_label_test_snapshot!(
schema_attr_ty_3,
"src/test_data/completion_test/dot/schema_attr_ty/schema_attr_ty.k",
8,
17,
Some('.')
);

completion_label_test_snapshot!(
schema_attr_ty_4,
"src/test_data/completion_test/dot/schema_attr_ty/schema_attr_ty.k",
10,
15,
Some('.')
);
}
8 changes: 8 additions & 0 deletions kclvm/tools/src/LSP/src/goto_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ pub(crate) fn find_def(kcl_pos: &KCLPos, gs: &GlobalState, exact: bool) -> Optio
}
}

pub(crate) fn find_symbol(kcl_pos: &KCLPos, gs: &GlobalState, exact: bool) -> Option<SymbolRef> {
if exact {
gs.look_up_exact_symbol(kcl_pos)
} else {
gs.look_up_closest_symbol(kcl_pos)
}
}

// Convert kcl position to GotoDefinitionResponse. This function will convert to
// None, Scalar or Array according to the number of positions
fn positions_to_goto_def_resp(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", got_labels)"
---
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", got_labels)"
---
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", got_labels)"
---
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", got_labels)"
---
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", got_labels)"
---
[]
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
schema A:
n: str

schema B:
named: A
list: [A]
dict: {str:A}
union: str|A

a: A
Loading