Skip to content

Commit

Permalink
feat: upgrade dict type into schema type (#1350)
Browse files Browse the repository at this point in the history
* feat: upgrade dict type into schema type in node_ty_map if expr's expected type is schema

Signed-off-by: he1pa <[email protected]>

* chore: recover accidentally deleted code

* remove resolver.ctx.ty.ctx.expected_ty, use resolver.ctx.config_expr_context instead

Signed-off-by: he1pa <[email protected]>

* use schema symbol stack to replace current schema symbol

Signed-off-by: he1pa <[email protected]>

* handle union type

* add a option emit_error in check_type() to not report an error when trying to upgrade dict into schema

Signed-off-by: he1pa <[email protected]>

* chore: remove unused code

Signed-off-by: he1pa <[email protected]>

* rebase main

* remove unused code

Signed-off-by: he1pa <[email protected]>

---------

Signed-off-by: he1pa <[email protected]>
  • Loading branch information
He1pa authored May 24, 2024
1 parent 6f9c637 commit 1a83c61
Show file tree
Hide file tree
Showing 16 changed files with 359 additions and 91 deletions.
4 changes: 2 additions & 2 deletions kclvm/sema/src/advanced_resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub struct Context<'ctx> {
scopes: Vec<ScopeRef>,
current_pkgpath: Option<String>,
current_filename: Option<String>,
current_schema_symbol: Option<SymbolRef>,
schema_symbol_stack: Vec<Option<SymbolRef>>,
start_pos: Position,
end_pos: Position,
cur_node: AstIndex,
Expand Down Expand Up @@ -104,7 +104,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
scopes: vec![],
current_filename: None,
current_pkgpath: None,
current_schema_symbol: None,
schema_symbol_stack: vec![],
start_pos: Position::dummy_pos(),
end_pos: Position::dummy_pos(),
cur_node: AstIndex::default(),
Expand Down
88 changes: 23 additions & 65 deletions kclvm/sema/src/advanced_resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,6 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
}
self.ctx.maybe_def = true;
self.walk_identifier_expr(target)?;

