Skip to content

Commit

Permalink
feat: upgrade dict type into schema type in node_ty_map if expr's exp…
Browse files Browse the repository at this point in the history
…ected type is schema

Signed-off-by: he1pa <[email protected]>
  • Loading branch information
He1pa committed May 22, 2024
1 parent 6bf352c commit 48e7705
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 5 deletions.
24 changes: 23 additions & 1 deletion kclvm/sema/src/advanced_resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
}

fn walk_assign_stmt(&mut self, assign_stmt: &'ctx ast::AssignStmt) -> Self::Result {
let old_current_schema_symbol = self.ctx.current_schema_symbol.clone();
for target in &assign_stmt.targets {
if target.node.names.is_empty() {
continue;
Expand All @@ -111,6 +112,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
}
self.walk_type_expr(assign_stmt.ty.as_ref().map(|ty| ty.as_ref()))?;
self.expr(&assign_stmt.value)?;
self.ctx.current_schema_symbol = old_current_schema_symbol;
Ok(None)
}

Expand Down Expand Up @@ -854,7 +856,27 @@ impl<'ctx> AdvancedResolver<'ctx> {
self.ctx.end_pos = end;
}
self.ctx.cur_node = expr.id.clone();
match self.walk_expr(&expr.node) {

let old_current_schema_symbol = self.ctx.current_schema_symbol.clone();

if let Some(expr_ty) = self.ctx.node_ty_map.get(&self.ctx.get_node_key(&expr.id)) {
match &expr_ty.kind {
TypeKind::Schema(_) => {
let schema_symbol = self
.gs
.get_symbols()
.get_type_symbol(&expr_ty, self.get_current_module_info())
.ok_or(anyhow!("schema_symbol not found"))?;
self.ctx.current_schema_symbol = Some(schema_symbol);
}
_ => {}
}
}

let expr_symbol = self.walk_expr(&expr.node);
self.ctx.current_schema_symbol = old_current_schema_symbol;

match expr_symbol {
Ok(None) => match self.ctx.node_ty_map.get(&self.ctx.get_node_key(&expr.id)) {
Some(ty) => {
if let ast::Expr::Missing(_) = expr.node {
Expand Down
28 changes: 27 additions & 1 deletion kclvm/sema/src/resolver/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ impl<'ctx> Resolver<'ctx> {
&mut self,
entries: &'ctx [ast::NodeRef<ast::ConfigEntry>],
) -> TypeRef {
// println!("{:?}", "walk_config_entries");
let (start, end) = match entries.len() {
0 => (self.ctx.start_pos.clone(), self.ctx.end_pos.clone()),
1 => entries[0].get_span_pos(),
Expand All @@ -478,17 +479,42 @@ impl<'ctx> Resolver<'ctx> {
let mut val_types: Vec<TypeRef> = vec![];
let mut attrs: IndexMap<String, Attr> = IndexMap::new();
for item in entries {
// println!("{:?}", item.get_span_pos());
let key = &item.node.key;
let value = &item.node.value;
let op = &item.node.operation;
let mut stack_depth: usize = 0;
self.check_config_entry(key, value);
stack_depth += self.switch_config_expr_context_by_key(key);
let mut has_insert_index = false;
let val_ty = match key {
Some(key) => match &key.node {
ast::Expr::Identifier(identifier) => {
let old_expected_ty = self.ctx.ty_ctx.expected_ty.clone();
if let Some(expected_ty) = self.ctx.ty_ctx.expected_ty.clone() {
match &expected_ty.kind {
TypeKind::Dict(dict_ty) => {
self.ctx.ty_ctx.expected_ty = Some(dict_ty.val_ty.clone());
}

TypeKind::Schema(schema_ty) => {
match identifier.names.len() {
1 => match schema_ty.attrs.get(&identifier.names[0].node) {
Some(attr) => {
self.ctx.ty_ctx.expected_ty = Some(attr.ty.clone())
}
None => {}
},
_ => {
// unreachable, len() = 0 is error, len() > 1 has be desuger in pre_process()
}
};
}
_ => {}
}
}
let mut val_ty = self.expr(value);
self.ctx.ty_ctx.expected_ty = old_expected_ty;

for _ in 0..identifier.names.len() - 1 {
val_ty = Type::dict_ref(self.str_ty(), val_ty.clone());
}
Expand Down
47 changes: 44 additions & 3 deletions kclvm/sema/src/resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
if target.node.names.len() == 1 {
self.ctx.l_value = true;
let expected_ty = self.walk_identifier_expr(target);
// println!("ssign expecy ty{:?}", expected_ty);
let old_expected_ty = self.ctx.ty_ctx.expected_ty.clone();
self.ctx.ty_ctx.expected_ty = Some(expected_ty.clone());
self.ctx.l_value = false;
match &expected_ty.kind {
TypeKind::Schema(ty) => {
Expand Down Expand Up @@ -185,6 +188,15 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
target.get_span_pos(),
None,
);

self.ctx.ty_ctx.expected_ty = old_expected_ty;
let upgrade_schema_type =
self.upgrade_dict_to_schema(value_ty.clone(), expected_ty.clone());
self.node_ty_map.insert(
self.get_node_key(assign_stmt.value.id.clone()),
upgrade_schema_type.clone(),
);

if !value_ty.is_any() && expected_ty.is_any() && assign_stmt.ty.is_none() {
self.set_type_to_scope(name, value_ty.clone(), &target.node.names[0]);
if let Some(schema_ty) = &self.ctx.schema {
Expand All @@ -199,11 +211,25 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
self.lookup_type_from_scope(name, target.get_span_pos());
self.ctx.l_value = true;
let expected_ty = self.walk_identifier_expr(target);
self.ctx.ty_ctx.expected_ty = Some(expected_ty.clone());
self.ctx.l_value = false;
value_ty = self.expr(&assign_stmt.value);
self.ctx.ty_ctx.expected_ty = None;
// Check type annotation if exists.
self.check_assignment_type_annotation(assign_stmt, value_ty.clone());
self.must_assignable_to(value_ty.clone(), expected_ty, target.get_span_pos(), None)
self.must_assignable_to(
value_ty.clone(),
expected_ty.clone(),
target.get_span_pos(),
None,
);

let upgrade_schema_type =
self.upgrade_dict_to_schema(value_ty.clone(), expected_ty.clone());
self.node_ty_map.insert(
self.get_node_key(assign_stmt.value.id.clone()),
upgrade_schema_type.clone(),
);
}
}
value_ty
Expand Down Expand Up @@ -677,7 +703,15 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {

fn walk_list_expr(&mut self, list_expr: &'ctx ast::ListExpr) -> Self::Result {
let stack_depth = self.switch_list_expr_context();
let old_expected_ty = self.ctx.ty_ctx.expected_ty.clone();
if let Some(ty) = self.ctx.ty_ctx.expected_ty.clone() {
match &ty.kind {
TypeKind::List(item_ty) => self.ctx.ty_ctx.expected_ty = Some(item_ty.clone()),
_ => {}
}
}
let item_type = sup(&self.exprs(&list_expr.elts).to_vec());
self.ctx.ty_ctx.expected_ty = old_expected_ty;
self.clear_config_expr_context(stack_depth, false);
Type::list_ref(item_type)
}
Expand Down Expand Up @@ -1233,8 +1267,15 @@ impl<'ctx> Resolver<'ctx> {
self.ctx.end_pos = end;
}
let ty = self.walk_expr(&expr.node);
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), ty.clone());
if let Some(expected_ty) = self.ctx.ty_ctx.expected_ty.clone() {
let upgrade_ty = self.upgrade_dict_to_schema(ty.clone(), expected_ty);
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), upgrade_ty);
} else {
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), ty.clone());
}

ty
}

Expand Down
24 changes: 24 additions & 0 deletions kclvm/sema/src/resolver/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,30 @@ impl<'ctx> Resolver<'ctx> {
}
}

// Upgrade the dict type into schema type if it is expected to schema
pub fn upgrade_dict_to_schema(&mut self, ty: TypeRef, expected_ty: TypeRef) -> TypeRef {
match (&ty.kind, &expected_ty.kind) {
(TypeKind::Dict(DictType { .. }), TypeKind::Schema(_)) => expected_ty,
(TypeKind::List(item_ty), TypeKind::List(expected_item_ty)) => {
Type::list(self.upgrade_dict_to_schema(item_ty.clone(), expected_item_ty.clone()))
.into()
}
(
TypeKind::Dict(DictType { key_ty, val_ty, .. }),
TypeKind::Dict(DictType {
key_ty: expected_key_ty,
val_ty: expected_val_ty,
..
}),
) => Type::dict(
self.upgrade_dict_to_schema(key_ty.clone(), expected_key_ty.clone()),
self.upgrade_dict_to_schema(val_ty.clone(), expected_val_ty.clone()),
)
.into(),
_ => expected_ty,
}
}

/// Check the type assignment statement between type annotation and target.
pub fn check_assignment_type_annotation(
&mut self,
Expand Down
2 changes: 2 additions & 0 deletions kclvm/sema/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct TypeContext {
pub dep_graph: DiGraph<String, ()>,
pub builtin_types: BuiltinTypes,
node_index_map: HashMap<String, NodeIndex>,
pub expected_ty: Option<TypeRef>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -46,6 +47,7 @@ impl TypeContext {
none: Arc::new(Type::NONE),
},
node_index_map: HashMap::new(),
expected_ty: None,
}
}

Expand Down
68 changes: 68 additions & 0 deletions kclvm/tools/src/LSP/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1648,4 +1648,72 @@ mod tests {
CompletionResponse::List(_) => panic!("test failed"),
}
}

#[test]
#[bench_test]
fn nested_test() {
let (file, program, _, gs) =
compile_test_file("src/test_data/completion_test/dot/nested/nested.k");

let pos = KCLPos {
filename: file.to_owned(),
line: 9,
column: Some(9),
};

let mut got = completion(None, &program, &pos, &gs).unwrap();

match &mut got {
CompletionResponse::Array(arr) => {
let labels: Vec<String> = arr.iter().map(|item| item.label.clone()).collect();
insta::assert_snapshot!(format!("{:?}", labels));
}
CompletionResponse::List(_) => panic!("test failed"),
}

let pos = KCLPos {
filename: file.to_owned(),
line: 18,
column: Some(9),
};

let mut got = completion(None, &program, &pos, &gs).unwrap();
match &mut got {
CompletionResponse::Array(arr) => {
let labels: Vec<String> = arr.iter().map(|item| item.label.clone()).collect();
insta::assert_snapshot!(format!("{:?}", labels));
}
CompletionResponse::List(_) => panic!("test failed"),
}

let pos = KCLPos {
filename: file.to_owned(),
line: 24,
column: Some(13),
};

let mut got = completion(None, &program, &pos, &gs).unwrap();
match &mut got {
CompletionResponse::Array(arr) => {
let labels: Vec<String> = arr.iter().map(|item| item.label.clone()).collect();
insta::assert_snapshot!(format!("{:?}", labels));
}
CompletionResponse::List(_) => panic!("test failed"),
}

let pos = KCLPos {
filename: file.to_owned(),
line: 33,
column: Some(13),
};

let mut got = completion(None, &program, &pos, &gs).unwrap();
match &mut got {
CompletionResponse::Array(arr) => {
let labels: Vec<String> = arr.iter().map(|item| item.label.clone()).collect();
insta::assert_snapshot!(format!("{:?}", labels));
}
CompletionResponse::List(_) => panic!("test failed"),
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", labels)"
---
["ab"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", labels)"
---
["ab"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", labels)"
---
["ab"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: tools/src/LSP/src/completion.rs
expression: "format!(\"{:?}\", labels)"
---
["ab"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
schema N:
ab: int

schema ListN:
a: [N]

list_N: ListN = {
a: [{

}]
}

schema DictN:
a: {str:N}

dictN: DictN = {
a.c = {

}
}
dictN1: DictN = {
a : {
c: {

}
}
}
schema ListListN:
a: [[N]]

listlistN: ListListN = {
a: [[{

}]]
}

0 comments on commit 48e7705

Please sign in to comment.