Skip to content

Commit

Permalink
feat: enhance runtime type cast and check for lambda arguments and re…
Browse files Browse the repository at this point in the history
…turn values (#1529)

Signed-off-by: peefy <[email protected]>
  • Loading branch information
Peefy authored Aug 1, 2024
1 parent bbac702 commit fe15ef9
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 21 deletions.
54 changes: 47 additions & 7 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2179,9 +2179,21 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
}
}
self.walk_arguments(&lambda_expr.args, args, kwargs);
let val = self
let mut val = self
.walk_stmts(&lambda_expr.body)
.expect(kcl_error::COMPILE_ERROR_MSG);
if let Some(ty) = &lambda_expr.return_ty {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
val = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
val,
type_annotation,
self.bool_value(false),
],
);
}
self.builder.build_return(Some(&val));
// Exist the function
self.builder.position_at_end(func_before_block);
Expand Down Expand Up @@ -2731,23 +2743,39 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
kwargs: BasicValueEnum<'ctx>,
) {
// Arguments names and defaults
let (arg_names, arg_defaults) = if let Some(args) = &arguments {
let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments {
let names = &args.node.args;
let types = &args.node.ty_list;
let defaults = &args.node.defaults;
(
names.iter().map(|identifier| &identifier.node).collect(),
types.iter().collect(),
defaults.iter().collect(),
)
} else {
(vec![], vec![])
(vec![], vec![], vec![])
};
// Default parameter values
for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) {
let arg_value = if let Some(value) = value {
for ((arg_name, arg_type), value) in
arg_names.iter().zip(&arg_types).zip(arg_defaults.iter())
{
let mut arg_value = if let Some(value) = value {
self.walk_expr(value).expect(kcl_error::COMPILE_ERROR_MSG)
} else {
self.none_value()
};
if let Some(ty) = arg_type {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
arg_value = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
arg_value,
type_annotation,
self.bool_value(false),
],
);
}
// Arguments are immutable, so we place them in different scopes.
self.store_argument_in_current_scope(&arg_name.get_name());
self.walk_identifier_with_ctx(arg_name, &ast::ExprContext::Store, Some(arg_value))
Expand All @@ -2756,7 +2784,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
// for loop in 0..argument_len in LLVM begin
let argument_len = self.build_call(&ApiFunc::kclvm_list_len.name(), &[args]);
let end_block = self.append_block("");
for (i, arg_name) in arg_names.iter().enumerate() {
for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() {
// Positional arguments
let is_in_range = self.builder.build_int_compare(
IntPredicate::ULT,
Expand All @@ -2768,14 +2796,26 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
self.builder
.build_conditional_branch(is_in_range, next_block, end_block);
self.builder.position_at_end(next_block);
let arg_value = self.build_call(
let mut arg_value = self.build_call(
&ApiFunc::kclvm_list_get_option.name(),
&[
self.current_runtime_ctx_ptr(),
args,
self.native_int_value(i as i32),
],
);
if let Some(ty) = arg_type {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
arg_value = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
arg_value,
type_annotation,
self.bool_value(false),
],
);
}
self.store_variable(&arg_name.names[0].node, arg_value);
}
// for loop in 0..argument_len in LLVM end
Expand Down
6 changes: 5 additions & 1 deletion kclvm/evaluator/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use kclvm_runtime::ValueRef;
use scopeguard::defer;

use crate::proxy::Proxy;
use crate::ty::type_pack_and_check;
use crate::Evaluator;
use crate::{error as kcl_error, EvalContext};

Expand Down Expand Up @@ -125,8 +126,11 @@ pub fn func_body(
}
// Evaluate arguments and keyword arguments and store values to local variables.
s.walk_arguments(&ctx.node.args, args, kwargs);
let result = s
let mut result = s
.walk_stmts(&ctx.node.body)
.expect(kcl_error::RUNTIME_ERROR_MSG);
if let Some(ty) = &ctx.node.return_ty {
result = type_pack_and_check(s, &result, vec![&ty.node.to_string()], false);
}
result
}
24 changes: 18 additions & 6 deletions kclvm/evaluator/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1449,23 +1449,31 @@ impl<'ctx> Evaluator<'ctx> {
kwargs: &ValueRef,
) {
// Arguments names and defaults
let (arg_names, arg_defaults) = if let Some(args) = &arguments {
let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments {
let names = &args.node.args;
let types = &args.node.ty_list;
let defaults = &args.node.defaults;
(
names.iter().map(|identifier| &identifier.node).collect(),
types.iter().collect(),
defaults.iter().collect(),
)
} else {
(vec![], vec![])
(vec![], vec![], vec![])
};
// Default parameter values
for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) {
let arg_value = if let Some(value) = value {
for ((arg_name, arg_type), value) in
arg_names.iter().zip(&arg_types).zip(arg_defaults.iter())
{
let mut arg_value = if let Some(value) = value {
self.walk_expr(value).expect(kcl_error::RUNTIME_ERROR_MSG)
} else {
self.none_value()
};
if let Some(ty) = arg_type {
arg_value =
type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false);
}
// Arguments are immutable, so we place them in different scopes.
let name = arg_name.get_name();
self.store_argument_in_current_scope(&name);
Expand All @@ -1477,14 +1485,18 @@ impl<'ctx> Evaluator<'ctx> {
}
// Positional arguments
let argument_len = args.len();
for (i, arg_name) in arg_names.iter().enumerate() {
for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() {
// Positional arguments
let is_in_range = i < argument_len;
if is_in_range {
let arg_value = match args.list_get_option(i as isize) {
let mut arg_value = match args.list_get_option(i as isize) {
Some(v) => v,
None => self.undefined_value(),
};
if let Some(ty) = arg_type {
arg_value =
type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false);
}
self.store_variable(&arg_name.names[0].node, arg_value);
} else {
break;
Expand Down
12 changes: 12 additions & 0 deletions kclvm/runtime/src/value/val_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,8 @@ pub fn check_type(value: &ValueRef, tpe: &str, strict: bool) -> bool {
// if value type is a built-in type e.g. str, int, float, bool
if match_builtin_type(value, tpe) {
return true;
} else if match_function_type(value, tpe) {
return true;
}
if value.is_schema() {
if strict {
Expand Down Expand Up @@ -532,6 +534,16 @@ pub fn match_builtin_type(value: &ValueRef, tpe: &str) -> bool {
value.type_str() == *tpe || (value.type_str() == BUILTIN_TYPE_INT && tpe == BUILTIN_TYPE_FLOAT)
}

/// match_function_type returns the value wether match the given the function type string
#[inline]
pub fn match_function_type(value: &ValueRef, tpe: &str) -> bool {
value.type_str() == *tpe
|| (value.type_str() == KCL_TYPE_FUNCTION
&& tpe.contains("(")
&& tpe.contains(")")
&& tpe.contains("->"))
}

/// is_literal_type returns the type string whether is a literal type
pub fn is_literal_type(tpe: &str) -> bool {
if KCL_NAME_CONSTANTS.contains(&tpe) {
Expand Down
29 changes: 22 additions & 7 deletions kclvm/sema/src/resolver/ty_erasure.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use kclvm_ast::ast;
use kclvm_ast::walker::MutSelfMutWalker;
use kclvm_ast::{ast, walk_if_mut, walk_list_mut};

#[derive(Default)]
struct TypeErasureTransformer;
Expand All @@ -14,14 +14,14 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer {
schema_index_signature.node.value_ty.node = FUNCTION.to_string().into();
}
}
for item in schema_stmt.body.iter_mut() {
if let kclvm_ast::ast::Stmt::SchemaAttr(attr) = &mut item.node {
self.walk_schema_attr(attr);
}
}
walk_if_mut!(self, walk_arguments, schema_stmt.args);
walk_list_mut!(self, walk_call_expr, schema_stmt.decorators);
walk_list_mut!(self, walk_check_expr, schema_stmt.checks);
walk_list_mut!(self, walk_stmt, schema_stmt.body);
}

fn walk_schema_attr(&mut self, schema_attr: &'ctx mut ast::SchemaAttr) {
walk_list_mut!(self, walk_call_expr, schema_attr.decorators);
walk_if_mut!(self, walk_expr, schema_attr.value);
if let kclvm_ast::ast::Type::Function(_) = schema_attr.ty.as_ref().node {
schema_attr.ty.node = FUNCTION.to_string().into();
}
Expand All @@ -34,6 +34,7 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer {
}
}
}
self.walk_expr(&mut assign_stmt.value.node);
}
fn walk_type_alias_stmt(&mut self, type_alias_stmt: &'ctx mut ast::TypeAliasStmt) {
if let kclvm_ast::ast::Type::Function(_) = type_alias_stmt.ty.as_ref().node {
Expand All @@ -46,6 +47,20 @@ impl<'ctx> MutSelfMutWalker<'ctx> for TypeErasureTransformer {
ty.node = FUNCTION.to_string().into();
}
}
for default in arguments.defaults.iter_mut() {
if let Some(d) = default.as_deref_mut() {
self.walk_expr(&mut d.node)
}
}
}
fn walk_lambda_expr(&mut self, lambda_expr: &'ctx mut ast::LambdaExpr) {
walk_if_mut!(self, walk_arguments, lambda_expr.args);
walk_list_mut!(self, walk_stmt, lambda_expr.body);
if let Some(ty) = lambda_expr.return_ty.as_mut() {
if let kclvm_ast::ast::Type::Function(_) = ty.as_ref().node {
ty.node = FUNCTION.to_string().into();
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
schema ProviderFamily:
version: str
marketplace: bool = True

providerFamily = lambda family: ProviderFamily -> ProviderFamily {
family
}

v = providerFamily({
version: "1.6.0"
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
v:
version: '1.6.0'
marketplace: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
schema ProviderFamily:
version: str
marketplace: bool = True

providerFamily = lambda -> ProviderFamily {
{
version: "1.6.0"
}
}

v = providerFamily()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
v:
version: '1.6.0'
marketplace: true

0 comments on commit fe15ef9

Please sign in to comment.