if let Some(target_ty) = self.ctx.node_ty_map.get(&self.ctx.get_node_key(&target.id)) {
match &target_ty.kind {
TypeKind::Schema(_) => {
let schema_symbol = self
.gs
.get_symbols()
.get_type_symbol(&target_ty, self.get_current_module_info())
.ok_or(anyhow!("schema_symbol not found"))?;
self.ctx.current_schema_symbol = Some(schema_symbol);
}
_ => {}
}
}
self.ctx.maybe_def = false;
}
self.walk_type_expr(assign_stmt.ty.as_ref().map(|ty| ty.as_ref()))?;
Expand Down Expand Up @@ -702,18 +688,9 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
.clone();
match schema_ty.kind {
TypeKind::Schema(_) => {
let schema_symbol = self
.gs
.get_symbols()
.get_type_symbol(&schema_ty, self.get_current_module_info())
.ok_or(anyhow!("schema_symbol not found"))?;
self.ctx.current_schema_symbol = Some(schema_symbol);
self.expr(&schema_expr.config)?;
self.do_arguments_symbol_resolve(&schema_expr.args, &schema_expr.kwargs)?;
}
TypeKind::Dict(_) => {
// TODO: for builtin ty symbol, get_type_symbol() just return None
}
_ => {
// Invalid schema type, nothing todo
}
Expand Down Expand Up @@ -854,7 +831,27 @@ impl<'ctx> AdvancedResolver<'ctx> {
self.ctx.end_pos = end;
}
self.ctx.cur_node = expr.id.clone();
match self.walk_expr(&expr.node) {

if let Some(expr_ty) = self.ctx.node_ty_map.get(&self.ctx.get_node_key(&expr.id)) {
match &expr_ty.kind {
TypeKind::Schema(_) => {
let schema_symbol = self
.gs
.get_symbols()
.get_type_symbol(&expr_ty, self.get_current_module_info())
.ok_or(anyhow!("schema_symbol not found"))?;
self.ctx.schema_symbol_stack.push(Some(schema_symbol));
}
_ => {
self.ctx.schema_symbol_stack.push(None);
}
}
}

let expr_symbol = self.walk_expr(&expr.node);
self.ctx.schema_symbol_stack.pop();

match expr_symbol {
Ok(None) => match self.ctx.node_ty_map.get(&self.ctx.get_node_key(&expr.id)) {
Some(ty) => {
if let ast::Expr::Missing(_) = expr.node {
Expand Down Expand Up @@ -1208,7 +1205,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
) -> anyhow::Result<()> {
let (start, end) = (self.ctx.start_pos.clone(), self.ctx.end_pos.clone());

let schema_symbol = self.ctx.current_schema_symbol.take();
let schema_symbol = self.ctx.schema_symbol_stack.last().unwrap_or(&None).clone();
let kind = match &schema_symbol {
Some(_) => LocalSymbolScopeKind::SchemaConfig,
None => LocalSymbolScopeKind::Value,
Expand All @@ -1231,9 +1228,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
for entry in entries.iter() {
if let Some(key) = &entry.node.key {
self.ctx.maybe_def = true;
if let Some(key_symbol_ref) = self.expr(key)? {
self.set_current_schema_symbol(key_symbol_ref)?;
}
self.expr(key)?;
self.ctx.maybe_def = false;
}

Expand All @@ -1253,43 +1248,6 @@ impl<'ctx> AdvancedResolver<'ctx> {
Ok(())
}

pub(crate) fn set_current_schema_symbol(
&mut self,
key_symbol_ref: SymbolRef,
) -> anyhow::Result<()> {
let symbols = self.gs.get_symbols();

if let Some(def_symbol_ref) = symbols
.get_symbol(key_symbol_ref)
.ok_or(anyhow!("def_symbol_ref not found"))?
.get_definition()
{
if let Some(node_key) = symbols.symbols_info.symbol_node_map.get(&def_symbol_ref) {
if let Some(def_ty) = self.ctx.node_ty_map.get(node_key) {
if let Some(ty) = get_possible_schema_ty(def_ty.clone()) {
self.ctx.current_schema_symbol =
self.gs.get_symbols().get_type_symbol(&ty, None);
}
}
}
}
fn get_possible_schema_ty(ty: Arc<Type>) -> Option<Arc<Type>> {
match &ty.kind {
crate::ty::TypeKind::List(ty) => get_possible_schema_ty(ty.clone()),
crate::ty::TypeKind::Dict(dict_ty) => {
get_possible_schema_ty(dict_ty.val_ty.clone())
}
crate::ty::TypeKind::Union(_) => {
// Todo: fix union schema type
None
}
crate::ty::TypeKind::Schema(_) => Some(ty.clone()),
_ => None,
}
}
Ok(())
}

pub(crate) fn resolve_decorator(&mut self, decorators: &'ctx [ast::NodeRef<ast::CallExpr>]) {
for decorator in decorators {
let func_ident = &decorator.node.func;
Expand Down
3 changes: 2 additions & 1 deletion kclvm/sema/src/resolver/arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl<'ctx> Resolver<'ctx> {
return;
}
};
self.must_assignable_to(ty.clone(), expected_ty, args[i].get_span_pos(), None)
self.must_assignable_to(ty.clone(), expected_ty, args[i].get_span_pos(), None, true)
}
// Do keyword argument type check
for (i, (arg_name, kwarg_ty)) in kwarg_types.iter().enumerate() {
Expand Down Expand Up @@ -132,6 +132,7 @@ impl<'ctx> Resolver<'ctx> {
expected_types[0].clone(),
kwargs[i].get_span_pos(),
None,
true,
);
};
}
Expand Down
2 changes: 2 additions & 0 deletions kclvm/sema/src/resolver/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ impl<'ctx> Resolver<'ctx> {
ty,
key.get_span_pos(),
Some(obj_last.get_span_pos()),
true,
);
}
self.clear_config_expr_context(stack_depth, false);
Expand Down Expand Up @@ -489,6 +490,7 @@ impl<'ctx> Resolver<'ctx> {
Some(key) => match &key.node {
ast::Expr::Identifier(identifier) => {
let mut val_ty = self.expr(value);

for _ in 0..identifier.names.len() - 1 {
val_ty = Type::dict_ref(self.str_ty(), val_ty.clone());
}
Expand Down
61 changes: 57 additions & 4 deletions kclvm/sema/src/resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty.clone(),
unification_stmt.target.get_span_pos(),
None,
true,
);
if !ty.is_any() && expected_ty.is_any() {
self.set_type_to_scope(&names[0].node, ty, &names[0]);
Expand Down Expand Up @@ -184,7 +185,19 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty.clone(),
target.get_span_pos(),
None,
true,
);

let upgrade_schema_type = self.upgrade_dict_to_schema(
value_ty.clone(),
expected_ty.clone(),
&assign_stmt.value.get_span_pos(),
);
self.node_ty_map.insert(
self.get_node_key(assign_stmt.value.id.clone()),
upgrade_schema_type.clone(),
);

if !value_ty.is_any() && expected_ty.is_any() && assign_stmt.ty.is_none() {
self.set_type_to_scope(name, value_ty.clone(), &target.node.names[0]);
if let Some(schema_ty) = &self.ctx.schema {
Expand All @@ -203,7 +216,23 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
value_ty = self.expr(&assign_stmt.value);
// Check type annotation if exists.
self.check_assignment_type_annotation(assign_stmt, value_ty.clone());
self.must_assignable_to(value_ty.clone(), expected_ty, target.get_span_pos(), None)
self.must_assignable_to(
value_ty.clone(),
expected_ty.clone(),
target.get_span_pos(),
None,
true,
);

let upgrade_schema_type = self.upgrade_dict_to_schema(
value_ty.clone(),
expected_ty.clone(),
&assign_stmt.value.get_span_pos(),
);
self.node_ty_map.insert(
self.get_node_key(assign_stmt.value.id.clone()),
upgrade_schema_type.clone(),
);
}
}
value_ty
Expand Down Expand Up @@ -261,6 +290,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty,
aug_assign_stmt.target.get_span_pos(),
None,
true,
);
self.ctx.l_value = false;
new_target_ty
Expand Down Expand Up @@ -427,6 +457,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty,
schema_attr.name.get_span_pos(),
None,
true,
);
}
// Assign
Expand All @@ -435,6 +466,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty,
schema_attr.name.get_span_pos(),
None,
true,
),
},
None => bug!("invalid ast schema attr op kind"),
Expand Down Expand Up @@ -1068,7 +1100,13 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
let real_ret_ty = self.stmts(&lambda_expr.body);
self.leave_scope();
self.ctx.in_lambda_expr.pop();
self.must_assignable_to(real_ret_ty.clone(), ret_ty.clone(), (start, end), None);
self.must_assignable_to(
real_ret_ty.clone(),
ret_ty.clone(),
(start, end),
None,
true,
);
if !real_ret_ty.is_any() && ret_ty.is_any() && lambda_expr.return_ty.is_none() {
ret_ty = real_ret_ty;
}
Expand Down Expand Up @@ -1232,9 +1270,24 @@ impl<'ctx> Resolver<'ctx> {
self.ctx.start_pos = start;
self.ctx.end_pos = end;
}

let expected_ty = match self.ctx.config_expr_context.last() {
Some(ty) => ty.clone().map(|o| o.ty),
None => None,
};

let ty = self.walk_expr(&expr.node);
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), ty.clone());

if let Some(expected_ty) = expected_ty {
let upgrade_ty =
self.upgrade_dict_to_schema(ty.clone(), expected_ty, &expr.get_span_pos());
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), upgrade_ty);
} else {
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), ty.clone());
}

ty
}

Expand Down
1 change: 1 addition & 0 deletions kclvm/sema/src/resolver/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ impl<'ctx> Resolver<'ctx> {
expected_ty,
index_signature_node.get_span_pos(),
None,
true,
);
}
}
Expand Down
Loading

0 comments on commit 1a83c61

Please sign in to comment.