diff --git a/src/binder/aggregate.rs b/src/binder/aggregate.rs index c8c1764d..096799c0 100644 --- a/src/binder/aggregate.rs +++ b/src/binder/aggregate.rs @@ -136,6 +136,10 @@ impl<'a, T: Transaction> Binder<'a, T> { self.visit_column_agg_expr(expr)?; } } + ScalarExpression::Position { expr, in_expr } => { + self.visit_column_agg_expr(expr)?; + self.visit_column_agg_expr(in_expr)?; + } ScalarExpression::Constant(_) | ScalarExpression::ColumnRef { .. } => (), ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), ScalarExpression::Tuple(args) @@ -356,6 +360,11 @@ impl<'a, T: Transaction> Binder<'a, T> { } Ok(()) } + ScalarExpression::Position { expr, in_expr } => { + self.validate_having_orderby(expr)?; + self.validate_having_orderby(in_expr)?; + Ok(()) + } ScalarExpression::Constant(_) => Ok(()), ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), ScalarExpression::Tuple(args) diff --git a/src/binder/expr.rs b/src/binder/expr.rs index 06a72a3d..b021c904 100644 --- a/src/binder/expr.rs +++ b/src/binder/expr.rs @@ -104,6 +104,10 @@ impl<'a, T: Transaction> Binder<'a, T> { from_expr, }) } + Expr::Position { expr, r#in } => Ok(ScalarExpression::Position { + expr: Box::new(self.bind_expr(expr)?), + in_expr: Box::new(self.bind_expr(r#in)?), + }), Expr::Subquery(subquery) => { let (sub_query, column) = self.bind_subquery(subquery)?; self.context.sub_query(SubQueryType::SubQuery(sub_query)); diff --git a/src/expression/evaluator.rs b/src/expression/evaluator.rs index f00239d2..f5331323 100644 --- a/src/expression/evaluator.rs +++ b/src/expression/evaluator.rs @@ -168,6 +168,19 @@ impl ScalarExpression { Ok(Arc::new(DataValue::Utf8(None))) } } + ScalarExpression::Position { expr, in_expr } => { + let unpack = |expr: &ScalarExpression| -> Result { + Ok(DataValue::clone(expr.eval(tuple, schema)?.as_ref()) + .cast(&LogicalType::Varchar(None))? + .utf8() + .unwrap_or("".to_owned())) + }; + let pattern = unpack(expr)?; + let str = unpack(in_expr)?; + Ok(Arc::new(DataValue::Int32(Some( + str.find(&pattern).map(|pos| pos as i32 + 1).unwrap_or(0), + )))) + } ScalarExpression::Reference { pos, .. } => { return Ok(tuple .values diff --git a/src/expression/mod.rs b/src/expression/mod.rs index 179c2d5e..7c9cb659 100644 --- a/src/expression/mod.rs +++ b/src/expression/mod.rs @@ -79,6 +79,10 @@ pub enum ScalarExpression { for_expr: Option>, from_expr: Option>, }, + Position { + expr: Box, + in_expr: Box, + }, // Temporary expression used for expression substitution Empty, Reference { @@ -201,6 +205,10 @@ impl ScalarExpression { expr.try_reference(output_exprs); } } + ScalarExpression::Position { expr, in_expr } => { + expr.try_reference(output_exprs); + in_expr.try_reference(output_exprs); + } ScalarExpression::Empty => unreachable!(), ScalarExpression::Constant(_) | ScalarExpression::ColumnRef(_) @@ -294,6 +302,9 @@ impl ScalarExpression { Some(true) ) } + ScalarExpression::Position { expr, in_expr } => { + expr.has_count_star() || in_expr.has_count_star() + } ScalarExpression::Empty => unreachable!(), ScalarExpression::Reference { expr, .. } => expr.has_count_star(), ScalarExpression::Tuple(args) => args.iter().any(Self::has_count_star), @@ -372,6 +383,7 @@ impl ScalarExpression { | ScalarExpression::In { .. } | ScalarExpression::Between { .. } => LogicalType::Boolean, ScalarExpression::SubString { .. } => LogicalType::Varchar(None), + ScalarExpression::Position { .. } => LogicalType::Integer, ScalarExpression::Alias { expr, .. } | ScalarExpression::Reference { expr, .. } => { expr.return_type() } @@ -448,6 +460,10 @@ impl ScalarExpression { columns_collect(from_expr, vec, only_column_ref); } } + ScalarExpression::Position { expr, in_expr } => { + columns_collect(expr, vec, only_column_ref); + columns_collect(in_expr, vec, only_column_ref); + } ScalarExpression::Constant(_) => (), ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), ScalarExpression::If { @@ -537,6 +553,9 @@ impl ScalarExpression { Some(true) ) } + ScalarExpression::Position { expr, in_expr } => { + expr.has_agg_call() || in_expr.has_agg_call() + } ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), ScalarExpression::Tuple(args) | ScalarExpression::Function(ScalarFunction { args, .. }) @@ -671,6 +690,13 @@ impl ScalarExpression { op("for", for_expr), ) } + ScalarExpression::Position { expr, in_expr } => { + format!( + "position({} in {})", + expr.output_name(), + in_expr.output_name() + ) + } ScalarExpression::Reference { expr, .. } => expr.output_name(), ScalarExpression::Empty => unreachable!(), ScalarExpression::Tuple(args) => { diff --git a/src/expression/range_detacher.rs b/src/expression/range_detacher.rs index a57113f4..e1c3acaf 100644 --- a/src/expression/range_detacher.rs +++ b/src/expression/range_detacher.rs @@ -223,6 +223,7 @@ impl<'a> RangeDetacher<'a> { | ScalarExpression::In { expr, .. } | ScalarExpression::Between { expr, .. } | ScalarExpression::SubString { expr, .. } => self.detach(expr), + ScalarExpression::Position { expr, .. } => self.detach(expr), ScalarExpression::IsNull { expr, negated, .. } => match expr.as_ref() { ScalarExpression::ColumnRef(column) => { if let (Some(col_id), Some(col_table)) = (column.id(), column.table_name()) { @@ -248,6 +249,7 @@ impl<'a> RangeDetacher<'a> { | ScalarExpression::In { .. } | ScalarExpression::Between { .. } | ScalarExpression::SubString { .. } + | ScalarExpression::Position { .. } | ScalarExpression::Function(_) | ScalarExpression::If { .. } | ScalarExpression::IfNull { .. } diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index 9f3e9698..b1fa515b 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -84,6 +84,9 @@ impl ScalarExpression { .map(|expr| expr.exist_column(table_name, col_id)) == Some(true) } + ScalarExpression::Position { expr, in_expr } => { + expr.exist_column(table_name, col_id) || in_expr.exist_column(table_name, col_id) + } ScalarExpression::Constant(_) => false, ScalarExpression::Reference { .. } | ScalarExpression::Empty => unreachable!(), ScalarExpression::If { diff --git a/tests/slt/sql_2016/E021_11.slt b/tests/slt/sql_2016/E021_11.slt index 4f0fc2d0..66fbbabe 100644 --- a/tests/slt/sql_2016/E021_11.slt +++ b/tests/slt/sql_2016/E021_11.slt @@ -1,11 +1,9 @@ # E021-11: POSITION function -# TODO: POSITION() - -# query I -# SELECT POSITION ( 'foo' IN 'bar' ) -# ---- -# 0 +query I +SELECT POSITION ( 'foo' IN 'bar' ) +---- +0 # query I # SELECT POSITION ( 'foo' IN 'bar' USING CHARACTERS )