From 494b0a229f803bd2e36b4ed40dba946d352f3b33 Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 11 Oct 2024 17:02:49 -0700 Subject: [PATCH 01/42] WIP: Add trendline PPL command Signed-off-by: James Duong --- .../org/opensearch/sql/analysis/Analyzer.java | 13 ++ .../sql/ast/AbstractNodeVisitor.java | 9 + .../opensearch/sql/ast/tree/Trendline.java | 69 ++++++ .../sql/planner/DefaultImplementor.java | 7 + .../logical/LogicalPlanNodeVisitor.java | 4 + .../sql/planner/logical/LogicalTrendline.java | 39 ++++ .../physical/PhysicalPlanNodeVisitor.java | 4 + .../planner/physical/TrendlineOperator.java | 207 ++++++++++++++++++ .../OpenSearchExecutionProtector.java | 10 + ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 5 + ppl/src/main/antlr/OpenSearchPPLParser.g4 | 13 ++ .../opensearch/sql/ppl/parser/AstBuilder.java | 11 + .../sql/ppl/parser/AstExpressionBuilder.java | 11 + 13 files changed, 402 insertions(+) create mode 100644 core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 71db736f78..ddfc8765f8 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -62,6 +62,7 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -100,6 +101,7 @@ import org.opensearch.sql.planner.logical.LogicalRemove; import org.opensearch.sql.planner.logical.LogicalRename; import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.planner.logical.LogicalTrendline; import org.opensearch.sql.planner.logical.LogicalValues; import org.opensearch.sql.planner.physical.datasource.DataSourceTable; import org.opensearch.sql.storage.Table; @@ -594,6 +596,17 @@ public LogicalPlan visitML(ML node, AnalysisContext context) { return new LogicalML(child, node.getArguments()); } + /** Build {@link LogicalTrendline} for Trendline command. */ + @Override + public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { + final LogicalPlan child = node.getChild().get(0).accept(this, context); + final List unresolvedComputations = node.getComputations(); + final List computations = + unresolvedComputations.stream().map(expression -> + (Trendline.TrendlineComputation) expression).toList(); + return new LogicalTrendline(child, computations); + } + @Override public LogicalPlan visitPaginate(Paginate paginate, AnalysisContext context) { LogicalPlan child = paginate.getChild().get(0).accept(this, context); diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index a0520dc70e..f27260dd5f 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -60,6 +60,7 @@ import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Values; /** AST nodes visitor Defines the traverse path. */ @@ -110,6 +111,14 @@ public T visitFilter(Filter node, C context) { return visitChildren(node, context); } + public T visitTrendline(Trendline node, C context) { + return visitChildren(node, context); + } + + public T visitTrendlineComputation(Trendline.TrendlineComputation node, C context) { + return visitChildren(node, context); + } + public T visitProject(Project node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java new file mode 100644 index 0000000000..112481032f --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; + +@ToString +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Trendline extends UnresolvedPlan { + + private UnresolvedPlan child; + private final List computations; + + @Override + public UnresolvedPlan attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Getter + public static class TrendlineComputation extends UnresolvedExpression { + + private final Integer numberOfDataPoints; + private final UnresolvedExpression dataField; + private final String alias; + private final TrendlineType computationType; + + public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, String computationType) { + this.numberOfDataPoints = numberOfDataPoints; + this.dataField = dataField; + this.alias = alias; + this.computationType = Trendline.TrendlineType.valueOf(computationType.toUpperCase()); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitTrendlineComputation(this, context); + } + } + + public enum TrendlineType { + SMA, + WMA + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index f962c3e4bf..c988084d1b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -23,6 +23,7 @@ import org.opensearch.sql.planner.logical.LogicalRemove; import org.opensearch.sql.planner.logical.LogicalRename; import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.planner.logical.LogicalTrendline; import org.opensearch.sql.planner.logical.LogicalValues; import org.opensearch.sql.planner.logical.LogicalWindow; import org.opensearch.sql.planner.physical.AggregationOperator; @@ -39,6 +40,7 @@ import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; import org.opensearch.sql.planner.physical.TakeOrderedOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.read.TableScanBuilder; @@ -166,6 +168,11 @@ public PhysicalPlan visitCloseCursor(LogicalCloseCursor node, C context) { return new CursorCloseOperator(visitChild(node, context)); } + @Override + public PhysicalPlan visitTrendline(LogicalTrendline plan, C context) { + return new TrendlineOperator(visitChild(plan, context), plan.getComputations()); + } + // Called when paging query requested without `FROM` clause only @Override public PhysicalPlan visitPaginate(LogicalPaginate plan, C context) { diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 156db35306..c9eedd8efc 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -104,6 +104,10 @@ public R visitAD(LogicalAD plan, C context) { return visitNode(plan, context); } + public R visitTrendline(LogicalTrendline plan, C context) { + return visitNode(plan, context); + } + public R visitPaginate(LogicalPaginate plan, C context) { return visitNode(plan, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java new file mode 100644 index 0000000000..3357b31e65 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.Collections; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.tree.Trendline; + +/* + * Trendline logical plan. + */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +public class LogicalTrendline extends LogicalPlan { + private final List computations; + + /** + * Constructor of LogicalTrendline. + * + * @param child child logical plan + * @param computations the computations for this trendline call. + */ + public LogicalTrendline(LogicalPlan child, List computations) { + super(Collections.singletonList(child)); + this.computations = computations; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index 67d7a05135..b86edcc8f3 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -96,6 +96,10 @@ public R visitML(PhysicalPlan node, C context) { return visitNode(node, context); } + public R visitTrendline(PhysicalPlan node, C context) { + return visitNode(node, context); + } + public R visitCursorClose(CursorCloseOperator node, C context) { return visitNode(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java new file mode 100644 index 0000000000..d002b3ceb8 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -0,0 +1,207 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.executor.ExecutionEngine; + +import com.google.common.base.Preconditions; +import com.google.common.collect.EvictingQueue; +import com.google.common.collect.ImmutableMap.Builder; + +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; + +/** Trendline command implementation */ +@ToString +@EqualsAndHashCode(callSuper = false) +public class TrendlineOperator extends PhysicalPlan { + @Getter private final PhysicalPlan input; + @Getter private final List computations; + private final List accumulators; + private final Map fieldToIndexMap; + private boolean hasAnotherRow = false; + private boolean isTuple = false; + + public TrendlineOperator(PhysicalPlan input, List computations) { + this.input = input; + this.computations = computations; + this.accumulators = computations.stream() + .map(TrendlineOperator::createAccumulator) + .toList(); + fieldToIndexMap = new HashMap<>(computations.size()); + for (int i = 0; i < computations.size(); ++i) { + + fieldToIndexMap.put(computations.get(i).getDataField().getChild().getFirst().toString(), i); + } + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Override + public List getChild() { + return Collections.singletonList(input); + } + + @Override + public boolean hasNext() { + return hasAnotherRow; + } + + @Override + public ExecutionEngine.Schema schema() { + // TODO: Don't hardcode the type. + return new ExecutionEngine.Schema( + computations.stream() + .map( + computation -> + new ExecutionEngine.Schema.Column(computation.getDataField().getChild().getFirst().toString(), + computation.getAlias(), DOUBLE)) + .collect(Collectors.toList())); } + + @Override + public ExprValue next() { + Preconditions.checkState(hasAnotherRow); + final ExprValue result; + if (isTuple) { + Builder mapBuilder = new Builder<>(); + for (int i = 0; i < accumulators.size(); ++i) { + final ExprValue calculateResult = accumulators.get(i).calculate(); + if (calculateResult == null) { + continue; + } + + if (null != computations.get(i).getAlias()) { + mapBuilder.put(computations.get(i).getAlias(), calculateResult); + } else { + mapBuilder.put(computations.get(i).getDataField().toString(), calculateResult); + } + } + result = ExprTupleValue.fromExprValueMap(mapBuilder.build()); + } else { + result = accumulators.getFirst().calculate(); + } + + if (input.hasNext()) { + final ExprValue next = input.next(); + consumeInputTuple(next); + } else { + hasAnotherRow = false; + } + return result; + } + + @Override + public void open() { + super.open(); + + // Position the cursor such that enough data points have been accumulated + // to get one trendline calculation. + final int smallestNumberOfDataPoints = computations.stream() + .mapToInt(Trendline.TrendlineComputation::getNumberOfDataPoints) + .min().orElseThrow(() -> new SyntaxCheckException("Period not supplied.")); + + int i; + for (i = 0; i < smallestNumberOfDataPoints && input.hasNext(); ++i) { + final ExprValue next = input.next(); + if (next.type() == STRUCT) { + isTuple = true; + } + consumeInputTuple(next); + } + + if (i == smallestNumberOfDataPoints) { + hasAnotherRow = true; + } + } + + private void consumeInputTuple(ExprValue inputValue) { + if (isTuple) { + Map tupleValue = ExprValueUtils.getTupleValue(inputValue); + for (String bindName : tupleValue.keySet()) { + final Integer index = fieldToIndexMap.get(bindName); + if (index == null) { + continue; + } + accumulators.get(index).accumulate(tupleValue.get(bindName)); + } + } else { + accumulators.getFirst().accumulate(inputValue); + } + } + + private static TrendlineAccumulator createAccumulator(Trendline.TrendlineComputation computation) { + switch (computation.getComputationType()) { + case SMA: + return new SimpleMovingAverageAccumulator(computation); + case WMA: + default: + throw new IllegalStateException("Unexpected value: " + computation.getComputationType()); + } + } + + /** + * Maintains stateful information for calculating the trendline. + */ + private interface TrendlineAccumulator { + void accumulate(ExprValue value); + ExprValue calculate(); + } + + private static class SimpleMovingAverageAccumulator implements TrendlineAccumulator { + private final ExprValue dataPointsNeeded; + private final EvictingQueue receivedValues; + private ExprValue runningAverage = new ExprDoubleValue(0.0); + + public SimpleMovingAverageAccumulator(Trendline.TrendlineComputation computation) { + dataPointsNeeded = new ExprIntegerValue(computation.getNumberOfDataPoints()); + receivedValues = EvictingQueue.create(computation.getNumberOfDataPoints()); + } + + @Override + public void accumulate(ExprValue value) { + receivedValues.add(value); + } + + @Override + public ExprValue calculate() { + // TODO: Calculate this properly using the DSL and optimize it to use + // a running average instead of iterating over the whole window. + if (receivedValues.size() < dataPointsNeeded.integerValue()) { + return ExprNullValue.of(); + } + ExprValue[] entries = new ExprValue[0]; + ExprValue[] data = receivedValues.toArray(entries); + double result = 0; + for (int i = 0; i < data.length; i++) { + result += data[i].doubleValue(); + } + + result /= receivedValues.size(); + return new ExprDoubleValue(result); + } + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 28827b0a54..0f673a2482 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -24,6 +24,7 @@ import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; import org.opensearch.sql.planner.physical.TakeOrderedOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.TableScanOperator; @@ -187,6 +188,15 @@ public PhysicalPlan visitML(PhysicalPlan node, Object context) { mlOperator.getNodeClient())); } + @Override + public PhysicalPlan visitTrendline(PhysicalPlan node, Object context) { + TrendlineOperator trendlineOperator = (TrendlineOperator) node; + return doProtect( + new TrendlineOperator( + visitInput(trendlineOperator.getInput(), context), + trendlineOperator.getComputations())); + } + PhysicalPlan visitInput(PhysicalPlan node, Object context) { if (null == node) { return node; diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 3ba8da74f4..6f4181c7f5 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -36,6 +36,7 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +TRENDLINE: 'TRENDLINE'; // COMMAND ASSIST KEYWORDS AS: 'AS'; @@ -57,6 +58,10 @@ STR: 'STR'; IP: 'IP'; NUM: 'NUM'; +// TRENDLINE KEYWORDS +SMA: 'SMA'; +WMA: 'WMA'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 89a32abe23..6325ca0065 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -50,6 +50,7 @@ commands | adCommand | mlCommand | fillnullCommand + | trendlineCommand ; searchCommand @@ -143,6 +144,18 @@ fillNullWithFieldVariousValues nullReplacementExpression : nullableField = fieldExpression EQUAL nullReplacement = valueExpression + +trendlineCommand + : TRENDLINE trendlineClause (trendlineClause)* + ; + +trendlineClause + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS AS alias = fieldExpression + ; + +trendlineType + : SMA + | WMA ; kmeansCommand diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 2fccb8e635..478bfac812 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -64,6 +64,7 @@ import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; @@ -421,6 +422,16 @@ public UnresolvedPlan visitFillNullWithFieldVariousValues( FillNull.ContainNullableFieldFill.ofVariousValue(replacementsBuilder.build())); } + /** trendline command. */ + @Override + public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { + List trendlineComputations = ctx.trendlineClause() + .stream() + .map(expressionBuilder::visit) + .collect(Collectors.toList()); + return new Trendline(trendlineComputations); + } + /** Get original text in query. */ private String getTextInQuery(ParserRuleContext ctx) { Token start = ctx.getStart(); diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 98c41027ff..87c906f931 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -52,6 +52,7 @@ import org.antlr.v4.runtime.RuleContext; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.*; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; @@ -75,6 +76,16 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression())); } + /** Trendline clause. */ + @Override + public UnresolvedExpression visitTrendlineClause(OpenSearchPPLParser.TrendlineClauseContext ctx) { + Integer numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + Field dataField = (Field) this.visitFieldExpression(ctx.field); + String alias = ctx.alias.getText(); + String computationType = ctx.trendlineType().getText(); + return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, computationType); + } + /** Logical expression excluding boolean, comparison. */ @Override public UnresolvedExpression visitLogicalNot(LogicalNotContext ctx) { From c6ec31fab74c92356b00e0c90ccb833aa5aacae7 Mon Sep 17 00:00:00 2001 From: James Duong Date: Mon, 14 Oct 2024 14:14:09 -0700 Subject: [PATCH 02/42] fix missing newline Signed-off-by: James Duong --- .../org/opensearch/sql/planner/physical/TrendlineOperator.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index d002b3ceb8..2371fe39e9 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -80,7 +80,8 @@ public ExecutionEngine.Schema schema() { computation -> new ExecutionEngine.Schema.Column(computation.getDataField().getChild().getFirst().toString(), computation.getAlias(), DOUBLE)) - .collect(Collectors.toList())); } + .collect(Collectors.toList())); + } @Override public ExprValue next() { From 9a7b6a4c509474b7db203ef7521b678c666c69cd Mon Sep 17 00:00:00 2001 From: James Duong Date: Mon, 14 Oct 2024 15:54:23 -0700 Subject: [PATCH 03/42] Optimize running average calculation and use DSL Rework calculation to use induction and use DSL for math Signed-off-by: James Duong --- .../planner/physical/TrendlineOperator.java | 100 ++++++++++++------ 1 file changed, 66 insertions(+), 34 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 2371fe39e9..87da10f20a 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -8,29 +8,27 @@ import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; +import com.google.common.base.Preconditions; +import com.google.common.collect.EvictingQueue; +import com.google.common.collect.ImmutableMap.Builder; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; - +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.common.antlr.SyntaxCheckException; -import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.executor.ExecutionEngine; - -import com.google.common.base.Preconditions; -import com.google.common.collect.EvictingQueue; -import com.google.common.collect.ImmutableMap.Builder; - -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; /** Trendline command implementation */ @ToString @@ -46,9 +44,7 @@ public class TrendlineOperator extends PhysicalPlan { public TrendlineOperator(PhysicalPlan input, List computations) { this.input = input; this.computations = computations; - this.accumulators = computations.stream() - .map(TrendlineOperator::createAccumulator) - .toList(); + this.accumulators = computations.stream().map(TrendlineOperator::createAccumulator).toList(); fieldToIndexMap = new HashMap<>(computations.size()); for (int i = 0; i < computations.size(); ++i) { @@ -78,8 +74,10 @@ public ExecutionEngine.Schema schema() { computations.stream() .map( computation -> - new ExecutionEngine.Schema.Column(computation.getDataField().getChild().getFirst().toString(), - computation.getAlias(), DOUBLE)) + new ExecutionEngine.Schema.Column( + computation.getDataField().getChild().getFirst().toString(), + computation.getAlias(), + DOUBLE)) .collect(Collectors.toList())); } @@ -121,9 +119,11 @@ public void open() { // Position the cursor such that enough data points have been accumulated // to get one trendline calculation. - final int smallestNumberOfDataPoints = computations.stream() - .mapToInt(Trendline.TrendlineComputation::getNumberOfDataPoints) - .min().orElseThrow(() -> new SyntaxCheckException("Period not supplied.")); + final int smallestNumberOfDataPoints = + computations.stream() + .mapToInt(Trendline.TrendlineComputation::getNumberOfDataPoints) + .min() + .orElseThrow(() -> new SyntaxCheckException("Period not supplied.")); int i; for (i = 0; i < smallestNumberOfDataPoints && input.hasNext(); ++i) { @@ -154,7 +154,8 @@ private void consumeInputTuple(ExprValue inputValue) { } } - private static TrendlineAccumulator createAccumulator(Trendline.TrendlineComputation computation) { + private static TrendlineAccumulator createAccumulator( + Trendline.TrendlineComputation computation) { switch (computation.getComputationType()) { case SMA: return new SimpleMovingAverageAccumulator(computation); @@ -164,18 +165,18 @@ private static TrendlineAccumulator createAccumulator(Trendline.TrendlineComputa } } - /** - * Maintains stateful information for calculating the trendline. - */ + /** Maintains stateful information for calculating the trendline. */ private interface TrendlineAccumulator { void accumulate(ExprValue value); + ExprValue calculate(); } + // TODO: Make the actual math polymorphic based on types to deal with datetimes. private static class SimpleMovingAverageAccumulator implements TrendlineAccumulator { private final ExprValue dataPointsNeeded; private final EvictingQueue receivedValues; - private ExprValue runningAverage = new ExprDoubleValue(0.0); + private ExprValue runningAverage = null; public SimpleMovingAverageAccumulator(Trendline.TrendlineComputation computation) { dataPointsNeeded = new ExprIntegerValue(computation.getNumberOfDataPoints()); @@ -184,25 +185,56 @@ public SimpleMovingAverageAccumulator(Trendline.TrendlineComputation computation @Override public void accumulate(ExprValue value) { + if (value == null) { + // Should this make the whole calculation null? + return; + } + + if (dataPointsNeeded.integerValue() == 1) { + runningAverage = value; + receivedValues.add(value); + return; + } + + final ExprValue valueToRemove; + if (receivedValues.size() == dataPointsNeeded.integerValue()) { + valueToRemove = receivedValues.remove(); + } else { + valueToRemove = null; + } receivedValues.add(value); + + if (receivedValues.size() == dataPointsNeeded.integerValue()) { + if (runningAverage != null) { + // We can use the previous average calculation. + // Subtract the evicted value / period and add the new value / period. + // Refactored, that would be previous + (newValue - oldValue) / period + runningAverage = + DSL.add( + DSL.literal(runningAverage), + DSL.divide( + DSL.subtract(DSL.literal(value), DSL.literal(valueToRemove)), + DSL.literal(dataPointsNeeded.doubleValue()))) + .valueOf(); + } else { + // This is the first average calculation so sum the entire receivedValues dataset. + final List data = receivedValues.stream().toList(); + Expression runningTotal = DSL.literal(0.0D); + for (ExprValue entry : data) { + runningTotal = DSL.add(runningTotal, DSL.literal(entry)); + } + runningAverage = + DSL.divide(runningTotal, DSL.literal(dataPointsNeeded.doubleValue())).valueOf(); + } + } } @Override public ExprValue calculate() { - // TODO: Calculate this properly using the DSL and optimize it to use - // a running average instead of iterating over the whole window. if (receivedValues.size() < dataPointsNeeded.integerValue()) { return ExprNullValue.of(); } - ExprValue[] entries = new ExprValue[0]; - ExprValue[] data = receivedValues.toArray(entries); - double result = 0; - for (int i = 0; i < data.length; i++) { - result += data[i].doubleValue(); - } - - result /= receivedValues.size(); - return new ExprDoubleValue(result); + return runningAverage; } } } From f39bef15a7f07216ee800a58b62580b0ad52e958 Mon Sep 17 00:00:00 2001 From: James Duong Date: Mon, 14 Oct 2024 15:54:33 -0700 Subject: [PATCH 04/42] spotless Signed-off-by: James Duong --- .../main/java/org/opensearch/sql/analysis/Analyzer.java | 5 +++-- .../main/java/org/opensearch/sql/ast/tree/Trendline.java | 9 ++++++--- .../java/org/opensearch/sql/ppl/parser/AstBuilder.java | 6 ++---- .../opensearch/sql/ppl/parser/AstExpressionBuilder.java | 3 ++- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index ddfc8765f8..cd93d726d3 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -602,8 +602,9 @@ public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { final LogicalPlan child = node.getChild().get(0).accept(this, context); final List unresolvedComputations = node.getComputations(); final List computations = - unresolvedComputations.stream().map(expression -> - (Trendline.TrendlineComputation) expression).toList(); + unresolvedComputations.stream() + .map(expression -> (Trendline.TrendlineComputation) expression) + .toList(); return new LogicalTrendline(child, computations); } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index 112481032f..cb137d97ff 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -6,6 +6,7 @@ package org.opensearch.sql.ast.tree; import com.google.common.collect.ImmutableList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -14,8 +15,6 @@ import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.UnresolvedExpression; -import java.util.List; - @ToString @Getter @RequiredArgsConstructor @@ -49,7 +48,11 @@ public static class TrendlineComputation extends UnresolvedExpression { private final String alias; private final TrendlineType computationType; - public TrendlineComputation(Integer numberOfDataPoints, UnresolvedExpression dataField, String alias, String computationType) { + public TrendlineComputation( + Integer numberOfDataPoints, + UnresolvedExpression dataField, + String alias, + String computationType) { this.numberOfDataPoints = numberOfDataPoints; this.dataField = dataField; this.alias = alias; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 478bfac812..d0e4224a81 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -425,10 +425,8 @@ public UnresolvedPlan visitFillNullWithFieldVariousValues( /** trendline command. */ @Override public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { - List trendlineComputations = ctx.trendlineClause() - .stream() - .map(expressionBuilder::visit) - .collect(Collectors.toList()); + List trendlineComputations = + ctx.trendlineClause().stream().map(expressionBuilder::visit).collect(Collectors.toList()); return new Trendline(trendlineComputations); } diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 87c906f931..da6edbb03d 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -83,7 +83,8 @@ public UnresolvedExpression visitTrendlineClause(OpenSearchPPLParser.TrendlineCl Field dataField = (Field) this.visitFieldExpression(ctx.field); String alias = ctx.alias.getText(); String computationType = ctx.trendlineType().getText(); - return new Trendline.TrendlineComputation(numberOfDataPoints, dataField, alias, computationType); + return new Trendline.TrendlineComputation( + numberOfDataPoints, dataField, alias, computationType); } /** Logical expression excluding boolean, comparison. */ From 92657a13fe53543b7839414f8404255331f18cca Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 18 Oct 2024 12:35:20 -0700 Subject: [PATCH 05/42] Make implementation preserve child data Preserve child data when the child field isn't overwritten by trendline Signed-off-by: James Duong --- .../planner/physical/TrendlineOperator.java | 117 ++++++------------ 1 file changed, 35 insertions(+), 82 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 87da10f20a..c62fbc3300 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -5,9 +5,6 @@ package org.opensearch.sql.planner.physical; -import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; -import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; - import com.google.common.base.Preconditions; import com.google.common.collect.EvictingQueue; import com.google.common.collect.ImmutableMap.Builder; @@ -15,18 +12,15 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; import org.opensearch.sql.ast.tree.Trendline; -import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.executor.ExecutionEngine; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -39,7 +33,6 @@ public class TrendlineOperator extends PhysicalPlan { private final List accumulators; private final Map fieldToIndexMap; private boolean hasAnotherRow = false; - private boolean isTuple = false; public TrendlineOperator(PhysicalPlan input, List computations) { this.input = input; @@ -64,93 +57,40 @@ public List getChild() { @Override public boolean hasNext() { - return hasAnotherRow; - } - - @Override - public ExecutionEngine.Schema schema() { - // TODO: Don't hardcode the type. - return new ExecutionEngine.Schema( - computations.stream() - .map( - computation -> - new ExecutionEngine.Schema.Column( - computation.getDataField().getChild().getFirst().toString(), - computation.getAlias(), - DOUBLE)) - .collect(Collectors.toList())); + return getChild().getFirst().hasNext(); } @Override public ExprValue next() { - Preconditions.checkState(hasAnotherRow); + Preconditions.checkState(hasNext()); final ExprValue result; - if (isTuple) { - Builder mapBuilder = new Builder<>(); - for (int i = 0; i < accumulators.size(); ++i) { - final ExprValue calculateResult = accumulators.get(i).calculate(); - if (calculateResult == null) { - continue; - } - - if (null != computations.get(i).getAlias()) { - mapBuilder.put(computations.get(i).getAlias(), calculateResult); - } else { - mapBuilder.put(computations.get(i).getDataField().toString(), calculateResult); - } + final ExprValue next = input.next(); + consumeInputTuple(next); + final Map inputStruct = ExprValueUtils.getTupleValue(next); + final Builder mapBuilder = new Builder<>(); + mapBuilder.putAll(inputStruct); + + // Add calculated trendline values, which might overwrite existing fields from the input. + for (int i = 0; i < accumulators.size(); ++i) { + final ExprValue calculateResult = accumulators.get(i).calculate(); + if (null != computations.get(i).getAlias()) { + mapBuilder.put(computations.get(i).getAlias(), calculateResult); + } else { + mapBuilder.put(computations.get(i).getDataField().getChild().getFirst().toString(), calculateResult); } - result = ExprTupleValue.fromExprValueMap(mapBuilder.build()); - } else { - result = accumulators.getFirst().calculate(); - } - - if (input.hasNext()) { - final ExprValue next = input.next(); - consumeInputTuple(next); - } else { - hasAnotherRow = false; } + result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast()); return result; } - @Override - public void open() { - super.open(); - - // Position the cursor such that enough data points have been accumulated - // to get one trendline calculation. - final int smallestNumberOfDataPoints = - computations.stream() - .mapToInt(Trendline.TrendlineComputation::getNumberOfDataPoints) - .min() - .orElseThrow(() -> new SyntaxCheckException("Period not supplied.")); - - int i; - for (i = 0; i < smallestNumberOfDataPoints && input.hasNext(); ++i) { - final ExprValue next = input.next(); - if (next.type() == STRUCT) { - isTuple = true; - } - consumeInputTuple(next); - } - - if (i == smallestNumberOfDataPoints) { - hasAnotherRow = true; - } - } - private void consumeInputTuple(ExprValue inputValue) { - if (isTuple) { - Map tupleValue = ExprValueUtils.getTupleValue(inputValue); - for (String bindName : tupleValue.keySet()) { - final Integer index = fieldToIndexMap.get(bindName); - if (index == null) { - continue; - } - accumulators.get(index).accumulate(tupleValue.get(bindName)); + final Map tupleValue = ExprValueUtils.getTupleValue(inputValue); + for (String bindName : tupleValue.keySet()) { + final Integer index = fieldToIndexMap.get(bindName); + if (index == null) { + continue; } - } else { - accumulators.getFirst().accumulate(inputValue); + accumulators.get(index).accumulate(tupleValue.get(bindName)); } } @@ -237,4 +177,17 @@ public ExprValue calculate() { return runningAverage; } } + + private static class WeightedMovingAverageAccumulator implements TrendlineAccumulator { + + @Override + public void accumulate(ExprValue value) { + + } + + @Override + public ExprValue calculate() { + return null; + } + } } From 5732ac4c10c166e932fb5e544b332eace323c238 Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 18 Oct 2024 12:36:17 -0700 Subject: [PATCH 06/42] Implement logging and explain for trendline Signed-off-by: James Duong --- .../org/opensearch/sql/executor/Explain.java | 25 +++++++++++++++++++ .../physical/PhysicalPlanNodeVisitor.java | 3 ++- .../OpenSearchExecutionProtector.java | 2 +- .../sql/ppl/utils/PPLQueryDataAnonymizer.java | 18 +++++++++++++ 4 files changed, 46 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/executor/Explain.java b/core/src/main/java/org/opensearch/sql/executor/Explain.java index fffbe6f693..99d66f1bfc 100644 --- a/core/src/main/java/org/opensearch/sql/executor/Explain.java +++ b/core/src/main/java/org/opensearch/sql/executor/Explain.java @@ -8,12 +8,14 @@ import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponseNode; import org.opensearch.sql.expression.Expression; @@ -31,6 +33,7 @@ import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; import org.opensearch.sql.planner.physical.TakeOrderedOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.TableScanOperator; @@ -211,6 +214,15 @@ public ExplainResponseNode visitNested(NestedOperator node, Object context) { explanNode -> explanNode.setDescription(ImmutableMap.of("nested", node.getFields()))); } + @Override + public ExplainResponseNode visitTrendline(TrendlineOperator node, Object context) { + return explain( + node, + context, + explainNode -> explainNode.setDescription( + ImmutableMap.of("computations", describeTrendlineComputations(node.getComputations())))); + } + protected ExplainResponseNode explain( PhysicalPlan node, Object context, Consumer doExplain) { ExplainResponseNode explainNode = new ExplainResponseNode(getOperatorName(node)); @@ -245,4 +257,17 @@ private Map> describeSortList( "sortOrder", p.getLeft().getSortOrder().toString(), "nullOrder", p.getLeft().getNullOrder().toString()))); } + + private List> describeTrendlineComputations( + List computations) { + return computations.stream() + .map(computation -> + ImmutableMap.of( + "computationType", computation.getComputationType().name().toLowerCase(Locale.ROOT), + "numberOfDataPoints", computation.getNumberOfDataPoints().toString(), + "dataField", computation.getDataField().getChild().getFirst().toString(), + "alias", computation.getAlias() != null ? computation.getAlias() : "")) + .collect(Collectors.toList()); + } + } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index b86edcc8f3..ac2740b76d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -5,6 +5,7 @@ package org.opensearch.sql.planner.physical; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.write.TableWriteOperator; @@ -96,7 +97,7 @@ public R visitML(PhysicalPlan node, C context) { return visitNode(node, context); } - public R visitTrendline(PhysicalPlan node, C context) { + public R visitTrendline(TrendlineOperator node, C context) { return visitNode(node, context); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 0f673a2482..41070d3f6f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -189,7 +189,7 @@ public PhysicalPlan visitML(PhysicalPlan node, Object context) { } @Override - public PhysicalPlan visitTrendline(PhysicalPlan node, Object context) { + public PhysicalPlan visitTrendline(TrendlineOperator node, Object context) { TrendlineOperator trendlineOperator = (TrendlineOperator) node; return doProtect( new TrendlineOperator( diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index a1ca0fd69a..e6163f15d0 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -9,6 +9,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.List; +import java.util.Locale; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; @@ -43,6 +44,7 @@ import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.planner.logical.LogicalAggregation; @@ -221,6 +223,13 @@ public String visitHead(Head node, String context) { return StringUtils.format("%s | head %d", child, size); } + @Override + public String visitTrendline(Trendline node, String context) { + String child = node.getChild().getFirst().accept(this, context); + String computations = visitExpressionList(node.getComputations()); + return StringUtils.format("%s | trendline %s", child, computations); + } + private String visitFieldList(List fieldList) { return fieldList.stream().map(this::visitExpression).collect(Collectors.joining(",")); } @@ -344,5 +353,14 @@ public String visitAlias(Alias node, String context) { String expr = node.getDelegated().accept(this, context); return StringUtils.format("%s", expr); } + + @Override + public String visitTrendlineComputation(Trendline.TrendlineComputation node, String context) { + final String dataField = node.getDataField().accept(this, context); + final String aliasOrEmpty = node.getAlias() != null ? " as " + node.getAlias() : ""; + final String computationType = node.getComputationType().name().toLowerCase(Locale.ROOT); + return StringUtils.format("%s(%d, %s)%s", + computationType, node.getNumberOfDataPoints(), dataField, aliasOrEmpty); + } } } From ec89a61b73e45bb79be30a1bc172e0dc72a7ac37 Mon Sep 17 00:00:00 2001 From: James Duong Date: Mon, 21 Oct 2024 11:07:10 -0700 Subject: [PATCH 07/42] Use List#get(0) instead of getFirst() to support older compilers Signed-off-by: James Duong --- core/src/main/java/org/opensearch/sql/executor/Explain.java | 2 +- .../opensearch/sql/planner/physical/TrendlineOperator.java | 5 ++--- .../org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/executor/Explain.java b/core/src/main/java/org/opensearch/sql/executor/Explain.java index 99d66f1bfc..baedea0fa4 100644 --- a/core/src/main/java/org/opensearch/sql/executor/Explain.java +++ b/core/src/main/java/org/opensearch/sql/executor/Explain.java @@ -265,7 +265,7 @@ private List> describeTrendlineComputations( ImmutableMap.of( "computationType", computation.getComputationType().name().toLowerCase(Locale.ROOT), "numberOfDataPoints", computation.getNumberOfDataPoints().toString(), - "dataField", computation.getDataField().getChild().getFirst().toString(), + "dataField", computation.getDataField().getChild().get(0).toString(), "alias", computation.getAlias() != null ? computation.getAlias() : "")) .collect(Collectors.toList()); } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index c62fbc3300..8a0bfa87c9 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -32,7 +32,6 @@ public class TrendlineOperator extends PhysicalPlan { @Getter private final List computations; private final List accumulators; private final Map fieldToIndexMap; - private boolean hasAnotherRow = false; public TrendlineOperator(PhysicalPlan input, List computations) { this.input = input; @@ -41,7 +40,7 @@ public TrendlineOperator(PhysicalPlan input, List(computations.size()); for (int i = 0; i < computations.size(); ++i) { - fieldToIndexMap.put(computations.get(i).getDataField().getChild().getFirst().toString(), i); + fieldToIndexMap.put(computations.get(i).getDataField().getChild().get(0).toString(), i); } } @@ -76,7 +75,7 @@ public ExprValue next() { if (null != computations.get(i).getAlias()) { mapBuilder.put(computations.get(i).getAlias(), calculateResult); } else { - mapBuilder.put(computations.get(i).getDataField().getChild().getFirst().toString(), calculateResult); + mapBuilder.put(computations.get(i).getDataField().getChild().get(0).toString(), calculateResult); } } result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast()); diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index e6163f15d0..ade900b46f 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -225,7 +225,7 @@ public String visitHead(Head node, String context) { @Override public String visitTrendline(Trendline node, String context) { - String child = node.getChild().getFirst().accept(this, context); + String child = node.getChild().get(0).accept(this, context); String computations = visitExpressionList(node.getComputations()); return StringUtils.format("%s | trendline %s", child, computations); } From a25303d57ceb02aec8ed9e4125d80e09615ef403 Mon Sep 17 00:00:00 2001 From: James Duong Date: Mon, 21 Oct 2024 12:21:43 -0700 Subject: [PATCH 08/42] Tell the project operator about new fields from trendline computations Signed-off-by: James Duong --- core/src/main/java/org/opensearch/sql/analysis/Analyzer.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index cd93d726d3..04c6a8d0b8 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -600,11 +600,16 @@ public LogicalPlan visitML(ML node, AnalysisContext context) { @Override public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { final LogicalPlan child = node.getChild().get(0).accept(this, context); + + final TypeEnvironment currEnv = context.peek(); final List unresolvedComputations = node.getComputations(); final List computations = unresolvedComputations.stream() .map(expression -> (Trendline.TrendlineComputation) expression) .toList(); + + computations.forEach(computation -> currEnv.define( + new Symbol(Namespace.FIELD_NAME, computation.getAlias()), ExprCoreType.DOUBLE)); return new LogicalTrendline(child, computations); } From e8b0fe26fbd1c4bcc6c8c9b9089a9ed2bd9d8693 Mon Sep 17 00:00:00 2001 From: James Duong Date: Tue, 22 Oct 2024 16:46:56 -0700 Subject: [PATCH 09/42] Add Trendline parser test Signed-off-by: James Duong --- .../java/org/opensearch/sql/ast/dsl/AstDSL.java | 13 +++++++++++++ .../java/org/opensearch/sql/ast/tree/Trendline.java | 6 ++++-- .../opensearch/sql/ppl/parser/AstBuilderTest.java | 12 ++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 8135731ff6..5746428455 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -62,6 +62,7 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; @@ -466,6 +467,18 @@ public static Limit limit(UnresolvedPlan input, Integer limit, Integer offset) { return new Limit(limit, offset).attach(input); } + public static Trendline trendline(UnresolvedPlan input, Trendline.TrendlineComputation... computations) { + return new Trendline(Arrays.asList(computations)).attach(input); + } + + public static Trendline.TrendlineComputation computation( + Integer numDataPoints, + UnresolvedExpression dataField, + String alias, + String type) { + return new Trendline.TrendlineComputation(numDataPoints, dataField, alias, type); + } + public static Parse parse( UnresolvedPlan input, ParseMethod parseMethod, diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index cb137d97ff..8858353b6b 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -7,6 +7,8 @@ import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Locale; + import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -25,7 +27,7 @@ public class Trendline extends UnresolvedPlan { private final List computations; @Override - public UnresolvedPlan attach(UnresolvedPlan child) { + public Trendline attach(UnresolvedPlan child) { this.child = child; return this; } @@ -56,7 +58,7 @@ public TrendlineComputation( this.numberOfDataPoints = numberOfDataPoints; this.dataField = dataField; this.alias = alias; - this.computationType = Trendline.TrendlineType.valueOf(computationType.toUpperCase()); + this.computationType = Trendline.TrendlineType.valueOf(computationType.toUpperCase(Locale.ROOT)); } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index ac2bce9dbc..e61bb97e4f 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -13,6 +13,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.compare; +import static org.opensearch.sql.ast.dsl.AstDSL.computation; import static org.opensearch.sql.ast.dsl.AstDSL.dedupe; import static org.opensearch.sql.ast.dsl.AstDSL.defaultDedupArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultFieldsArgs; @@ -38,6 +39,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.span; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.tableFunction; +import static org.opensearch.sql.ast.dsl.AstDSL.trendline; import static org.opensearch.sql.ast.dsl.AstDSL.unresolvedArg; import static org.opensearch.sql.utils.SystemIndexUtils.DATASOURCES_TABLE_NAME; import static org.opensearch.sql.utils.SystemIndexUtils.mappingTable; @@ -692,6 +694,16 @@ public void testFillNullCommandVariousValues() { .build()))); } + public void testTrendline() { + assertEqual( + "source=t | trendline sma(5, test_field) as test_field_alias sma(1, test_field_2) as test_field_alias_2", + trendline( + relation("t"), + computation(5, field("test_field"), "test_field_alias", "sma"), + computation(1, field("test_field)2"), "test_field_alias_2", "sma") + )); + } + @Test public void testDescribeCommand() { assertEqual("describe t", relation(mappingTable("t"))); From f96a657c0eb5396748d7cbc04612ba4d2078bc21 Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 23 Oct 2024 13:30:58 -0700 Subject: [PATCH 10/42] Spotless Signed-off-by: James Duong --- .../java/org/opensearch/sql/analysis/Analyzer.java | 6 ++++-- .../java/org/opensearch/sql/ast/dsl/AstDSL.java | 8 +++----- .../java/org/opensearch/sql/ast/tree/Trendline.java | 4 ++-- .../java/org/opensearch/sql/executor/Explain.java | 13 ++++++++----- .../planner/physical/PhysicalPlanNodeVisitor.java | 1 - .../sql/planner/physical/TrendlineOperator.java | 7 +++---- .../sql/ppl/utils/PPLQueryDataAnonymizer.java | 4 ++-- .../opensearch/sql/ppl/parser/AstBuilderTest.java | 6 +++--- 8 files changed, 25 insertions(+), 24 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 04c6a8d0b8..82cb62d2e2 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -608,8 +608,10 @@ public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { .map(expression -> (Trendline.TrendlineComputation) expression) .toList(); - computations.forEach(computation -> currEnv.define( - new Symbol(Namespace.FIELD_NAME, computation.getAlias()), ExprCoreType.DOUBLE)); + computations.forEach( + computation -> + currEnv.define( + new Symbol(Namespace.FIELD_NAME, computation.getAlias()), ExprCoreType.DOUBLE)); return new LogicalTrendline(child, computations); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 5746428455..816cabcc21 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -467,15 +467,13 @@ public static Limit limit(UnresolvedPlan input, Integer limit, Integer offset) { return new Limit(limit, offset).attach(input); } - public static Trendline trendline(UnresolvedPlan input, Trendline.TrendlineComputation... computations) { + public static Trendline trendline( + UnresolvedPlan input, Trendline.TrendlineComputation... computations) { return new Trendline(Arrays.asList(computations)).attach(input); } public static Trendline.TrendlineComputation computation( - Integer numDataPoints, - UnresolvedExpression dataField, - String alias, - String type) { + Integer numDataPoints, UnresolvedExpression dataField, String alias, String type) { return new Trendline.TrendlineComputation(numDataPoints, dataField, alias, type); } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index 8858353b6b..f58527621d 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -8,7 +8,6 @@ import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Locale; - import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -58,7 +57,8 @@ public TrendlineComputation( this.numberOfDataPoints = numberOfDataPoints; this.dataField = dataField; this.alias = alias; - this.computationType = Trendline.TrendlineType.valueOf(computationType.toUpperCase(Locale.ROOT)); + this.computationType = + Trendline.TrendlineType.valueOf(computationType.toUpperCase(Locale.ROOT)); } @Override diff --git a/core/src/main/java/org/opensearch/sql/executor/Explain.java b/core/src/main/java/org/opensearch/sql/executor/Explain.java index baedea0fa4..096fb240dd 100644 --- a/core/src/main/java/org/opensearch/sql/executor/Explain.java +++ b/core/src/main/java/org/opensearch/sql/executor/Explain.java @@ -219,8 +219,10 @@ public ExplainResponseNode visitTrendline(TrendlineOperator node, Object context return explain( node, context, - explainNode -> explainNode.setDescription( - ImmutableMap.of("computations", describeTrendlineComputations(node.getComputations())))); + explainNode -> + explainNode.setDescription( + ImmutableMap.of( + "computations", describeTrendlineComputations(node.getComputations())))); } protected ExplainResponseNode explain( @@ -261,13 +263,14 @@ private Map> describeSortList( private List> describeTrendlineComputations( List computations) { return computations.stream() - .map(computation -> + .map( + computation -> ImmutableMap.of( - "computationType", computation.getComputationType().name().toLowerCase(Locale.ROOT), + "computationType", + computation.getComputationType().name().toLowerCase(Locale.ROOT), "numberOfDataPoints", computation.getNumberOfDataPoints().toString(), "dataField", computation.getDataField().getChild().get(0).toString(), "alias", computation.getAlias() != null ? computation.getAlias() : "")) .collect(Collectors.toList()); } - } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index ac2740b76d..66c7219e39 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -5,7 +5,6 @@ package org.opensearch.sql.planner.physical; -import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.storage.TableScanOperator; import org.opensearch.sql.storage.write.TableWriteOperator; diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 8a0bfa87c9..ad6cc32158 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -75,7 +75,8 @@ public ExprValue next() { if (null != computations.get(i).getAlias()) { mapBuilder.put(computations.get(i).getAlias(), calculateResult); } else { - mapBuilder.put(computations.get(i).getDataField().getChild().get(0).toString(), calculateResult); + mapBuilder.put( + computations.get(i).getDataField().getChild().get(0).toString(), calculateResult); } } result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast()); @@ -180,9 +181,7 @@ public ExprValue calculate() { private static class WeightedMovingAverageAccumulator implements TrendlineAccumulator { @Override - public void accumulate(ExprValue value) { - - } + public void accumulate(ExprValue value) {} @Override public ExprValue calculate() { diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index ade900b46f..15b78ac956 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -359,8 +359,8 @@ public String visitTrendlineComputation(Trendline.TrendlineComputation node, Str final String dataField = node.getDataField().accept(this, context); final String aliasOrEmpty = node.getAlias() != null ? " as " + node.getAlias() : ""; final String computationType = node.getComputationType().name().toLowerCase(Locale.ROOT); - return StringUtils.format("%s(%d, %s)%s", - computationType, node.getNumberOfDataPoints(), dataField, aliasOrEmpty); + return StringUtils.format( + "%s(%d, %s)%s", computationType, node.getNumberOfDataPoints(), dataField, aliasOrEmpty); } } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index e61bb97e4f..6039a1ce8d 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -696,12 +696,12 @@ public void testFillNullCommandVariousValues() { public void testTrendline() { assertEqual( - "source=t | trendline sma(5, test_field) as test_field_alias sma(1, test_field_2) as test_field_alias_2", + "source=t | trendline sma(5, test_field) as test_field_alias sma(1, test_field_2) as" + + " test_field_alias_2", trendline( relation("t"), computation(5, field("test_field"), "test_field_alias", "sma"), - computation(1, field("test_field)2"), "test_field_alias_2", "sma") - )); + computation(1, field("test_field)2"), "test_field_alias_2", "sma"))); } @Test From ed1670aa41657966bb4578269a69ca2768538220 Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 23 Oct 2024 14:24:24 -0700 Subject: [PATCH 11/42] LogicalPlan testing Signed-off-by: James Duong --- .../sql/planner/logical/LogicalPlanDSL.java | 6 ++++++ .../org/opensearch/sql/analysis/AnalyzerTest.java | 14 ++++++++++++++ .../logical/LogicalPlanNodeVisitorTest.java | 10 +++++++++- .../opensearch/sql/ppl/parser/AstBuilderTest.java | 2 +- 4 files changed, 30 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 2a886ba0ca..3a4acdf3aa 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedExpression; @@ -130,6 +131,11 @@ public static LogicalPlan rareTopN( return new LogicalRareTopN(input, commandType, noOfResults, Arrays.asList(fields), groupByList); } + public static LogicalTrendline trendline( + LogicalPlan input, Trendline.TrendlineComputation... computations) { + return new LogicalTrendline(input, Arrays.asList(computations)); + } + @SafeVarargs public LogicalPlan values(List... values) { return new LogicalValues(Arrays.asList(values)); diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 4f06ce9d23..c7603de34c 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -18,6 +18,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.compare; +import static org.opensearch.sql.ast.dsl.AstDSL.computation; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; import static org.opensearch.sql.ast.dsl.AstDSL.filteredAggregate; @@ -1481,6 +1482,19 @@ public void fillnull_various_values() { AstDSL.field("int_null_value"), AstDSL.intLiteral(1)))))); } + @Test + public void trendline() { + assertAnalyzeEqual( + LogicalPlanDSL.trendline( + LogicalPlanDSL.relation("schema", table), + computation(5, field("float_value"), "test_field_alias", "sma"), + computation(1, field("double_value"), "test_field_alias_2", "sma")), + AstDSL.trendline( + AstDSL.relation("schema"), + computation(5, field("float_value"), "test_field_alias", "sma"), + computation(1, field("double_value"), "test_field_alias_2", "sma"))); + } + @Test public void ad_batchRCF_relation() { Map argumentMap = diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index f212749f48..9b947f6e21 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -25,6 +25,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprValueUtils; @@ -141,6 +142,12 @@ public TableWriteOperator build(PhysicalPlan child) { LogicalCloseCursor closeCursor = new LogicalCloseCursor(cursor); + LogicalTrendline trendline = + new LogicalTrendline( + relation, + Collections.singletonList( + AstDSL.computation(1, AstDSL.field("testField"), "dummy", "sma"))); + return Stream.of( relation, tableScanBuilder, @@ -163,7 +170,8 @@ public TableWriteOperator build(PhysicalPlan child) { paginate, nested, cursor, - closeCursor) + closeCursor, + trendline) .map(Arguments::of); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 6039a1ce8d..73cd01f3cf 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -701,7 +701,7 @@ public void testTrendline() { trendline( relation("t"), computation(5, field("test_field"), "test_field_alias", "sma"), - computation(1, field("test_field)2"), "test_field_alias_2", "sma"))); + computation(1, field("test_field_2"), "test_field_alias_2", "sma"))); } @Test From cb4ea19d4d059e49884531fae5e3eac4eb3d0bdb Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 25 Oct 2024 06:32:14 -0700 Subject: [PATCH 12/42] Physical plan tests Signed-off-by: James Duong --- .../planner/physical/TrendlineOperator.java | 21 ++- .../opensearch/sql/executor/ExplainTest.java | 32 +++++ .../physical/TrendlineOperatorTest.java | 131 ++++++++++++++++++ .../sql/ppl/utils/PPLQueryDataAnonymizer.java | 8 +- .../ppl/utils/PPLQueryDataAnonymizerTest.java | 7 + 5 files changed, 186 insertions(+), 13 deletions(-) create mode 100644 core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index ad6cc32158..5503fe876b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -5,7 +5,6 @@ package org.opensearch.sql.planner.physical; -import com.google.common.base.Preconditions; import com.google.common.collect.EvictingQueue; import com.google.common.collect.ImmutableMap.Builder; import java.util.Collections; @@ -17,7 +16,6 @@ import lombok.ToString; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.data.model.ExprIntegerValue; -import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -30,8 +28,8 @@ public class TrendlineOperator extends PhysicalPlan { @Getter private final PhysicalPlan input; @Getter private final List computations; - private final List accumulators; - private final Map fieldToIndexMap; + @EqualsAndHashCode.Exclude private final List accumulators; + @EqualsAndHashCode.Exclude private final Map fieldToIndexMap; public TrendlineOperator(PhysicalPlan input, List computations) { this.input = input; @@ -61,7 +59,6 @@ public boolean hasNext() { @Override public ExprValue next() { - Preconditions.checkState(hasNext()); final ExprValue result; final ExprValue next = input.next(); consumeInputTuple(next); @@ -72,11 +69,13 @@ public ExprValue next() { // Add calculated trendline values, which might overwrite existing fields from the input. for (int i = 0; i < accumulators.size(); ++i) { final ExprValue calculateResult = accumulators.get(i).calculate(); - if (null != computations.get(i).getAlias()) { - mapBuilder.put(computations.get(i).getAlias(), calculateResult); - } else { - mapBuilder.put( - computations.get(i).getDataField().getChild().get(0).toString(), calculateResult); + if (null != calculateResult) { + if (null != computations.get(i).getAlias()) { + mapBuilder.put(computations.get(i).getAlias(), calculateResult); + } else { + mapBuilder.put( + computations.get(i).getDataField().getChild().get(0).toString(), calculateResult); + } } } result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast()); @@ -172,7 +171,7 @@ public void accumulate(ExprValue value) { @Override public ExprValue calculate() { if (receivedValues.size() < dataPointsNeeded.integerValue()) { - return ExprNullValue.of(); + return null; } return runningAverage; } diff --git a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java index eaeae07242..549cec5c7f 100644 --- a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java @@ -31,6 +31,7 @@ import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.values; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.window; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; @@ -39,6 +40,7 @@ import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; @@ -52,8 +54,11 @@ import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.storage.TableScanOperator; +import com.google.common.collect.ImmutableMap; + @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class ExplainTest extends ExpressionTestBase { @@ -256,6 +261,33 @@ void can_explain_nested() { explain.apply(plan)); } + @Test + void can_explain_trendline() { + PhysicalPlan plan = new TrendlineOperator(tableScan, Arrays.asList( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma"))); + assertEquals( + new ExplainResponse( + new ExplainResponseNode( + "TrendlineOperator", + ImmutableMap.of("computations", List.of( + ImmutableMap.of( + "computationType", + "sma", + "numberOfDataPoints", 2, + "dataField", "distance", + "alias", "distance_alias"), + ImmutableMap.of( + "computationType", + "sma", + "numberOfDataPoints", 3, + "dataField", "time", + "alias", "time_alias"))), + singletonList(tableScan.explainNode()))), + explain.apply(plan)); + } + + private static class FakeTableScan extends TableScanOperator { @Override public boolean hasNext() { diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java new file mode 100644 index 0000000000..900a291e61 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -0,0 +1,131 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; + +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.data.model.ExprValueUtils; + +import com.google.common.collect.ImmutableMap; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(MockitoExtension.class) +public class TrendlineOperatorTest { + @Mock private PhysicalPlan inputPlan; + + @Test + public void calculates_simple_moving_average_one_field_one_sample() { + when(inputPlan.hasNext()).thenReturn(true, false); + when(inputPlan.next()) + .thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + + var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation( + 1, AstDSL.field("distance"), "distance_alias", "sma"))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), plan.next()); + } + + @Test + public void calculates_simple_moving_average_one_field_two_samples() { + when(inputPlan.hasNext()).thenReturn(true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + + var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation( + 2, AstDSL.field("distance"), "distance_alias", "sma"))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0))); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_one_field_two_samples_three_rows() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation( + 2, AstDSL.field("distance"), "distance_alias", "sma"))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0))); + assertTrue(plan.hasNext()); + assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_multiple_computations() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20))); + + var plan = new TrendlineOperator(inputPlan, Arrays.asList( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0))); + assertTrue(plan.hasNext()); + assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0))); + assertFalse(plan.hasNext()); + } + + public void alias_overwrites_input_field() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation( + 2, AstDSL.field("distance"), "time", "sma"))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 100)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0))); + assertTrue(plan.hasNext()); + assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0))); + assertFalse(plan.hasNext()); + } +} diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index 15b78ac956..934e8ca11a 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -226,7 +226,7 @@ public String visitHead(Head node, String context) { @Override public String visitTrendline(Trendline node, String context) { String child = node.getChild().get(0).accept(this, context); - String computations = visitExpressionList(node.getComputations()); + String computations = visitExpressionList(node.getComputations(), " "); return StringUtils.format("%s | trendline %s", child, computations); } @@ -235,9 +235,13 @@ private String visitFieldList(List fieldList) { } private String visitExpressionList(List expressionList) { + return visitExpressionList(expressionList, ","); + } + + private String visitExpressionList(List expressionList, String delimiter) { return expressionList.isEmpty() ? "" - : expressionList.stream().map(this::visitExpression).collect(Collectors.joining(",")); + : expressionList.stream().map(this::visitExpression).collect(Collectors.joining(delimiter)); } private String visitExpression(UnresolvedExpression expression) { diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index b5b4c97f13..06f8fbb061 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -89,6 +89,13 @@ public void testDedupCommand() { anonymize("source=t | dedup f1, f2")); } + @Test + public void testTrendlineCommand() { + assertEquals( + "source=t | trendline sma(2, date) as date_alias sma(3, time) as time_alias", + anonymize("source=t | trendline sma(2, date) as date_alias sma(3, time) as time_alias")); + } + @Test public void testHeadCommandWithNumber() { assertEquals("source=t | head 3", anonymize("source=t | head 3")); From fc3027ea1ead301f8bdefad9e1cdaa2756f7cd2b Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 25 Oct 2024 06:42:52 -0700 Subject: [PATCH 13/42] Spotless Signed-off-by: James Duong --- .../opensearch/sql/executor/ExplainTest.java | 47 ++++++---- .../physical/TrendlineOperatorTest.java | 92 +++++++++++++------ 2 files changed, 93 insertions(+), 46 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java index 549cec5c7f..c7fc9b4c50 100644 --- a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java @@ -31,6 +31,7 @@ import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.values; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.window; +import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -57,8 +58,6 @@ import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.storage.TableScanOperator; -import com.google.common.collect.ImmutableMap; - @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class ExplainTest extends ExpressionTestBase { @@ -263,31 +262,41 @@ void can_explain_nested() { @Test void can_explain_trendline() { - PhysicalPlan plan = new TrendlineOperator(tableScan, Arrays.asList( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), - AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma"))); + PhysicalPlan plan = + new TrendlineOperator( + tableScan, + Arrays.asList( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma"))); assertEquals( new ExplainResponse( new ExplainResponseNode( "TrendlineOperator", - ImmutableMap.of("computations", List.of( - ImmutableMap.of( - "computationType", - "sma", - "numberOfDataPoints", 2, - "dataField", "distance", - "alias", "distance_alias"), - ImmutableMap.of( - "computationType", - "sma", - "numberOfDataPoints", 3, - "dataField", "time", - "alias", "time_alias"))), + ImmutableMap.of( + "computations", + List.of( + ImmutableMap.of( + "computationType", + "sma", + "numberOfDataPoints", + 2, + "dataField", + "distance", + "alias", + "distance_alias"), + ImmutableMap.of( + "computationType", + "sma", + "numberOfDataPoints", + 3, + "dataField", + "time", + "alias", + "time_alias"))), singletonList(tableScan.explainNode()))), explain.apply(plan)); } - private static class FakeTableScan extends TableScanOperator { @Override public boolean hasNext() { diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 900a291e61..1a913fa31e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -10,9 +10,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.Collections; - import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; @@ -22,8 +22,6 @@ import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.data.model.ExprValueUtils; -import com.google.common.collect.ImmutableMap; - @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) public class TrendlineOperatorTest { @@ -35,12 +33,18 @@ public void calculates_simple_moving_average_one_field_one_sample() { when(inputPlan.next()) .thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); - var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation( - 1, AstDSL.field("distance"), "distance_alias", "sma"))); + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", "sma"))); plan.open(); assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), plan.next()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), + plan.next()); } @Test @@ -51,15 +55,21 @@ public void calculates_simple_moving_average_one_field_two_samples() { ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); - - var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation( - 2, AstDSL.field("distance"), "distance_alias", "sma"))); + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"))); plan.open(); assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); - assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0))); + assertEquals( + plan.next(), + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0))); assertFalse(plan.hasNext()); } @@ -72,16 +82,26 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); - var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation( - 2, AstDSL.field("distance"), "distance_alias", "sma"))); + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"))); plan.open(); assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); - assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0))); + assertEquals( + plan.next(), + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0))); assertTrue(plan.hasNext()); - assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))); + assertEquals( + plan.next(), + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))); assertFalse(plan.hasNext()); } @@ -94,17 +114,29 @@ public void calculates_simple_moving_average_multiple_computations() { ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20))); - var plan = new TrendlineOperator(inputPlan, Arrays.asList( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), - AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"))); + var plan = + new TrendlineOperator( + inputPlan, + Arrays.asList( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"))); plan.open(); assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); - assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0))); + assertEquals( + plan.next(), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0))); assertTrue(plan.hasNext()); - assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0))); + assertEquals( + plan.next(), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0))); assertFalse(plan.hasNext()); } @@ -116,16 +148,22 @@ public void alias_overwrites_input_field() { ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); - var plan = new TrendlineOperator(inputPlan, Collections.singletonList(AstDSL.computation( - 2, AstDSL.field("distance"), "time", "sma"))); + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + AstDSL.computation(2, AstDSL.field("distance"), "time", "sma"))); plan.open(); assertTrue(plan.hasNext()); - assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 100)), plan.next()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 100)), plan.next()); assertTrue(plan.hasNext()); - assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0))); + assertEquals( + plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0))); assertTrue(plan.hasNext()); - assertEquals(plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0))); + assertEquals( + plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0))); assertFalse(plan.hasNext()); } } From a7354c5894581e2c7c27c1e81cc5267a38d1cfc7 Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 25 Oct 2024 07:02:00 -0700 Subject: [PATCH 14/42] Fix explain test failure Signed-off-by: James Duong --- .../test/java/org/opensearch/sql/executor/ExplainTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java index c7fc9b4c50..72708f84f1 100644 --- a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java @@ -279,7 +279,7 @@ void can_explain_trendline() { "computationType", "sma", "numberOfDataPoints", - 2, + "2", "dataField", "distance", "alias", @@ -288,7 +288,7 @@ void can_explain_trendline() { "computationType", "sma", "numberOfDataPoints", - 3, + "3", "dataField", "time", "alias", From 332f5e64dd2f9903f2dc56a9697b1bf1fdfc6bc7 Mon Sep 17 00:00:00 2001 From: James Duong Date: Fri, 25 Oct 2024 17:00:12 -0700 Subject: [PATCH 15/42] Add integration tests Signed-off-by: James Duong --- .../planner/physical/TrendlineOperator.java | 36 +++++++----- .../sql/ppl/TrendlineCommandIT.java | 56 +++++++++++++++++++ 2 files changed, 77 insertions(+), 15 deletions(-) create mode 100644 integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 5503fe876b..b74c922418 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -9,6 +9,7 @@ import com.google.common.collect.ImmutableMap.Builder; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import lombok.EqualsAndHashCode; @@ -30,15 +31,20 @@ public class TrendlineOperator extends PhysicalPlan { @Getter private final List computations; @EqualsAndHashCode.Exclude private final List accumulators; @EqualsAndHashCode.Exclude private final Map fieldToIndexMap; + @EqualsAndHashCode.Exclude private final HashSet aliases; public TrendlineOperator(PhysicalPlan input, List computations) { this.input = input; this.computations = computations; this.accumulators = computations.stream().map(TrendlineOperator::createAccumulator).toList(); fieldToIndexMap = new HashMap<>(computations.size()); + aliases = new HashSet<>(computations.size()); for (int i = 0; i < computations.size(); ++i) { - - fieldToIndexMap.put(computations.get(i).getDataField().getChild().get(0).toString(), i); + final Trendline.TrendlineComputation computation = computations.get(i); + fieldToIndexMap.put(computation.getDataField().getChild().get(0).toString(), i); + if (computation.getAlias() != null) { + aliases.add(computation.getAlias()); + } } } @@ -61,36 +67,36 @@ public boolean hasNext() { public ExprValue next() { final ExprValue result; final ExprValue next = input.next(); - consumeInputTuple(next); - final Map inputStruct = ExprValueUtils.getTupleValue(next); + final Map inputStruct = consumeInputTuple(next); final Builder mapBuilder = new Builder<>(); mapBuilder.putAll(inputStruct); // Add calculated trendline values, which might overwrite existing fields from the input. for (int i = 0; i < accumulators.size(); ++i) { final ExprValue calculateResult = accumulators.get(i).calculate(); - if (null != calculateResult) { - if (null != computations.get(i).getAlias()) { - mapBuilder.put(computations.get(i).getAlias(), calculateResult); - } else { - mapBuilder.put( - computations.get(i).getDataField().getChild().get(0).toString(), calculateResult); - } + final String field = + null != computations.get(i).getAlias() + ? computations.get(i).getAlias() + : computations.get(i).getDataField().getChild().get(0).toString(); + if (calculateResult != null) { + mapBuilder.put(field, calculateResult); } } + result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast()); return result; } - private void consumeInputTuple(ExprValue inputValue) { + private Map consumeInputTuple(ExprValue inputValue) { final Map tupleValue = ExprValueUtils.getTupleValue(inputValue); for (String bindName : tupleValue.keySet()) { final Integer index = fieldToIndexMap.get(bindName); - if (index == null) { - continue; + if (index != null) { + accumulators.get(index).accumulate(tupleValue.get(bindName)); } - accumulators.get(index).accumulate(tupleValue.get(bindName)); } + tupleValue.keySet().removeAll(aliases); + return tupleValue; } private static TrendlineAccumulator createAccumulator( diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java new file mode 100644 index 0000000000..98e33c09a6 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -0,0 +1,56 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; + +public class TrendlineCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws IOException { + loadIndex(Index.BANK); + } + + @Test + public void testTrendline() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " balance_trend | fields balance_trend", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } + + @Test + public void testTrendlineMultipleFields() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " balance_trend sma(2, account_number) as account_number_trend | fields" + + " balance_trend, account_number_trend", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(null, null), rows(44313.0, 28.5), rows(39882.5, 13.0)); + } + + @Test + public void testTrendlineOverwritesExistingField() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " age | fields age", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } +} From 289a74f1915050fad882ff35a6a5490fdf249205 Mon Sep 17 00:00:00 2001 From: James Duong Date: Sun, 27 Oct 2024 08:58:23 -0700 Subject: [PATCH 16/42] Add trendline documentation Signed-off-by: James Duong --- docs/user/ppl/cmd/trendline.rst | 63 +++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 docs/user/ppl/cmd/trendline.rst diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst new file mode 100644 index 0000000000..38d3375890 --- /dev/null +++ b/docs/user/ppl/cmd/trendline.rst @@ -0,0 +1,63 @@ +============= +rename +============= + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 2 + + +Description +============ +| Use the ``trendline`` command to calculate the moving average on one or more fields in a search result. + + +Syntax +============ +trendline (, ) AS [" " (, ) AS ]... + +* average-type: mandatory. The moving average computation. Can be ``sma`` (simple moving average) currently. +* number-of-samples: mandatory. The number of samples to use in the average calculation. Must be a positive non-zero integer. +* source-field: mandatory. The field to compute the average on. +* target-field: mandatory. The field name to report the computation under. + + +Example 1: Calculate the moving average on one field. +===================================================== + +The example shows how to calculate the moving average on one field. + +PPL query:: + + os> source=accounts | trendline sma(2, account_number) as an | fields an; + fetched rows / total rows = 4/4 + +------+ + | an | + |------| + | null | + | 3.5 | + | 9.5 | + | 15.5 | + +------+ + + +Example 2: Calculate the moving average on multiple fields. +=========================================================== + +The example shows how to calculate the moving average on multiple fields. + +PPL query:: + + os> source=accounts | trendline sma(2, account_number) as an sma(2, age) as age_trend | fields an, age_trend ; + fetched rows / total rows = 4/4 + +------+-----------+ + | an | age_trend | + |------|-----------| + | null | null | + | 3.5 | 34.0 | + | 9.5 | 32.0 | + | 15.5 | 30.5 | + +------+-----------+ + From 1ec45f02a1250f548f543f70d02b0fbbdfa55730 Mon Sep 17 00:00:00 2001 From: James Duong Date: Sun, 27 Oct 2024 14:28:39 -0700 Subject: [PATCH 17/42] Make null handling consistent with normal aggregation Signed-off-by: James Duong --- .../planner/physical/TrendlineOperator.java | 2 +- .../physical/TrendlineOperatorTest.java | 33 +++++++++++++++++-- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index b74c922418..edc6b2162b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -131,7 +131,7 @@ public SimpleMovingAverageAccumulator(Trendline.TrendlineComputation computation @Override public void accumulate(ExprValue value) { if (value == null) { - // Should this make the whole calculation null? + // Ignore null values, for consistency with average aggregate. return; } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 1a913fa31e..5b71842d09 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -140,6 +140,7 @@ public void calculates_simple_moving_average_multiple_computations() { assertFalse(plan.hasNext()); } + @Test public void alias_overwrites_input_field() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); when(inputPlan.next()) @@ -156,8 +157,7 @@ public void alias_overwrites_input_field() { plan.open(); assertTrue(plan.hasNext()); - assertEquals( - ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 100)), plan.next()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100)), plan.next()); assertTrue(plan.hasNext()); assertEquals( plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0))); @@ -166,4 +166,33 @@ public void alias_overwrites_input_field() { plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0))); assertFalse(plan.hasNext()); } + + @Test + public void calculates_simple_moving_average_one_field_two_samples_three_rows_null_value() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 300, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + assertTrue(plan.hasNext()); + assertEquals( + plan.next(), + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0))); + assertFalse(plan.hasNext()); + } } From bf038837681173f2e78342d080fbf08c708ba84b Mon Sep 17 00:00:00 2001 From: James Duong Date: Sun, 27 Oct 2024 14:32:40 -0700 Subject: [PATCH 18/42] Make alias mandatory Signed-off-by: James Duong --- .../main/java/org/opensearch/sql/executor/Explain.java | 2 +- .../sql/planner/physical/TrendlineOperator.java | 9 ++------- .../opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java | 4 ++-- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/executor/Explain.java b/core/src/main/java/org/opensearch/sql/executor/Explain.java index 096fb240dd..cdf689bd45 100644 --- a/core/src/main/java/org/opensearch/sql/executor/Explain.java +++ b/core/src/main/java/org/opensearch/sql/executor/Explain.java @@ -270,7 +270,7 @@ private List> describeTrendlineComputations( computation.getComputationType().name().toLowerCase(Locale.ROOT), "numberOfDataPoints", computation.getNumberOfDataPoints().toString(), "dataField", computation.getDataField().getChild().get(0).toString(), - "alias", computation.getAlias() != null ? computation.getAlias() : "")) + "alias", computation.getAlias())) .collect(Collectors.toList()); } } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index edc6b2162b..d5e6ef8da8 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -42,9 +42,7 @@ public TrendlineOperator(PhysicalPlan input, List Date: Sun, 27 Oct 2024 14:32:50 -0700 Subject: [PATCH 19/42] Remove weighted moving average stub Signed-off-by: James Duong --- .../sql/planner/physical/TrendlineOperator.java | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index d5e6ef8da8..2f5fb79038 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -177,15 +177,4 @@ public ExprValue calculate() { return runningAverage; } } - - private static class WeightedMovingAverageAccumulator implements TrendlineAccumulator { - - @Override - public void accumulate(ExprValue value) {} - - @Override - public ExprValue calculate() { - return null; - } - } } From 0c9dfb04b10f6ba220a7268ea83d23b9d93feb42 Mon Sep 17 00:00:00 2001 From: James Duong Date: Sun, 27 Oct 2024 15:01:35 -0700 Subject: [PATCH 20/42] Add DefaultImplementor and PhysicalPlanNodeVisitor tests Signed-off-by: James Duong --- .../sql/planner/DefaultImplementorTest.java | 18 ++++++++++++ .../physical/PhysicalPlanNodeVisitorTest.java | 28 +++++++++++++++++-- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index 8e71fc2bec..a525998ffc 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -7,6 +7,7 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -44,8 +45,10 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.executor.pagination.PlanSerializer; @@ -63,11 +66,13 @@ import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.logical.LogicalTrendline; import org.opensearch.sql.planner.logical.LogicalValues; import org.opensearch.sql.planner.physical.CursorCloseOperator; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; import org.opensearch.sql.planner.physical.ProjectOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; @@ -304,4 +309,17 @@ public void visitLimit_support_return_takeOrdered() { 5); assertEquals(physicalPlanTree, logicalLimit.accept(implementor, null)); } + + @Test + public void visitTrendline_should_build_TrendlineOperator() { + var logicalChild = mock(LogicalPlan.class); + var physicalChild = mock(PhysicalPlan.class); + when(logicalChild.accept(implementor, null)).thenReturn(physicalChild); + final Trendline.TrendlineComputation computation = + AstDSL.computation(1, AstDSL.field("field"), "alias", "sma"); + var logicalPlan = new LogicalTrendline(logicalChild, Collections.singletonList(computation)); + var implemented = logicalPlan.accept(implementor, null); + assertInstanceOf(TrendlineOperator.class, implemented); + assertSame(physicalChild, implemented.getChild().get(0)); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 17fb128ace..716f83ae0b 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -29,6 +29,7 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -43,6 +44,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.DSL; @@ -65,7 +67,14 @@ public void print_physical_plan() { agg( rareTopN( filter( - limit(new TestScan(), 1, 1), + limit( + new TrendlineOperator( + new TestScan(), + Collections.singletonList( + AstDSL.computation( + 1, AstDSL.field("field"), "alias", "sma"))), + 1, + 1), DSL.equal(DSL.ref("response", INTEGER), DSL.literal(10))), CommandType.TOP, ImmutableList.of(), @@ -85,7 +94,8 @@ public void print_physical_plan() { + "\t\t\tAggregation->\n" + "\t\t\t\tRareTopN->\n" + "\t\t\t\t\tFilter->\n" - + "\t\t\t\t\t\tLimit->", + + "\t\t\t\t\t\tLimit->\n" + + "\t\t\t\t\t\t\tTrendline->", printer.print(plan)); } @@ -134,6 +144,12 @@ public static Stream getPhysicalPlanForTest() { PhysicalPlan cursorClose = new CursorCloseOperator(plan); + PhysicalPlan trendline = + new TrendlineOperator( + plan, + Collections.singletonList( + AstDSL.computation(1, AstDSL.field("field"), "alias", "sma"))); + return Stream.of( Arguments.of(filter, "filter"), Arguments.of(aggregation, "aggregation"), @@ -149,7 +165,8 @@ public static Stream getPhysicalPlanForTest() { Arguments.of(rareTopN, "rareTopN"), Arguments.of(limit, "limit"), Arguments.of(nested, "nested"), - Arguments.of(cursorClose, "cursorClose")); + Arguments.of(cursorClose, "cursorClose"), + Arguments.of(trendline, "trendline")); } @ParameterizedTest(name = "{1}") @@ -223,6 +240,11 @@ public String visitLimit(LimitOperator node, Integer tabs) { return name(node, "Limit->", tabs); } + @Override + public String visitTrendline(TrendlineOperator node, Integer tabs) { + return name(node, "Trendline->", tabs); + } + private String name(PhysicalPlan node, String current, int tabs) { String child = node.getChild().get(0).accept(this, tabs + 1); StringBuilder sb = new StringBuilder(); From a5c455c397f91f408d9d479ccb05187d6c9a9e4d Mon Sep 17 00:00:00 2001 From: James Duong Date: Sun, 27 Oct 2024 15:56:15 -0700 Subject: [PATCH 21/42] Propagate type information Resolve base field types in prep for supporting datetime trendlines. Signed-off-by: James Duong --- .../org/opensearch/sql/analysis/Analyzer.java | 34 ++++++++++++++++--- .../org/opensearch/sql/executor/Explain.java | 6 +++- .../sql/planner/logical/LogicalPlanDSL.java | 3 +- .../sql/planner/logical/LogicalTrendline.java | 7 ++-- .../planner/physical/TrendlineOperator.java | 21 +++++++----- .../opensearch/sql/analysis/AnalyzerTest.java | 4 +-- .../opensearch/sql/executor/ExplainTest.java | 6 ++-- .../sql/planner/DefaultImplementorTest.java | 4 ++- .../logical/LogicalPlanNodeVisitorTest.java | 5 ++- .../physical/PhysicalPlanNodeVisitorTest.java | 8 +++-- .../physical/TrendlineOperatorTest.java | 30 ++++++++++++---- .../OpenSearchExecutionProtector.java | 5 +-- 12 files changed, 97 insertions(+), 36 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 82cb62d2e2..999aa30d3a 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -608,11 +608,37 @@ public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { .map(expression -> (Trendline.TrendlineComputation) expression) .toList(); + final ImmutableList.Builder> + computationsAndTypes = ImmutableList.builder(); computations.forEach( - computation -> - currEnv.define( - new Symbol(Namespace.FIELD_NAME, computation.getAlias()), ExprCoreType.DOUBLE)); - return new LogicalTrendline(child, computations); + computation -> { + final Expression resolvedField = + expressionAnalyzer.analyze(computation.getDataField(), context); + final ExprCoreType averageType; + // Duplicate the semantics of AvgAggregator#create(): + // - All numerical types have the DOUBLE type for the moving average. + // - All datetime types have the same datetime type for the moving average. + if (ExprCoreType.numberTypes().contains(resolvedField.type())) { + averageType = ExprCoreType.DOUBLE; + } else if (ExprCoreType.DATE == resolvedField.type()) { + averageType = ExprCoreType.DATE; + } else if (ExprCoreType.TIME == resolvedField.type()) { + averageType = ExprCoreType.TIME; + } else if (ExprCoreType.TIMESTAMP == resolvedField.type()) { + averageType = ExprCoreType.TIMESTAMP; + } else { + throw new SemanticCheckException( + String.format( + "Invalid field used for trendline computation %s. Source field %s had type %s" + + " but must be a numerical or datetime field.", + computation.getAlias(), + computation.getDataField().getChild().get(0), + resolvedField.type().typeName())); + } + currEnv.define(new Symbol(Namespace.FIELD_NAME, computation.getAlias()), averageType); + computationsAndTypes.add(Pair.of(computation, averageType)); + }); + return new LogicalTrendline(child, computationsAndTypes.build()); } @Override diff --git a/core/src/main/java/org/opensearch/sql/executor/Explain.java b/core/src/main/java/org/opensearch/sql/executor/Explain.java index cdf689bd45..31890a8090 100644 --- a/core/src/main/java/org/opensearch/sql/executor/Explain.java +++ b/core/src/main/java/org/opensearch/sql/executor/Explain.java @@ -222,7 +222,11 @@ public ExplainResponseNode visitTrendline(TrendlineOperator node, Object context explainNode -> explainNode.setDescription( ImmutableMap.of( - "computations", describeTrendlineComputations(node.getComputations())))); + "computations", + describeTrendlineComputations( + node.getComputations().stream() + .map(Pair::getKey) + .collect(Collectors.toList()))))); } protected ExplainResponseNode explain( diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 3a4acdf3aa..13c6d7a979 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -16,6 +16,7 @@ import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedExpression; @@ -132,7 +133,7 @@ public static LogicalPlan rareTopN( } public static LogicalTrendline trendline( - LogicalPlan input, Trendline.TrendlineComputation... computations) { + LogicalPlan input, Pair... computations) { return new LogicalTrendline(input, Arrays.asList(computations)); } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java index 3357b31e65..3e992035e2 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java @@ -10,7 +10,9 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.data.type.ExprCoreType; /* * Trendline logical plan. @@ -19,7 +21,7 @@ @ToString @EqualsAndHashCode(callSuper = true) public class LogicalTrendline extends LogicalPlan { - private final List computations; + private final List> computations; /** * Constructor of LogicalTrendline. @@ -27,7 +29,8 @@ public class LogicalTrendline extends LogicalPlan { * @param child child logical plan * @param computations the computations for this trendline call. */ - public LogicalTrendline(LogicalPlan child, List computations) { + public LogicalTrendline( + LogicalPlan child, List> computations) { super(Collections.singletonList(child)); this.computations = computations; } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 2f5fb79038..d1dc749a06 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -15,11 +15,14 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; + +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; @@ -28,19 +31,20 @@ @EqualsAndHashCode(callSuper = false) public class TrendlineOperator extends PhysicalPlan { @Getter private final PhysicalPlan input; - @Getter private final List computations; + @Getter private final List> computations; @EqualsAndHashCode.Exclude private final List accumulators; @EqualsAndHashCode.Exclude private final Map fieldToIndexMap; @EqualsAndHashCode.Exclude private final HashSet aliases; - public TrendlineOperator(PhysicalPlan input, List computations) { + public TrendlineOperator( + PhysicalPlan input, List> computations) { this.input = input; this.computations = computations; this.accumulators = computations.stream().map(TrendlineOperator::createAccumulator).toList(); fieldToIndexMap = new HashMap<>(computations.size()); aliases = new HashSet<>(computations.size()); for (int i = 0; i < computations.size(); ++i) { - final Trendline.TrendlineComputation computation = computations.get(i); + final Trendline.TrendlineComputation computation = computations.get(i).getKey(); fieldToIndexMap.put(computation.getDataField().getChild().get(0).toString(), i); aliases.add(computation.getAlias()); } @@ -72,7 +76,7 @@ public ExprValue next() { // Add calculated trendline values, which might overwrite existing fields from the input. for (int i = 0; i < accumulators.size(); ++i) { final ExprValue calculateResult = accumulators.get(i).calculate(); - final String field = computations.get(i).getAlias(); + final String field = computations.get(i).getKey().getAlias(); if (calculateResult != null) { mapBuilder.put(field, calculateResult); } @@ -95,13 +99,14 @@ private Map consumeInputTuple(ExprValue inputValue) { } private static TrendlineAccumulator createAccumulator( - Trendline.TrendlineComputation computation) { - switch (computation.getComputationType()) { + Pair computation) { + switch (computation.getKey().getComputationType()) { case SMA: - return new SimpleMovingAverageAccumulator(computation); + return new SimpleMovingAverageAccumulator(computation.getKey()); case WMA: default: - throw new IllegalStateException("Unexpected value: " + computation.getComputationType()); + throw new IllegalStateException( + "Unexpected value: " + computation.getKey().getComputationType()); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index c7603de34c..3b25f4b9f4 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -1487,8 +1487,8 @@ public void trendline() { assertAnalyzeEqual( LogicalPlanDSL.trendline( LogicalPlanDSL.relation("schema", table), - computation(5, field("float_value"), "test_field_alias", "sma"), - computation(1, field("double_value"), "test_field_alias_2", "sma")), + Pair.of(computation(5, field("float_value"), "test_field_alias", "sma"), DOUBLE), + Pair.of(computation(1, field("double_value"), "test_field_alias_2", "sma"), DOUBLE)), AstDSL.trendline( AstDSL.relation("schema"), computation(5, field("float_value"), "test_field_alias", "sma"), diff --git a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java index 72708f84f1..0e71f72b50 100644 --- a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java @@ -266,8 +266,10 @@ void can_explain_trendline() { new TrendlineOperator( tableScan, Arrays.asList( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), - AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma"))); + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + DOUBLE), + Pair.of(AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma"), DOUBLE))); assertEquals( new ExplainResponse( new ExplainResponseNode( diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index a525998ffc..b62f59f192 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -317,7 +317,9 @@ public void visitTrendline_should_build_TrendlineOperator() { when(logicalChild.accept(implementor, null)).thenReturn(physicalChild); final Trendline.TrendlineComputation computation = AstDSL.computation(1, AstDSL.field("field"), "alias", "sma"); - var logicalPlan = new LogicalTrendline(logicalChild, Collections.singletonList(computation)); + var logicalPlan = + new LogicalTrendline( + logicalChild, Collections.singletonList(Pair.of(computation, ExprCoreType.DOUBLE))); var implemented = logicalPlan.accept(implementor, null); assertInstanceOf(TrendlineOperator.class, implemented); assertSame(physicalChild, implemented.getChild().get(0)); diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 9b947f6e21..8fd031e666 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -29,6 +29,7 @@ import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; @@ -146,7 +147,9 @@ public TableWriteOperator build(PhysicalPlan child) { new LogicalTrendline( relation, Collections.singletonList( - AstDSL.computation(1, AstDSL.field("testField"), "dummy", "sma"))); + Pair.of( + AstDSL.computation(1, AstDSL.field("testField"), "dummy", "sma"), + ExprCoreType.DOUBLE))); return Stream.of( relation, diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 716f83ae0b..f079791195 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -71,8 +71,10 @@ public void print_physical_plan() { new TrendlineOperator( new TestScan(), Collections.singletonList( - AstDSL.computation( - 1, AstDSL.field("field"), "alias", "sma"))), + Pair.of( + AstDSL.computation( + 1, AstDSL.field("field"), "alias", "sma"), + DOUBLE))), 1, 1), DSL.equal(DSL.ref("response", INTEGER), DSL.literal(10))), @@ -148,7 +150,7 @@ public static Stream getPhysicalPlanForTest() { new TrendlineOperator( plan, Collections.singletonList( - AstDSL.computation(1, AstDSL.field("field"), "alias", "sma"))); + Pair.of(AstDSL.computation(1, AstDSL.field("field"), "alias", "sma"), DOUBLE))); return Stream.of( Arguments.of(filter, "filter"), diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 5b71842d09..f1f88bdeca 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.Collections; +import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; @@ -21,6 +22,7 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) @@ -37,7 +39,9 @@ public void calculates_simple_moving_average_one_field_one_sample() { new TrendlineOperator( inputPlan, Collections.singletonList( - AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", "sma"))); + Pair.of( + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", "sma"), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); @@ -59,7 +63,9 @@ public void calculates_simple_moving_average_one_field_two_samples() { new TrendlineOperator( inputPlan, Collections.singletonList( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"))); + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); @@ -86,7 +92,9 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() new TrendlineOperator( inputPlan, Collections.singletonList( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"))); + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); @@ -118,8 +126,12 @@ public void calculates_simple_moving_average_multiple_computations() { new TrendlineOperator( inputPlan, Arrays.asList( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), - AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"))); + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + ExprCoreType.DOUBLE), + Pair.of( + AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); @@ -153,7 +165,9 @@ public void alias_overwrites_input_field() { new TrendlineOperator( inputPlan, Collections.singletonList( - AstDSL.computation(2, AstDSL.field("distance"), "time", "sma"))); + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "time", "sma"), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); @@ -180,7 +194,9 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu new TrendlineOperator( inputPlan, Collections.singletonList( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"))); + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + ExprCoreType.DOUBLE))); plan.open(); assertTrue(plan.hasNext()); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 41070d3f6f..358bc10ab4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -190,11 +190,8 @@ public PhysicalPlan visitML(PhysicalPlan node, Object context) { @Override public PhysicalPlan visitTrendline(TrendlineOperator node, Object context) { - TrendlineOperator trendlineOperator = (TrendlineOperator) node; return doProtect( - new TrendlineOperator( - visitInput(trendlineOperator.getInput(), context), - trendlineOperator.getComputations())); + new TrendlineOperator(visitInput(node.getInput(), context), node.getComputations())); } PhysicalPlan visitInput(PhysicalPlan node, Object context) { From 4abbae7451fb5c1597be613f571c4285fd33696c Mon Sep 17 00:00:00 2001 From: James Duong Date: Mon, 28 Oct 2024 09:32:28 -0700 Subject: [PATCH 22/42] Tweak math to evaluate lazily and reduce division Signed-off-by: James Duong --- .../planner/physical/TrendlineOperator.java | 42 ++++++++----------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index d1dc749a06..9dc38a96e4 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -15,16 +15,15 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.ToString; - import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Trendline; -import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.LiteralExpression; /** Trendline command implementation */ @ToString @@ -119,12 +118,12 @@ private interface TrendlineAccumulator { // TODO: Make the actual math polymorphic based on types to deal with datetimes. private static class SimpleMovingAverageAccumulator implements TrendlineAccumulator { - private final ExprValue dataPointsNeeded; + private final LiteralExpression dataPointsNeeded; private final EvictingQueue receivedValues; - private ExprValue runningAverage = null; + private Expression runningTotal = null; public SimpleMovingAverageAccumulator(Trendline.TrendlineComputation computation) { - dataPointsNeeded = new ExprIntegerValue(computation.getNumberOfDataPoints()); + dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); receivedValues = EvictingQueue.create(computation.getNumberOfDataPoints()); } @@ -135,51 +134,44 @@ public void accumulate(ExprValue value) { return; } - if (dataPointsNeeded.integerValue() == 1) { - runningAverage = value; + if (dataPointsNeeded.valueOf().integerValue() == 1) { + runningTotal = DSL.literal(value); receivedValues.add(value); return; } final ExprValue valueToRemove; - if (receivedValues.size() == dataPointsNeeded.integerValue()) { + if (receivedValues.size() == dataPointsNeeded.valueOf().integerValue()) { valueToRemove = receivedValues.remove(); } else { valueToRemove = null; } receivedValues.add(value); - if (receivedValues.size() == dataPointsNeeded.integerValue()) { - if (runningAverage != null) { - // We can use the previous average calculation. - // Subtract the evicted value / period and add the new value / period. - // Refactored, that would be previous + (newValue - oldValue) / period - runningAverage = - DSL.add( - DSL.literal(runningAverage), - DSL.divide( - DSL.subtract(DSL.literal(value), DSL.literal(valueToRemove)), - DSL.literal(dataPointsNeeded.doubleValue()))) - .valueOf(); + if (receivedValues.size() == dataPointsNeeded.valueOf().integerValue()) { + if (runningTotal != null) { + // We can use the previous calculation. + // Subtract the evicted value and add the new value. + // Refactored, that would be previous + (newValue - oldValue). + runningTotal = + DSL.add(runningTotal, DSL.subtract(DSL.literal(value), DSL.literal(valueToRemove))); } else { // This is the first average calculation so sum the entire receivedValues dataset. final List data = receivedValues.stream().toList(); - Expression runningTotal = DSL.literal(0.0D); + runningTotal = DSL.literal(0.0D); for (ExprValue entry : data) { runningTotal = DSL.add(runningTotal, DSL.literal(entry)); } - runningAverage = - DSL.divide(runningTotal, DSL.literal(dataPointsNeeded.doubleValue())).valueOf(); } } } @Override public ExprValue calculate() { - if (receivedValues.size() < dataPointsNeeded.integerValue()) { + if (receivedValues.size() < dataPointsNeeded.valueOf().integerValue()) { return null; } - return runningAverage; + return DSL.divide(runningTotal, dataPointsNeeded).valueOf(); } } } From db6a08aabf4cba46295f8f4c9dbfee8c685bfc49 Mon Sep 17 00:00:00 2001 From: James Duong Date: Mon, 28 Oct 2024 16:11:01 -0700 Subject: [PATCH 23/42] Add support for moving averages on datetime types Signed-off-by: James Duong --- .../org/opensearch/sql/analysis/Analyzer.java | 31 +-- .../planner/physical/TrendlineOperator.java | 169 ++++++++++++++-- .../opensearch/sql/analysis/AnalyzerTest.java | 23 +++ .../physical/TrendlineOperatorTest.java | 180 ++++++++++++++++-- 4 files changed, 364 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 999aa30d3a..6bc1bdde0d 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -10,7 +10,10 @@ import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST; import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; +import static org.opensearch.sql.data.type.ExprCoreType.DATE; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; +import static org.opensearch.sql.data.type.ExprCoreType.TIME; +import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE; @@ -620,20 +623,22 @@ public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { // - All datetime types have the same datetime type for the moving average. if (ExprCoreType.numberTypes().contains(resolvedField.type())) { averageType = ExprCoreType.DOUBLE; - } else if (ExprCoreType.DATE == resolvedField.type()) { - averageType = ExprCoreType.DATE; - } else if (ExprCoreType.TIME == resolvedField.type()) { - averageType = ExprCoreType.TIME; - } else if (ExprCoreType.TIMESTAMP == resolvedField.type()) { - averageType = ExprCoreType.TIMESTAMP; } else { - throw new SemanticCheckException( - String.format( - "Invalid field used for trendline computation %s. Source field %s had type %s" - + " but must be a numerical or datetime field.", - computation.getAlias(), - computation.getDataField().getChild().get(0), - resolvedField.type().typeName())); + switch (resolvedField.type()) { + case DATE: + case TIME: + case TIMESTAMP: + averageType = (ExprCoreType) resolvedField.type(); + break; + default: + throw new SemanticCheckException( + String.format( + "Invalid field used for trendline computation %s. Source field %s had type" + + " %s but must be a numerical or datetime field.", + computation.getAlias(), + computation.getDataField().getChild().get(0), + resolvedField.type().typeName())); + } } currEnv.define(new Symbol(Namespace.FIELD_NAME, computation.getAlias()), averageType); computationsAndTypes.add(Pair.of(computation, averageType)); diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 9dc38a96e4..65af0e01f6 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -5,8 +5,12 @@ package org.opensearch.sql.planner.physical; +import static java.time.temporal.ChronoUnit.MILLIS; + import com.google.common.collect.EvictingQueue; import com.google.common.collect.ImmutableMap.Builder; +import java.time.Instant; +import java.time.LocalTime; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -101,7 +105,7 @@ private static TrendlineAccumulator createAccumulator( Pair computation) { switch (computation.getKey().getComputationType()) { case SMA: - return new SimpleMovingAverageAccumulator(computation.getKey()); + return new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); case WMA: default: throw new IllegalStateException( @@ -114,17 +118,34 @@ private interface TrendlineAccumulator { void accumulate(ExprValue value); ExprValue calculate(); + + static ArithmeticEvaluator getEvaluator(ExprCoreType type) { + switch (type) { + case DOUBLE: + return NumericArithmeticEvaluator.INSTANCE; + case DATE: + return DateArithmeticEvaluator.INSTANCE; + case TIME: + return TimeArithmeticEvaluator.INSTANCE; + case TIMESTAMP: + return TimestampArithmeticEvaluator.INSTANCE; + } + throw new IllegalArgumentException( + String.format("Invalid type %s used for moving average.", type.typeName())); + } } - // TODO: Make the actual math polymorphic based on types to deal with datetimes. private static class SimpleMovingAverageAccumulator implements TrendlineAccumulator { private final LiteralExpression dataPointsNeeded; private final EvictingQueue receivedValues; + private final ArithmeticEvaluator evaluator; private Expression runningTotal = null; - public SimpleMovingAverageAccumulator(Trendline.TrendlineComputation computation) { + public SimpleMovingAverageAccumulator( + Trendline.TrendlineComputation computation, ExprCoreType type) { dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); receivedValues = EvictingQueue.create(computation.getNumberOfDataPoints()); + evaluator = TrendlineAccumulator.getEvaluator(type); } @Override @@ -135,7 +156,7 @@ public void accumulate(ExprValue value) { } if (dataPointsNeeded.valueOf().integerValue() == 1) { - runningTotal = DSL.literal(value); + runningTotal = evaluator.calculateFirstTotal(Collections.singletonList(value)); receivedValues.add(value); return; } @@ -153,15 +174,11 @@ public void accumulate(ExprValue value) { // We can use the previous calculation. // Subtract the evicted value and add the new value. // Refactored, that would be previous + (newValue - oldValue). - runningTotal = - DSL.add(runningTotal, DSL.subtract(DSL.literal(value), DSL.literal(valueToRemove))); + runningTotal = evaluator.add(runningTotal, value, valueToRemove); } else { // This is the first average calculation so sum the entire receivedValues dataset. final List data = receivedValues.stream().toList(); - runningTotal = DSL.literal(0.0D); - for (ExprValue entry : data) { - runningTotal = DSL.add(runningTotal, DSL.literal(entry)); - } + runningTotal = evaluator.calculateFirstTotal(data); } } } @@ -170,8 +187,138 @@ public void accumulate(ExprValue value) { public ExprValue calculate() { if (receivedValues.size() < dataPointsNeeded.valueOf().integerValue()) { return null; + } else if (dataPointsNeeded.valueOf().integerValue() == 1) { + return receivedValues.peek(); + } + return evaluator.evaluate(runningTotal, dataPointsNeeded); + } + } + + private interface ArithmeticEvaluator { + Expression calculateFirstTotal(List dataPoints); + + Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue); + + ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints); + } + + private static class NumericArithmeticEvaluator implements ArithmeticEvaluator { + private static final NumericArithmeticEvaluator INSTANCE = new NumericArithmeticEvaluator(); + + private NumericArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0.0D); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(dataPoint.doubleValue())); } - return DSL.divide(runningTotal, dataPointsNeeded).valueOf(); + return DSL.literal(total.valueOf().doubleValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add(runningTotal, DSL.subtract(DSL.literal(incomingValue), DSL.literal(evictedValue))) + .valueOf() + .doubleValue()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return DSL.divide(runningTotal, numberOfDataPoints).valueOf(); + } + } + + private static class DateArithmeticEvaluator implements ArithmeticEvaluator { + private static final DateArithmeticEvaluator INSTANCE = new DateArithmeticEvaluator(); + + private DateArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + return TimestampArithmeticEvaluator.INSTANCE.calculateFirstTotal(dataPoints); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return TimestampArithmeticEvaluator.INSTANCE.add(runningTotal, incomingValue, evictedValue); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + final ExprValue timestampResult = + TimestampArithmeticEvaluator.INSTANCE.evaluate(runningTotal, numberOfDataPoints); + return ExprValueUtils.dateValue(timestampResult.dateValue()); + } + } + + private static class TimeArithmeticEvaluator implements ArithmeticEvaluator { + private static final TimeArithmeticEvaluator INSTANCE = new TimeArithmeticEvaluator(); + + private TimeArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(MILLIS.between(LocalTime.MIN, dataPoint.timeValue()))); + } + return DSL.literal(total.valueOf().longValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract( + DSL.literal(MILLIS.between(LocalTime.MIN, incomingValue.timeValue())), + DSL.literal(MILLIS.between(LocalTime.MIN, evictedValue.timeValue())))) + .valueOf()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return ExprValueUtils.timeValue( + LocalTime.MIN.plus( + DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue(), MILLIS)); + } + } + + private static class TimestampArithmeticEvaluator implements ArithmeticEvaluator { + private static final TimestampArithmeticEvaluator INSTANCE = new TimestampArithmeticEvaluator(); + + private TimestampArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(dataPoint.timestampValue().toEpochMilli())); + } + return DSL.literal(total.valueOf().longValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract( + DSL.literal(incomingValue.timestampValue().toEpochMilli()), + DSL.literal(evictedValue.timestampValue().toEpochMilli()))) + .valueOf()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return ExprValueUtils.timestampValue( + Instant.ofEpochMilli(DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue())); } } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 3b25f4b9f4..3366fe8aae 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -1495,6 +1495,29 @@ public void trendline() { computation(1, field("double_value"), "test_field_alias_2", "sma"))); } + @Test + public void trendline_datetime_types() { + assertAnalyzeEqual( + LogicalPlanDSL.trendline( + LogicalPlanDSL.relation("schema", table), + Pair.of( + computation(5, field("timestamp_value"), "test_field_alias", "sma"), TIMESTAMP)), + AstDSL.trendline( + AstDSL.relation("schema"), + computation(5, field("timestamp_value"), "test_field_alias", "sma"))); + } + + @Test + public void trendline_illegal_type() { + assertThrows( + SemanticCheckException.class, + () -> + analyze( + AstDSL.trendline( + AstDSL.relation("schema"), + computation(5, field("array_value"), "test_field_alias", "sma")))); + } + @Test public void ad_batchRCF_relation() { Map argumentMap = diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index f1f88bdeca..7fea71d03f 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -7,10 +7,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableMap; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalTime; import java.util.Arrays; import java.util.Collections; import org.apache.commons.lang3.tuple.Pair; @@ -73,9 +77,9 @@ public void calculates_simple_moving_average_one_field_two_samples() { ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0))); + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); assertFalse(plan.hasNext()); } @@ -102,14 +106,14 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0))); + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0))); + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), + plan.next()); assertFalse(plan.hasNext()); } @@ -139,16 +143,16 @@ public void calculates_simple_moving_average_multiple_computations() { ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue( ImmutableMap.of( - "distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0))); + "distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)), + plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue( ImmutableMap.of( - "distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0))); + "distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0)), + plan.next()); assertFalse(plan.hasNext()); } @@ -174,10 +178,10 @@ public void alias_overwrites_input_field() { assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0))); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0))); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0)), plan.next()); assertFalse(plan.hasNext()); } @@ -203,12 +207,158 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void use_illegal_core_type() { + assertThrows( + IllegalArgumentException.class, + () -> { + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + ExprCoreType.ARRAY))); + }); + } + + @Test + public void calculates_simple_moving_average_date() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("date"), "date_alias", "sma"), + ExprCoreType.DATE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(3)))), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(9)))), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_time() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"), + ExprCoreType.TIME))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", LocalTime.MIN)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(3))), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(9))), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_timestamp() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", "sma"), + ExprCoreType.TIMESTAMP))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), plan.next()); assertTrue(plan.hasNext()); assertEquals( - plan.next(), ExprValueUtils.tupleValue( - ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0))); + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1000), + "timestamp_alias", + Instant.EPOCH.plusMillis(500))), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1500), + "timestamp_alias", + Instant.EPOCH.plusMillis(1250))), + plan.next()); assertFalse(plan.hasNext()); } } From ce2d089e6309955712138e706fc49f5459f72f0f Mon Sep 17 00:00:00 2001 From: James Duong Date: Tue, 29 Oct 2024 09:18:15 -0700 Subject: [PATCH 24/42] Add missing doc link Signed-off-by: James Duong --- docs/user/ppl/index.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/user/ppl/index.rst b/docs/user/ppl/index.rst index 9525874c59..ef8cff334e 100644 --- a/docs/user/ppl/index.rst +++ b/docs/user/ppl/index.rst @@ -74,6 +74,8 @@ The query start with search command and then flowing a set of command delimited - `stats command `_ + - `trendline command `_ + - `where command `_ - `head command `_ From 0e114211abd9af16d790c6b822a6252021855be9 Mon Sep 17 00:00:00 2001 From: James Duong Date: Tue, 29 Oct 2024 14:22:32 -0700 Subject: [PATCH 25/42] Fix code coverage gaps - Fix handling of possible null values from input field - Remove weighted moving average from computation type enum Signed-off-by: James Duong --- .../planner/physical/TrendlineOperator.java | 21 +++------ .../physical/TrendlineOperatorTest.java | 47 +++++++++++++++++++ ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 1 - ppl/src/main/antlr/OpenSearchPPLParser.g4 | 1 - 4 files changed, 54 insertions(+), 16 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java index 65af0e01f6..7bf10964cf 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -94,7 +94,10 @@ private Map consumeInputTuple(ExprValue inputValue) { for (String bindName : tupleValue.keySet()) { final Integer index = fieldToIndexMap.get(bindName); if (index != null) { - accumulators.get(index).accumulate(tupleValue.get(bindName)); + final ExprValue fieldValue = tupleValue.get(bindName); + if (!fieldValue.isNull()) { + accumulators.get(index).accumulate(fieldValue); + } } } tupleValue.keySet().removeAll(aliases); @@ -103,14 +106,9 @@ private Map consumeInputTuple(ExprValue inputValue) { private static TrendlineAccumulator createAccumulator( Pair computation) { - switch (computation.getKey().getComputationType()) { - case SMA: - return new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); - case WMA: - default: - throw new IllegalStateException( - "Unexpected value: " + computation.getKey().getComputationType()); - } + // Add a switch statement based on computation type to choose the accumulator when more + // types of computations are supported. + return new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); } /** Maintains stateful information for calculating the trendline. */ @@ -150,11 +148,6 @@ public SimpleMovingAverageAccumulator( @Override public void accumulate(ExprValue value) { - if (value == null) { - // Ignore null values, for consistency with average aggregate. - return; - } - if (dataPointsNeeded.valueOf().integerValue() == 1) { runningTotal = evaluator.calculateFirstTotal(Collections.singletonList(value)); receivedValues.add(value); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 7fea71d03f..993767046e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -25,6 +25,7 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; @@ -216,6 +217,38 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu assertFalse(plan.hasNext()); } + @Test + public void use_null_value() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", "sma"), + ExprCoreType.DOUBLE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), + plan.next()); + assertFalse(plan.hasNext()); + } + @Test public void use_illegal_core_type() { assertThrows( @@ -230,6 +263,20 @@ public void use_illegal_core_type() { }); } + @Test + public void use_illegal_computation_type() { + assertThrows( + IllegalArgumentException.class, + () -> { + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "fake"), + ExprCoreType.DOUBLE))); + }); + } + @Test public void calculates_simple_moving_average_date() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 6f4181c7f5..4a883fa656 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -60,7 +60,6 @@ NUM: 'NUM'; // TRENDLINE KEYWORDS SMA: 'SMA'; -WMA: 'WMA'; // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 6325ca0065..9077dd2b25 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -155,7 +155,6 @@ trendlineClause trendlineType : SMA - | WMA ; kmeansCommand From 68ac7adfb7d8e353448681ccdff0dac83eb22568 Mon Sep 17 00:00:00 2001 From: James Duong Date: Tue, 29 Oct 2024 14:42:04 -0700 Subject: [PATCH 26/42] Fix docs typo Signed-off-by: James Duong --- docs/user/ppl/cmd/trendline.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 38d3375890..592b152d70 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -1,5 +1,5 @@ ============= -rename +trendline ============= .. rubric:: Table of contents From 048930c1d83f2462aef9b2bb5cf53df21ddc31f7 Mon Sep 17 00:00:00 2001 From: James Duong Date: Tue, 29 Oct 2024 16:12:45 -0700 Subject: [PATCH 27/42] Add missing OpenSearchExecutionProtectorTest Signed-off-by: James Duong --- .../OpenSearchExecutionProtectorTest.java | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index da06c1eb66..10cdab34c8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -23,6 +23,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -37,10 +38,12 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.client.node.NodeClient; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.expression.DSL; @@ -67,6 +70,7 @@ import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; import org.opensearch.sql.planner.physical.TakeOrderedOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -318,6 +322,22 @@ public void test_visitTakeOrdered() { resourceMonitor(takeOrdered), executionProtector.visitTakeOrdered(takeOrdered, null)); } + @Test + public void test_visitTrendline() { + final TrendlineOperator trendlineOperator = + new TrendlineOperator( + PhysicalPlanDSL.values(emptyList()), + Collections.singletonList( + Pair.of( + new Trendline.TrendlineComputation( + 1, AstDSL.field("dummy"), "dummy_alias", "sma"), + DOUBLE))); + + assertEquals( + resourceMonitor(trendlineOperator), + executionProtector.visitTrendline(trendlineOperator, null)); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } From 1dd720a75c65b15aa9bedc144db77d1c3a85fa0f Mon Sep 17 00:00:00 2001 From: James Duong Date: Tue, 29 Oct 2024 16:12:54 -0700 Subject: [PATCH 28/42] Include trendline in docs test Signed-off-by: James Duong --- docs/category.json | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/category.json b/docs/category.json index aacfc43478..32f56cfb46 100644 --- a/docs/category.json +++ b/docs/category.json @@ -25,6 +25,7 @@ "user/ppl/cmd/sort.rst", "user/ppl/cmd/stats.rst", "user/ppl/cmd/syntax.rst", + "user/ppl/cmd/trendline.rst", "user/ppl/cmd/top.rst", "user/ppl/cmd/where.rst", "user/ppl/general/identifiers.rst", From 48a93e4ff568d8330efd5ddeb8663c09c78f02af Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 30 Oct 2024 10:36:35 -0700 Subject: [PATCH 29/42] Fix typo drawing example table Signed-off-by: James Duong --- docs/user/ppl/cmd/trendline.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 592b152d70..d092db2d59 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -54,7 +54,7 @@ PPL query:: fetched rows / total rows = 4/4 +------+-----------+ | an | age_trend | - |------|-----------| + |------+-----------| | null | null | | 3.5 | 34.0 | | 9.5 | 32.0 | From 4e2e2c04b5af113eef52c420c2a62f73971634b9 Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 30 Oct 2024 17:20:23 -0700 Subject: [PATCH 30/42] Add explain integration test Signed-off-by: James Duong --- .../org/opensearch/sql/ppl/ExplainIT.java | 12 +++++++ .../ppl/explain_trendline_push.json | 32 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_push.json diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index b9c7f89ba0..52ba0bc411 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -100,6 +100,18 @@ public void testFillNullPushDownExplain() throws Exception { + " | fillnull with -1 in age,balance | fields age, balance")); } + @Test + public void testTrendlinePushDownExplain() throws Exception { + String expected = loadFromFile("expectedOutput/ppl/explain_trendline_push.json"); + + assertJsonEquals( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account" + + "| trendline sma(2, age) as ageTrend " + + "| fields ageTrend")); + } + String loadFromFile(String filename) throws Exception { URI uri = Resources.getResource(filename).toURI(); return new String(Files.readAllBytes(Paths.get(uri))); diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_push.json new file mode 100644 index 0000000000..754535dc32 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_push.json @@ -0,0 +1,32 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[ageTrend]" + }, + "children": [ + { + "name": "TrendlineOperator", + "description": { + "computations": [ + { + "computationType" : "sma", + "numberOfDataPoints" : "2", + "dataField" : "age", + "alias" : "ageTrend" + } + ] + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":5,\"timeout\":\"1m\"}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" + }, + "children": [] + } + ] + } + ] + } +} From 54a85690da0118741500d0aa2ffbf22e3f1ae53f Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 30 Oct 2024 17:27:58 -0700 Subject: [PATCH 31/42] Fix trendline explain IT test Signed-off-by: James Duong --- integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index 52ba0bc411..c604b74348 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -108,6 +108,7 @@ public void testTrendlinePushDownExplain() throws Exception { expected, explainQueryToString( "source=opensearch-sql_test_index_account" + + "| head 5 " + "| trendline sma(2, age) as ageTrend " + "| fields ageTrend")); } From aca31bdb383171af1dc1a713cd05a66c633f8b17 Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 30 Oct 2024 18:00:18 -0700 Subject: [PATCH 32/42] Make the alias optional Also evaluate the computation type in the parser Signed-off-by: James Duong --- .../org/opensearch/sql/ast/dsl/AstDSL.java | 2 +- .../opensearch/sql/ast/tree/Trendline.java | 12 +++------ .../opensearch/sql/analysis/AnalyzerTest.java | 16 ++++++------ .../opensearch/sql/executor/ExplainTest.java | 6 ++--- .../sql/planner/DefaultImplementorTest.java | 3 ++- .../logical/LogicalPlanNodeVisitorTest.java | 3 ++- .../physical/PhysicalPlanNodeVisitorTest.java | 5 ++-- .../physical/TrendlineOperatorTest.java | 25 ++++++++++--------- .../sql/ppl/TrendlineCommandIT.java | 11 ++++++++ .../OpenSearchExecutionProtectorTest.java | 3 ++- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../sql/ppl/parser/AstExpressionBuilder.java | 14 ++++++++--- .../sql/ppl/parser/AstBuilderTest.java | 12 +++++++-- 13 files changed, 70 insertions(+), 44 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 816cabcc21..4aef491af3 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -473,7 +473,7 @@ public static Trendline trendline( } public static Trendline.TrendlineComputation computation( - Integer numDataPoints, UnresolvedExpression dataField, String alias, String type) { + Integer numDataPoints, Field dataField, String alias, Trendline.TrendlineType type) { return new Trendline.TrendlineComputation(numDataPoints, dataField, alias, type); } diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index f58527621d..33eda22280 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -7,13 +7,13 @@ import com.google.common.collect.ImmutableList; import java.util.List; -import java.util.Locale; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import lombok.ToString; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.UnresolvedExpression; @ToString @@ -45,20 +45,16 @@ public T accept(AbstractNodeVisitor visitor, C context) { public static class TrendlineComputation extends UnresolvedExpression { private final Integer numberOfDataPoints; - private final UnresolvedExpression dataField; + private final Field dataField; private final String alias; private final TrendlineType computationType; public TrendlineComputation( - Integer numberOfDataPoints, - UnresolvedExpression dataField, - String alias, - String computationType) { + Integer numberOfDataPoints, Field dataField, String alias, TrendlineType computationType) { this.numberOfDataPoints = numberOfDataPoints; this.dataField = dataField; this.alias = alias; - this.computationType = - Trendline.TrendlineType.valueOf(computationType.toUpperCase(Locale.ROOT)); + this.computationType = computationType; } @Override diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 3366fe8aae..265d878e66 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -34,6 +34,7 @@ import static org.opensearch.sql.ast.tree.Sort.SortOption; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.model.ExprValueUtils.stringValue; import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; @@ -1487,12 +1488,12 @@ public void trendline() { assertAnalyzeEqual( LogicalPlanDSL.trendline( LogicalPlanDSL.relation("schema", table), - Pair.of(computation(5, field("float_value"), "test_field_alias", "sma"), DOUBLE), - Pair.of(computation(1, field("double_value"), "test_field_alias_2", "sma"), DOUBLE)), + Pair.of(computation(5, field("float_value"), "test_field_alias", SMA), DOUBLE), + Pair.of(computation(1, field("double_value"), "test_field_alias_2", SMA), DOUBLE)), AstDSL.trendline( AstDSL.relation("schema"), - computation(5, field("float_value"), "test_field_alias", "sma"), - computation(1, field("double_value"), "test_field_alias_2", "sma"))); + computation(5, field("float_value"), "test_field_alias", SMA), + computation(1, field("double_value"), "test_field_alias_2", SMA))); } @Test @@ -1500,11 +1501,10 @@ public void trendline_datetime_types() { assertAnalyzeEqual( LogicalPlanDSL.trendline( LogicalPlanDSL.relation("schema", table), - Pair.of( - computation(5, field("timestamp_value"), "test_field_alias", "sma"), TIMESTAMP)), + Pair.of(computation(5, field("timestamp_value"), "test_field_alias", SMA), TIMESTAMP)), AstDSL.trendline( AstDSL.relation("schema"), - computation(5, field("timestamp_value"), "test_field_alias", "sma"))); + computation(5, field("timestamp_value"), "test_field_alias", SMA))); } @Test @@ -1515,7 +1515,7 @@ public void trendline_illegal_type() { analyze( AstDSL.trendline( AstDSL.relation("schema"), - computation(5, field("array_value"), "test_field_alias", "sma")))); + computation(5, field("array_value"), "test_field_alias", SMA)))); } @Test diff --git a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java index 0e71f72b50..febf662843 100644 --- a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.opensearch.sql.ast.tree.RareTopN.CommandType.TOP; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; @@ -267,9 +268,8 @@ void can_explain_trendline() { tableScan, Arrays.asList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), - DOUBLE), - Pair.of(AstDSL.computation(3, AstDSL.field("time"), "time_alias", "sma"), DOUBLE))); + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), DOUBLE), + Pair.of(AstDSL.computation(3, AstDSL.field("time"), "time_alias", SMA), DOUBLE))); assertEquals( new ExplainResponse( new ExplainResponseNode( diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index b62f59f192..8ee0dd7e70 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -13,6 +13,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; @@ -316,7 +317,7 @@ public void visitTrendline_should_build_TrendlineOperator() { var physicalChild = mock(PhysicalPlan.class); when(logicalChild.accept(implementor, null)).thenReturn(physicalChild); final Trendline.TrendlineComputation computation = - AstDSL.computation(1, AstDSL.field("field"), "alias", "sma"); + AstDSL.computation(1, AstDSL.field("field"), "alias", SMA); var logicalPlan = new LogicalTrendline( logicalChild, Collections.singletonList(Pair.of(computation, ExprCoreType.DOUBLE))); diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 8fd031e666..43ce23ed56 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.Mockito.mock; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.named; @@ -148,7 +149,7 @@ public TableWriteOperator build(PhysicalPlan child) { relation, Collections.singletonList( Pair.of( - AstDSL.computation(1, AstDSL.field("testField"), "dummy", "sma"), + AstDSL.computation(1, AstDSL.field("testField"), "dummy", SMA), ExprCoreType.DOUBLE))); return Stream.of( diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index f079791195..26f288e6b6 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.Mockito.mock; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.expression.DSL.named; @@ -73,7 +74,7 @@ public void print_physical_plan() { Collections.singletonList( Pair.of( AstDSL.computation( - 1, AstDSL.field("field"), "alias", "sma"), + 1, AstDSL.field("field"), "alias", SMA), DOUBLE))), 1, 1), @@ -150,7 +151,7 @@ public static Stream getPhysicalPlanForTest() { new TrendlineOperator( plan, Collections.singletonList( - Pair.of(AstDSL.computation(1, AstDSL.field("field"), "alias", "sma"), DOUBLE))); + Pair.of(AstDSL.computation(1, AstDSL.field("field"), "alias", SMA), DOUBLE))); return Stream.of( Arguments.of(filter, "filter"), diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index 993767046e..ae7c4255f7 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.when; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import com.google.common.collect.ImmutableMap; import java.time.Instant; @@ -45,7 +46,7 @@ public void calculates_simple_moving_average_one_field_one_sample() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); plan.open(); @@ -69,7 +70,7 @@ public void calculates_simple_moving_average_one_field_two_samples() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); plan.open(); @@ -98,7 +99,7 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows() inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); plan.open(); @@ -132,10 +133,10 @@ public void calculates_simple_moving_average_multiple_computations() { inputPlan, Arrays.asList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE), Pair.of( - AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"), + AstDSL.computation(2, AstDSL.field("time"), "time_alias", SMA), ExprCoreType.DOUBLE))); plan.open(); @@ -171,7 +172,7 @@ public void alias_overwrites_input_field() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "time", "sma"), + AstDSL.computation(2, AstDSL.field("distance"), "time", SMA), ExprCoreType.DOUBLE))); plan.open(); @@ -200,7 +201,7 @@ public void calculates_simple_moving_average_one_field_two_samples_three_rows_nu inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); plan.open(); @@ -231,7 +232,7 @@ public void use_null_value() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.DOUBLE))); plan.open(); @@ -258,7 +259,7 @@ public void use_illegal_core_type() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "sma"), + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), ExprCoreType.ARRAY))); }); } @@ -294,7 +295,7 @@ public void calculates_simple_moving_average_date() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("date"), "date_alias", "sma"), + AstDSL.computation(2, AstDSL.field("date"), "date_alias", SMA), ExprCoreType.DATE))); plan.open(); @@ -341,7 +342,7 @@ public void calculates_simple_moving_average_time() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("time"), "time_alias", "sma"), + AstDSL.computation(2, AstDSL.field("time"), "time_alias", SMA), ExprCoreType.TIME))); plan.open(); @@ -381,7 +382,7 @@ public void calculates_simple_moving_average_timestamp() { inputPlan, Collections.singletonList( Pair.of( - AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", "sma"), + AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", SMA), ExprCoreType.TIMESTAMP))); plan.open(); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index 98e33c09a6..ae8d33cdb2 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -53,4 +53,15 @@ public void testTrendlineOverwritesExistingField() throws IOException { TEST_INDEX_BANK)); verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); } + + @Test + public void testTrendlineNoAlias() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) |" + + " fields balance_trendline", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index 10cdab34c8..724178bd34 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.*; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; @@ -330,7 +331,7 @@ public void test_visitTrendline() { Collections.singletonList( Pair.of( new Trendline.TrendlineComputation( - 1, AstDSL.field("dummy"), "dummy_alias", "sma"), + 1, AstDSL.field("dummy"), "dummy_alias", SMA), DOUBLE))); assertEquals( diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 9077dd2b25..298e98da4b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -150,7 +150,7 @@ trendlineCommand ; trendlineClause - : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS AS alias = fieldExpression + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = fieldExpression)? ; trendlineType diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index da6edbb03d..6100a38998 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -45,6 +45,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -79,10 +80,15 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { /** Trendline clause. */ @Override public UnresolvedExpression visitTrendlineClause(OpenSearchPPLParser.TrendlineClauseContext ctx) { - Integer numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); - Field dataField = (Field) this.visitFieldExpression(ctx.field); - String alias = ctx.alias.getText(); - String computationType = ctx.trendlineType().getText(); + final Integer numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + final Field dataField = (Field) this.visitFieldExpression(ctx.field); + final String alias = + ctx.alias != null + ? ctx.alias.getText() + : dataField.getChild().get(0).toString() + "_trendline"; + + final Trendline.TrendlineType computationType = + Trendline.TrendlineType.valueOf(ctx.trendlineType().getText().toUpperCase(Locale.ROOT)); return new Trendline.TrendlineComputation( numberOfDataPoints, dataField, alias, computationType); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 73cd01f3cf..6cedacbaba 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -41,6 +41,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.tableFunction; import static org.opensearch.sql.ast.dsl.AstDSL.trendline; import static org.opensearch.sql.ast.dsl.AstDSL.unresolvedArg; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.utils.SystemIndexUtils.DATASOURCES_TABLE_NAME; import static org.opensearch.sql.utils.SystemIndexUtils.mappingTable; @@ -700,8 +701,15 @@ public void testTrendline() { + " test_field_alias_2", trendline( relation("t"), - computation(5, field("test_field"), "test_field_alias", "sma"), - computation(1, field("test_field_2"), "test_field_alias_2", "sma"))); + computation(5, field("test_field"), "test_field_alias", SMA), + computation(1, field("test_field_2"), "test_field_alias_2", SMA))); + } + + @Test + public void testTrendlineNoAlias() { + assertEqual( + "source=t | trendline sma(5, test_field)", + trendline(relation("t"), computation(5, field("test_field"), "test_field_trendline", SMA))); } @Test From 50d590a274f8732fb708dd6a294db1fb578417d6 Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 30 Oct 2024 18:07:14 -0700 Subject: [PATCH 33/42] Add validation on number of data points Signed-off-by: James Duong --- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../opensearch/sql/ppl/parser/AstExpressionBuilder.java | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 298e98da4b..271089682b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -150,7 +150,7 @@ trendlineCommand ; trendlineClause - : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = fieldExpression)? + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? ; trendlineType diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 6100a38998..746ebb6452 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -54,6 +54,7 @@ import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; @@ -80,7 +81,11 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { /** Trendline clause. */ @Override public UnresolvedExpression visitTrendlineClause(OpenSearchPPLParser.TrendlineClauseContext ctx) { - final Integer numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + final int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + if (numberOfDataPoints < 1) { + throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1"); + } + final Field dataField = (Field) this.visitFieldExpression(ctx.field); final String alias = ctx.alias != null From 07c7efbdf8e7e4c7a443d737201370fd27c6a09d Mon Sep 17 00:00:00 2001 From: James Duong Date: Thu, 31 Oct 2024 14:10:56 -0700 Subject: [PATCH 34/42] Add sort functionality to trendline Sort by creating a LogicalSort between the input plan and LogicalTrendline Signed-off-by: James Duong --- .../org/opensearch/sql/analysis/Analyzer.java | 56 +++++++++++-------- .../org/opensearch/sql/ast/dsl/AstDSL.java | 7 ++- .../opensearch/sql/ast/tree/Trendline.java | 4 +- .../opensearch/sql/analysis/AnalyzerTest.java | 23 ++++++++ .../physical/TrendlineOperatorTest.java | 14 ----- .../org/opensearch/sql/ppl/ExplainIT.java | 13 +++++ .../sql/ppl/TrendlineCommandIT.java | 11 ++++ .../ppl/explain_trendline_sort_push.json | 32 +++++++++++ ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 13 ++++- .../sql/ppl/parser/AstExpressionBuilder.java | 6 +- .../sql/ppl/utils/PPLQueryDataAnonymizer.java | 5 +- 12 files changed, 137 insertions(+), 49 deletions(-) create mode 100644 integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_sort_push.json diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 6bc1bdde0d..d0051568c4 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -25,6 +25,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -474,23 +475,7 @@ public LogicalPlan visitParse(Parse node, AnalysisContext context) { @Override public LogicalPlan visitSort(Sort node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); - ExpressionReferenceOptimizer optimizer = - new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); - - List> sortList = - node.getSortList().stream() - .map( - sortField -> { - var analyzed = expressionAnalyzer.analyze(sortField.getField(), context); - if (analyzed == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", sortField.getField())); - } - Expression expression = optimizer.optimize(analyzed, context); - return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression); - }) - .collect(Collectors.toList()); - return new LogicalSort(child, sortList); + return buildSort(child, context, node.getSortList()); } /** Build {@link LogicalDedupe}. */ @@ -605,12 +590,7 @@ public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { final LogicalPlan child = node.getChild().get(0).accept(this, context); final TypeEnvironment currEnv = context.peek(); - final List unresolvedComputations = node.getComputations(); - final List computations = - unresolvedComputations.stream() - .map(expression -> (Trendline.TrendlineComputation) expression) - .toList(); - + final List computations = node.getComputations(); final ImmutableList.Builder> computationsAndTypes = ImmutableList.builder(); computations.forEach( @@ -643,7 +623,14 @@ public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { currEnv.define(new Symbol(Namespace.FIELD_NAME, computation.getAlias()), averageType); computationsAndTypes.add(Pair.of(computation, averageType)); }); - return new LogicalTrendline(child, computationsAndTypes.build()); + + if (node.getSortByField().isEmpty()) { + return new LogicalTrendline(child, computationsAndTypes.build()); + } + + return new LogicalTrendline( + buildSort(child, context, Collections.singletonList(node.getSortByField().get())), + computationsAndTypes.build()); } @Override @@ -664,6 +651,27 @@ public LogicalPlan visitCloseCursor(CloseCursor closeCursor, AnalysisContext con return new LogicalCloseCursor(closeCursor.getChild().get(0).accept(this, context)); } + private LogicalSort buildSort( + LogicalPlan child, AnalysisContext context, List sortFields) { + ExpressionReferenceOptimizer optimizer = + new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); + + List> sortList = + sortFields.stream() + .map( + sortField -> { + var analyzed = expressionAnalyzer.analyze(sortField.getField(), context); + if (analyzed == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", sortField.getField())); + } + Expression expression = optimizer.optimize(analyzed, context); + return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression); + }) + .collect(Collectors.toList()); + return new LogicalSort(child, sortList); + } + /** * The first argument is always "asc", others are optional. Given nullFirst argument, use its * value. Otherwise just use DEFAULT_ASC/DESC. diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 4aef491af3..d9956609ec 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -468,8 +469,10 @@ public static Limit limit(UnresolvedPlan input, Integer limit, Integer offset) { } public static Trendline trendline( - UnresolvedPlan input, Trendline.TrendlineComputation... computations) { - return new Trendline(Arrays.asList(computations)).attach(input); + UnresolvedPlan input, + Optional sortField, + Trendline.TrendlineComputation... computations) { + return new Trendline(sortField, Arrays.asList(computations)).attach(input); } public static Trendline.TrendlineComputation computation( diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index 33eda22280..3f9f9e2fbc 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Optional; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -23,7 +24,8 @@ public class Trendline extends UnresolvedPlan { private UnresolvedPlan child; - private final List computations; + private final Optional sortByField; + private final List computations; @Override public Trendline attach(UnresolvedPlan child) { diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 265d878e66..d6cb0544d8 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -68,6 +68,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Disabled; @@ -91,6 +92,7 @@ import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; @@ -1492,6 +1494,7 @@ public void trendline() { Pair.of(computation(1, field("double_value"), "test_field_alias_2", SMA), DOUBLE)), AstDSL.trendline( AstDSL.relation("schema"), + Optional.empty(), computation(5, field("float_value"), "test_field_alias", SMA), computation(1, field("double_value"), "test_field_alias_2", SMA))); } @@ -1504,6 +1507,7 @@ public void trendline_datetime_types() { Pair.of(computation(5, field("timestamp_value"), "test_field_alias", SMA), TIMESTAMP)), AstDSL.trendline( AstDSL.relation("schema"), + Optional.empty(), computation(5, field("timestamp_value"), "test_field_alias", SMA))); } @@ -1515,9 +1519,28 @@ public void trendline_illegal_type() { analyze( AstDSL.trendline( AstDSL.relation("schema"), + Optional.empty(), computation(5, field("array_value"), "test_field_alias", SMA)))); } + @Test + public void trendline_with_sort() { + assertAnalyzeEqual( + LogicalPlanDSL.trendline( + LogicalPlanDSL.sort( + LogicalPlanDSL.relation("schema", table), + Pair.of( + new SortOption(SortOrder.ASC, NullOrder.NULL_FIRST), + DSL.ref("float_value", ExprCoreType.FLOAT))), + Pair.of(computation(5, field("float_value"), "test_field_alias", SMA), DOUBLE), + Pair.of(computation(1, field("double_value"), "test_field_alias_2", SMA), DOUBLE)), + AstDSL.trendline( + AstDSL.relation("schema"), + Optional.of(field("float_value", argument("asc", booleanLiteral(true)))), + computation(5, field("float_value"), "test_field_alias", SMA), + computation(1, field("double_value"), "test_field_alias_2", SMA))); + } + @Test public void ad_batchRCF_relation() { Map argumentMap = diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java index ae7c4255f7..ef2c2907ce 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -264,20 +264,6 @@ public void use_illegal_core_type() { }); } - @Test - public void use_illegal_computation_type() { - assertThrows( - IllegalArgumentException.class, - () -> { - new TrendlineOperator( - inputPlan, - Collections.singletonList( - Pair.of( - AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", "fake"), - ExprCoreType.DOUBLE))); - }); - } - @Test public void calculates_simple_moving_average_date() { when(inputPlan.hasNext()).thenReturn(true, true, true, false); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index c604b74348..531a24bad6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -113,6 +113,19 @@ public void testTrendlinePushDownExplain() throws Exception { + "| fields ageTrend")); } + @Test + public void testTrendlineWithSortPushDownExplain() throws Exception { + String expected = loadFromFile("expectedOutput/ppl/explain_trendline_sort_push.json"); + + assertJsonEquals( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account" + + "| head 5 " + + "| trendline sort age sma(2, age) as ageTrend " + + "| fields ageTrend")); + } + String loadFromFile(String filename) throws Exception { URI uri = Resources.getResource(filename).toURI(); return new String(Files.readAllBytes(Paths.get(uri))); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java index ae8d33cdb2..38baa0f01f 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -64,4 +64,15 @@ public void testTrendlineNoAlias() throws IOException { TEST_INDEX_BANK)); verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); } + + @Test + public void testTrendlineWithSort() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | trendline sort balance sma(2, balance) |" + + " fields balance_trendline", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } } diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_sort_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_sort_push.json new file mode 100644 index 0000000000..6629434108 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_sort_push.json @@ -0,0 +1,32 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[ageTrend]" + }, + "children": [ + { + "name": "TrendlineOperator", + "description": { + "computations": [ + { + "computationType" : "sma", + "numberOfDataPoints" : "2", + "dataField" : "age", + "alias" : "ageTrend" + } + ] + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":5,\"timeout\":\"1m\",\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" + }, + "children": [] + } + ] + } + ] + } +} diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 271089682b..0d3a328a9f 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -146,7 +146,7 @@ nullReplacementExpression : nullableField = fieldExpression EQUAL nullReplacement = valueExpression trendlineCommand - : TRENDLINE trendlineClause (trendlineClause)* + : TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)* ; trendlineClause diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index d0e4224a81..c3c31ee2e1 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -425,9 +425,16 @@ public UnresolvedPlan visitFillNullWithFieldVariousValues( /** trendline command. */ @Override public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { - List trendlineComputations = - ctx.trendlineClause().stream().map(expressionBuilder::visit).collect(Collectors.toList()); - return new Trendline(trendlineComputations); + List trendlineComputations = + ctx.trendlineClause().stream() + .map(expressionBuilder::visit) + .map(Trendline.TrendlineComputation.class::cast) + .collect(Collectors.toList()); + return Optional.ofNullable(ctx.sortField()) + .map(this::internalVisitExpression) + .map(Field.class::cast) + .map(sort -> new Trendline(Optional.of(sort), trendlineComputations)) + .orElse(new Trendline(Optional.empty(), trendlineComputations)); } /** Get original text in query. */ diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 746ebb6452..8bc98c8eee 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -80,10 +80,12 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { /** Trendline clause. */ @Override - public UnresolvedExpression visitTrendlineClause(OpenSearchPPLParser.TrendlineClauseContext ctx) { + public Trendline.TrendlineComputation visitTrendlineClause( + OpenSearchPPLParser.TrendlineClauseContext ctx) { final int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); if (numberOfDataPoints < 1) { - throw new SyntaxCheckException("Number of trendline data-points must be greater than or equal to 1"); + throw new SyntaxCheckException( + "Number of trendline data-points must be greater than or equal to 1"); } final Field dataField = (Field) this.visitFieldExpression(ctx.field); diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index a0047e3c8d..96e21eafcd 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -234,11 +234,12 @@ private String visitFieldList(List fieldList) { return fieldList.stream().map(this::visitExpression).collect(Collectors.joining(",")); } - private String visitExpressionList(List expressionList) { + private String visitExpressionList(List expressionList) { return visitExpressionList(expressionList, ","); } - private String visitExpressionList(List expressionList, String delimiter) { + private String visitExpressionList( + List expressionList, String delimiter) { return expressionList.isEmpty() ? "" : expressionList.stream().map(this::visitExpression).collect(Collectors.joining(delimiter)); From b60093ec97756db532271f9dfc77442e7096ec31 Mon Sep 17 00:00:00 2001 From: James Duong Date: Thu, 31 Oct 2024 14:25:41 -0700 Subject: [PATCH 35/42] Make docs more consistent with Spark Also add examples with the sort option and without an alias Signed-off-by: James Duong --- docs/user/ppl/cmd/trendline.rst | 41 +++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index d092db2d59..14d12f34d0 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -11,18 +11,27 @@ trendline Description ============ -| Use the ``trendline`` command to calculate the moving average on one or more fields in a search result. - +| Using ``trendline`` command to calculate moving averages of fields. Syntax ============ -trendline (, ) AS [" " (, ) AS ]... +`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory when sorting is used. The field used to sort. +* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. the name of the field the moving average should be calculated for. +* alias: optional. the name of the resulting column containing the moving average. + +And the moment only the Simple Moving Average (SMA) type is supported. -* average-type: mandatory. The moving average computation. Can be ``sma`` (simple moving average) currently. -* number-of-samples: mandatory. The number of samples to use in the average calculation. Must be a positive non-zero integer. -* source-field: mandatory. The field to compute the average on. -* target-field: mandatory. The field name to report the computation under. +It is calculated like + f[i]: The value of field 'f' in the i-th data-point + n: The number of data-points in the moving window (period) + t: The current time index + + SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t Example 1: Calculate the moving average on one field. ===================================================== @@ -61,3 +70,21 @@ PPL query:: | 15.5 | 30.5 | +------+-----------+ +Example 4: Calculate the moving average on one field without specifying an alias. +================================================================================= + +The example shows how to calculate the moving average on one field. + +PPL query:: + + os> source=accounts | trendline sma(2, account_number) | fields accounts_trendline; + fetched rows / total rows = 4/4 + +--------------------+ + | accounts_trendline | + |--------------------| + | null | + | 3.5 | + | 9.5 | + | 15.5 | + +--------------------+ + From 25ccada11d0a9ba4804ba39a4fa79a83ef42df09 Mon Sep 17 00:00:00 2001 From: James Duong Date: Thu, 31 Oct 2024 15:51:59 -0700 Subject: [PATCH 36/42] Fix docs typo in example Signed-off-by: James Duong --- docs/user/ppl/cmd/trendline.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 14d12f34d0..1212dc3bb7 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -77,7 +77,7 @@ The example shows how to calculate the moving average on one field. PPL query:: - os> source=accounts | trendline sma(2, account_number) | fields accounts_trendline; + os> source=accounts | trendline sma(2, account_number) | fields account_number_trendline; fetched rows / total rows = 4/4 +--------------------+ | accounts_trendline | From c7b7cb1385ed29dbbb36a913406f64b4796593d0 Mon Sep 17 00:00:00 2001 From: James Duong Date: Thu, 31 Oct 2024 16:00:42 -0700 Subject: [PATCH 37/42] Add missed update to AstBuilderTest for sort option Signed-off-by: James Duong --- .../sql/ppl/parser/AstBuilderTest.java | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 6cedacbaba..b28992d8d7 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -48,6 +48,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.Optional; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -701,15 +702,61 @@ public void testTrendline() { + " test_field_alias_2", trendline( relation("t"), + Optional.empty(), computation(5, field("test_field"), "test_field_alias", SMA), computation(1, field("test_field_2"), "test_field_alias_2", SMA))); } + @Test + public void testTrendlineSort() { + assertEqual( + "source=t | trendline sort test_field sma(5, test_field)", + trendline( + relation("t"), + Optional.of( + field( + "test_field", + argument("asc", booleanLiteral(true)), + argument("type", nullLiteral()))), + computation(5, field("test_field"), "test_field_trendline", SMA))); + } + + @Test + public void testTrendlineSortDesc() { + assertEqual( + "source=t | trendline sort - test_field sma(5, test_field)", + trendline( + relation("t"), + Optional.of( + field( + "test_field", + argument("asc", booleanLiteral(false)), + argument("type", nullLiteral()))), + computation(5, field("test_field"), "test_field_trendline", SMA))); + } + + @Test + public void testTrendlineSortAsc() { + assertEqual( + "source=t | trendline sort + test_field sma(5, test_field)", + trendline( + relation("t"), + Optional.of( + field( + "test_field", + argument("asc", booleanLiteral(true)), + argument("type", nullLiteral()))), + computation(5, field("test_field"), "test_field_trendline", SMA))); + } + @Test public void testTrendlineNoAlias() { assertEqual( "source=t | trendline sma(5, test_field)", - trendline(relation("t"), computation(5, field("test_field"), "test_field_trendline", SMA))); + trendline( + relation("t"), + Optional.empty(), + computation(5, field("test_field"), "test_field_trendline", SMA))); } @Test From 23917a580b1a13f6b0dff6cbc0a037095533c187 Mon Sep 17 00:00:00 2001 From: James Duong Date: Thu, 7 Nov 2024 12:00:05 -0800 Subject: [PATCH 38/42] Add test for checking an invalid number of samples Signed-off-by: James Duong --- .../java/org/opensearch/sql/ppl/parser/AstBuilderTest.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index b28992d8d7..c6f4ed2044 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -7,6 +7,7 @@ import static java.util.Collections.emptyList; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.agg; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.alias; @@ -64,6 +65,7 @@ import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.RareTopN.CommandType; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; public class AstBuilderTest { @@ -759,6 +761,11 @@ public void testTrendlineNoAlias() { computation(5, field("test_field"), "test_field_trendline", SMA))); } + @Test + public void testTrendlineTooFewSamples() { + assertThrows(SyntaxCheckException.class, () -> plan("source=t | trendline sma(0, test_field)")); + } + @Test public void testDescribeCommand() { assertEqual("describe t", relation(mappingTable("t"))); From b01dded7595a1d5d27337921e4dcf2cd502605ef Mon Sep 17 00:00:00 2001 From: James Duong Date: Wed, 27 Nov 2024 11:58:15 -0800 Subject: [PATCH 39/42] Add Trendline to KeywordsCanBeId Signed-off-by: James Duong --- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 1 + 1 file changed, 1 insertion(+) diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 0d3a328a9f..1b7739d81c 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -888,6 +888,7 @@ keywordsCanBeId | KMEANS | AD | ML + | TRENDLINE // commands assist keywords | SOURCE | INDEX From 8da8255548beb1d2cdd4f0e2274998b00443c9dc Mon Sep 17 00:00:00 2001 From: James Duong Date: Tue, 3 Dec 2024 16:31:11 -0800 Subject: [PATCH 40/42] Fix wrong column name in docs Signed-off-by: James Duong --- docs/user/ppl/cmd/trendline.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 1212dc3bb7..4e528508f9 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -79,12 +79,12 @@ PPL query:: os> source=accounts | trendline sma(2, account_number) | fields account_number_trendline; fetched rows / total rows = 4/4 - +--------------------+ - | accounts_trendline | - |--------------------| - | null | - | 3.5 | - | 9.5 | - | 15.5 | - +--------------------+ + +--------------------------+ + | account_number_trendline | + |--------------------------| + | null | + | 3.5 | + | 9.5 | + | 15.5 | + +--------------------------+ From 9f1684f6b4ea444ff5b6ac7f69b3303af4a637ac Mon Sep 17 00:00:00 2001 From: James Duong Date: Tue, 10 Dec 2024 11:10:05 -0800 Subject: [PATCH 41/42] Fix rebase error in parse Add back missing semi-colon Signed-off-by: James Duong --- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 1 + 1 file changed, 1 insertion(+) diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 1b7739d81c..c9d0f2e110 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -144,6 +144,7 @@ fillNullWithFieldVariousValues nullReplacementExpression : nullableField = fieldExpression EQUAL nullReplacement = valueExpression + ; trendlineCommand : TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)* From 0ff1653c0a2a90f78ac853ff633df0e36493f316 Mon Sep 17 00:00:00 2001 From: Andrew Carbonetto Date: Thu, 12 Dec 2024 10:19:29 -0800 Subject: [PATCH 42/42] PPL-Trendline: remove unused grammar; clean doc Signed-off-by: Andrew Carbonetto --- .../main/java/org/opensearch/sql/ast/tree/Trendline.java | 3 +-- docs/user/ppl/cmd/trendline.rst | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java index 3f9f9e2fbc..aa4fcc200d 100644 --- a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -66,7 +66,6 @@ public R accept(AbstractNodeVisitor nodeVisitor, C context) { } public enum TrendlineType { - SMA, - WMA + SMA } } diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst index 4e528508f9..166a3c056f 100644 --- a/docs/user/ppl/cmd/trendline.rst +++ b/docs/user/ppl/cmd/trendline.rst @@ -19,9 +19,9 @@ Syntax * [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. * sort-field: mandatory when sorting is used. The field used to sort. -* number-of-datapoints: mandatory. number of datapoints to calculate the moving average (must be greater than zero). -* field: mandatory. the name of the field the moving average should be calculated for. -* alias: optional. the name of the resulting column containing the moving average. +* number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. The name of the field the moving average should be calculated for. +* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). And the moment only the Simple Moving Average (SMA) type is supported.