diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index 915ed7416c6c..f9d603a7ed67 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -22,6 +22,7 @@ use common_ast::ast::ColumnID; use common_ast::ast::Expr; use common_ast::ast::Identifier; use common_ast::ast::IntervalKind as ASTIntervalKind; +use common_ast::ast::Lambda; use common_ast::ast::Literal; use common_ast::ast::MapAccessor; use common_ast::ast::Query; @@ -812,15 +813,6 @@ impl<'a> TypeChecker<'a> { } if GENERAL_WINDOW_FUNCTIONS.contains(&func_name) { - if matches!( - self.bind_context.expr_context, - ExprContext::InLambdaFunction - ) { - return Err(ErrorCode::SemanticError( - "window functions can not be used in lambda function".to_string(), - ) - .set_span(*span)); - } // general window function if window.is_none() { return Err(ErrorCode::SemanticError(format!( @@ -835,16 +827,6 @@ impl<'a> TypeChecker<'a> { self.resolve_window(*span, display_name, window, func) .await? } else if AggregateFunctionFactory::instance().contains(func_name) { - if matches!( - self.bind_context.expr_context, - ExprContext::InLambdaFunction - ) { - return Err(ErrorCode::SemanticError( - "aggregate functions can not be used in lambda function".to_string(), - ) - .set_span(*span)); - } - let in_window = self.in_window_function; self.in_window_function = self.in_window_function || window.is_some(); let in_aggregate_function = self.in_aggregate_function; @@ -866,113 +848,14 @@ impl<'a> TypeChecker<'a> { Box::new((new_agg_func.into(), data_type)) } } else if GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) { - if matches!( - self.bind_context.expr_context, - ExprContext::InLambdaFunction - ) { - return Err(ErrorCode::SemanticError( - "lambda functions can not be used in lambda function".to_string(), - ) - .set_span(*span)); - } if lambda.is_none() { return Err(ErrorCode::SemanticError(format!( "function {func_name} must have a lambda expression", ))); } let lambda = lambda.as_ref().unwrap(); - - let params = lambda - .params - .iter() - .map(|param| param.name.to_lowercase()) - .collect::>(); - - // TODO: support multiple params - if params.len() != 1 { - return Err(ErrorCode::SemanticError(format!( - "incorrect number of parameters in lambda function, {func_name} expects 1 parameter", - ))); - } - - if args.len() != 1 { - return Err(ErrorCode::SemanticError(format!( - "invalid arguments for lambda function, {func_name} expects 1 argument" - ))); - } - let box (arg, arg_type) = self.resolve(args[0]).await?; - - let inner_ty = match arg_type.remove_nullable() { - DataType::Array(box inner_ty) => inner_ty.clone(), - DataType::Null | DataType::EmptyArray => DataType::Null, - _ => { - return Err(ErrorCode::SemanticError( - "invalid arguments for lambda function, argument data type must be array".to_string() - )); - } - }; - let box (lambda_expr, lambda_type) = - parse_lambda_expr(self.ctx.clone(), ¶ms[0], &inner_ty, &lambda.expr)?; - - let return_type = if func_name == "array_filter" { - if lambda_type.remove_nullable() == DataType::Boolean { - arg_type.clone() - } else { - return Err(ErrorCode::SemanticError( - "invalid lambda function for `array_filter`, the result data type of lambda function must be boolean".to_string() - )); - } - } else if arg_type.is_nullable() { - DataType::Nullable(Box::new(DataType::Array(Box::new(lambda_type)))) - } else { - DataType::Array(Box::new(lambda_type)) - }; - - match arg_type.remove_nullable() { - // Null and Empty array can convert to ConstantExpr - DataType::Null => Box::new(( - ConstantExpr { - span: *span, - value: Scalar::Null, - } - .into(), - DataType::Null, - )), - DataType::EmptyArray => Box::new(( - ConstantExpr { - span: *span, - value: Scalar::EmptyArray, - } - .into(), - DataType::EmptyArray, - )), - _ => { - // generate lambda expression - let lambda_field = DataField::new("0", inner_ty.clone()); - let lambda_schema = DataSchema::new(vec![lambda_field]); - - let expr = lambda_expr.type_check(&lambda_schema)?.project_column_ref( - |index| lambda_schema.index_of(&index.to_string()).unwrap(), - ); - let (expr, _) = - ConstantFolder::fold(&expr, &self.func_ctx, &BUILTIN_FUNCTIONS); - let remote_lambda_expr = expr.as_remote_expr(); - let lambda_display = format!("{} -> {}", params[0], expr.sql_display()); - - Box::new(( - LambdaFunc { - span: *span, - func_name: func_name.to_string(), - args: vec![arg], - lambda_expr: Box::new(remote_lambda_expr), - lambda_display, - return_type: Box::new(return_type.clone()), - } - .into(), - return_type, - )) - } - } + self.resolve_lambda_function(*span, func_name, &args, lambda) + .await? } else { // Scalar function let params = params @@ -1482,6 +1365,15 @@ impl<'a> TypeChecker<'a> { func_name: &str, args: &[&Expr], ) -> Result { + if matches!( + self.bind_context.expr_context, + ExprContext::InLambdaFunction + ) { + return Err(ErrorCode::SemanticError( + "window functions can not be used in lambda function".to_string(), + ) + .set_span(span)); + } // try to resolve window function without arguments first if let Ok(window_func) = WindowFuncType::from_name(func_name) { return Ok(window_func); @@ -1707,6 +1599,15 @@ impl<'a> TypeChecker<'a> { params: &[Literal], args: &[&Expr], ) -> Result<(AggregateFunction, DataType)> { + if matches!( + self.bind_context.expr_context, + ExprContext::InLambdaFunction + ) { + return Err(ErrorCode::SemanticError( + "aggregate functions can not be used in lambda function".to_string(), + ) + .set_span(span)); + } if self.in_aggregate_function { if self.in_window_function { // The aggregate function can be in window function call, @@ -1798,6 +1699,120 @@ impl<'a> TypeChecker<'a> { Ok((new_agg_func, data_type)) } + #[async_backtrace::framed] + async fn resolve_lambda_function( + &mut self, + span: Span, + func_name: &str, + args: &[&Expr], + lambda: &Lambda, + ) -> Result> { + if matches!( + self.bind_context.expr_context, + ExprContext::InLambdaFunction + ) { + return Err(ErrorCode::SemanticError( + "lambda functions can not be used in lambda function".to_string(), + ) + .set_span(span)); + } + let params = lambda + .params + .iter() + .map(|param| param.name.to_lowercase()) + .collect::>(); + + // TODO: support multiple params + if params.len() != 1 { + return Err(ErrorCode::SemanticError(format!( + "incorrect number of parameters in lambda function, {func_name} expects 1 parameter", + ))); + } + + if args.len() != 1 { + return Err(ErrorCode::SemanticError(format!( + "invalid arguments for lambda function, {func_name} expects 1 argument" + ))); + } + let box (arg, arg_type) = self.resolve(args[0]).await?; + + let inner_ty = match arg_type.remove_nullable() { + DataType::Array(box inner_ty) => inner_ty.clone(), + DataType::Null | DataType::EmptyArray => DataType::Null, + _ => { + return Err(ErrorCode::SemanticError( + "invalid arguments for lambda function, argument data type must be array" + .to_string(), + )); + } + }; + let box (lambda_expr, lambda_type) = + parse_lambda_expr(self.ctx.clone(), ¶ms[0], &inner_ty, &lambda.expr)?; + + let return_type = if func_name == "array_filter" { + if lambda_type.remove_nullable() == DataType::Boolean { + arg_type.clone() + } else { + return Err(ErrorCode::SemanticError( + "invalid lambda function for `array_filter`, the result data type of lambda function must be boolean".to_string() + )); + } + } else if arg_type.is_nullable() { + DataType::Nullable(Box::new(DataType::Array(Box::new(lambda_type)))) + } else { + DataType::Array(Box::new(lambda_type)) + }; + + let (lambda_func, data_type) = match arg_type.remove_nullable() { + // Null and Empty array can convert to ConstantExpr + DataType::Null => ( + ConstantExpr { + span, + value: Scalar::Null, + } + .into(), + DataType::Null, + ), + DataType::EmptyArray => ( + ConstantExpr { + span, + value: Scalar::EmptyArray, + } + .into(), + DataType::EmptyArray, + ), + _ => { + // generate lambda expression + let lambda_field = DataField::new("0", inner_ty.clone()); + let lambda_schema = DataSchema::new(vec![lambda_field]); + + let expr = lambda_expr + .type_check(&lambda_schema)? + .project_column_ref(|index| { + lambda_schema.index_of(&index.to_string()).unwrap() + }); + let (expr, _) = ConstantFolder::fold(&expr, &self.func_ctx, &BUILTIN_FUNCTIONS); + let remote_lambda_expr = expr.as_remote_expr(); + let lambda_display = format!("{} -> {}", params[0], expr.sql_display()); + + ( + LambdaFunc { + span, + func_name: func_name.to_string(), + args: vec![arg], + lambda_expr: Box::new(remote_lambda_expr), + lambda_display, + return_type: Box::new(return_type.clone()), + } + .into(), + return_type, + ) + } + }; + + Ok(Box::new((lambda_func, data_type))) + } + /// Resolve function call. #[async_backtrace::framed] pub async fn resolve_function( diff --git a/tests/sqllogictests/suites/query/02_function/02_0061_function_array.test b/tests/sqllogictests/suites/query/02_function/02_0061_function_array.test index f7994466f783..3dd21f6c2b32 100644 --- a/tests/sqllogictests/suites/query/02_function/02_0061_function_array.test +++ b/tests/sqllogictests/suites/query/02_function/02_0061_function_array.test @@ -228,6 +228,9 @@ select array_transform(col1, A -> a * 2), array_apply(col2, B -> upper(B)) from statement error 1065 select array_transform([1, 2], x -> y + 1) +statement error 1065 +select array_transform([1, 2], x -> count(*)) + query T select array_filter([5, -6, NULL, 7], x -> x > 0) ----