Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: upgrade dict type into schema type #1350

Merged
merged 9 commits into from
May 24, 2024
4 changes: 2 additions & 2 deletions kclvm/sema/src/advanced_resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub struct Context<'ctx> {
scopes: Vec<ScopeRef>,
current_pkgpath: Option<String>,
current_filename: Option<String>,
current_schema_symbol: Option<SymbolRef>,
schema_symbol_stack: Vec<Option<SymbolRef>>,
start_pos: Position,
end_pos: Position,
cur_node: AstIndex,
Expand Down Expand Up @@ -104,7 +104,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
scopes: vec![],
current_filename: None,
current_pkgpath: None,
current_schema_symbol: None,
schema_symbol_stack: vec![],
start_pos: Position::dummy_pos(),
end_pos: Position::dummy_pos(),
cur_node: AstIndex::default(),
Expand Down
88 changes: 23 additions & 65 deletions kclvm/sema/src/advanced_resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,6 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
}
self.ctx.maybe_def = true;
self.walk_identifier_expr(target)?;

if let Some(target_ty) = self.ctx.node_ty_map.get(&self.ctx.get_node_key(&target.id)) {
match &target_ty.kind {
TypeKind::Schema(_) => {
let schema_symbol = self
.gs
.get_symbols()
.get_type_symbol(&target_ty, self.get_current_module_info())
.ok_or(anyhow!("schema_symbol not found"))?;
self.ctx.current_schema_symbol = Some(schema_symbol);
}
_ => {}
}
}
self.ctx.maybe_def = false;
}
self.walk_type_expr(assign_stmt.ty.as_ref().map(|ty| ty.as_ref()))?;
Expand Down Expand Up @@ -702,18 +688,9 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for AdvancedResolver<'ctx> {
.clone();
match schema_ty.kind {
TypeKind::Schema(_) => {
let schema_symbol = self
.gs
.get_symbols()
.get_type_symbol(&schema_ty, self.get_current_module_info())
.ok_or(anyhow!("schema_symbol not found"))?;
self.ctx.current_schema_symbol = Some(schema_symbol);
self.expr(&schema_expr.config)?;
self.do_arguments_symbol_resolve(&schema_expr.args, &schema_expr.kwargs)?;
}
TypeKind::Dict(_) => {
// TODO: for builtin ty symbol, get_type_symbol() just return None
}
_ => {
// Invalid schema type, nothing todo
}
Expand Down Expand Up @@ -854,7 +831,27 @@ impl<'ctx> AdvancedResolver<'ctx> {
self.ctx.end_pos = end;
}
self.ctx.cur_node = expr.id.clone();
match self.walk_expr(&expr.node) {

if let Some(expr_ty) = self.ctx.node_ty_map.get(&self.ctx.get_node_key(&expr.id)) {
match &expr_ty.kind {
TypeKind::Schema(_) => {
let schema_symbol = self
.gs
.get_symbols()
.get_type_symbol(&expr_ty, self.get_current_module_info())
.ok_or(anyhow!("schema_symbol not found"))?;
self.ctx.schema_symbol_stack.push(Some(schema_symbol));
}
_ => {
self.ctx.schema_symbol_stack.push(None);
}
}
}

let expr_symbol = self.walk_expr(&expr.node);
self.ctx.schema_symbol_stack.pop();

match expr_symbol {
Ok(None) => match self.ctx.node_ty_map.get(&self.ctx.get_node_key(&expr.id)) {
Some(ty) => {
if let ast::Expr::Missing(_) = expr.node {
Expand Down Expand Up @@ -1208,7 +1205,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
) -> anyhow::Result<()> {
let (start, end) = (self.ctx.start_pos.clone(), self.ctx.end_pos.clone());

let schema_symbol = self.ctx.current_schema_symbol.take();
let schema_symbol = self.ctx.schema_symbol_stack.last().unwrap_or(&None).clone();
let kind = match &schema_symbol {
Some(_) => LocalSymbolScopeKind::SchemaConfig,
None => LocalSymbolScopeKind::Value,
Expand All @@ -1231,9 +1228,7 @@ impl<'ctx> AdvancedResolver<'ctx> {
for entry in entries.iter() {
if let Some(key) = &entry.node.key {
self.ctx.maybe_def = true;
if let Some(key_symbol_ref) = self.expr(key)? {
self.set_current_schema_symbol(key_symbol_ref)?;
}
self.expr(key)?;
self.ctx.maybe_def = false;
}

Expand All @@ -1253,43 +1248,6 @@ impl<'ctx> AdvancedResolver<'ctx> {
Ok(())
}

pub(crate) fn set_current_schema_symbol(
&mut self,
key_symbol_ref: SymbolRef,
) -> anyhow::Result<()> {
let symbols = self.gs.get_symbols();

if let Some(def_symbol_ref) = symbols
.get_symbol(key_symbol_ref)
.ok_or(anyhow!("def_symbol_ref not found"))?
.get_definition()
{
if let Some(node_key) = symbols.symbols_info.symbol_node_map.get(&def_symbol_ref) {
if let Some(def_ty) = self.ctx.node_ty_map.get(node_key) {
if let Some(ty) = get_possible_schema_ty(def_ty.clone()) {
self.ctx.current_schema_symbol =
self.gs.get_symbols().get_type_symbol(&ty, None);
}
}
}
}
fn get_possible_schema_ty(ty: Arc<Type>) -> Option<Arc<Type>> {
match &ty.kind {
crate::ty::TypeKind::List(ty) => get_possible_schema_ty(ty.clone()),
crate::ty::TypeKind::Dict(dict_ty) => {
get_possible_schema_ty(dict_ty.val_ty.clone())
}
crate::ty::TypeKind::Union(_) => {
// Todo: fix union schema type
None
}
crate::ty::TypeKind::Schema(_) => Some(ty.clone()),
_ => None,
}
}
Ok(())
}

pub(crate) fn resolve_decorator(&mut self, decorators: &'ctx [ast::NodeRef<ast::CallExpr>]) {
for decorator in decorators {
let func_ident = &decorator.node.func;
Expand Down
3 changes: 2 additions & 1 deletion kclvm/sema/src/resolver/arg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl<'ctx> Resolver<'ctx> {
return;
}
};
self.must_assignable_to(ty.clone(), expected_ty, args[i].get_span_pos(), None)
self.must_assignable_to(ty.clone(), expected_ty, args[i].get_span_pos(), None, true)
}
// Do keyword argument type check
for (i, (arg_name, kwarg_ty)) in kwarg_types.iter().enumerate() {
Expand Down Expand Up @@ -132,6 +132,7 @@ impl<'ctx> Resolver<'ctx> {
expected_types[0].clone(),
kwargs[i].get_span_pos(),
None,
true,
);
};
}
Expand Down
2 changes: 2 additions & 0 deletions kclvm/sema/src/resolver/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ impl<'ctx> Resolver<'ctx> {
ty,
key.get_span_pos(),
Some(obj_last.get_span_pos()),
true,
);
}
self.clear_config_expr_context(stack_depth, false);
Expand Down Expand Up @@ -489,6 +490,7 @@ impl<'ctx> Resolver<'ctx> {
Some(key) => match &key.node {
ast::Expr::Identifier(identifier) => {
let mut val_ty = self.expr(value);

for _ in 0..identifier.names.len() - 1 {
val_ty = Type::dict_ref(self.str_ty(), val_ty.clone());
}
Expand Down
61 changes: 57 additions & 4 deletions kclvm/sema/src/resolver/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty.clone(),
unification_stmt.target.get_span_pos(),
None,
true,
);
if !ty.is_any() && expected_ty.is_any() {
self.set_type_to_scope(&names[0].node, ty, &names[0]);
Expand Down Expand Up @@ -184,7 +185,19 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty.clone(),
target.get_span_pos(),
None,
true,
);

let upgrade_schema_type = self.upgrade_dict_to_schema(
value_ty.clone(),
expected_ty.clone(),
&assign_stmt.value.get_span_pos(),
);
self.node_ty_map.insert(
self.get_node_key(assign_stmt.value.id.clone()),
upgrade_schema_type.clone(),
);

if !value_ty.is_any() && expected_ty.is_any() && assign_stmt.ty.is_none() {
self.set_type_to_scope(name, value_ty.clone(), &target.node.names[0]);
if let Some(schema_ty) = &self.ctx.schema {
Expand All @@ -203,7 +216,23 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
value_ty = self.expr(&assign_stmt.value);
// Check type annotation if exists.
self.check_assignment_type_annotation(assign_stmt, value_ty.clone());
self.must_assignable_to(value_ty.clone(), expected_ty, target.get_span_pos(), None)
self.must_assignable_to(
value_ty.clone(),
expected_ty.clone(),
target.get_span_pos(),
None,
true,
);

let upgrade_schema_type = self.upgrade_dict_to_schema(
value_ty.clone(),
expected_ty.clone(),
&assign_stmt.value.get_span_pos(),
);
self.node_ty_map.insert(
self.get_node_key(assign_stmt.value.id.clone()),
upgrade_schema_type.clone(),
);
}
}
value_ty
Expand Down Expand Up @@ -261,6 +290,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty,
aug_assign_stmt.target.get_span_pos(),
None,
true,
);
self.ctx.l_value = false;
new_target_ty
Expand Down Expand Up @@ -427,6 +457,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty,
schema_attr.name.get_span_pos(),
None,
true,
);
}
// Assign
Expand All @@ -435,6 +466,7 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
expected_ty,
schema_attr.name.get_span_pos(),
None,
true,
),
},
None => bug!("invalid ast schema attr op kind"),
Expand Down Expand Up @@ -1068,7 +1100,13 @@ impl<'ctx> MutSelfTypedResultWalker<'ctx> for Resolver<'ctx> {
let real_ret_ty = self.stmts(&lambda_expr.body);
self.leave_scope();
self.ctx.in_lambda_expr.pop();
self.must_assignable_to(real_ret_ty.clone(), ret_ty.clone(), (start, end), None);
self.must_assignable_to(
real_ret_ty.clone(),
ret_ty.clone(),
(start, end),
None,
true,
);
if !real_ret_ty.is_any() && ret_ty.is_any() && lambda_expr.return_ty.is_none() {
ret_ty = real_ret_ty;
}
Expand Down Expand Up @@ -1232,9 +1270,24 @@ impl<'ctx> Resolver<'ctx> {
self.ctx.start_pos = start;
self.ctx.end_pos = end;
}

let expected_ty = match self.ctx.config_expr_context.last() {
Some(ty) => ty.clone().map(|o| o.ty),
None => None,
};

let ty = self.walk_expr(&expr.node);
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), ty.clone());

if let Some(expected_ty) = expected_ty {
let upgrade_ty =
self.upgrade_dict_to_schema(ty.clone(), expected_ty, &expr.get_span_pos());
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), upgrade_ty);
} else {
self.node_ty_map
.insert(self.get_node_key(expr.id.clone()), ty.clone());
}

ty
}

Expand Down
1 change: 1 addition & 0 deletions kclvm/sema/src/resolver/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ impl<'ctx> Resolver<'ctx> {
expected_ty,
index_signature_node.get_span_pos(),
None,
true,
);
}
}
Expand Down
Loading
Loading