From 4e2787a0f2f1b4f9f463da90097ff1869d16176d Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Thu, 29 Aug 2024 22:14:26 +0800 Subject: [PATCH 1/4] initial commit Signed-off-by: Lantao Jin --- .../sql/analysis/AnalysisContext.java | 4 + .../org/opensearch/sql/analysis/Analyzer.java | 14 ++ .../ExpressionReferenceOptimizer.java | 7 + .../sql/analysis/symbol/SymbolTable.java | 15 ++ .../sql/ast/AbstractNodeVisitor.java | 5 + .../org/opensearch/sql/ast/dsl/AstDSL.java | 9 + .../org/opensearch/sql/ast/tree/Join.java | 51 ++++ .../function/BuiltinFunctionName.java | 4 + .../sql/planner/DefaultImplementor.java | 84 +++++++ .../sql/planner/logical/LogicalJoin.java | 37 +++ .../sql/planner/logical/LogicalPlanDSL.java | 10 + .../logical/LogicalPlanNodeVisitor.java | 4 + .../physical/PhysicalPlanNodeVisitor.java | 5 + .../physical/join/HashJoinOperator.java | 225 +++++++++++++++++ .../planner/physical/join/JoinOperator.java | 21 ++ .../physical/join/JoinPredicatesHelper.java | 143 +++++++++++ .../physical/join/NestedLoopJoinOperator.java | 231 ++++++++++++++++++ .../opensearch/sql/analysis/AnalyzerTest.java | 206 ++++++++++++++++ .../sql/analysis/AnalyzerTestBase.java | 3 +- .../sql/analysis/SelectAnalyzeTest.java | 42 ++++ .../logical/LogicalPlanNodeVisitorTest.java | 68 +++++- .../planner/physical/JoinOperatorTest.java | 54 ++++ .../physical/PhysicalPlanTestBase.java | 22 +- ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 10 + ppl/src/main/antlr/OpenSearchPPLParser.g4 | 51 +++- .../opensearch/sql/ppl/parser/AstBuilder.java | 44 ++++ .../sql/ppl/parser/AstExpressionBuilder.java | 11 + 27 files changed, 1363 insertions(+), 17 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/ast/tree/Join.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/logical/LogicalJoin.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java create mode 100644 
core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/join/JoinPredicatesHelper.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java create mode 100644 core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java diff --git a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java index f1f29e9b38..b4be232b0f 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java +++ b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java @@ -41,6 +41,10 @@ public void push() { environment = new TypeEnvironment(environment); } + public void cleanFields() { + environment.clearAllFields(); + } + /** * Return current environment. * diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index d5e8b93b13..8d8b95706f 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -48,6 +48,7 @@ import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.ML; @@ -88,6 +89,7 @@ import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalJoin; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; @@ -136,6 +138,18 @@ public LogicalPlan 
analyze(UnresolvedPlan unresolved, AnalysisContext context) { return unresolved.accept(this, context); } + @Override + public LogicalPlan visitJoin(Join node, AnalysisContext context) { + // TODO tables-join instead of plans-join supported only now + LogicalPlan left = visitRelation((Relation) node.getLeft(), context); + LogicalPlan right = visitRelation((Relation) node.getRight(), context); + Expression condition = expressionAnalyzer.analyze(node.getJoinCondition(), context); + ExpressionReferenceOptimizer optimizer = + new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), left, right); + Expression optimized = optimizer.optimize(condition, context); + return new LogicalJoin(left, right, node.getJoinType(), optimized); + } + @Override public LogicalPlan visitRelation(Relation node, AnalysisContext context) { QualifiedName qualifiedName = node.getTableQualifiedName(); diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java index 398f848f16..c9b618a70f 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -5,6 +5,7 @@ package org.opensearch.sql.analysis; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -53,6 +54,12 @@ public ExpressionReferenceOptimizer( logicalPlan.accept(new ExpressionMapBuilder(), null); } + public ExpressionReferenceOptimizer( + BuiltinFunctionRepository repository, LogicalPlan... 
logicalPlans) { + this.repository = repository; + Arrays.stream(logicalPlans).forEach(p -> p.accept(new ExpressionMapBuilder(), null)); + } + public Expression optimize(Expression analyzed, AnalysisContext context) { return analyzed.accept(this, context); } diff --git a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java index 8bb6824a63..e5798ba22c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java +++ b/core/src/main/java/org/opensearch/sql/analysis/symbol/SymbolTable.java @@ -72,6 +72,21 @@ public Optional lookup(Symbol symbol) { Map table = tableByNamespace.get(symbol.getNamespace()); ExprType type = null; if (table != null) { + // To handle the field named start with [index.], for example index1.field1, + // this is used by Join query. + if (symbol.getNamespace() == Namespace.FIELD_NAME) { + String[] parts = symbol.getName().split("\\."); + if (parts.length == 2) { + // extract the indexName + if (tableByNamespace.get(Namespace.INDEX_NAME) != null) { + String indexName = tableByNamespace.get(Namespace.INDEX_NAME).firstKey(); + if (indexName != null && indexName.equals(parts[0])) { + type = table.get(parts[1]); + return Optional.ofNullable(type); + } + } + } + } type = table.get(symbol.getName()); } return Optional.ofNullable(type); diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 973b10310b..91cd4d3464 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -47,6 +47,7 @@ import org.opensearch.sql.ast.tree.FetchCursor; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.Limit; import 
org.opensearch.sql.ast.tree.ML; @@ -109,6 +110,10 @@ public T visitFilter(Filter node, C context) { return visitChildren(node, context); } + public T visitJoin(Join node, C context) { + return visitChildren(node, context); + } + public T visitProject(Project node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 4f3056b0f7..cbc88c8cae 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -48,6 +48,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Limit; import org.opensearch.sql.ast.tree.Parse; import org.opensearch.sql.ast.tree.Project; @@ -471,4 +472,12 @@ public static Parse parse( java.util.Map arguments) { return new Parse(parseMethod, sourceField, pattern, arguments, input); } + + public static Join join( + UnresolvedPlan left, + UnresolvedPlan right, + Join.JoinType joinType, + UnresolvedExpression condition) { + return new Join(left, right, joinType, condition); + } } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Join.java b/core/src/main/java/org/opensearch/sql/ast/tree/Join.java new file mode 100644 index 0000000000..6905c31c49 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + 
+@RequiredArgsConstructor +@Getter +@EqualsAndHashCode(callSuper = false) +@ToString +public class Join extends UnresolvedPlan { + private final UnresolvedPlan left; + private final UnresolvedPlan right; + private final JoinType joinType; + private final UnresolvedExpression joinCondition; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(left, right); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitJoin(this, context); + } + + public enum JoinType { + CROSS, + INNER, + SEMI, + ANTI, + LEFT, + RIGHT, + FULL + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index fd5ea14a2e..747fb26ae6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -289,6 +289,10 @@ public static Optional of(String str) { return Optional.ofNullable(ALL_NATIVE_FUNCTIONS.getOrDefault(FunctionName.of(str), null)); } + public static BuiltinFunctionName of(FunctionName name) { + return ALL_NATIVE_FUNCTIONS.get(name); + } + public static Optional ofAggregation(String functionName) { return Optional.ofNullable( AGGREGATION_FUNC_MAPPING.getOrDefault(functionName.toLowerCase(Locale.ROOT), null)); diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index f962c3e4bf..9e4f264151 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -5,13 +5,25 @@ package org.opensearch.sql.planner; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import 
org.apache.commons.lang3.tuple.Pair; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.executor.pagination.PlanSerializer; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalCloseCursor; import org.opensearch.sql.planner.logical.LogicalDedupe; import org.opensearch.sql.planner.logical.LogicalEval; import org.opensearch.sql.planner.logical.LogicalFetchCursor; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalJoin; import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalNested; import org.opensearch.sql.planner.logical.LogicalPaginate; @@ -41,6 +53,9 @@ import org.opensearch.sql.planner.physical.TakeOrderedOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; +import org.opensearch.sql.planner.physical.join.HashJoinOperator; +import org.opensearch.sql.planner.physical.join.JoinPredicatesHelper; +import org.opensearch.sql.planner.physical.join.NestedLoopJoinOperator; import org.opensearch.sql.storage.read.TableScanBuilder; import org.opensearch.sql.storage.write.TableWriteBuilder; @@ -54,6 +69,7 @@ * @param context type */ public class DefaultImplementor extends LogicalPlanNodeVisitor { + private static final Logger LOG = LogManager.getLogger(); @Override public PhysicalPlan visitRareTopN(LogicalRareTopN node, C context) { @@ -156,6 +172,74 @@ public PhysicalPlan visitRelation(LogicalRelation node, C context) { + "implementing and optimizing logical plan with relation involved"); } + @Override + 
public PhysicalPlan visitJoin(LogicalJoin join, C ctx) { + LOG.debug("join condition is {}", join.getCondition()); + List predicates = + JoinPredicatesHelper.splitConjunctivePredicates(join.getCondition()); + // Extract all equi-join key pairs + List> equiJoinKeys = new ArrayList<>(); + for (Expression predicate : predicates) { + if (JoinPredicatesHelper.isEqual(predicate)) { + Pair pair = + JoinPredicatesHelper.extractJoinKeys((FunctionExpression) predicate); + if (pair.getLeft() instanceof ReferenceExpression + && pair.getRight() instanceof ReferenceExpression) { + if (canEvaluate((ReferenceExpression) pair.getLeft(), join.getLeft()) + && canEvaluate((ReferenceExpression) pair.getRight(), join.getRight())) { + equiJoinKeys.add(pair); + } else { + throw new SemanticCheckException( + StringUtils.format("Join key must be a field of index.")); + } + } else { + throw new SemanticCheckException( + StringUtils.format( + "Join condition must contain field only. E.g. t1.field1 = t2.field2 AND" + + " t1.field3 = t2.field4. But found {}", + predicate.getClass().getSimpleName())); + } + } else { + equiJoinKeys.clear(); + break; + } + } + + // 1. Determining Join with Hint. TODO + // 2. Pick hash join if it is an equi-join and hash join supported + if (!equiJoinKeys.isEmpty()) { + Pair, List> unzipped = JoinPredicatesHelper.unzip(equiJoinKeys); + List leftKeys = unzipped.getLeft(); + List rightKeys = unzipped.getRight(); + LOG.info("EquiJoin leftKeys are {}, rightKeys are {}", leftKeys, rightKeys); + + return new HashJoinOperator( + leftKeys, + rightKeys, + join.getType(), + visitRelation((LogicalRelation) join.getLeft(), ctx), + visitRelation((LogicalRelation) join.getRight(), ctx), + Optional.empty()); + // 3. Pick sort merge join if the join keys are sortable. TODO + } else { + // 4. Pick Nested loop join if is a non-equi-join. 
TODO + return new NestedLoopJoinOperator( + visitRelation((LogicalRelation) join.getLeft(), ctx), + visitRelation((LogicalRelation) join.getRight(), ctx), + join.getType(), + join.getCondition()); + } + } + + /** Return true if the reference can be evaluated in relation */ + private boolean canEvaluate(ReferenceExpression expr, LogicalPlan plan) { + if (plan instanceof LogicalRelation relation) { + return relation.getTable().getFieldTypes().containsKey(expr.getAttr()); + } else { + throw new UnsupportedOperationException("Only relation can be used in join"); + } + } + @Override public PhysicalPlan visitFetchCursor(LogicalFetchCursor plan, C context) { return new PlanSerializer(plan.getEngine()).convertToPlan(plan.getCursor()); diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalJoin.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalJoin.java new file mode 100644 index 0000000000..4ba86fe920 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalJoin.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.expression.Expression; + +@ToString +@EqualsAndHashCode(callSuper = true) +@Getter +public class LogicalJoin extends LogicalPlan { + private final LogicalPlan left; + private final LogicalPlan right; + private final Join.JoinType type; + private final Expression condition; + + public LogicalJoin( + LogicalPlan left, LogicalPlan right, Join.JoinType type, Expression condition) { + super(ImmutableList.of(left, right)); + this.left = left; + this.right = right; + this.type = type; + this.condition = condition; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return 
visitor.visitJoin(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 2a886ba0ca..40ebdc0ce3 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -13,6 +13,7 @@ import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.Expression; @@ -138,4 +139,13 @@ public LogicalPlan values(List... values) { public static LogicalPlan limit(LogicalPlan input, Integer limit, Integer offset) { return new LogicalLimit(input, limit, offset); } + + public LogicalPlan innerJoin(LogicalPlan left, LogicalPlan right, Expression condition) { + return join(left, right, Join.JoinType.INNER, condition); + } + + public LogicalPlan join( + LogicalPlan left, LogicalPlan right, Join.JoinType joinType, Expression condition) { + return new LogicalJoin(left, right, joinType, condition); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 156db35306..532dcfb734 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -115,4 +115,8 @@ public R visitFetchCursor(LogicalFetchCursor plan, C context) { public R visitCloseCursor(LogicalCloseCursor plan, C context) { return visitNode(plan, context); } + + public R visitJoin(LogicalJoin plan, C context) { + return visitNode(plan, context); + } } diff --git 
a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java
index 67d7a05135..55771b0b15 100644
--- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java
+++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java
@@ -5,6 +5,7 @@
 
 package org.opensearch.sql.planner.physical;
 
+import org.opensearch.sql.planner.physical.join.JoinOperator;
 import org.opensearch.sql.storage.TableScanOperator;
 import org.opensearch.sql.storage.write.TableWriteOperator;
 
@@ -99,4 +100,8 @@ public R visitML(PhysicalPlan node, C context) {
   public R visitCursorClose(CursorCloseOperator node, C context) {
     return visitNode(node, context);
   }
+
+  public R visitJoin(JoinOperator node, C context) {
+    return visitNode(node, context);
+  }
 }
diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java
new file mode 100644
index 0000000000..ab860c789e
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java
@@ -0,0 +1,179 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.planner.physical.join;
+
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.ast.tree.Join;
+import org.opensearch.sql.data.model.ExprTupleValue;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.planner.physical.PhysicalPlan;
+
+/**
+ * Equi-join operator that hashes the rows of the left child and probes them with the right child.
+ *
+ * <p>The values of all join keys of a row form one composite key, so a condition such as {@code
+ * a.x = b.x AND a.y = b.y} only pairs rows that agree on both columns. Several left rows may share
+ * a key. Rows whose key contains a NULL or MISSING value never match.
+ *
+ * <p>The complete result is computed and buffered in {@link #open()}. Supported join types are
+ * INNER, LEFT, SEMI and ANTI.
+ */
+@RequiredArgsConstructor
+public class HashJoinOperator extends JoinOperator {
+  private final List<Expression> leftKeys;
+  private final List<Expression> rightKeys;
+  private final Join.JoinType joinType;
+  private final PhysicalPlan left;
+  private final PhysicalPlan right;
+  private final Optional<Expression> nonEquiCond;
+
+  private Iterator<ExprValue> joinedIterator = Collections.emptyIterator();
+
+  /**
+   * Opens both children, runs the join and buffers the joined rows for iteration.
+   *
+   * @throws IllegalArgumentException if the join type is not INNER, LEFT, SEMI or ANTI
+   */
+  @Override
+  public void open() {
+    // Validate first so that no child is left open when the plan is rejected.
+    if (!isSupported(joinType)) {
+      throw new IllegalArgumentException("Unsupported join type: " + joinType);
+    }
+    // Build the hash table from the left side; buildRows keeps every left row in input order.
+    left.open();
+    List<ExprValue> buildRows = new ArrayList<>();
+    Map<List<ExprValue>, List<Integer>> hashed = buildHashed(buildRows);
+    // The right side is streamed against the hash table.
+    right.open();
+
+    boolean emitPairs = joinType == Join.JoinType.INNER || joinType == Join.JoinType.LEFT;
+    BitSet matched = new BitSet(buildRows.size());
+    ImmutableList.Builder<ExprValue> joined = ImmutableList.builder();
+    while (right.hasNext()) {
+      ExprValue rightRow = right.next();
+      Optional<List<ExprValue>> key = evaluateKey(rightKeys, rightRow);
+      if (key.isEmpty()) {
+        continue;
+      }
+      for (int index : hashed.getOrDefault(key.get(), List.of())) {
+        ExprValue joinedRow = combineExprTupleValue(buildRows.get(index), rightRow);
+        if (satisfiesNonEquiCondition(joinedRow)) {
+          matched.set(index);
+          if (emitPairs) {
+            joined.add(joinedRow);
+          }
+        }
+      }
+    }
+
+    // LEFT and ANTI append the left rows without a partner, SEMI the left rows with one.
+    // NOTE(review): unmatched LEFT rows are emitted unchanged, so right-side fields are absent
+    // rather than explicit NULLs (the right schema is not known here) - confirm this is intended.
+    if (joinType != Join.JoinType.INNER) {
+      boolean wantMatched = joinType == Join.JoinType.SEMI;
+      for (int index = 0; index < buildRows.size(); index++) {
+        if (matched.get(index) == wantMatched) {
+          joined.add(buildRows.get(index));
+        }
+      }
+    }
+    joinedIterator = joined.build().iterator();
+  }
+
+  @Override
+  public void close() {
+    joinedIterator = Collections.emptyIterator();
+    try {
+      left.close();
+    } finally {
+      // Close the right child even if closing the left one failed.
+      right.close();
+    }
+  }
+
+  @Override
+  public boolean hasNext() {
+    return joinedIterator.hasNext();
+  }
+
+  @Override
+  public ExprValue next() {
+    return joinedIterator.next();
+  }
+
+  @Override
+  public List<PhysicalPlan> getChild() {
+    return ImmutableList.of(left, right);
+  }
+
+  private static boolean isSupported(Join.JoinType type) {
+    return type == Join.JoinType.INNER
+        || type == Join.JoinType.LEFT
+        || type == Join.JoinType.SEMI
+        || type == Join.JoinType.ANTI;
+  }
+
+  /**
+   * Drains the left child into {@code buildRows} and indexes the rows by composite join key.
+   *
+   * @param buildRows receives every left row, including rows whose key is NULL or MISSING
+   * @return composite key mapped to the positions in {@code buildRows} of the rows carrying it
+   */
+  private Map<List<ExprValue>, List<Integer>> buildHashed(List<ExprValue> buildRows) {
+    Map<List<ExprValue>, List<Integer>> hashed = new LinkedHashMap<>();
+    while (left.hasNext()) {
+      ExprValue row = left.next();
+      int index = buildRows.size();
+      buildRows.add(row);
+      evaluateKey(leftKeys, row)
+          .ifPresent(key -> hashed.computeIfAbsent(key, k -> new ArrayList<>()).add(index));
+    }
+    return hashed;
+  }
+
+  /** Evaluates all key expressions on a row; empty if any key value is NULL or MISSING. */
+  private static Optional<List<ExprValue>> evaluateKey(List<Expression> keys, ExprValue row) {
+    List<ExprValue> values = new ArrayList<>(keys.size());
+    for (Expression key : keys) {
+      ExprValue value = key.valueOf(row.bindingTuples());
+      if (value == null || value.isNull() || value.isMissing()) {
+        return Optional.empty();
+      }
+      values.add(value);
+    }
+    return Optional.of(values);
+  }
+
+  /** Returns true if there is no residual condition or it evaluates to a non-null TRUE. */
+  private boolean satisfiesNonEquiCondition(ExprValue joinedRow) {
+    if (nonEquiCond.isEmpty()) {
+      return true;
+    }
+    ExprValue result = nonEquiCond.get().valueOf(joinedRow.bindingTuples());
+    return !result.isNull() && !result.isMissing() && result.booleanValue();
+  }
+
+  /**
+   * Merges two tuples into a new tuple without modifying either input. On a field name clash the
+   * value from the right row wins.
+   */
+  private ExprTupleValue combineExprTupleValue(ExprValue left, ExprValue right) {
+    Map<String, ExprValue> combinedMap = new LinkedHashMap<>(left.tupleValue());
+    combinedMap.putAll(right.tupleValue());
+    return ExprTupleValue.fromExprValueMap(combinedMap);
+  }
+}
diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java
new file mode 100644
index 0000000000..ee2e3ace44
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.planner.physical.join;
+
+import java.util.List;
+import org.opensearch.sql.planner.physical.PhysicalPlan;
+import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor;
+
+/** Base class of the physical join operators; dispatches to the visitor's {@code visitJoin}. */
+public abstract class JoinOperator extends PhysicalPlan {
+
+  @Override
+  public <R, C> R accept(PhysicalPlanNodeVisitor<R, C> visitor, C context) {
+    return visitor.visitJoin(this, context);
+  }
+
+  @Override
+  public abstract List<PhysicalPlan> getChild();
+}
diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinPredicatesHelper.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinPredicatesHelper.java
new file mode 100644
index 0000000000..4c789750a5
--- /dev/null
+++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinPredicatesHelper.java
@@ -0,0 +1,143 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.planner.physical.join;
+
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import lombok.experimental.UtilityClass;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.opensearch.sql.common.utils.StringUtils;
+import org.opensearch.sql.exception.SemanticCheckException;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.FunctionExpression;
+import 
org.opensearch.sql.expression.function.BuiltinFunctionName;
+
+/** Static helpers for classifying and decomposing the predicates of a join condition. */
+@UtilityClass
+public class JoinPredicatesHelper {
+
+  /** Returns true if {@code function} is a call of the built-in {@code functionName}. */
+  private static boolean instanceOf(Expression function, BuiltinFunctionName functionName) {
+    return function instanceof FunctionExpression
+        && ((FunctionExpression) function).getFunctionName().equals(functionName.getName());
+  }
+
+  /** Only logical connectives and comparison operators may appear in a join condition. */
+  private static boolean isValidJoinPredicate(FunctionExpression predicate) {
+    return isAnd(predicate)
+        || isOr(predicate)
+        || isEqual(predicate)
+        || isNotEqual(predicate)
+        || isLess(predicate)
+        || isLte(predicate)
+        || isGreater(predicate)
+        || isGte(predicate);
+  }
+
+  /**
+   * Returns the first and last argument of a join predicate as a (left, right) pair.
+   *
+   * @throws SemanticCheckException if the predicate is not a supported join function
+   */
+  public static ImmutablePair<Expression, Expression> extractJoinKeys(
+      FunctionExpression predicate) {
+    if (!isValidJoinPredicate(predicate)) {
+      // NOTE(review): StringUtils.format is presumably String.format-based, hence %s - confirm.
+      throw new SemanticCheckException(
+          StringUtils.format(
+              "Join condition %s is an invalid function",
+              predicate.getFunctionName().getFunctionName()));
+    }
+    return ImmutablePair.of(
+        predicate.getArguments().getFirst(), predicate.getArguments().getLast());
+  }
+
+  /** Flattens nested binary calls of {@code connective} into the list of their operands. */
+  private static List<Expression> split(Expression condition, BuiltinFunctionName connective) {
+    if (!instanceOf(condition, connective)) {
+      return ImmutableList.of(condition);
+    }
+    List<Expression> arguments = ((FunctionExpression) condition).getArguments();
+    return Stream.concat(
+            split(arguments.getFirst(), connective).stream(),
+            split(arguments.getLast(), connective).stream())
+        .collect(Collectors.toList());
+  }
+
+  /** Splits {@code a AND b AND c} into {@code [a, b, c]}; a non-AND input yields itself. */
+  public static List<Expression> splitConjunctivePredicates(Expression condition) {
+    return split(condition, BuiltinFunctionName.AND);
+  }
+
+  /** Splits {@code a OR b OR c} into {@code [a, b, c]}; a non-OR input yields itself. */
+  public static List<Expression> splitDisjunctivePredicates(Expression condition) {
+    return split(condition, BuiltinFunctionName.OR);
+  }
+
+  /** Turns a list of (left, right) pairs into a pair of parallel left and right lists. */
+  public static Pair<List<Expression>, List<Expression>> unzip(
+      List<Pair<Expression, Expression>> pairs) {
+    List<Expression> leftList = new ArrayList<>();
+    List<Expression> rightList = new ArrayList<>();
+    for (Pair<Expression, Expression> pair : pairs) {
+      leftList.add(pair.getLeft());
+      rightList.add(pair.getRight());
+    }
+    return Pair.of(leftList, rightList);
+  }
+
+  public static boolean isAnd(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.AND);
+  }
+
+  public static boolean isOr(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.OR);
+  }
+
+  public static boolean isEqual(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.EQUAL);
+  }
+
+  public static boolean isNot(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.NOT);
+  }
+
+  public static boolean isXor(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.XOR);
+  }
+
+  public static boolean isNotEqual(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.NOTEQUAL);
+  }
+
+  public static boolean isLess(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.LESS);
+  }
+
+  public static boolean isLte(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.LTE);
+  }
+
+  public static boolean isGreater(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.GREATER);
+  }
+
+  public static boolean isGte(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.GTE);
+  }
+
+  public static boolean isLike(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.LIKE);
+  }
+
+  public static boolean isNotLike(Expression expression) {
+    return instanceOf(expression, BuiltinFunctionName.NOT_LIKE);
+  }
+}
diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java
new file mode 100644
index 0000000000..d933c42be6
--- /dev/null
+++ 
b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java @@ -0,0 +1,231 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical.join; + +import com.google.common.collect.ImmutableList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.physical.PhysicalPlan; + +@RequiredArgsConstructor +public class NestedLoopJoinOperator extends JoinOperator { + private final PhysicalPlan left; + private final PhysicalPlan right; + private final Join.JoinType joinType; + private final Expression condition; + + private final ImmutableList.Builder joinedBuilder = ImmutableList.builder(); + private Iterator joinedIterator; + // Build side is left by default, set the smaller side as the build side in future TODO + private Iterator buildSide; + + @Override + public void open() { + left.open(); + right.open(); + + // buildSide is left plan by default + buildSide = left; + + if (joinType == Join.JoinType.INNER) { + List cached = cacheStreamedSide(right); + innerJoin(cached); + } else if (joinType == Join.JoinType.LEFT) { + // build side is right plan and streamed side is left plan in left outer join. 
+ buildSide = right; + List cached = cacheStreamedSide(left); + outerJoin(cached); + } else if (joinType == Join.JoinType.RIGHT) { + List cached = cacheStreamedSide(right); + outerJoin(cached); + } else if (joinType == Join.JoinType.SEMI) { + List cached = cacheStreamedSide(right); + semiJoin(cached); + } else if (joinType == Join.JoinType.ANTI) { + List cached = cacheStreamedSide(right); + antiJoin(cached); + } else { + // LeftOuter with BuildLeft + // RightOuter with BuildRight + // FullOuter + List cached = cacheStreamedSide(right); + defaultJoin(cached); + } + } + + /** Convert iterator to a list to allow multiple iterations */ + private List cacheStreamedSide(PhysicalPlan plan) { + ImmutableList.Builder streamedBuilder = ImmutableList.builder(); + plan.forEachRemaining(streamedBuilder::add); + return streamedBuilder.build(); + } + + @Override + public void close() { + joinedIterator = null; + left.close(); + right.close(); + } + + private void innerJoin(List cacheStreamedSide) { + Iterator streamed = cacheStreamedSide.iterator(); + while (streamed.hasNext()) { + ExprValue leftRow = buildSide.next(); + + for (ExprValue rightRow : cacheStreamedSide) { + ExprTupleValue joinedRow = combineExprTupleValue(leftRow, rightRow); + ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && (conditionValue.booleanValue())) { + joinedBuilder.add(joinedRow); + } + } + } + joinedIterator = joinedBuilder.build().iterator(); + } + + /** The implementation for LeftOuter with BuildRight, RightOuter with BuildLeft */ + private void outerJoin(List cacheStreamedSide) { + Set matchedRows = new HashSet<>(); + + // Probe phase + for (ExprValue streamedRow : cacheStreamedSide) { + boolean matched = false; + while (buildSide.hasNext()) { + ExprValue buildRow = buildSide.next(); + ExprTupleValue joinedRow = + combineExprTupleValue( + joinType == Join.JoinType.LEFT ? 
streamedRow : buildRow, + joinType == Join.JoinType.LEFT ? buildRow : streamedRow); + ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + joinedBuilder.add(joinedRow); + matchedRows.add(streamedRow); + matched = true; + break; + } + } + if (!matched) { + ExprTupleValue joinedRow = + combineExprTupleValue( + joinType == Join.JoinType.LEFT ? streamedRow : ExprValueUtils.nullValue(), + joinType == Join.JoinType.LEFT ? ExprValueUtils.nullValue() : streamedRow); + joinedBuilder.add(joinedRow); + } + } + + // Add unmatched rows + if (joinType == Join.JoinType.LEFT) { + while (buildSide.hasNext()) { + ExprValue buildRow = buildSide.next(); + if (!matchedRows.contains(buildRow)) { + ExprTupleValue joinedRow = combineExprTupleValue(ExprValueUtils.nullValue(), buildRow); + joinedBuilder.add(joinedRow); + } + } + } else if (joinType == Join.JoinType.RIGHT) { + for (ExprValue streamedRow : cacheStreamedSide) { + if (!matchedRows.contains(streamedRow)) { + ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, ExprValueUtils.nullValue()); + joinedBuilder.add(joinedRow); + } + } + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + private void semiJoin(List cacheStreamedSide) { + Set matchedRows = new HashSet<>(); + + // Probe phase + for (ExprValue streamedRow : cacheStreamedSide) { + while (buildSide.hasNext()) { + ExprValue buildRow = buildSide.next(); + ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); + ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + matchedRows.add(streamedRow); + break; + } + } + } + + // Add matched rows + for (ExprValue row : matchedRows) { + joinedBuilder.add(row); + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + // Java + private void 
antiJoin(List cacheStreamedSide) { + Set matchedRows = new HashSet<>(); + + // Probe phase + for (ExprValue streamedRow : cacheStreamedSide) { + boolean matched = false; + while (buildSide.hasNext()) { + ExprValue buildRow = buildSide.next(); + ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); + ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + matchedRows.add(streamedRow); + matched = true; + break; + } + } + if (!matched) { + matchedRows.add(streamedRow); + } + } + + // Add unmatched rows + for (ExprValue row : cacheStreamedSide) { + if (!matchedRows.contains(row)) { + joinedBuilder.add(row); + } + } + + joinedIterator = joinedBuilder.build().iterator(); + } + + private void defaultJoin(List cacheStreamedSide) {} + + @Override + public boolean hasNext() { + return joinedIterator != null && joinedIterator.hasNext(); + } + + @Override + public ExprValue next() { + return joinedIterator.next(); + } + + private ExprTupleValue combineExprTupleValue(ExprValue left, ExprValue right) { + Map combinedMap = left.tupleValue(); + combinedMap.putAll(right.tupleValue()); + return ExprTupleValue.fromExprValueMap(combinedMap); + } + + @Override + public List getChild() { + return ImmutableList.of(left, right); + } +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 8d935b11d2..3668b13485 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -81,6 +81,7 @@ import org.opensearch.sql.ast.tree.AD; import org.opensearch.sql.ast.tree.CloseCursor; import org.opensearch.sql.ast.tree.FetchCursor; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import 
org.opensearch.sql.ast.tree.Paginate; @@ -1767,4 +1768,209 @@ public void visit_close_cursor() { () -> assertEquals("pewpew", ((LogicalFetchCursor) analyzed.getChild().get(0)).getCursor())); } + + @Test + public void inner_join() { + assertAnalyzeEqual( + LogicalPlanDSL.innerJoin( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.INNER, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void left_outer_join() { + assertAnalyzeEqual( + LogicalPlanDSL.join( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + Join.JoinType.LEFT, + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.LEFT, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void right_outer_join() { + assertAnalyzeEqual( + LogicalPlanDSL.join( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + Join.JoinType.RIGHT, + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + 
DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.RIGHT, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void anti_join() { + assertAnalyzeEqual( + LogicalPlanDSL.join( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + Join.JoinType.ANTI, + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.ANTI, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void semi_join() { + assertAnalyzeEqual( + LogicalPlanDSL.join( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + Join.JoinType.SEMI, + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.SEMI, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), AstDSL.field("schema2.double_value"))))); + } + + @Test + public void basic_SPJG() { + // Select(Filter)-Project-Join-GroupBy + // SELECT + // schema1.string_value, + // schema2.string_value, 
+ // AVG(schema1.integer_value), + // MIN(schema2.long_value), + // FROM + // schema1 + // INNER JOIN + // schema2 + // ON + // schema1.integer_value = schema2.integer_value + // AND + // schema1.double_value = schema2.double_value + // WHERE + // schema1.integer_value > 10 + // GROUP BY + // schema1.string_value, schema2.string_value + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.filter( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.innerJoin( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + ImmutableList.of( + DSL.named( + "AVG(schema1.integer_value)", + DSL.avg(DSL.ref("schema1.integer_value", INTEGER))), + DSL.named( + "MIN(schema2.long_value)", + DSL.min(DSL.ref("schema2.long_value", LONG)))), + ImmutableList.of( + DSL.named("schema1.string_value", DSL.ref("schema1.string_value", STRING)), + DSL.named( + "schema2.string_value", DSL.ref("schema2.string_value", STRING)))), + DSL.greater( + DSL.ref("schema1.integer_value", INTEGER), DSL.literal(integerValue(10)))), + DSL.named("schema1.string_value", DSL.ref("schema1.string_value", STRING)), + DSL.named("schema2.string_value", DSL.ref("schema2.string_value", STRING))), + AstDSL.projectWithArg( + AstDSL.filter( + AstDSL.agg( + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.INNER, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), + AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), + AstDSL.field("schema2.double_value")))), + ImmutableList.of( + alias( + "AVG(schema1.integer_value)", + aggregate("AVG", qualifiedName("schema1.integer_value"))), + alias( + "MIN(schema2.long_value)", + aggregate("MIN", 
qualifiedName("schema2.long_value")))), + emptyList(), + ImmutableList.of( + alias("schema1.string_value", qualifiedName("schema1.string_value")), + alias("schema2.string_value", qualifiedName("schema2.string_value"))), + emptyList()), + compare(">", AstDSL.field("schema1.integer_value"), intLiteral(10))), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("schema1.string_value", AstDSL.field("schema1.string_value")), + AstDSL.alias("schema2.string_value", AstDSL.field("schema2.string_value")))); + } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index 0bf959a1b7..16da1e539d 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -177,7 +177,8 @@ protected ExpressionAnalyzer expressionAnalyzer() { } protected void assertAnalyzeEqual(LogicalPlan expected, UnresolvedPlan unresolvedPlan) { - assertEquals(expected, analyze(unresolvedPlan)); + LogicalPlan actual = analyze(unresolvedPlan); + assertEquals(expected, actual); } protected LogicalPlan analyze(UnresolvedPlan unresolvedPlan) { diff --git a/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java b/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java index 27edc588fa..df4cdcb99d 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.AllFields; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; @@ -132,4 +133,45 @@ public void rename_and_project_all() { AstDSL.defaultFieldsArgs(), AllFields.of())); } + + @Test + public 
void project_all_from_join() { + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.innerJoin( + LogicalPlanDSL.relation("schema1", table), + LogicalPlanDSL.relation("schema2", table), + DSL.and( + DSL.equal( + DSL.ref("schema1.integer_value", INTEGER), + DSL.ref("schema2.integer_value", INTEGER)), + DSL.equal( + DSL.ref("schema1.double_value", DOUBLE), + DSL.ref("schema2.double_value", DOUBLE)))), + DSL.named("schema1.integer_value", DSL.ref("schema1.integer_value", INTEGER)), + DSL.named("schema1.double_value", DSL.ref("schema1.double_value", DOUBLE)), + DSL.named("schema1.string_value", DSL.ref("schema1.string_value", STRING)), + DSL.named("schema2.integer_value", DSL.ref("schema2.integer_value", INTEGER)), + DSL.named("schema2.double_value", DSL.ref("schema2.double_value", DOUBLE)), + DSL.named("schema2.string_value", DSL.ref("schema2.string_value", STRING))), + AstDSL.projectWithArg( + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.INNER, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), + AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("schema1.double_value"), + AstDSL.field("schema2.double_value")))), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("schema1.integer_value", AstDSL.field("schema1.integer_value")), + AstDSL.alias("schema1.double_value", AstDSL.field("schema1.double_value")), + AstDSL.alias("schema1.string_value", AstDSL.field("schema1.string_value")), + AstDSL.alias("schema2.integer_value", AstDSL.field("schema2.integer_value")), + AstDSL.alias("schema2.double_value", AstDSL.field("schema2.double_value")), + AstDSL.alias("schema2.string_value", AstDSL.field("schema2.string_value")))); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index f212749f48..f4be576d79 100644 --- 
a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -78,6 +78,56 @@ public void logical_plan_should_be_traversable() { assertEquals(5, result); } + @Test + public void table_join_plan_should_be_traversable() { + LogicalPlan leftRelation = LogicalPlanDSL.relation("schema1", table); + LogicalPlan rightRelation = LogicalPlanDSL.relation("schema2", table); + LogicalPlan join = LogicalPlanDSL.innerJoin(leftRelation, rightRelation, expression); + LogicalPlan logicalPlan = + LogicalPlanDSL.rename( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.rareTopN( + LogicalPlanDSL.filter(join, expression), + CommandType.TOP, + ImmutableList.of(expression), + expression), + ImmutableList.of(DSL.named("avg", aggregator)), + ImmutableList.of(DSL.named("group", expression))), + ImmutableMap.of(ref, ref)); + Integer result = logicalPlan.accept(new NodesCount(), null); + assertEquals(7, result); + } + + @Test + public void complex_join_plan_should_be_traversable() { + LogicalPlan leftPlan = + LogicalPlanDSL.rename( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.rareTopN( + LogicalPlanDSL.filter(LogicalPlanDSL.relation("schema", table), expression), + CommandType.TOP, + ImmutableList.of(expression), + expression), + ImmutableList.of(DSL.named("avg", aggregator)), + ImmutableList.of(DSL.named("group", expression))), + ImmutableMap.of(ref, ref)); + + LogicalPlan rightPlan = + LogicalPlanDSL.rename( + LogicalPlanDSL.aggregation( + LogicalPlanDSL.rareTopN( + LogicalPlanDSL.filter(LogicalPlanDSL.relation("schema", table), expression), + CommandType.TOP, + ImmutableList.of(expression), + expression), + ImmutableList.of(DSL.named("avg", aggregator)), + ImmutableList.of(DSL.named("group", expression))), + ImmutableMap.of(ref, ref)); + LogicalPlan join = LogicalPlanDSL.innerJoin(leftPlan, rightPlan, expression); + Integer result = join.accept(new NodesCount(), null); + 
assertEquals(11, result); + } + @SuppressWarnings("unchecked") private static Stream getLogicalPlansForVisitorTest() { LogicalPlan relation = LogicalPlanDSL.relation("schema", table); @@ -141,6 +191,12 @@ public TableWriteOperator build(PhysicalPlan child) { LogicalCloseCursor closeCursor = new LogicalCloseCursor(cursor); + LogicalPlan relation2 = LogicalPlanDSL.relation("schema2", table); + + LogicalPlan join = + LogicalPlanDSL.innerJoin( + (LogicalRelation) relation, (LogicalRelation) relation2, expression); + return Stream.of( relation, tableScanBuilder, @@ -163,7 +219,8 @@ public TableWriteOperator build(PhysicalPlan child) { paginate, nested, cursor, - closeCursor) + closeCursor, + join) .map(Arguments::of); } @@ -214,5 +271,14 @@ public Integer visitRareTopN(LogicalRareTopN plan, Object context) { .mapToInt(Integer::intValue) .sum(); } + + @Override + public Integer visitJoin(LogicalJoin plan, Object context) { + return 1 + + plan.getChild().stream() + .map(child -> child.accept(this, context)) + .mapToInt(Integer::intValue) + .sum(); + } } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java new file mode 100644 index 0000000000..2f114eced8 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; + +import com.google.common.collect.ImmutableMap; +import java.util.List; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import 
org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.planner.physical.join.NestedLoopJoinOperator; + +@ExtendWith(MockitoExtension.class) +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class JoinOperatorTest extends PhysicalPlanTestBase { + + public void nested_loop_join_test() { + PhysicalPlan left = testScan(compoundInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.INNER, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "ip", + "209.160.24.63", + "action", + "GET", + "response", + 404, + "referer", + "www.amazon.com")))); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java index 6399f945ed..f1b76611a5 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java @@ -29,17 +29,17 @@ public class PhysicalPlanTestBase { protected static final List countTestInputs = new ImmutableList.Builder() - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 3, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "testString", "asdf"))) - 
.add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 8, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "testString", "asdf"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "testString", "asdf"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 3, "name", "c"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 8, "name", "h"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k"))) .build(); protected static final List inputs = diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 9f707c13cd..4c4b8a9c32 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -36,6 +36,16 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; +//Native JOIN KEYWORDS +JOIN: 'JOIN'; +ON: 'ON'; +INNER: 'INNER'; +OUTER: 'OUTER'; +FULL: 'FULL'; +SEMI: 'SEMI'; +ANTI: 'ANTI'; +CROSS: 'CROSS'; + // COMMAND ASSIST KEYWORDS AS: 'AS'; BY: 'BY'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 
b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 4dc223b028..8d1d3ca9c1 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -166,16 +166,49 @@ mlArg // clauses fromClause - : SOURCE EQUAL tableSourceClause - | INDEX EQUAL tableSourceClause - | SOURCE EQUAL tableFunction - | INDEX EQUAL tableFunction + : (SOURCE | INDEX) EQUAL tableSourceClause + | (SOURCE | INDEX) EQUAL tableFunction + | (SOURCE | INDEX) EQUAL relation ; tableSourceClause : tableSource (COMMA tableSource)* ; +// TODO two-tables join only. Multi-tables join `relationExtension*` is unsupported in current implementation. +relation + : tablePrimary relationExtension + ; + +tablePrimary + : tableSource (AS alias = qualifiedName)? + ; + +relationExtension + : joinSource + ; + + // TODO joinCriteria could be none `(joinCriteria?)` for complex cases. It's unsupported in current implementation. + // TODO join hints `(hintStatement)?` is unsupported in current implementation. + // TODO directly tables jon only, join two plans is unsupported in current implementation. +joinSource + : (joinType) JOIN right = tablePrimary joinCriteria + ; + +joinType + : INNER? + | CROSS + | LEFT OUTER? + | RIGHT OUTER? + | FULL OUTER? + | LEFT? SEMI + | LEFT? 
ANTI + ; + +joinCriteria + : ON logicalExpression + ; + renameClasue : orignalField = wcFieldExpression AS renamedField = wcFieldExpression ; @@ -925,4 +958,14 @@ keywordsCanBeId | SPARKLINE | C | DC + // JOIN + | ON + | INNER + | CROSS + | OUTER + | SEMI + | LEFT + | RIGHT + | FULL + | ANTI ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 78fe28b49e..b22b557832 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -53,6 +53,7 @@ import org.opensearch.sql.ast.tree.Eval; import org.opensearch.sql.ast.tree.Filter; import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.Parse; @@ -312,6 +313,8 @@ public UnresolvedPlan visitTopCommand(TopCommandContext ctx) { public UnresolvedPlan visitFromClause(FromClauseContext ctx) { if (ctx.tableFunction() != null) { return visitTableFunction(ctx.tableFunction()); + } else if (ctx.relation() != null) { + return visitRelation(ctx.relation()); } else { return visitTableSourceClause(ctx.tableSourceClause()); } @@ -337,6 +340,47 @@ public UnresolvedPlan visitTableFunction(TableFunctionContext ctx) { return new TableFunction(this.internalVisitExpression(ctx.qualifiedName()), builder.build()); } + @Override + public UnresolvedPlan visitTablePrimary(OpenSearchPPLParser.TablePrimaryContext ctx) { + if (ctx.alias != null) { + return new Relation(this.internalVisitExpression(ctx.tableSource()), ctx.alias.getText()); + } else { + return new Relation(this.internalVisitExpression(ctx.tableSource())); + } + } + + @Override + public UnresolvedPlan visitRelation(OpenSearchPPLParser.RelationContext ctx) { + return withRelationExtensions(ctx, visitTablePrimary(ctx.tablePrimary())); + } + + private UnresolvedPlan 
withRelationExtensions( + OpenSearchPPLParser.RelationContext ctx, UnresolvedPlan tablePrimary) { + OpenSearchPPLParser.JoinSourceContext joinCtx = ctx.relationExtension().joinSource(); + Join.JoinType joinType; + if (joinCtx.joinType() == null) { + joinType = Join.JoinType.INNER; + } else if (joinCtx.joinType().INNER() != null) { + joinType = Join.JoinType.INNER; + } else if (joinCtx.joinType().CROSS() != null) { + joinType = Join.JoinType.CROSS; + } else if (joinCtx.joinType().FULL() != null) { + joinType = Join.JoinType.FULL; + } else if (joinCtx.joinType().SEMI() != null) { + joinType = Join.JoinType.SEMI; + } else if (joinCtx.joinType().ANTI() != null) { + joinType = Join.JoinType.ANTI; + } else if (joinCtx.joinType().LEFT() != null) { + joinType = Join.JoinType.LEFT; + } else if (joinCtx.joinType().RIGHT() != null) { + joinType = Join.JoinType.RIGHT; + } else { + joinType = Join.JoinType.INNER; + } + UnresolvedExpression joinCondition = this.internalVisitExpression(joinCtx.joinCriteria()); + return new Join(tablePrimary, visitTablePrimary(joinCtx.right), joinType, joinCondition); + } + /** Navigate to & build AST expression. 
*/ private UnresolvedExpression internalVisitExpression(ParseTree tree) { return expressionBuilder.visit(tree); diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index aec22ac231..bf1aabd813 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -69,6 +69,17 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor Date: Mon, 2 Sep 2024 13:37:36 +0800 Subject: [PATCH 2/4] Support NestedLoopJoin for join type INNER, LEFT, RIGHT, SEMI and ANTI Signed-off-by: Lantao Jin --- .../org/opensearch/sql/ast/tree/Join.java | 6 +- .../physical/join/NestedLoopJoinOperator.java | 133 ++-- .../planner/physical/JoinOperatorTest.java | 631 +++++++++++++++++- .../physical/PhysicalPlanTestBase.java | 44 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 9 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 16 +- .../sql/ppl/parser/AstExpressionBuilder.java | 11 - 7 files changed, 728 insertions(+), 122 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Join.java b/core/src/main/java/org/opensearch/sql/ast/tree/Join.java index 6905c31c49..f70d46de84 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Join.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Join.java @@ -40,12 +40,12 @@ public T accept(AbstractNodeVisitor nodeVisitor, C context) { } public enum JoinType { - CROSS, INNER, - SEMI, - ANTI, LEFT, RIGHT, + SEMI, + ANTI, + CROSS, FULL } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java index d933c42be6..c79f820221 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java +++ 
b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java @@ -5,9 +5,12 @@ package org.opensearch.sql.planner.physical.join; +import static org.opensearch.sql.data.type.ExprCoreType.UNDEFINED; + import com.google.common.collect.ImmutableList; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -40,33 +43,34 @@ public void open() { buildSide = left; if (joinType == Join.JoinType.INNER) { - List cached = cacheStreamedSide(right); - innerJoin(cached); + List cachedBuildSide = cacheIterator(right); + innerJoin(cachedBuildSide, left); } else if (joinType == Join.JoinType.LEFT) { // build side is right plan and streamed side is left plan in left outer join. - buildSide = right; - List cached = cacheStreamedSide(left); - outerJoin(cached); + List cachedBuildSide = cacheIterator(right); + outerJoin(cachedBuildSide, left); } else if (joinType == Join.JoinType.RIGHT) { - List cached = cacheStreamedSide(right); - outerJoin(cached); + // build side is left plan and streamed side is right plan in right outer join. 
+ List cachedBuildSide = cacheIterator(left); + outerJoin(cachedBuildSide, right); } else if (joinType == Join.JoinType.SEMI) { - List cached = cacheStreamedSide(right); - semiJoin(cached); + // build right plan in left semi + // TODO support buildLeft in LEFT SEMI + List cachedBuildSide = cacheIterator(right); + semiJoin(cachedBuildSide, left); } else if (joinType == Join.JoinType.ANTI) { - List cached = cacheStreamedSide(right); - antiJoin(cached); + // build right plan in left semi + // TODO support buildLeft in LEFT ANTI + List cachedBuildSide = cacheIterator(right); + antiJoin(cachedBuildSide, left); } else { - // LeftOuter with BuildLeft - // RightOuter with BuildRight // FullOuter - List cached = cacheStreamedSide(right); - defaultJoin(cached); + throw new UnsupportedOperationException("Unsupported Join Type " + joinType); } } /** Convert iterator to a list to allow multiple iterations */ - private List cacheStreamedSide(PhysicalPlan plan) { + private List cacheIterator(PhysicalPlan plan) { ImmutableList.Builder streamedBuilder = ImmutableList.builder(); plan.forEachRemaining(streamedBuilder::add); return streamedBuilder.build(); @@ -79,13 +83,12 @@ public void close() { right.close(); } - private void innerJoin(List cacheStreamedSide) { - Iterator streamed = cacheStreamedSide.iterator(); - while (streamed.hasNext()) { - ExprValue leftRow = buildSide.next(); + private void innerJoin(List cachedBuildSide, Iterator streamedSide) { + while (streamedSide.hasNext()) { + ExprValue streamedRow = streamedSide.next(); - for (ExprValue rightRow : cacheStreamedSide) { - ExprTupleValue joinedRow = combineExprTupleValue(leftRow, rightRow); + for (ExprValue buildRow : cachedBuildSide) { + ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); if (!(conditionValue.isNull() || conditionValue.isMissing()) && (conditionValue.booleanValue())) { @@ -97,64 +100,34 @@ private void 
innerJoin(List cacheStreamedSide) { } /** The implementation for LeftOuter with BuildRight, RightOuter with BuildLeft */ - private void outerJoin(List cacheStreamedSide) { - Set matchedRows = new HashSet<>(); - - // Probe phase - for (ExprValue streamedRow : cacheStreamedSide) { + private void outerJoin(List cachedBuildSide, Iterator streamedSide) { + while (streamedSide.hasNext()) { + ExprValue streamedRow = streamedSide.next(); boolean matched = false; - while (buildSide.hasNext()) { - ExprValue buildRow = buildSide.next(); - ExprTupleValue joinedRow = - combineExprTupleValue( - joinType == Join.JoinType.LEFT ? streamedRow : buildRow, - joinType == Join.JoinType.LEFT ? buildRow : streamedRow); + for (ExprValue buildRow : cachedBuildSide) { + ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); if (!(conditionValue.isNull() || conditionValue.isMissing()) && conditionValue.booleanValue()) { joinedBuilder.add(joinedRow); - matchedRows.add(streamedRow); matched = true; - break; } } if (!matched) { - ExprTupleValue joinedRow = - combineExprTupleValue( - joinType == Join.JoinType.LEFT ? streamedRow : ExprValueUtils.nullValue(), - joinType == Join.JoinType.LEFT ? 
ExprValueUtils.nullValue() : streamedRow); + ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, ExprValueUtils.nullValue()); joinedBuilder.add(joinedRow); } } - // Add unmatched rows - if (joinType == Join.JoinType.LEFT) { - while (buildSide.hasNext()) { - ExprValue buildRow = buildSide.next(); - if (!matchedRows.contains(buildRow)) { - ExprTupleValue joinedRow = combineExprTupleValue(ExprValueUtils.nullValue(), buildRow); - joinedBuilder.add(joinedRow); - } - } - } else if (joinType == Join.JoinType.RIGHT) { - for (ExprValue streamedRow : cacheStreamedSide) { - if (!matchedRows.contains(streamedRow)) { - ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, ExprValueUtils.nullValue()); - joinedBuilder.add(joinedRow); - } - } - } - joinedIterator = joinedBuilder.build().iterator(); } - private void semiJoin(List cacheStreamedSide) { + private void semiJoin(List cachedBuildSide, Iterator streamedSide) { Set matchedRows = new HashSet<>(); - // Probe phase - for (ExprValue streamedRow : cacheStreamedSide) { - while (buildSide.hasNext()) { - ExprValue buildRow = buildSide.next(); + while (streamedSide.hasNext()) { + ExprValue streamedRow = streamedSide.next(); + for (ExprValue buildRow : cachedBuildSide) { ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); if (!(conditionValue.isNull() || conditionValue.isMissing()) @@ -165,7 +138,6 @@ private void semiJoin(List cacheStreamedSide) { } } - // Add matched rows for (ExprValue row : matchedRows) { joinedBuilder.add(row); } @@ -173,41 +145,27 @@ private void semiJoin(List cacheStreamedSide) { joinedIterator = joinedBuilder.build().iterator(); } - // Java - private void antiJoin(List cacheStreamedSide) { - Set matchedRows = new HashSet<>(); - - // Probe phase - for (ExprValue streamedRow : cacheStreamedSide) { + private void antiJoin(List cachedBuildSide, Iterator streamedSide) { + while (streamedSide.hasNext()) 
{ + ExprValue streamedRow = streamedSide.next(); boolean matched = false; - while (buildSide.hasNext()) { - ExprValue buildRow = buildSide.next(); + for (ExprValue buildRow : cachedBuildSide) { ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); if (!(conditionValue.isNull() || conditionValue.isMissing()) && conditionValue.booleanValue()) { - matchedRows.add(streamedRow); matched = true; break; } } if (!matched) { - matchedRows.add(streamedRow); - } - } - - // Add unmatched rows - for (ExprValue row : cacheStreamedSide) { - if (!matchedRows.contains(row)) { - joinedBuilder.add(row); + joinedBuilder.add(streamedRow); } } joinedIterator = joinedBuilder.build().iterator(); } - private void defaultJoin(List cacheStreamedSide) {} - @Override public boolean hasNext() { return joinedIterator != null && joinedIterator.hasNext(); @@ -218,9 +176,14 @@ public ExprValue next() { return joinedIterator.next(); } - private ExprTupleValue combineExprTupleValue(ExprValue left, ExprValue right) { - Map combinedMap = left.tupleValue(); - combinedMap.putAll(right.tupleValue()); + private ExprTupleValue combineExprTupleValue(ExprValue streamedRow, ExprValue buildRow) { + ExprValue left = joinType == Join.JoinType.RIGHT ? buildRow : streamedRow; + ExprValue right = joinType == Join.JoinType.RIGHT ? streamedRow : buildRow; + Map leftTuple = left.type().equals(UNDEFINED) ? Map.of() : left.tupleValue(); + Map rightTuple = + right.type().equals(UNDEFINED) ? 
Map.of() : right.tupleValue(); + Map combinedMap = new LinkedHashMap<>(leftTuple); + combinedMap.putAll(rightTuple); return ExprTupleValue.fromExprValueMap(combinedMap); } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java index 2f114eced8..01e8f59ae8 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java @@ -14,9 +14,11 @@ import java.util.List; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; @@ -26,8 +28,9 @@ @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) public class JoinOperatorTest extends PhysicalPlanTestBase { - public void nested_loop_join_test() { - PhysicalPlan left = testScan(compoundInputs); + @Test + public void nested_loop_inner_join_test() { + PhysicalPlan left = testScan(joinTestInputs); PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = new NestedLoopJoinOperator( @@ -36,19 +39,627 @@ public void nested_loop_join_test() { Join.JoinType.INNER, DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); List result = execute(joinPlan); + result.forEach(System.out::println); assertEquals(7, result.size()); assertThat( result, containsInAnyOrder( ExprValueUtils.tupleValue( ImmutableMap.of( - "ip", - "209.160.24.63", - "action", - "GET", - "response", - 404, - "referer", - "www.amazon.com")))); + "day", + new ExprDateValue("2021-01-03"), + "host", + 
"h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10, + "id", + 10, + "name", + "j")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6, + "id", + 6, + "name", + "f")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)))); + } + + @Test + public void nested_loop_inner_join_test_2() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.INNER, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 2, + "name", + "b", + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 3, "day", new 
ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 6, + "name", + "f", + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 10, + "name", + "j", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10)))); + } + + @Test + public void nested_loop_left_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.LEFT, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10, + "id", + 10, + "name", + "j")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6, + "id", + 6, + "name", + "f")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), 
"host", "h2", "errors", 8, "id", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } + + @Test + public void nested_loop_left_join_test_2() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.LEFT, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 2, + "name", + "b", + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 6, + "name", + "f", + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 
+ 10, + "name", + "j", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + } + + @Test + public void nested_loop_right_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.RIGHT, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6, + "id", + 6, + "name", + "f")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10, + "id", + 10, + "name", + "j")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + 
ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + } + + @Test + public void nested_loop_right_join_test_2() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.RIGHT, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 2, + "name", + "b", + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 10, + "name", + "j", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 6, + "name", + "f", + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } + + @Test + public void nested_loop_semi_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = 
testScan(countTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.SEMI, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10)))); + } + + @Test + public void nested_loop_semi_join_test_2() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.SEMI, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(6, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 3)), + 
ExprValueUtils.tupleValue(ImmutableMap.of("id", 8)))); + } + + @Test + public void nested_loop_anti_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.ANTI, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(2, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } + + @Test + public void nested_loop_anti_join_test_2() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new NestedLoopJoinOperator( + left, + right, + Join.JoinType.ANTI, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(5, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java index f1b76611a5..35de25375d 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java +++ 
b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java @@ -31,17 +31,57 @@ public class PhysicalPlanTestBase { new ImmutableList.Builder() .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a"))) .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 3, "name", "c"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 3))) .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d"))) .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e"))) .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f"))) .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 8, "name", "h"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 8))) .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i"))) .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j"))) .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k"))) .build(); + protected static final List joinTestInputs = + new ImmutableList.Builder() + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new 
ExprDateValue("2021-01-07"), "host", "h2", "errors", 8))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13))) + .build(); + protected static final List inputs = new ImmutableList.Builder() .add( diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 8d1d3ca9c1..29fc53f1c0 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -166,9 +166,12 @@ mlArg // clauses fromClause - : (SOURCE | INDEX) EQUAL tableSourceClause - | (SOURCE | INDEX) EQUAL tableFunction - | (SOURCE | INDEX) EQUAL relation + : SOURCE EQUAL tableSourceClause + | INDEX EQUAL tableSourceClause + | SOURCE EQUAL tableFunction + | INDEX EQUAL tableFunction + | SOURCE EQUAL relation + | INDEX EQUAL relation ; tableSourceClause diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index b22b557832..3931ee8941 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -362,18 +362,18 @@ private UnresolvedPlan withRelationExtensions( joinType = Join.JoinType.INNER; } else if (joinCtx.joinType().INNER() != null) { joinType = Join.JoinType.INNER; - } else if (joinCtx.joinType().CROSS() != null) { - joinType = Join.JoinType.CROSS; - } else if (joinCtx.joinType().FULL() != null) { - joinType = Join.JoinType.FULL; - } else if (joinCtx.joinType().SEMI() != null) { - joinType = Join.JoinType.SEMI; - } else if (joinCtx.joinType().ANTI() != null) { - joinType = Join.JoinType.ANTI; } else if (joinCtx.joinType().LEFT() != null) { joinType = Join.JoinType.LEFT; } else if (joinCtx.joinType().RIGHT() != null) { joinType = Join.JoinType.RIGHT; + 
} else if (joinCtx.joinType().SEMI() != null) { + joinType = Join.JoinType.SEMI; + } else if (joinCtx.joinType().ANTI() != null) { + joinType = Join.JoinType.ANTI; + } else if (joinCtx.joinType().CROSS() != null) { + joinType = Join.JoinType.CROSS; + } else if (joinCtx.joinType().FULL() != null) { + joinType = Join.JoinType.FULL; } else { joinType = Join.JoinType.INNER; } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index bf1aabd813..aec22ac231 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -69,17 +69,6 @@ public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor Date: Tue, 3 Sep 2024 20:28:00 +0800 Subject: [PATCH 3/4] Support HashJoin with default implementation of memory hash table Signed-off-by: Lantao Jin --- .../sql/analysis/AnalysisContext.java | 4 - .../sql/planner/DefaultImplementor.java | 21 +- .../physical/join/DefaultHashedRelation.java | 54 + .../physical/join/HashJoinOperator.java | 253 +++-- .../planner/physical/join/HashedRelation.java | 32 + .../planner/physical/join/JoinOperator.java | 40 + .../physical/join/NestedLoopJoinOperator.java | 116 ++- .../physical/HashJoinOperatorTest.java | 950 ++++++++++++++++++ ...t.java => NestedLoopJoinOperatorTest.java} | 110 +- 9 files changed, 1340 insertions(+), 240 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/join/DefaultHashedRelation.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/join/HashedRelation.java create mode 100644 core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java rename core/src/test/java/org/opensearch/sql/planner/physical/{JoinOperatorTest.java => NestedLoopJoinOperatorTest.java} (89%) diff --git 
a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java index b4be232b0f..f1f29e9b38 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java +++ b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java @@ -41,10 +41,6 @@ public void push() { environment = new TypeEnvironment(environment); } - public void cleanFields() { - environment.clearAllFields(); - } - /** * Return current environment. * diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index 9e4f264151..4dc0f5a27c 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -5,12 +5,16 @@ package org.opensearch.sql.planner; +import static org.opensearch.sql.planner.physical.join.JoinOperator.BuildSide.BuildLeft; +import static org.opensearch.sql.planner.physical.join.JoinOperator.BuildSide.BuildRight; + import java.util.ArrayList; import java.util.List; import java.util.Optional; import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.executor.pagination.PlanSerializer; @@ -54,6 +58,7 @@ import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.planner.physical.join.HashJoinOperator; +import org.opensearch.sql.planner.physical.join.JoinOperator; import org.opensearch.sql.planner.physical.join.JoinPredicatesHelper; import org.opensearch.sql.planner.physical.join.NestedLoopJoinOperator; import org.opensearch.sql.storage.read.TableScanBuilder; @@ -205,7 
+210,8 @@ && canEvaluate((ReferenceExpression) pair.getRight(), join.getRight())) { } } - // 1. Determining Join with Hint. TODO + // 1. Determining Join with Hint and build side. + JoinOperator.BuildSide buildSide = determineBuildSide(join.getType()); // 2. Pick hash join if it is an equi-join and hash join supported if (!equiJoinKeys.isEmpty()) { Pair, List> unzipped = JoinPredicatesHelper.unzip(equiJoinKeys); @@ -217,6 +223,7 @@ && canEvaluate((ReferenceExpression) pair.getRight(), join.getRight())) { leftKeys, rightKeys, join.getType(), + buildSide, visitRelation((LogicalRelation) join.getLeft(), ctx), visitRelation((LogicalRelation) join.getRight(), ctx), Optional.empty()); @@ -227,10 +234,22 @@ && canEvaluate((ReferenceExpression) pair.getRight(), join.getRight())) { visitRelation((LogicalRelation) join.getLeft(), ctx), visitRelation((LogicalRelation) join.getRight(), ctx), join.getType(), + buildSide, join.getCondition()); } } + /** + * Build side is right by default (except RightOuter). TODO set the smaller side as the build side + * TODO set build side from hint if provided + * + * @param joinType Join type + * @return Build side + */ + private JoinOperator.BuildSide determineBuildSide(Join.JoinType joinType) { + return joinType == Join.JoinType.RIGHT ? 
BuildLeft : BuildRight; + } + /** Return true if the reference can be evaluated in relation */ private boolean canEvaluate(ReferenceExpression expr, LogicalPlan plan) { if (plan instanceof LogicalRelation relation) { diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/DefaultHashedRelation.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/DefaultHashedRelation.java new file mode 100644 index 0000000000..6a032db5c1 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/DefaultHashedRelation.java @@ -0,0 +1,54 @@ +package org.opensearch.sql.planner.physical.join; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.opensearch.sql.data.model.ExprValue; + +public class DefaultHashedRelation implements HashedRelation, Serializable { + + private final Map> map = new HashMap<>(); + private int numKeys; + private int numValues; + + @Override + public List get(ExprValue key) { + return map.get(key); + } + + @Override + public ExprValue getValue(ExprValue key) { + List values = map.get(key); + return values != null && !values.isEmpty() ? 
values.getFirst() : null; + } + + @Override + public boolean containsKey(ExprValue key) { + return map.containsKey(key); + } + + @Override + public Iterator keyIterator() { + return map.keySet().iterator(); + } + + @Override + public boolean isUniqueKey() { + return numKeys == numValues; + } + + @Override + public void close() { + map.clear(); + } + + @Override + public void put(ExprValue key, ExprValue value) { + map.computeIfAbsent(key, k -> new ArrayList<>()).add(value); + numKeys = map.size(); /* distinct keys only, so isUniqueKey() can detect duplicate keys */ + numValues++; + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java index ab860c789e..0be6612e96 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java @@ -5,14 +5,15 @@ package org.opensearch.sql.planner.physical.join; +import static org.opensearch.sql.planner.physical.join.JoinOperator.BuildSide.BuildRight; + import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.IntStream; import lombok.RequiredArgsConstructor; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.data.model.ExprTupleValue; @@ -21,11 +22,22 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.planner.physical.PhysicalPlan; +/** + * Hash Join Operator. For best performance, the build side should be set a smaller table, without + * hint and CBO, we treat right side as a smaller table by default and the build side set to right. + * TODO add join hint support.
Best practice in PPL: source=bigger | INNER JOIN smaller ON + * bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is right + * (smaller), and the streamed side is left (bigger). For RIGHT OUTER join, the build side is always + * left. If the smaller table is left, it will get the best performance: source=smaller | RIGHT JOIN + * bigger ON bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is + * left (smaller), and the streamed side is right (bigger). + */ @RequiredArgsConstructor public class HashJoinOperator extends JoinOperator { private final List leftKeys; private final List rightKeys; private final Join.JoinType joinType; + private final BuildSide buildSide; private final PhysicalPlan left; private final PhysicalPlan right; private final Optional nonEquiCond; @@ -33,51 +45,74 @@ public class HashJoinOperator extends JoinOperator { private final ImmutableList.Builder joinedBuilder = ImmutableList.builder(); private Iterator joinedIterator; + private HashedRelation hashed; + private List buildKeys; + private List streamedKeys; + @Override public void open() { - // Build hash table from left + if (!(leftKeys.size() == rightKeys.size() + && IntStream.range(0, leftKeys.size()) + .allMatch(i -> sameType(leftKeys.get(i), rightKeys.get(i))))) { + throw new IllegalArgumentException( + "Join keys from two sides should have same length and types"); + } + left.open(); - Map hashed = buildHashed(); - // Set streamed side to right right.open(); - Iterator streamed = right; - - if (joinType == Join.JoinType.INNER) { - innerJoin(streamed, hashed); - } else if (joinType == Join.JoinType.LEFT) { - leftOuterJoin(streamed, hashed); - } else if (joinType == Join.JoinType.SEMI) { - semiJoin(streamed, hashed); - } else if (joinType == Join.JoinType.ANTI) { - antiJoin(streamed, hashed); + Iterator streamed; + if (buildSide == BuildRight) { + hashed = buildHashed(right, rightKeys); + streamed = left; + buildKeys = 
rightKeys; + streamedKeys = leftKeys; } else { - throw new IllegalArgumentException("Unsupported join type: " + joinType); + hashed = buildHashed(left, leftKeys); + streamed = right; + buildKeys = leftKeys; + streamedKeys = rightKeys; + } + + switch (joinType) { + case INNER -> innerJoin(streamed); + case LEFT, RIGHT -> outerJoin(streamed); + case SEMI -> semiJoin(streamed); + case ANTI -> antiJoin(streamed); + default -> throw new UnsupportedOperationException("Unsupported Join Type " + joinType); } } @Override public void close() { joinedIterator = null; + if (hashed != null) { + hashed.close(); + hashed = null; + } left.close(); right.close(); } - private void innerJoin(Iterator streamed, Map hashed) { + @Override + public void innerJoin(Iterator streamed) { while (streamed.hasNext()) { - ExprValue rightRow = streamed.next(); - for (Expression rightKey : rightKeys) { - ExprValue rightRowKey = rightKey.valueOf(rightRow.bindingTuples()); - if (rightRowKey != null && hashed.containsKey(rightRowKey)) { - ExprValue leftRow = hashed.get(rightRowKey); - ExprValue joinedRow = combineExprTupleValue(leftRow, rightRow); - if (nonEquiCond.isPresent()) { - ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); - if (!(conditionValue.isNull() || conditionValue.isMissing()) - && conditionValue.booleanValue()) { + ExprValue streamedRow = streamed.next(); + + for (Expression streamedKey : streamedKeys) { + ExprValue streamedRowKey = streamedKey.valueOf(streamedRow.bindingTuples()); + if (streamedRowKey != null && hashed.containsKey(streamedRowKey)) { + List matchedBuildRows = hashed.get(streamedRowKey); + for (ExprValue matchedBuildRow : matchedBuildRows) { + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); + if (nonEquiCond.isPresent()) { + ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && 
conditionValue.booleanValue()) { + joinedBuilder.add(joinedRow); + } + } else { joinedBuilder.add(joinedRow); } - } else { - joinedBuilder.add(joinedRow); } } } @@ -85,37 +120,40 @@ private void innerJoin(Iterator streamed, Map h joinedIterator = joinedBuilder.build().iterator(); } - private void leftOuterJoin(Iterator streamed, Map hashed) { - // Track matched keys to identify unmatched left rows later - Set matchedKeys = new HashSet<>(); - + /** The implementation for outer join: LeftOuter with BuildRight RightOuter with BuildLeft */ + @Override + public void outerJoin(Iterator streamed) { while (streamed.hasNext()) { - ExprValue rightRow = streamed.next(); - for (Expression rightKey : rightKeys) { - ExprValue rightRowKey = rightKey.valueOf(rightRow.bindingTuples()); - if (rightRowKey != null && hashed.containsKey(rightRowKey)) { - ExprValue leftRow = hashed.get(rightRowKey); - ExprValue joinedRow = combineExprTupleValue(leftRow, rightRow); - if (nonEquiCond.isPresent()) { - ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); - if (!(conditionValue.isNull() || conditionValue.isMissing()) - && conditionValue.booleanValue()) { + ExprValue streamedRow = streamed.next(); + boolean matched = false; + for (Expression streamedKey : streamedKeys) { + ExprValue streamedRowKey = streamedKey.valueOf(streamedRow.bindingTuples()); + if (streamedRowKey != null && hashed.containsKey(streamedRowKey)) { + List matchedBuildRows = hashed.get(streamedRowKey); + for (ExprValue matchedBuildRow : matchedBuildRows) { + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); + if (nonEquiCond.isPresent()) { + ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + joinedBuilder.add(joinedRow); + matched = true; + } + } else { joinedBuilder.add(joinedRow); - matchedKeys.add(rightRowKey); + matched = true; 
} - } else { - joinedBuilder.add(joinedRow); - matchedKeys.add(rightRowKey); } + } else { + // if any streamedRowKey does not match, the remaining keys are not checked. + matched = false; + break; } } - } - // Add unmatched left rows with nulls for the right side - for (Map.Entry entry : hashed.entrySet()) { - if (!matchedKeys.contains(entry.getKey())) { - ExprValue leftRow = entry.getValue(); - ExprValue joinedRow = combineExprTupleValue(leftRow, ExprValueUtils.nullValue()); + if (!matched) { + ExprTupleValue joinedRow = + combineExprTupleValue(buildSide, streamedRow, ExprValueUtils.nullValue()); joinedBuilder.add(joinedRow); } } @@ -123,83 +161,90 @@ private void leftOuterJoin(Iterator streamed, Map streamed, Map hashed) { - Set matchedKeys = new HashSet<>(); + @Override + public void semiJoin(Iterator streamed) { + Set matchedRows = new HashSet<>(); while (streamed.hasNext()) { - ExprValue rightRow = streamed.next(); - for (Expression rightKey : rightKeys) { - ExprValue rightRowKey = rightKey.valueOf(rightRow.bindingTuples()); - if (rightRowKey != null && hashed.containsKey(rightRowKey)) { - ExprValue leftRow = hashed.get(rightRowKey); - if (nonEquiCond.isPresent()) { - ExprValue joinedRow = combineExprTupleValue(leftRow, rightRow); - ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); - if (!(conditionValue.isNull() || conditionValue.isMissing()) - && conditionValue.booleanValue()) { - matchedKeys.add(rightRowKey); + ExprValue streamedRow = streamed.next(); + for (Expression streamedKey : streamedKeys) { + ExprValue streamedRowKey = streamedKey.valueOf(streamedRow.bindingTuples()); + if (streamedRowKey != null && hashed.containsKey(streamedRowKey)) { + List matchedBuildRows = hashed.get(streamedRowKey); + for (ExprValue matchedBuildRow : matchedBuildRows) { + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); + if (nonEquiCond.isPresent()) { + ExprValue conditionValue = 
nonEquiCond.get().valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() || conditionValue.isMissing()) + && conditionValue.booleanValue()) { + matchedRows.add(streamedRow); + } + } else { + matchedRows.add(streamedRow); } - } else { - matchedKeys.add(rightRowKey); } + } else { + // if any streamedRowKey does not match, the remaining keys are not checked. + break; } } } - // Add matched left rows to the result - for (ExprValue key : matchedKeys) { - joinedBuilder.add(hashed.get(key)); + for (ExprValue row : matchedRows) { + joinedBuilder.add(row); } joinedIterator = joinedBuilder.build().iterator(); } - private void antiJoin(Iterator streamed, Map hashed) { - Set matchedKeys = new HashSet<>(); - + @Override + public void antiJoin(Iterator streamed) { while (streamed.hasNext()) { - ExprValue rightRow = streamed.next(); - for (Expression rightKey : rightKeys) { - ExprValue rightRowKey = rightKey.valueOf(rightRow.bindingTuples()); - if (rightRowKey != null && hashed.containsKey(rightRowKey)) { - ExprValue leftRow = hashed.get(rightRowKey); - if (nonEquiCond.isPresent()) { - ExprValue joinedRow = combineExprTupleValue(leftRow, rightRow); - ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); - if (!(conditionValue.isNull() || conditionValue.isMissing()) - && conditionValue.booleanValue()) { - matchedKeys.add(rightRowKey); + ExprValue streamedRow = streamed.next(); + boolean matched = false; + for (Expression streamedKey : streamedKeys) { + ExprValue streamedRowKey = streamedKey.valueOf(streamedRow.bindingTuples()); + if (streamedRowKey != null && hashed.containsKey(streamedRowKey)) { + List matchedBuildRows = hashed.get(streamedRowKey); + for (ExprValue matchedBuildRow : matchedBuildRows) { + if (nonEquiCond.isPresent()) { + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); + ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); + if (!(conditionValue.isNull() ||
conditionValue.isMissing()) + && conditionValue.booleanValue()) { + matched = true; + } + } else { + matched = true; } - } else { - matchedKeys.add(rightRowKey); } + } else { + // if any streamedRowKey does not match, the remaining keys are not checked. + matched = false; + break; } } - } - - // Add unmatched left rows to the result - for (Map.Entry entry : hashed.entrySet()) { - if (!matchedKeys.contains(entry.getKey())) { - joinedBuilder.add(entry.getValue()); + if (!matched) { + joinedBuilder.add(streamedRow); } } joinedIterator = joinedBuilder.build().iterator(); } - private Map buildHashed() { - ImmutableMap.Builder leftTableBuilder = ImmutableMap.builder(); - while (left.hasNext()) { - ExprValue row = left.next(); - for (Expression leftKey : leftKeys) { - ExprValue rowKey = leftKey.valueOf(row.bindingTuples()); + private HashedRelation buildHashed(PhysicalPlan buildSide, List buildKeys) { + HashedRelation hashedRelation = new DefaultHashedRelation(); + while (buildSide.hasNext()) { + ExprValue row = buildSide.next(); + for (Expression buildKey : buildKeys) { + ExprValue rowKey = buildKey.valueOf(row.bindingTuples()); if (rowKey != null) { - leftTableBuilder.put(rowKey, row); + hashedRelation.put(rowKey, row); break; } } } - return leftTableBuilder.build(); + return hashedRelation; } @Override @@ -216,10 +261,4 @@ public ExprValue next() { public List getChild() { return ImmutableList.of(left, right); } - - private ExprTupleValue combineExprTupleValue(ExprValue left, ExprValue right) { - Map combinedMap = left.tupleValue(); - combinedMap.putAll(right.tupleValue()); - return ExprTupleValue.fromExprValueMap(combinedMap); - } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/HashedRelation.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashedRelation.java new file mode 100644 index 0000000000..9dbb226175 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashedRelation.java @@ -0,0 +1,32 @@ 
+package org.opensearch.sql.planner.physical.join; + +import java.util.Iterator; +import java.util.List; +import org.opensearch.sql.data.model.ExprValue; + +public interface HashedRelation { + + /** Return every row stored under the given key. */ + List get(ExprValue key); + + /** + * Return the single matched row. Intended for use only when {@link #isUniqueKey()} is + * true. + */ + ExprValue getValue(ExprValue key); + + /** Return whether the relation contains the given key. */ + boolean containsKey(ExprValue key); + + /** Return an iterator over the keys of the relation. */ + Iterator keyIterator(); + + /** Return whether the key is unique, i.e. no key holds more than one row. */ + boolean isUniqueKey(); + + /** Add the row {@code value} under {@code key}; a key may hold several rows. */ + void put(ExprValue key, ExprValue value); + + /** Release the resources held by the relation. */ + void close(); +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java index ee2e3ace44..8e60c81415 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java @@ -5,7 +5,16 @@ package org.opensearch.sql.planner.physical.join; +import static org.opensearch.sql.data.type.ExprCoreType.UNDEFINED; + +import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.Expression; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; @@ -18,4 +27,35 @@ public R accept(PhysicalPlanNodeVisitor visitor, C context) { @Override public abstract List getChild(); + + public abstract void innerJoin(Iterator streamedSide); + + public abstract void outerJoin(Iterator streamedSide); + + public abstract void semiJoin(Iterator streamedSide); + +
public abstract void antiJoin(Iterator streamedSide); + + protected ExprTupleValue combineExprTupleValue( + BuildSide buildSide, ExprValue streamedRow, ExprValue buildRow) { + ExprValue left = buildSide == BuildSide.BuildLeft ? buildRow : streamedRow; + ExprValue right = buildSide == BuildSide.BuildLeft ? streamedRow : buildRow; + Map leftTuple = left.type().equals(UNDEFINED) ? Map.of() : left.tupleValue(); + Map rightTuple = + right.type().equals(UNDEFINED) ? Map.of() : right.tupleValue(); + Map combinedMap = new LinkedHashMap<>(leftTuple); + combinedMap.putAll(rightTuple); + return ExprTupleValue.fromExprValueMap(combinedMap); + } + + protected boolean sameType(Expression expr1, Expression expr2) { + ExprType type1 = expr1.type(); + ExprType type2 = expr2.type(); + return type1.isCompatible(type2); + } + + public enum BuildSide { + BuildLeft, + BuildRight + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java index c79f820221..f9817b70c6 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java @@ -5,14 +5,12 @@ package org.opensearch.sql.planner.physical.join; -import static org.opensearch.sql.data.type.ExprCoreType.UNDEFINED; +import static org.opensearch.sql.planner.physical.join.JoinOperator.BuildSide.BuildRight; import com.google.common.collect.ImmutableList; import java.util.HashSet; import java.util.Iterator; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.Set; import lombok.RequiredArgsConstructor; import org.opensearch.sql.ast.tree.Join; @@ -22,73 +20,67 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.planner.physical.PhysicalPlan; +/** + * Nested Loop Join Operator. 
For best performance, the build side should be set a smaller table, + * without hint and CBO, we treat right side as a smaller table by default and the build side set to + * right. TODO add join hint support. Best practice in PPL: source=bigger | INNER JOIN smaller ON + * bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is right + * (smaller), and the streamed side is left (bigger). For RIGHT OUTER join, the build side is always + * left. If the smaller table is left, it will get the best performance: source=smaller | RIGHT JOIN + * bigger ON bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is + * left (smaller), and the streamed side is right (bigger). + */ @RequiredArgsConstructor public class NestedLoopJoinOperator extends JoinOperator { private final PhysicalPlan left; private final PhysicalPlan right; private final Join.JoinType joinType; + private final BuildSide buildSide; private final Expression condition; private final ImmutableList.Builder joinedBuilder = ImmutableList.builder(); private Iterator joinedIterator; - // Build side is left by default, set the smaller side as the build side in future TODO - private Iterator buildSide; + + private List cachedBuildSide; @Override public void open() { left.open(); right.open(); - - // buildSide is left plan by default - buildSide = left; - - if (joinType == Join.JoinType.INNER) { - List cachedBuildSide = cacheIterator(right); - innerJoin(cachedBuildSide, left); - } else if (joinType == Join.JoinType.LEFT) { - // build side is right plan and streamed side is left plan in left outer join. - List cachedBuildSide = cacheIterator(right); - outerJoin(cachedBuildSide, left); - } else if (joinType == Join.JoinType.RIGHT) { - // build side is left plan and streamed side is right plan in right outer join. 
- List cachedBuildSide = cacheIterator(left); - outerJoin(cachedBuildSide, right); - } else if (joinType == Join.JoinType.SEMI) { - // build right plan in left semi - // TODO support buildLeft in LEFT SEMI - List cachedBuildSide = cacheIterator(right); - semiJoin(cachedBuildSide, left); - } else if (joinType == Join.JoinType.ANTI) { - // build right plan in left semi - // TODO support buildLeft in LEFT ANTI - List cachedBuildSide = cacheIterator(right); - antiJoin(cachedBuildSide, left); + Iterator streamed; + if (buildSide == BuildRight) { + cachedBuildSide = cacheIterator(right); + streamed = left; } else { - // FullOuter - throw new UnsupportedOperationException("Unsupported Join Type " + joinType); + cachedBuildSide = cacheIterator(left); + streamed = right; } - } - /** Convert iterator to a list to allow multiple iterations */ - private List cacheIterator(PhysicalPlan plan) { - ImmutableList.Builder streamedBuilder = ImmutableList.builder(); - plan.forEachRemaining(streamedBuilder::add); - return streamedBuilder.build(); + switch (joinType) { + case INNER -> innerJoin(streamed); + case LEFT, RIGHT -> outerJoin(streamed); + case SEMI -> semiJoin(streamed); + case ANTI -> antiJoin(streamed); + default -> throw new UnsupportedOperationException("Unsupported Join Type " + joinType); + } } @Override public void close() { joinedIterator = null; + cachedBuildSide = null; left.close(); right.close(); } - private void innerJoin(List cachedBuildSide, Iterator streamedSide) { + /** The implementation for inner join: Inner with BuildRight */ + @Override + public void innerJoin(Iterator streamedSide) { while (streamedSide.hasNext()) { ExprValue streamedRow = streamedSide.next(); for (ExprValue buildRow : cachedBuildSide) { - ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); + ExprTupleValue joinedRow = combineExprTupleValue(buildSide, streamedRow, buildRow); ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); if 
(!(conditionValue.isNull() || conditionValue.isMissing()) && (conditionValue.booleanValue())) { @@ -99,13 +91,14 @@ private void innerJoin(List cachedBuildSide, Iterator stre joinedIterator = joinedBuilder.build().iterator(); } - /** The implementation for LeftOuter with BuildRight, RightOuter with BuildLeft */ - private void outerJoin(List cachedBuildSide, Iterator streamedSide) { + /** The implementation for outer join: LeftOuter with BuildRight RightOuter with BuildLeft */ + @Override + public void outerJoin(Iterator streamedSide) { while (streamedSide.hasNext()) { ExprValue streamedRow = streamedSide.next(); boolean matched = false; for (ExprValue buildRow : cachedBuildSide) { - ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); + ExprTupleValue joinedRow = combineExprTupleValue(buildSide, streamedRow, buildRow); ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); if (!(conditionValue.isNull() || conditionValue.isMissing()) && conditionValue.booleanValue()) { @@ -114,7 +107,8 @@ private void outerJoin(List cachedBuildSide, Iterator stre } } if (!matched) { - ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, ExprValueUtils.nullValue()); + ExprTupleValue joinedRow = + combineExprTupleValue(buildSide, streamedRow, ExprValueUtils.nullValue()); joinedBuilder.add(joinedRow); } } @@ -122,13 +116,17 @@ private void outerJoin(List cachedBuildSide, Iterator stre joinedIterator = joinedBuilder.build().iterator(); } - private void semiJoin(List cachedBuildSide, Iterator streamedSide) { + /** + * The implementation for left semi join: LeftSemi with BuildRight TODO LeftSemi with buildLeft + */ + @Override + public void semiJoin(Iterator streamedSide) { Set matchedRows = new HashSet<>(); while (streamedSide.hasNext()) { ExprValue streamedRow = streamedSide.next(); for (ExprValue buildRow : cachedBuildSide) { - ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); + ExprTupleValue joinedRow = 
combineExprTupleValue(buildSide, streamedRow, buildRow); ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); if (!(conditionValue.isNull() || conditionValue.isMissing()) && conditionValue.booleanValue()) { @@ -145,12 +143,16 @@ private void semiJoin(List cachedBuildSide, Iterator strea joinedIterator = joinedBuilder.build().iterator(); } - private void antiJoin(List cachedBuildSide, Iterator streamedSide) { + /** + * The implementation for left anti join: LeftAnti with BuildRight TODO LeftAnti with buildLeft + */ + @Override + public void antiJoin(Iterator streamedSide) { while (streamedSide.hasNext()) { ExprValue streamedRow = streamedSide.next(); boolean matched = false; for (ExprValue buildRow : cachedBuildSide) { - ExprTupleValue joinedRow = combineExprTupleValue(streamedRow, buildRow); + ExprTupleValue joinedRow = combineExprTupleValue(buildSide, streamedRow, buildRow); ExprValue conditionValue = condition.valueOf(joinedRow.bindingTuples()); if (!(conditionValue.isNull() || conditionValue.isMissing()) && conditionValue.booleanValue()) { @@ -166,6 +168,13 @@ private void antiJoin(List cachedBuildSide, Iterator strea joinedIterator = joinedBuilder.build().iterator(); } + /** Convert iterator to a list to allow multiple iterations */ + private List cacheIterator(PhysicalPlan plan) { + ImmutableList.Builder streamedBuilder = ImmutableList.builder(); + plan.forEachRemaining(streamedBuilder::add); + return streamedBuilder.build(); + } + @Override public boolean hasNext() { return joinedIterator != null && joinedIterator.hasNext(); @@ -176,17 +185,6 @@ public ExprValue next() { return joinedIterator.next(); } - private ExprTupleValue combineExprTupleValue(ExprValue streamedRow, ExprValue buildRow) { - ExprValue left = joinType == Join.JoinType.RIGHT ? buildRow : streamedRow; - ExprValue right = joinType == Join.JoinType.RIGHT ? streamedRow : buildRow; - Map leftTuple = left.type().equals(UNDEFINED) ? 
Map.of() : left.tupleValue(); - Map rightTuple = - right.type().equals(UNDEFINED) ? Map.of() : right.tupleValue(); - Map combinedMap = new LinkedHashMap<>(leftTuple); - combinedMap.putAll(rightTuple); - return ExprTupleValue.fromExprValueMap(combinedMap); - } - @Override public List getChild() { return ImmutableList.of(left, right); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java new file mode 100644 index 0000000000..b0e6b36fc7 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java @@ -0,0 +1,950 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.List; +import java.util.Optional; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.physical.join.HashJoinOperator; +import org.opensearch.sql.planner.physical.join.JoinOperator; + +@ExtendWith(MockitoExtension.class) 
+@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +public class HashJoinOperatorTest extends PhysicalPlanTestBase { + private final JoinOperator.BuildSide defaultBuildSide = JoinOperator.BuildSide.BuildRight; + private final Optional defaultNonEquiCond = + Optional.of( + DSL.and( + DSL.equal(DSL.ref("host", STRING), DSL.literal("h1")), + DSL.lte(DSL.ref("id", INTEGER), DSL.literal(5)))); + + @Test + public void inner_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.INNER, + defaultBuildSide, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10, + "id", + 10, + "name", + "j")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6, + "id", + 6, + "name", + "f")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)))); + } + + @Test 
+ public void inner_join_side_exchange_test() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("id", INTEGER)), + ImmutableList.of(DSL.ref("errors", INTEGER)), + Join.JoinType.INNER, + defaultBuildSide, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 2, + "name", + "b", + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 6, + "name", + "f", + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 10, + "name", + "j", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10)))); + } + + @Test + public void inner_join_with_non_equi_cond_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.INNER, + defaultBuildSide, + left, + right, + defaultNonEquiCond); + List 
result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")))); + } + + @Test + public void left_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.LEFT, + defaultBuildSide, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10, + "id", + 10, + "name", + "j")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6, + "id", + 6, + "name", + "f")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + 
"day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } + + @Test + public void left_join_side_exchange_test() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("id", INTEGER)), + ImmutableList.of(DSL.ref("errors", INTEGER)), + Join.JoinType.LEFT, + defaultBuildSide, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 2, + "name", + "b", + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 6, + "name", + "f", + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + 
ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 10, + "name", + "j", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + } + + @Test + public void left_join_with_non_equi_cond_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.INNER, + defaultBuildSide, + left, + right, + defaultNonEquiCond); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")))); + } + + @Test + public void right_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.RIGHT, + JoinOperator.BuildSide.BuildLeft, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + 
ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6, + "id", + 6, + "name", + "f")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10, + "id", + 10, + "name", + "j")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + } + + @Test + public void right_join_side_exchange_test() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("id", INTEGER)), + ImmutableList.of(DSL.ref("errors", INTEGER)), + Join.JoinType.RIGHT, + JoinOperator.BuildSide.BuildLeft, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(9, result.size()); + assertThat( + result, + containsInAnyOrder( + 
ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 2, + "name", + "b", + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 10, + "name", + "j", + "day", + new ExprDateValue("2021-01-04"), + "host", + "h2", + "errors", + 10)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 1, + "name", + "a", + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", + 6, + "name", + "f", + "day", + new ExprDateValue("2021-01-07"), + "host", + "h1", + "errors", + 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } + + @Test + public void right_join_with_non_equi_cond_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.RIGHT, + JoinOperator.BuildSide.BuildLeft, + left, + right, + defaultNonEquiCond); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(12, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-04"), + "host", + "h1", + "errors", + 1, + "id", + 1, + 
"name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-06"), + "host", + "h1", + "errors", + 1, + "id", + 1, + "name", + "a")), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", + new ExprDateValue("2021-01-03"), + "host", + "h1", + "errors", + 2, + "id", + 2, + "name", + "b")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 3)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 8)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + } + + @Test + public void semi_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.SEMI, + defaultBuildSide, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), 
"host", "h1", "errors", 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10)))); + } + + @Test + public void semi_join_side_exchange_test() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("id", INTEGER)), + ImmutableList.of(DSL.ref("errors", INTEGER)), + Join.JoinType.SEMI, + defaultBuildSide, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(6, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 3)), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 8)))); + } + + @Test + public void semi_join_non_equi_cond_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.SEMI, + defaultBuildSide, + left, + right, + defaultNonEquiCond); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(3, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1)), + 
ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2)))); + } + + @Test + public void anti_join_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.ANTI, + defaultBuildSide, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(2, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } + + @Test + public void anti_join_side_exchange_test() { + // Exchange the tables + PhysicalPlan left = testScan(countTestInputs); + PhysicalPlan right = testScan(joinTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + ImmutableList.of(DSL.ref("id", INTEGER)), + ImmutableList.of(DSL.ref("errors", INTEGER)), + Join.JoinType.ANTI, + defaultBuildSide, + left, + right, + Optional.empty()); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(5, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), + ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + } + + @Test + public void anti_join_non_equi_cond_test() { + PhysicalPlan left = testScan(joinTestInputs); + PhysicalPlan right = testScan(countTestInputs); + PhysicalPlan joinPlan = + new HashJoinOperator( + 
ImmutableList.of(DSL.ref("errors", INTEGER)), + ImmutableList.of(DSL.ref("id", INTEGER)), + Join.JoinType.ANTI, + defaultBuildSide, + left, + right, + defaultNonEquiCond); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(6, result.size()); + assertThat( + result, + containsInAnyOrder( + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6)), + ExprValueUtils.tupleValue( + ImmutableMap.of("day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java similarity index 89% rename from core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java rename to core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java index 01e8f59ae8..b13bcb47e3 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java @@ -22,22 +22,37 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.planner.physical.join.JoinOperator; import org.opensearch.sql.planner.physical.join.NestedLoopJoinOperator; @ExtendWith(MockitoExtension.class) 
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class JoinOperatorTest extends PhysicalPlanTestBase { +public class NestedLoopJoinOperatorTest extends PhysicalPlanTestBase { + private final JoinOperator.BuildSide defaultBuildSide = JoinOperator.BuildSide.BuildRight; + + private PhysicalPlan makeNestedLoopJoin( + PhysicalPlan left, PhysicalPlan right, Join.JoinType joinType) { + return makeNestedLoopJoin(left, right, joinType, defaultBuildSide); + } + + private PhysicalPlan makeNestedLoopJoin( + PhysicalPlan left, + PhysicalPlan right, + Join.JoinType joinType, + JoinOperator.BuildSide buildSide) { + return new NestedLoopJoinOperator( + left, + right, + joinType, + buildSide, + DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + } @Test - public void nested_loop_inner_join_test() { + public void inner_join_test() { PhysicalPlan left = testScan(joinTestInputs); PhysicalPlan right = testScan(countTestInputs); - PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.INNER, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.INNER); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(7, result.size()); @@ -113,16 +128,11 @@ public void nested_loop_inner_join_test() { } @Test - public void nested_loop_inner_join_test_2() { + public void inner_join_side_exchange_test() { // Exchange the tables PhysicalPlan left = testScan(countTestInputs); PhysicalPlan right = testScan(joinTestInputs); - PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.INNER, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.INNER); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(7, result.size()); @@ -198,15 +208,10 @@ public void 
nested_loop_inner_join_test_2() { } @Test - public void nested_loop_left_join_test() { + public void left_join_test() { PhysicalPlan left = testScan(joinTestInputs); PhysicalPlan right = testScan(countTestInputs); - PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.LEFT, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.LEFT); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(9, result.size()); @@ -288,16 +293,11 @@ public void nested_loop_left_join_test() { } @Test - public void nested_loop_left_join_test_2() { + public void left_join_side_exchange_test() { // Exchange the tables PhysicalPlan left = testScan(countTestInputs); PhysicalPlan right = testScan(joinTestInputs); - PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.LEFT, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.LEFT); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(12, result.size()); @@ -378,15 +378,11 @@ public void nested_loop_left_join_test_2() { } @Test - public void nested_loop_right_join_test() { + public void right_join_test() { PhysicalPlan left = testScan(joinTestInputs); PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.RIGHT, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + makeNestedLoopJoin(left, right, Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(12, result.size()); @@ -467,16 +463,12 @@ public void nested_loop_right_join_test() { } @Test - public void nested_loop_right_join_test_2() { + public void right_join_side_exchange_test() { // Exchange the tables PhysicalPlan left = 
testScan(countTestInputs); PhysicalPlan right = testScan(joinTestInputs); PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.RIGHT, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + makeNestedLoopJoin(left, right, Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(9, result.size()); @@ -558,15 +550,10 @@ public void nested_loop_right_join_test_2() { } @Test - public void nested_loop_semi_join_test() { + public void semi_join_test() { PhysicalPlan left = testScan(joinTestInputs); PhysicalPlan right = testScan(countTestInputs); - PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.SEMI, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.SEMI); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(7, result.size()); @@ -591,16 +578,11 @@ public void nested_loop_semi_join_test() { } @Test - public void nested_loop_semi_join_test_2() { + public void semi_join_side_exchange_test() { // Exchange the tables PhysicalPlan left = testScan(countTestInputs); PhysicalPlan right = testScan(joinTestInputs); - PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.SEMI, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.SEMI); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(6, result.size()); @@ -616,15 +598,10 @@ public void nested_loop_semi_join_test_2() { } @Test - public void nested_loop_anti_join_test() { + public void anti_join_test() { PhysicalPlan left = testScan(joinTestInputs); PhysicalPlan right = testScan(countTestInputs); - PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.ANTI, - 
DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.ANTI); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(2, result.size()); @@ -640,16 +617,11 @@ public void nested_loop_anti_join_test() { } @Test - public void nested_loop_anti_join_test_2() { + public void anti_join_side_exchange_test() { // Exchange the tables PhysicalPlan left = testScan(countTestInputs); PhysicalPlan right = testScan(joinTestInputs); - PhysicalPlan joinPlan = - new NestedLoopJoinOperator( - left, - right, - Join.JoinType.ANTI, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); + PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.ANTI); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(5, result.size()); From 1b83310aea0e85d430ea65c7c5d14d74fa712fdc Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 4 Sep 2024 23:26:36 +0800 Subject: [PATCH 4/4] Resolve ambiguous join condition by adding output schema in join Signed-off-by: Lantao Jin --- .../ExpressionReferenceOptimizer.java | 60 ++ .../sql/planner/DefaultImplementor.java | 1 + .../physical/datasource/DataSourceTable.java | 2 +- .../datasource/DataSourceTableScan.java | 23 + .../physical/join/HashJoinOperator.java | 31 +- .../planner/physical/join/JoinOperator.java | 58 +- .../physical/join/NestedLoopJoinOperator.java | 20 +- .../opensearch/sql/analysis/AnalyzerTest.java | 20 + .../org/opensearch/sql/config/TestConfig.java | 11 + .../physical/HashJoinOperatorTest.java | 839 +++--------------- .../physical/JoinOperatorTestHelper.java | 780 ++++++++++++++++ .../physical/NestedLoopJoinOperatorTest.java | 605 +++---------- .../physical/PhysicalPlanTestBase.java | 74 +- 13 files changed, 1250 insertions(+), 1274 deletions(-) create mode 100644 core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTestHelper.java diff --git 
a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java index c9b618a70f..e598bb2efc 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -5,11 +5,15 @@ package org.opensearch.sql.analysis; +import static org.opensearch.sql.common.utils.StringUtils.format; + import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; +import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.FunctionExpression; @@ -23,6 +27,7 @@ import org.opensearch.sql.planner.logical.LogicalAggregation; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.logical.LogicalWindow; /** @@ -48,6 +53,11 @@ public class ExpressionReferenceOptimizer */ private final Map expressionMap = new HashMap<>(); + private String leftRelationName; + private String rightRelationName; + private Set leftSideAttributes; + private Set rightSideAttributes; + public ExpressionReferenceOptimizer( BuiltinFunctionRepository repository, LogicalPlan logicalPlan) { this.repository = repository; @@ -57,6 +67,17 @@ public ExpressionReferenceOptimizer( public ExpressionReferenceOptimizer( BuiltinFunctionRepository repository, LogicalPlan... logicalPlans) { this.repository = repository; + // To resolve join condition, we store the left side and the right side of the join.
+ if (logicalPlans.length == 2) { + // TODO current implementation only support two-tables join, so we can directly convert them + // to LogicalRelation. To support two-plans join, we can get the LogicalRelation by searching. + this.leftRelationName = ((LogicalRelation) logicalPlans[0]).getRelationName(); + this.rightRelationName = ((LogicalRelation) logicalPlans[1]).getRelationName(); + this.leftSideAttributes = + ((LogicalRelation) logicalPlans[0]).getTable().getFieldTypes().keySet(); + this.rightSideAttributes = + ((LogicalRelation) logicalPlans[1]).getTable().getFieldTypes().keySet(); + } Arrays.stream(logicalPlans).forEach(p -> p.accept(new ExpressionMapBuilder(), null)); } @@ -69,6 +90,45 @@ public Expression visitNode(Expression node, AnalysisContext context) { return node; } + /** + * Add index prefix to reference attribute of join condition. The attribute could be: case 1: + * Field -> Index.Field case 2: Field.Field -> Index.Field.Field case 3: .Index.Field, + * .Index.Field.Field -> do nothing case 4: Index.Field, Index.Field.Field -> do nothing + */ + @Override + public Expression visitReference(ReferenceExpression node, AnalysisContext context) { + if (leftRelationName == null || rightRelationName == null) { + return node; + } + + String attr = node.getAttr(); + // case 1 or case 2 + if (!attr.contains(".") || (!attr.startsWith(".") && !isIndexPrefix(attr))) { + return replaceReferenceExpressionWithIndexPrefix(node, attr); + } + return node; + } + + private ReferenceExpression replaceReferenceExpressionWithIndexPrefix( + ReferenceExpression node, String attr) { + if (leftSideAttributes.contains(attr) && rightSideAttributes.contains(attr)) { + throw new SemanticCheckException(format("Reference `%s` is ambiguous", attr)); + } else if (leftSideAttributes.contains(attr)) { + return new ReferenceExpression(format("%s.%s", leftRelationName, attr), node.type()); + } else if (rightSideAttributes.contains(attr)) { + return new 
ReferenceExpression(format("%s.%s", rightRelationName, attr), node.type()); + } else { + return node; + } + } + + private boolean isIndexPrefix(String attr) { + int separator = attr.indexOf('.'); + String possibleIndexPrefix = attr.substring(0, separator); + return leftRelationName.contains(possibleIndexPrefix) + || rightRelationName.contains(possibleIndexPrefix); + } + @Override public Expression visitFunction(FunctionExpression node, AnalysisContext context) { if (expressionMap.containsKey(node)) { diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index 4dc0f5a27c..54fd10010b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -253,6 +253,7 @@ private JoinOperator.BuildSide determineBuildSide(Join.JoinType joinType) { /** Return true if the reference can be evaluated in relation */ private boolean canEvaluate(ReferenceExpression expr, LogicalPlan plan) { if (plan instanceof LogicalRelation relation) { + // TODO need fix, the attr() contains relation prefix: Index.Field return relation.getTable().getFieldTypes().containsKey(expr.getAttr()); } else { throw new UnsupportedOperationException("Only relation can be used in join"); diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTable.java b/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTable.java index 5542d0f0e4..9606f1cef9 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTable.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTable.java @@ -47,7 +47,7 @@ public static class DataSourceTableDefaultImplementor extends DefaultImplementor @Override public PhysicalPlan visitRelation(LogicalRelation node, Object context) { - return new 
DataSourceTableScan(dataSourceService); + return new DataSourceTableScan(dataSourceService, node); } } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScan.java b/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScan.java index 89e21377dc..b2c6fb737d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScan.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/datasource/DataSourceTableScan.java @@ -14,11 +14,14 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.datasource.DataSourceService; import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.storage.TableScanOperator; /** @@ -29,11 +32,19 @@ public class DataSourceTableScan extends TableScanOperator { private final DataSourceService dataSourceService; + private final LogicalRelation relation; + private final String relationName; private Iterator iterator; public DataSourceTableScan(DataSourceService dataSourceService) { + this(dataSourceService, null); + } + + public DataSourceTableScan(DataSourceService dataSourceService, LogicalRelation relation) { this.dataSourceService = dataSourceService; + this.relation = relation; + this.relationName = relation == null ? null : relation.getRelationName(); this.iterator = Collections.emptyIterator(); } @@ -68,4 +79,16 @@ public boolean hasNext() { public ExprValue next() { return iterator.next(); } + + @Override + public ExecutionEngine.Schema schema() { + List columns = + relation.getTable().getFieldTypes().entrySet().stream() + .map( + (entry) -> + new
ExecutionEngine.Schema.Column( + entry.getKey(), relationName + "." + entry.getKey(), entry.getValue())) + .collect(Collectors.toList()); + return new ExecutionEngine.Schema(columns); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java index 0be6612e96..112484c19c 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/HashJoinOperator.java @@ -14,7 +14,6 @@ import java.util.Optional; import java.util.Set; import java.util.stream.IntStream; -import lombok.RequiredArgsConstructor; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -32,16 +31,28 @@ * bigger ON bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is * left (smaller), and the streamed side is right (bigger). 
*/ -@RequiredArgsConstructor public class HashJoinOperator extends JoinOperator { private final List leftKeys; private final List rightKeys; - private final Join.JoinType joinType; private final BuildSide buildSide; - private final PhysicalPlan left; - private final PhysicalPlan right; private final Optional nonEquiCond; + // write the construct method + public HashJoinOperator( + List leftKeys, + List rightKeys, + Join.JoinType joinType, + BuildSide buildSide, + PhysicalPlan left, + PhysicalPlan right, + Optional nonEquiCond) { + super(left, right, joinType); + this.leftKeys = leftKeys; + this.rightKeys = rightKeys; + this.buildSide = buildSide; + this.nonEquiCond = nonEquiCond; + } + private final ImmutableList.Builder joinedBuilder = ImmutableList.builder(); private Iterator joinedIterator; @@ -51,6 +62,8 @@ public class HashJoinOperator extends JoinOperator { @Override public void open() { + left.open(); + right.open(); if (!(leftKeys.size() == rightKeys.size() && IntStream.range(0, leftKeys.size()) .allMatch(i -> sameType(leftKeys.get(i), rightKeys.get(i))))) { @@ -58,8 +71,6 @@ public void open() { "Join keys from two sides should have same length and types"); } - left.open(); - right.open(); Iterator streamed; if (buildSide == BuildRight) { hashed = buildHashed(right, rightKeys); @@ -84,13 +95,13 @@ public void open() { @Override public void close() { + left.close(); + right.close(); joinedIterator = null; if (hashed != null) { hashed.close(); hashed = null; } - left.close(); - right.close(); } @Override @@ -208,7 +219,7 @@ public void antiJoin(Iterator streamed) { List matchedBuildRows = hashed.get(streamedRowKey); for (ExprValue matchedBuildRow : matchedBuildRows) { if (nonEquiCond.isPresent()) { - ExprValue joinedRow = combineExprTupleValue(buildSide, matchedBuildRow, streamedRow); + ExprValue joinedRow = combineExprTupleValue(buildSide, streamedRow, matchedBuildRow); ExprValue conditionValue = nonEquiCond.get().valueOf(joinedRow.bindingTuples()); if 
(!(conditionValue.isNull() || conditionValue.isMissing()) && conditionValue.booleanValue()) { diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java index 8e60c81415..d884bb643c 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/JoinOperator.java @@ -5,20 +5,54 @@ package org.opensearch.sql.planner.physical.join; -import static org.opensearch.sql.data.type.ExprCoreType.UNDEFINED; - +import java.util.Collection; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanNodeVisitor; public abstract class JoinOperator extends PhysicalPlan { + protected PhysicalPlan left; + protected PhysicalPlan right; + protected Join.JoinType joinType; + + protected ExecutionEngine.Schema leftSchema; + protected ExecutionEngine.Schema rightSchema; + protected ExecutionEngine.Schema outputSchema; + + JoinOperator(PhysicalPlan left, PhysicalPlan right, Join.JoinType joinType) { + this.left = left; + this.right = right; + this.joinType = joinType; + this.leftSchema = left.schema(); + this.rightSchema = right.schema(); + getOutputSchema(); + } + + private void getOutputSchema() { + switch (joinType) { + case INNER, LEFT, RIGHT, FULL -> { // merge left and right schemas + List columns = + 
Stream.of(left.schema().getColumns(), right.schema().getColumns()) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + this.outputSchema = new ExecutionEngine.Schema(columns); + } + case SEMI, ANTI -> outputSchema = left.schema(); // left schema only + default -> throw new UnsupportedOperationException("Unsupported Join Type " + joinType); + } + } @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { @@ -40,14 +74,28 @@ protected ExprTupleValue combineExprTupleValue( BuildSide buildSide, ExprValue streamedRow, ExprValue buildRow) { ExprValue left = buildSide == BuildSide.BuildLeft ? buildRow : streamedRow; ExprValue right = buildSide == BuildSide.BuildLeft ? streamedRow : buildRow; - Map leftTuple = left.type().equals(UNDEFINED) ? Map.of() : left.tupleValue(); - Map rightTuple = - right.type().equals(UNDEFINED) ? Map.of() : right.tupleValue(); + Map leftTuple = getExprTupleMapFromSchema(left, leftSchema); + Map rightTuple = getExprTupleMapFromSchema(right, rightSchema); Map combinedMap = new LinkedHashMap<>(leftTuple); combinedMap.putAll(rightTuple); return ExprTupleValue.fromExprValueMap(combinedMap); } + private Map getExprTupleMapFromSchema( + ExprValue row, ExecutionEngine.Schema schema) { + Map map = new LinkedHashMap<>(); + if (row.isNull()) { + schema.getColumns().forEach(col -> map.put(col.getAlias(), ExprNullValue.of())); + } else { + // replace to indexName.fieldName as tupleMap key in case the field names are same in join + // tables. 
+ schema + .getColumns() + .forEach(col -> map.put(col.getAlias(), row.tupleValue().get(col.getName()))); + } + return map; + } + protected boolean sameType(Expression expr1, Expression expr2) { ExprType type1 = expr1.type(); ExprType type2 = expr2.type(); diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java index f9817b70c6..800d3798a3 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/join/NestedLoopJoinOperator.java @@ -12,7 +12,6 @@ import java.util.Iterator; import java.util.List; import java.util.Set; -import lombok.RequiredArgsConstructor; import org.opensearch.sql.ast.tree.Join; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; @@ -30,14 +29,21 @@ * bigger ON bigger.field1 = smaller.field2 AND bigger.field3 = smaller.field4 The build side is * left (smaller), and the streamed side is right (bigger). 
*/ -@RequiredArgsConstructor public class NestedLoopJoinOperator extends JoinOperator { - private final PhysicalPlan left; - private final PhysicalPlan right; - private final Join.JoinType joinType; private final BuildSide buildSide; private final Expression condition; + public NestedLoopJoinOperator( + PhysicalPlan left, + PhysicalPlan right, + Join.JoinType joinType, + BuildSide buildSide, + Expression condition) { + super(left, right, joinType); + this.buildSide = buildSide; + this.condition = condition; + } + private final ImmutableList.Builder joinedBuilder = ImmutableList.builder(); private Iterator joinedIterator; @@ -67,10 +73,10 @@ public void open() { @Override public void close() { - joinedIterator = null; - cachedBuildSide = null; left.close(); right.close(); + joinedIterator = null; + cachedBuildSide = null; } /** The implementation for inner join: Inner with BuildRight */ diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 3668b13485..f12ddf5fd6 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -1973,4 +1973,24 @@ public void basic_SPJG() { AstDSL.alias("schema1.string_value", AstDSL.field("schema1.string_value")), AstDSL.alias("schema2.string_value", AstDSL.field("schema2.string_value")))); } + + @Test + public void join_condition_is_ambiguous() { + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + AstDSL.join( + AstDSL.relation("schema1"), + AstDSL.relation("schema2"), + Join.JoinType.INNER, + AstDSL.and( + AstDSL.equalTo( + AstDSL.field("schema1.integer_value"), + AstDSL.field("schema2.integer_value")), + AstDSL.equalTo( + AstDSL.field("double_value"), AstDSL.field("double_value")))))); + assertEquals("Reference `double_value` is ambiguous", exception.getMessage()); + } } diff --git 
a/core/src/test/java/org/opensearch/sql/config/TestConfig.java b/core/src/test/java/org/opensearch/sql/config/TestConfig.java index 92b6aac64f..3c12c4a1a6 100644 --- a/core/src/test/java/org/opensearch/sql/config/TestConfig.java +++ b/core/src/test/java/org/opensearch/sql/config/TestConfig.java @@ -61,6 +61,17 @@ public class TestConfig { .put("comment.data", ExprCoreType.STRING) .build(); + public static Map typeMapping2 = + new ImmutableMap.Builder() + .put("i_value", ExprCoreType.INTEGER) + .put("l_value", ExprCoreType.LONG) + .put("f_value", ExprCoreType.FLOAT) + .put("d_value", ExprCoreType.DOUBLE) + .put("msg", ExprCoreType.STRING) + .put("msg.info", ExprCoreType.STRING) + .put("msg.info.id", ExprCoreType.STRING) + .build(); + protected StorageEngine storageEngine() { return new StorageEngine() { @Override diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java index b0e6b36fc7..416bde5c3a 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/HashJoinOperatorTest.java @@ -11,12 +11,10 @@ import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Optional; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; @@ -26,759 +24,212 @@ import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; -import 
org.opensearch.sql.planner.physical.join.HashJoinOperator; import org.opensearch.sql.planner.physical.join.JoinOperator; @ExtendWith(MockitoExtension.class) -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class HashJoinOperatorTest extends PhysicalPlanTestBase { - private final JoinOperator.BuildSide defaultBuildSide = JoinOperator.BuildSide.BuildRight; +public class HashJoinOperatorTest extends JoinOperatorTestHelper { + private final Optional emptyNonEquiCond = Optional.empty(); private final Optional defaultNonEquiCond = Optional.of( DSL.and( - DSL.equal(DSL.ref("host", STRING), DSL.literal("h1")), - DSL.lte(DSL.ref("id", INTEGER), DSL.literal(5)))); + DSL.equal(DSL.ref("error_t.host", STRING), DSL.literal("h1")), + DSL.lte(DSL.ref("name_t.id", INTEGER), DSL.literal(5)))); @Test public void inner_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.INNER, - defaultBuildSide, - left, - right, - Optional.empty()); + makeHashJoin( + Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(7, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10, - "id", 
- 10, - "name", - "j")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6, - "id", - 6, - "name", - "f")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)))); + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10)); } @Test - public void inner_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); + public void inner_join_side_reversed_test() { PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("id", INTEGER)), - ImmutableList.of(DSL.ref("errors", INTEGER)), - Join.JoinType.INNER, - defaultBuildSide, - left, - right, - Optional.empty()); + makeHashJoin( + Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(7, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 2, - "name", - "b", - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 6, - "name", - "f", - "day", - new 
ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 10, - "name", - "j", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10)))); + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10)); } @Test public void inner_join_with_non_equi_cond_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.INNER, - defaultBuildSide, - left, - right, - defaultNonEquiCond); + makeHashJoin( + Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, defaultNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(3, result.size()); - assertThat( - result, - containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")))); + assertThat(result, containsInAnyOrder(error1_id1, error1_id1_duplicated, error2_id2)); } @Test public void left_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.LEFT, - defaultBuildSide, - left, - 
right, - Optional.empty()); + makeHashJoin( + Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(9, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10, - "id", - 10, - "name", - "j")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6, - "id", - 6, - "name", - "f")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10, + error12_null, + error13_null)); } @Test - public void left_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); + public void left_join_side_reversed_test() { PhysicalPlan joinPlan = - new HashJoinOperator( - 
ImmutableList.of(DSL.ref("id", INTEGER)), - ImmutableList.of(DSL.ref("errors", INTEGER)), - Join.JoinType.LEFT, - defaultBuildSide, - left, - right, - Optional.empty()); + makeHashJoin(Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(12, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 2, - "name", - "b", - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 6, - "name", - "f", - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 10, - "name", - "j", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10, + id4_null, + id5_null, + id7_null, + id9_null, + id11_null)); } @Test public void 
left_join_with_non_equi_cond_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.INNER, - defaultBuildSide, - left, - right, - defaultNonEquiCond); + makeHashJoin( + Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, defaultNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); - assertEquals(3, result.size()); + assertEquals(9, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")))); + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_null, + error6_null, + error8_null, + error10_null, + error12_null, + error13_null)); } @Test public void right_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.RIGHT, - JoinOperator.BuildSide.BuildLeft, - left, - right, - Optional.empty()); + makeHashJoin( + Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, emptyNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(12, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - 
"errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6, - "id", - 6, - "name", - "f")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10, - "id", - 10, - "name", - "j")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10, + null_id4, + null_id5, + null_id7, + null_id9, + null_id11)); } @Test - public void right_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); + public void right_join_side_reversed_test() { PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("id", INTEGER)), - ImmutableList.of(DSL.ref("errors", INTEGER)), - Join.JoinType.RIGHT, - JoinOperator.BuildSide.BuildLeft, - left, - right, - Optional.empty()); + makeHashJoin(Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, 
emptyNonEquiCond, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(9, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 2, - "name", - "b", - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 10, - "name", - "j", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 6, - "name", - "f", - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10, + null_error12, + null_error13)); } @Test public void right_join_with_non_equi_cond_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.RIGHT, - JoinOperator.BuildSide.BuildLeft, - left, - right, - 
defaultNonEquiCond); + makeHashJoin( + Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, defaultNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(12, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 3)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 8)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + error1_id1, + error1_id1_duplicated, + error2_id2, + null_id3, + null_id4, + null_id5, + null_id6, + null_id7, + null_id8, + null_id9, + null_id10, + null_id11)); } @Test public void semi_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.SEMI, - defaultBuildSide, - left, - right, - Optional.empty()); + makeHashJoin( + Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, false); List result = execute(joinPlan); 
result.forEach(System.out::println); assertEquals(7, result.size()); @@ -803,19 +254,9 @@ public void semi_join_test() { } @Test - public void semi_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); + public void semi_join_side_reversed_test() { PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("id", INTEGER)), - ImmutableList.of(DSL.ref("errors", INTEGER)), - Join.JoinType.SEMI, - defaultBuildSide, - left, - right, - Optional.empty()); + makeHashJoin(Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(6, result.size()); @@ -826,23 +267,27 @@ public void semi_join_side_exchange_test() { ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j")), ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a")), ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 3)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 8)))); + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 3); + put("name", null); + } + }), + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 8); + put("name", null); + } + }))); } @Test public void semi_join_non_equi_cond_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.SEMI, - defaultBuildSide, - left, - right, - defaultNonEquiCond); + makeHashJoin( + Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, defaultNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(3, result.size()); @@ -860,17 +305,9 @@ public void semi_join_non_equi_cond_test() { 
@Test public void anti_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.ANTI, - defaultBuildSide, - left, - right, - Optional.empty()); + makeHashJoin( + Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(2, result.size()); @@ -886,19 +323,9 @@ public void anti_join_test() { } @Test - public void anti_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); + public void anti_join_side_reversed_test() { PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("id", INTEGER)), - ImmutableList.of(DSL.ref("errors", INTEGER)), - Join.JoinType.ANTI, - defaultBuildSide, - left, - right, - Optional.empty()); + makeHashJoin(Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, emptyNonEquiCond, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(5, result.size()); @@ -914,17 +341,9 @@ public void anti_join_side_exchange_test() { @Test public void anti_join_non_equi_cond_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - new HashJoinOperator( - ImmutableList.of(DSL.ref("errors", INTEGER)), - ImmutableList.of(DSL.ref("id", INTEGER)), - Join.JoinType.ANTI, - defaultBuildSide, - left, - right, - defaultNonEquiCond); + makeHashJoin( + Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, defaultNonEquiCond, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(6, result.size()); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTestHelper.java 
b/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTestHelper.java new file mode 100644 index 0000000000..72f4d9b244 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/JoinOperatorTestHelper.java @@ -0,0 +1,780 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.opensearch.sql.data.type.ExprCoreType.DATE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Optional; +import org.opensearch.sql.ast.tree.Join; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.planner.physical.join.HashJoinOperator; +import org.opensearch.sql.planner.physical.join.JoinOperator; +import org.opensearch.sql.planner.physical.join.NestedLoopJoinOperator; + +public class JoinOperatorTestHelper extends PhysicalPlanTestBase { + + private final List errorInputs = + new ImmutableList.Builder() + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10))) + .add( + 
ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12))) + .add( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13))) + .build(); + + private final List nameInputs = + new ImmutableList.Builder() + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b"))) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 3); + put("name", null); + } + })) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g"))) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 8); + put("name", null); + } + })) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k"))) + .build(); + + private final List sameNameInputs = + new ImmutableList.Builder() + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 3, "name", "c"))) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 5); + put("name", null); + } + })) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + 
{ + put("id", 8); + put("name", null); + } + })) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "jj"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "jjj"))) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 15, "name", "o"))) + .add( + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 16); + put("name", null); + } + })) + .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 17, "name", "q"))) + .build(); + + private final ExecutionEngine.Schema errorSchema = + new ExecutionEngine.Schema( + List.of( + new ExecutionEngine.Schema.Column("day", "error_t.day", DATE), + new ExecutionEngine.Schema.Column("host", "error_t.host", STRING), + new ExecutionEngine.Schema.Column("errors", "error_t.errors", INTEGER))); + + private final ExecutionEngine.Schema nameSchema = + new ExecutionEngine.Schema( + List.of( + new ExecutionEngine.Schema.Column("id", "name_t.id", INTEGER), + new ExecutionEngine.Schema.Column("name", "name_t.name", STRING))); + + private final ExecutionEngine.Schema sameNameSchema = + new ExecutionEngine.Schema( + List.of( + new ExecutionEngine.Schema.Column("id", "name_t2.id", INTEGER), + new ExecutionEngine.Schema.Column("name", "name_t2.name", STRING))); + + public PhysicalPlan makeNestedLoopJoin( + Join.JoinType joinType, JoinOperator.BuildSide buildSide, boolean reversed) { + PhysicalPlan left = + reversed + ? testTableScan("name_t", nameSchema, nameInputs) + : testTableScan("error_t", errorSchema, errorInputs); + PhysicalPlan right = + reversed + ? 
testTableScan("error_t", errorSchema, errorInputs) + : testTableScan("name_t", nameSchema, nameInputs); + return new NestedLoopJoinOperator( + left, + right, + joinType, + buildSide, + DSL.equal(DSL.ref("error_t.errors", INTEGER), DSL.ref("name_t.id", INTEGER))); + } + + public PhysicalPlan makeNestedLoopJoinWithSameColumnNames( + Join.JoinType joinType, JoinOperator.BuildSide buildSide, boolean reversed) { + PhysicalPlan left = + reversed + ? testTableScan("name_t2", sameNameSchema, sameNameInputs) + : testTableScan("name_t", nameSchema, nameInputs); + PhysicalPlan right = + reversed + ? testTableScan("name_t", nameSchema, nameInputs) + : testTableScan("name_t2", sameNameSchema, sameNameInputs); + return new NestedLoopJoinOperator( + left, + right, + joinType, + buildSide, + DSL.equal(DSL.ref("name_t.id", INTEGER), DSL.ref("name_t2.id", INTEGER))); + } + + public PhysicalPlan makeHashJoin( + Join.JoinType joinType, + JoinOperator.BuildSide buildSide, + Optional nonEquiCond, + boolean reversed) { + PhysicalPlan left = + reversed + ? testTableScan("name_t", nameSchema, nameInputs) + : testTableScan("error_t", errorSchema, errorInputs); + PhysicalPlan right = + reversed + ? testTableScan("error_t", errorSchema, errorInputs) + : testTableScan("name_t", nameSchema, nameInputs); + List leftKeys = + reversed + ? ImmutableList.of(DSL.ref("id", INTEGER)) + : ImmutableList.of(DSL.ref("errors", INTEGER)); + + List rightKeys = + reversed + ? ImmutableList.of(DSL.ref("errors", INTEGER)) + : ImmutableList.of(DSL.ref("id", INTEGER)); + return new HashJoinOperator(leftKeys, rightKeys, joinType, buildSide, left, right, nonEquiCond); + } + + public PhysicalPlan makeHashJoinWithSameColumnNames( + Join.JoinType joinType, + JoinOperator.BuildSide buildSide, + Optional nonEquiCond, + boolean reversed) { + PhysicalPlan left = + reversed + ? testTableScan("name_t", nameSchema, nameInputs) + : testTableScan("name_t2", sameNameSchema, sameNameInputs); + PhysicalPlan right = + reversed + ? 
testTableScan("name_t2", sameNameSchema, sameNameInputs) + : testTableScan("name_t", nameSchema, nameInputs); + List leftKeys = ImmutableList.of(DSL.ref("id", INTEGER)); + + List rightKeys = ImmutableList.of(DSL.ref("id", INTEGER)); + return new HashJoinOperator(leftKeys, rightKeys, joinType, buildSide, left, right, nonEquiCond); + } + + /** {day:DATE '2021-01-04',host:"h1",errors:1,id:1,name:"a"} */ + protected ExprValue error1_id1 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-04"), + "error_t.host", + "h1", + "error_t.errors", + 1, + "name_t.id", + 1, + "name_t.name", + "a")); + + /** {day:DATE '2021-01-06',host:"h1",errors:1,id:1,name:"a"} */ + protected ExprValue error1_id1_duplicated = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-06"), + "error_t.host", + "h1", + "error_t.errors", + 1, + "name_t.id", + 1, + "name_t.name", + "a")); + + /** {day:DATE '2021-01-03',host:"h1",errors:2,id:2,name:"b"} */ + protected ExprValue error2_id2 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-03"), + "error_t.host", + "h1", + "error_t.errors", + 2, + "name_t.id", + 2, + "name_t.name", + "b")); + + /** {day:DATE '2021-01-03',host:"h2",errors:3,id:3,name:NULL} */ + protected ExprValue error3_id3 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-03")); + put("error_t.host", "h2"); + put("error_t.errors", 3); + put("name_t.id", 3); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-03',host:"h2",errors:3,id:NULL,name:NULL} */ + protected ExprValue error3_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-03")); + put("error_t.host", "h2"); + put("error_t.errors", 3); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-07',host:"h1",errors:6,id:6,name:"f"} */ + 
protected ExprValue error6_id6 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-07"), + "error_t.host", + "h1", + "error_t.errors", + 6, + "name_t.id", + 6, + "name_t.name", + "f")); + + /** {day:DATE '2021-01-07',host:"h1",errors:6,id:NULL,name:NULL} */ + protected ExprValue error6_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h1"); + put("error_t.errors", 6); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-07',host:"h2",errors:8,id:8,name:NULL} */ + protected ExprValue error8_id8 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 8); + put("name_t.id", 8); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-07',host:"h2",errors:8,id:NULL,name:NULL} */ + protected ExprValue error8_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 8); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-04',host:"h2",errors:10,id:10,name:"j"} */ + protected ExprValue error10_id10 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "error_t.day", + new ExprDateValue("2021-01-04"), + "error_t.host", + "h2", + "error_t.errors", + 10, + "name_t.id", + 10, + "name_t.name", + "j")); + + /** {day:DATE '2021-01-04',host:"h2",errors:10,id:NULL,name:NULL} */ + protected ExprValue error10_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-04")); + put("error_t.host", "h2"); + put("error_t.errors", 10); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-07',host:"h2",errors:12,id:NULL,name:NULL} */ + protected ExprValue 
error12_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 12); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {day:DATE '2021-01-08',host:"h1",errors:13,id:NULL,name:NULL} */ + protected ExprValue error13_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", new ExprDateValue("2021-01-08")); + put("error_t.host", "h1"); + put("error_t.errors", 13); + put("name_t.id", null); + put("name_t.name", null); + } + }); + + /** {id:1,name:"a",day:DATE '2021-01-04',host:"h1",errors:1} */ + protected ExprValue id1_error1 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 1, + "name_t.name", + "a", + "error_t.day", + new ExprDateValue("2021-01-04"), + "error_t.host", + "h1", + "error_t.errors", + 1)); + + /** {id:1,name:"a",day:DATE '2021-01-06',host:"h1",errors:1} */ + protected ExprValue id1_error1_duplicated = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 1, + "name_t.name", + "a", + "error_t.day", + new ExprDateValue("2021-01-06"), + "error_t.host", + "h1", + "error_t.errors", + 1)); + + /** {id:2,name:"b",day:DATE '2021-01-03',host:"h1",errors:2} */ + protected ExprValue id2_error2 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 2, + "name_t.name", + "b", + "error_t.day", + new ExprDateValue("2021-01-03"), + "error_t.host", + "h1", + "error_t.errors", + 2)); + + /** {id:3,name:NULL,day:DATE '2021-01-03',host:"h2",errors:3} */ + protected ExprValue id3_error3 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 3); + put("name_t.name", null); + put("error_t.day", new ExprDateValue("2021-01-03")); + put("error_t.host", "h2"); + put("error_t.errors", 3); + } + }); + + /** {id:4,name:"d",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id4_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + 
put("name_t.id", 4); + put("name_t.name", "d"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {id:5,name:"e",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id5_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 5); + put("name_t.name", "e"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {id:6,name:"f",day:DATE '2021-01-07',host:"h1",errors:6} */ + protected ExprValue id6_error6 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 6, + "name_t.name", + "f", + "error_t.day", + new ExprDateValue("2021-01-07"), + "error_t.host", + "h1", + "error_t.errors", + 6)); + + /** {id:7,name:"g",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id7_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 7); + put("name_t.name", "g"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {id:8,name:NULL,day:DATE '2021-01-07',host:"h2",errors:8} */ + protected ExprValue id8_error8 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 8); + put("name_t.name", null); + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 8); + } + }); + + /** {id:9,name:"i",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id9_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 9); + put("name_t.name", "i"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {id:10,name:"j",day:DATE '2021-01-04',host:"h2",errors:10} */ + protected ExprValue id10_error10 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", + 10, + "name_t.name", + "j", + "error_t.day", + new ExprDateValue("2021-01-04"), + "error_t.host", + "h2", + "error_t.errors", + 10)); + + /** 
{id:11,name:"k",day:NULL,host:NULL,errors:NULL} */ + protected ExprValue id11_null = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 11); + put("name_t.name", "k"); + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:3,name:NULL} */ + protected ExprValue null_id3 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 3); + put("name_t.name", null); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:4,name:"d"} */ + protected ExprValue null_id4 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 4); + put("name_t.name", "d"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:5,name:"e"} */ + protected ExprValue null_id5 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 5); + put("name_t.name", "e"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:6,name:"f"} */ + protected ExprValue null_id6 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 6); + put("name_t.name", "f"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:7,name:"g"} */ + protected ExprValue null_id7 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 7); + put("name_t.name", "g"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:8,name:NULL} */ + protected ExprValue null_id8 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + 
put("error_t.errors", null); + put("name_t.id", 8); + put("name_t.name", null); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:9,name:"i"} */ + protected ExprValue null_id9 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 9); + put("name_t.name", "i"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:10,name:"j"} */ + protected ExprValue null_id10 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 10); + put("name_t.name", "j"); + } + }); + + /** {day:NULL,host:NULL,errors:NULL,id:11,name:"k"} */ + protected ExprValue null_id11 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("error_t.day", null); + put("error_t.host", null); + put("error_t.errors", null); + put("name_t.id", 11); + put("name_t.name", "k"); + } + }); + + /** {id:NULL,name:NULL,day:DATE '2021-01-07',host:"h2",errors:12} */ + protected ExprValue null_error12 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", null); + put("name_t.name", null); + put("error_t.day", new ExprDateValue("2021-01-07")); + put("error_t.host", "h2"); + put("error_t.errors", 12); + } + }); + + /** {id:NULL,name:NULL,day:DATE '2021-01-08',host:"h1",errors:13} */ + protected ExprValue null_error13 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", null); + put("name_t.name", null); + put("error_t.day", new ExprDateValue("2021-01-08")); + put("error_t.host", "h1"); + put("error_t.errors", 13); + } + }); + + /** {name_t.id:1,name_t.name:"a",name_t2.id:1,name_t2.name:"a"} */ + ExprValue id1_same_id1 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", 1, "name_t.name", "a", "name_t2.id", 1, "name_t2.name", "a")); + + /** {name_t.id:3,name_t.name:NULL,name_t2.id:3,name_t2.name:"c"} */ + ExprValue 
id3_same_id3 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 3); + put("name_t.name", null); + put("name_t2.id", 3); + put("name_t2.name", "c"); + } + }); + + /** {name_t.id:5,name_t.name:"e",name_t2.id:5,name_t2.name:NULL} */ + ExprValue id5_same_id5 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 5); + put("name_t.name", "e"); + put("name_t2.id", 5); + put("name_t2.name", null); + } + }); + + /** {name_t.id:8,name_t.name:NULL,name_t2.id:8,name_t2.name:NULL} */ + ExprValue id8_same_id8 = + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("name_t.id", 8); + put("name_t.name", null); + put("name_t2.id", 8); + put("name_t2.name", null); + } + }); + + /** {name_t.id:10,name_t.name:"j",name_t2.id:10,name_t2.name:"j"} */ + ExprValue id10_same_id10 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", 10, "name_t.name", "j", "name_t2.id", 10, "name_t2.name", "j")); + + /** {name_t.id:10,name_t.name:"j",name_t2.id:10,name_t2.name:"jj"} */ + ExprValue id10_same_id10_duplicated = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", 10, "name_t.name", "j", "name_t2.id", 10, "name_t2.name", "jj")); + + /** {name_t.id:10,name_t.name:"j",name_t2.id:10,name_t2.name:"jjj"} */ + ExprValue id10_same_id10_duplicated2 = + ExprValueUtils.tupleValue( + ImmutableMap.of( + "name_t.id", 10, "name_t.name", "j", "name_t2.id", 10, "name_t2.name", "jjj")); +} diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java index b13bcb47e3..8e78954662 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/NestedLoopJoinOperatorTest.java @@ -8,12 +8,10 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import 
static org.junit.jupiter.api.Assertions.assertEquals; -import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import com.google.common.collect.ImmutableMap; +import java.util.LinkedHashMap; import java.util.List; -import org.junit.jupiter.api.DisplayNameGeneration; -import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; @@ -21,539 +19,143 @@ import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.expression.DSL; import org.opensearch.sql.planner.physical.join.JoinOperator; -import org.opensearch.sql.planner.physical.join.NestedLoopJoinOperator; @ExtendWith(MockitoExtension.class) -@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class NestedLoopJoinOperatorTest extends PhysicalPlanTestBase { - private final JoinOperator.BuildSide defaultBuildSide = JoinOperator.BuildSide.BuildRight; - - private PhysicalPlan makeNestedLoopJoin( - PhysicalPlan left, PhysicalPlan right, Join.JoinType joinType) { - return makeNestedLoopJoin(left, right, joinType, defaultBuildSide); - } - - private PhysicalPlan makeNestedLoopJoin( - PhysicalPlan left, - PhysicalPlan right, - Join.JoinType joinType, - JoinOperator.BuildSide buildSide) { - return new NestedLoopJoinOperator( - left, - right, - joinType, - buildSide, - DSL.equal(DSL.ref("errors", INTEGER), DSL.ref("id", INTEGER))); - } +public class NestedLoopJoinOperatorTest extends JoinOperatorTestHelper { @Test public void inner_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); - PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.INNER); + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, false); List result = 
execute(joinPlan); result.forEach(System.out::println); assertEquals(7, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10, - "id", - 10, - "name", - "j")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6, - "id", - 6, - "name", - "f")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)))); + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10)); } @Test - public void inner_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); - PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.INNER); + public void inner_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(7, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - 
"errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 2, - "name", - "b", - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 6, - "name", - "f", - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 10, - "name", - "j", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10)))); + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10)); } @Test public void left_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); - PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.LEFT); + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(9, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10, - "id", - 10, - "name", - "j")), - 
ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6, - "id", - 6, - "name", - "f")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10, + error12_null, + error13_null)); } @Test - public void left_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); - PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.LEFT); + public void left_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.LEFT, JoinOperator.BuildSide.BuildRight, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(12, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 2, - "name", - "b", - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2)), - 
ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 6, - "name", - "f", - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 10, - "name", - "j", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10, + id4_null, + id5_null, + id7_null, + id9_null, + id11_null)); } @Test public void right_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); PhysicalPlan joinPlan = - makeNestedLoopJoin(left, right, Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft); + makeNestedLoopJoin(Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(12, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1, - "id", - 1, - "name", - "a")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - 
"errors", - 2, - "id", - 2, - "name", - "b")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3, "id", 3)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6, - "id", - 6, - "name", - "f")), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8, "id", 8)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10, - "id", - 10, - "name", - "j")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); + error1_id1, + error1_id1_duplicated, + error2_id2, + error3_id3, + error6_id6, + error8_id8, + error10_id10, + null_id4, + null_id5, + null_id7, + null_id9, + null_id11)); } @Test - public void right_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); + public void right_join_side_reversed_test() { PhysicalPlan joinPlan = - makeNestedLoopJoin(left, right, Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft); + makeNestedLoopJoin(Join.JoinType.RIGHT, JoinOperator.BuildSide.BuildLeft, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(9, result.size()); assertThat( result, containsInAnyOrder( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 2, - "name", - "b", - "day", - new ExprDateValue("2021-01-03"), - "host", - "h1", - "errors", - 2)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 3, "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3)), - 
ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 10, - "name", - "j", - "day", - new ExprDateValue("2021-01-04"), - "host", - "h2", - "errors", - 10)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 1, - "name", - "a", - "day", - new ExprDateValue("2021-01-06"), - "host", - "h1", - "errors", - 1)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", - 6, - "name", - "f", - "day", - new ExprDateValue("2021-01-07"), - "host", - "h1", - "errors", - 6)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "id", 8, "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12)), - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13)))); + id1_error1, + id1_error1_duplicated, + id2_error2, + id3_error3, + id6_error6, + id8_error8, + id10_error10, + null_error12, + null_error13)); } @Test public void semi_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); - PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.SEMI); + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(7, result.size()); @@ -578,11 +180,9 @@ public void semi_join_test() { } @Test - public void semi_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); - PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.SEMI); + public void semi_join_side_reversed_test() { + PhysicalPlan joinPlan = + 
makeNestedLoopJoin(Join.JoinType.SEMI, JoinOperator.BuildSide.BuildRight, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(6, result.size()); @@ -593,15 +193,26 @@ public void semi_join_side_exchange_test() { ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j")), ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a")), ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f")), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 3)), - ExprValueUtils.tupleValue(ImmutableMap.of("id", 8)))); + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 3); + put("name", null); + } + }), + ExprValueUtils.tupleValue( + new LinkedHashMap<>() { + { + put("id", 8); + put("name", null); + } + }))); } @Test public void anti_join_test() { - PhysicalPlan left = testScan(joinTestInputs); - PhysicalPlan right = testScan(countTestInputs); - PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.ANTI); + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, false); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(2, result.size()); @@ -617,11 +228,9 @@ public void anti_join_test() { } @Test - public void anti_join_side_exchange_test() { - // Exchange the tables - PhysicalPlan left = testScan(countTestInputs); - PhysicalPlan right = testScan(joinTestInputs); - PhysicalPlan joinPlan = makeNestedLoopJoin(left, right, Join.JoinType.ANTI); + public void anti_join_side_reversed_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoin(Join.JoinType.ANTI, JoinOperator.BuildSide.BuildRight, true); List result = execute(joinPlan); result.forEach(System.out::println); assertEquals(5, result.size()); @@ -634,4 +243,28 @@ public void anti_join_side_exchange_test() { ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i")), ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k")))); } + + // 
+-----------------------------------------+ + // | Test join tables with same column names | + // +-----------------------------------------+ + + @Test + public void same_column_names_inner_join_test() { + PhysicalPlan joinPlan = + makeNestedLoopJoinWithSameColumnNames( + Join.JoinType.INNER, JoinOperator.BuildSide.BuildRight, false); + List result = execute(joinPlan); + result.forEach(System.out::println); + assertEquals(7, result.size()); + assertThat( + result, + containsInAnyOrder( + id1_same_id1, + id3_same_id3, + id5_same_id5, + id8_same_id8, + id10_same_id10, + id10_same_id10_duplicated, + id10_same_id10_duplicated2)); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java index 35de25375d..14b53d434a 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanTestBase.java @@ -20,6 +20,7 @@ import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.env.Environment; @@ -27,61 +28,6 @@ public class PhysicalPlanTestBase { - protected static final List countTestInputs = - new ImmutableList.Builder() - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 1, "name", "a"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 2, "name", "b"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 3))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 4, "name", "d"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 5, "name", "e"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 6, "name", "f"))) - 
.add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 7, "name", "g"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 8))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 9, "name", "i"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 10, "name", "j"))) - .add(ExprValueUtils.tupleValue(ImmutableMap.of("id", 11, "name", "k"))) - .build(); - - protected static final List joinTestInputs = - new ImmutableList.Builder() - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-03"), "host", "h1", "errors", 2))) - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-03"), "host", "h2", "errors", 3))) - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-04"), "host", "h1", "errors", 1))) - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-04"), "host", "h2", "errors", 10))) - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-06"), "host", "h1", "errors", 1))) - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h1", "errors", 6))) - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 8))) - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-07"), "host", "h2", "errors", 12))) - .add( - ExprValueUtils.tupleValue( - ImmutableMap.of( - "day", new ExprDateValue("2021-01-08"), "host", "h1", "errors", 13))) - .build(); - protected static final List inputs = new ImmutableList.Builder() .add( @@ -334,8 +280,15 @@ protected static PhysicalPlan testScan(List inputs) { return new TestScan(inputs); } + protected static PhysicalPlan testTableScan( + String relationName, ExecutionEngine.Schema schema, List inputs) { + return new TestScan(inputs, relationName, schema); + } + protected static class TestScan extends 
PhysicalPlan implements SerializablePlan { private final Iterator iterator; + private ExecutionEngine.Schema schema; + private String relationName; public TestScan() { iterator = inputs.iterator(); @@ -345,6 +298,12 @@ public TestScan(List inputs) { iterator = inputs.iterator(); } + public TestScan(List inputs, String relationName, ExecutionEngine.Schema schema) { + iterator = inputs.iterator(); + this.relationName = relationName; + this.schema = schema; + } + @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { return null; @@ -365,6 +324,11 @@ public ExprValue next() { return iterator.next(); } + @Override + public ExecutionEngine.Schema schema() { + return this.schema; + } + @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {}