Skip to content

Commit

Permalink
fix(query): fix count all in lambda function (databendlabs#13991)
Browse files Browse the repository at this point in the history
* fix(query): fix count all in lambda function

* refactor resolve lambda function

* add tests

* fix
  • Loading branch information
b41sh authored Dec 12, 2023
1 parent 313288e commit 193ed56
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 120 deletions.
255 changes: 135 additions & 120 deletions src/query/sql/src/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use common_ast::ast::ColumnID;
use common_ast::ast::Expr;
use common_ast::ast::Identifier;
use common_ast::ast::IntervalKind as ASTIntervalKind;
use common_ast::ast::Lambda;
use common_ast::ast::Literal;
use common_ast::ast::MapAccessor;
use common_ast::ast::Query;
Expand Down Expand Up @@ -812,15 +813,6 @@ impl<'a> TypeChecker<'a> {
}

if GENERAL_WINDOW_FUNCTIONS.contains(&func_name) {
if matches!(
self.bind_context.expr_context,
ExprContext::InLambdaFunction
) {
return Err(ErrorCode::SemanticError(
"window functions can not be used in lambda function".to_string(),
)
.set_span(*span));
}
// general window function
if window.is_none() {
return Err(ErrorCode::SemanticError(format!(
Expand All @@ -835,16 +827,6 @@ impl<'a> TypeChecker<'a> {
self.resolve_window(*span, display_name, window, func)
.await?
} else if AggregateFunctionFactory::instance().contains(func_name) {
if matches!(
self.bind_context.expr_context,
ExprContext::InLambdaFunction
) {
return Err(ErrorCode::SemanticError(
"aggregate functions can not be used in lambda function".to_string(),
)
.set_span(*span));
}

let in_window = self.in_window_function;
self.in_window_function = self.in_window_function || window.is_some();
let in_aggregate_function = self.in_aggregate_function;
Expand All @@ -866,113 +848,14 @@ impl<'a> TypeChecker<'a> {
Box::new((new_agg_func.into(), data_type))
}
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
if matches!(
self.bind_context.expr_context,
ExprContext::InLambdaFunction
) {
return Err(ErrorCode::SemanticError(
"lambda functions can not be used in lambda function".to_string(),
)
.set_span(*span));
}
if lambda.is_none() {
return Err(ErrorCode::SemanticError(format!(
"function {func_name} must have a lambda expression",
)));
}
let lambda = lambda.as_ref().unwrap();

let params = lambda
.params
.iter()
.map(|param| param.name.to_lowercase())
.collect::<Vec<_>>();

// TODO: support multiple params
if params.len() != 1 {
return Err(ErrorCode::SemanticError(format!(
"incorrect number of parameters in lambda function, {func_name} expects 1 parameter",
)));
}

if args.len() != 1 {
return Err(ErrorCode::SemanticError(format!(
"invalid arguments for lambda function, {func_name} expects 1 argument"
)));
}
let box (arg, arg_type) = self.resolve(args[0]).await?;

let inner_ty = match arg_type.remove_nullable() {
DataType::Array(box inner_ty) => inner_ty.clone(),
DataType::Null | DataType::EmptyArray => DataType::Null,
_ => {
return Err(ErrorCode::SemanticError(
"invalid arguments for lambda function, argument data type must be array".to_string()
));
}
};
let box (lambda_expr, lambda_type) =
parse_lambda_expr(self.ctx.clone(), &params[0], &inner_ty, &lambda.expr)?;

let return_type = if func_name == "array_filter" {
if lambda_type.remove_nullable() == DataType::Boolean {
arg_type.clone()
} else {
return Err(ErrorCode::SemanticError(
"invalid lambda function for `array_filter`, the result data type of lambda function must be boolean".to_string()
));
}
} else if arg_type.is_nullable() {
DataType::Nullable(Box::new(DataType::Array(Box::new(lambda_type))))
} else {
DataType::Array(Box::new(lambda_type))
};

match arg_type.remove_nullable() {
// Null and Empty array can convert to ConstantExpr
DataType::Null => Box::new((
ConstantExpr {
span: *span,
value: Scalar::Null,
}
.into(),
DataType::Null,
)),
DataType::EmptyArray => Box::new((
ConstantExpr {
span: *span,
value: Scalar::EmptyArray,
}
.into(),
DataType::EmptyArray,
)),
_ => {
// generate lambda expression
let lambda_field = DataField::new("0", inner_ty.clone());
let lambda_schema = DataSchema::new(vec![lambda_field]);

let expr = lambda_expr.type_check(&lambda_schema)?.project_column_ref(
|index| lambda_schema.index_of(&index.to_string()).unwrap(),
);
let (expr, _) =
ConstantFolder::fold(&expr, &self.func_ctx, &BUILTIN_FUNCTIONS);
let remote_lambda_expr = expr.as_remote_expr();
let lambda_display = format!("{} -> {}", params[0], expr.sql_display());

Box::new((
LambdaFunc {
span: *span,
func_name: func_name.to_string(),
args: vec![arg],
lambda_expr: Box::new(remote_lambda_expr),
lambda_display,
return_type: Box::new(return_type.clone()),
}
.into(),
return_type,
))
}
}
self.resolve_lambda_function(*span, func_name, &args, lambda)
.await?
} else {
// Scalar function
let params = params
Expand Down Expand Up @@ -1482,6 +1365,15 @@ impl<'a> TypeChecker<'a> {
func_name: &str,
args: &[&Expr],
) -> Result<WindowFuncType> {
if matches!(
self.bind_context.expr_context,
ExprContext::InLambdaFunction
) {
return Err(ErrorCode::SemanticError(
"window functions can not be used in lambda function".to_string(),
)
.set_span(span));
}
// try to resolve window function without arguments first
if let Ok(window_func) = WindowFuncType::from_name(func_name) {
return Ok(window_func);
Expand Down Expand Up @@ -1707,6 +1599,15 @@ impl<'a> TypeChecker<'a> {
params: &[Literal],
args: &[&Expr],
) -> Result<(AggregateFunction, DataType)> {
if matches!(
self.bind_context.expr_context,
ExprContext::InLambdaFunction
) {
return Err(ErrorCode::SemanticError(
"aggregate functions can not be used in lambda function".to_string(),
)
.set_span(span));
}
if self.in_aggregate_function {
if self.in_window_function {
// The aggregate function can be in window function call,
Expand Down Expand Up @@ -1798,6 +1699,120 @@ impl<'a> TypeChecker<'a> {
Ok((new_agg_func, data_type))
}

/// Resolve a call to a general lambda function (for example `array_filter`)
/// into a scalar expression and its data type.
///
/// The call must carry exactly one lambda parameter and exactly one argument,
/// and the argument must resolve to an array, `NULL`, or an empty array.
///
/// Result type:
/// - `array_filter`: the lambda must return a (nullable) boolean; the result
///   has the same type as the argument.
/// - any other function: an array of the lambda's result type, wrapped in
///   `Nullable` when the argument type is nullable.
///
/// When the argument type is `Null` or `EmptyArray` the call is replaced by the
/// corresponding constant; the lambda body is still parsed and checked first.
///
/// # Errors
///
/// Returns `ErrorCode::SemanticError` (with `span` attached so the diagnostic
/// can point at the offending call) when:
/// - the call appears inside another lambda function,
/// - the lambda does not have exactly one parameter,
/// - the call does not have exactly one argument,
/// - the argument is not an array, `NULL`, or an empty array,
/// - `array_filter` is given a lambda whose result type is not boolean.
///
/// Errors from resolving the argument, parsing the lambda body, or type
/// checking the lambda expression are propagated unchanged.
#[async_backtrace::framed]
async fn resolve_lambda_function(
    &mut self,
    span: Span,
    func_name: &str,
    args: &[&Expr],
    lambda: &Lambda,
) -> Result<Box<(ScalarExpr, DataType)>> {
    // Nested lambda functions are rejected outright.
    if matches!(
        self.bind_context.expr_context,
        ExprContext::InLambdaFunction
    ) {
        return Err(ErrorCode::SemanticError(
            "lambda functions can not be used in lambda function".to_string(),
        )
        .set_span(span));
    }
    // Lambda parameter names are lower-cased before they are bound.
    let params = lambda
        .params
        .iter()
        .map(|param| param.name.to_lowercase())
        .collect::<Vec<_>>();

    // TODO: support multiple params
    if params.len() != 1 {
        return Err(ErrorCode::SemanticError(format!(
            "incorrect number of parameters in lambda function, {func_name} expects 1 parameter",
        ))
        .set_span(span));
    }

    if args.len() != 1 {
        return Err(ErrorCode::SemanticError(format!(
            "invalid arguments for lambda function, {func_name} expects 1 argument"
        ))
        .set_span(span));
    }
    let box (arg, arg_type) = self.resolve(args[0]).await?;

    // Element type seen by the lambda body; `Null` and `EmptyArray` inputs
    // expose a `Null` element type.
    let inner_ty = match arg_type.remove_nullable() {
        DataType::Array(box inner_ty) => inner_ty.clone(),
        DataType::Null | DataType::EmptyArray => DataType::Null,
        _ => {
            return Err(ErrorCode::SemanticError(
                "invalid arguments for lambda function, argument data type must be array"
                    .to_string(),
            )
            .set_span(span));
        }
    };
    let box (lambda_expr, lambda_type) =
        parse_lambda_expr(self.ctx.clone(), &params[0], &inner_ty, &lambda.expr)?;

    let return_type = if func_name == "array_filter" {
        // `array_filter` keeps the input type but needs a boolean predicate.
        if lambda_type.remove_nullable() == DataType::Boolean {
            arg_type.clone()
        } else {
            return Err(ErrorCode::SemanticError(
                "invalid lambda function for `array_filter`, the result data type of lambda function must be boolean".to_string()
            )
            .set_span(span));
        }
    } else if arg_type.is_nullable() {
        DataType::Nullable(Box::new(DataType::Array(Box::new(lambda_type))))
    } else {
        DataType::Array(Box::new(lambda_type))
    };

    let (lambda_func, data_type) = match arg_type.remove_nullable() {
        // Null and Empty array can convert to ConstantExpr
        DataType::Null => (
            ConstantExpr {
                span,
                value: Scalar::Null,
            }
            .into(),
            DataType::Null,
        ),
        DataType::EmptyArray => (
            ConstantExpr {
                span,
                value: Scalar::EmptyArray,
            }
            .into(),
            DataType::EmptyArray,
        ),
        _ => {
            // generate lambda expression
            // The lambda is type checked against a one-field schema whose
            // single field is named "0" and has the array's element type.
            let lambda_field = DataField::new("0", inner_ty.clone());
            let lambda_schema = DataSchema::new(vec![lambda_field]);

            // NOTE(review): the `unwrap` assumes every column reference left
            // in the lambda body names a field of `lambda_schema` -- confirm
            // that `parse_lambda_expr` guarantees this.
            let expr = lambda_expr
                .type_check(&lambda_schema)?
                .project_column_ref(|index| {
                    lambda_schema.index_of(&index.to_string()).unwrap()
                });
            let (expr, _) = ConstantFolder::fold(&expr, &self.func_ctx, &BUILTIN_FUNCTIONS);
            let remote_lambda_expr = expr.as_remote_expr();
            let lambda_display = format!("{} -> {}", params[0], expr.sql_display());

            (
                LambdaFunc {
                    span,
                    func_name: func_name.to_string(),
                    args: vec![arg],
                    lambda_expr: Box::new(remote_lambda_expr),
                    lambda_display,
                    return_type: Box::new(return_type.clone()),
                }
                .into(),
                return_type,
            )
        }
    };

    Ok(Box::new((lambda_func, data_type)))
}

/// Resolve function call.
#[async_backtrace::framed]
pub async fn resolve_function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,9 @@ select array_transform(col1, A -> a * 2), array_apply(col2, B -> upper(B)) from
statement error 1065
select array_transform([1, 2], x -> y + 1)

statement error 1065
select array_transform([1, 2], x -> count(*))

query T
select array_filter([5, -6, NULL, 7], x -> x > 0)
----
Expand Down

0 comments on commit 193ed56

Please sign in to comment.