diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 2a886ba0ca..3a4acdf3aa 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -15,6 +15,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedExpression; @@ -130,6 +131,11 @@ public static LogicalPlan rareTopN( return new LogicalRareTopN(input, commandType, noOfResults, Arrays.asList(fields), groupByList); } + public static LogicalTrendline trendline( + LogicalPlan input, Trendline.TrendlineComputation... computations) { + return new LogicalTrendline(input, Arrays.asList(computations)); + } + @SafeVarargs public LogicalPlan values(List... values) { return new LogicalValues(Arrays.asList(values)); diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 8d935b11d2..d2613409b8 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -18,6 +18,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.compare; +import static org.opensearch.sql.ast.dsl.AstDSL.computation; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; import static org.opensearch.sql.ast.dsl.AstDSL.filteredAggregate; @@ -1437,6 +1438,19 @@ public void kmeanns_relation() { new Kmeans(AstDSL.relation("schema"), argumentMap)); } + @Test + public void trendline() { + assertAnalyzeEqual( + LogicalPlanDSL.trendline( + LogicalPlanDSL.relation("schema", table), + computation(5, field("float_value"), "test_field_alias", "sma"), + computation(1, field("double_value"), "test_field_alias_2", "sma")), + AstDSL.trendline( + AstDSL.relation("schema"), + computation(5, field("float_value"), "test_field_alias", "sma"), + computation(1, field("double_value"), "test_field_alias_2", "sma"))); + } + @Test public void ad_batchRCF_relation() { Map argumentMap = diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index f212749f48..9b947f6e21 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -25,6 +25,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprValueUtils; @@ -141,6 +142,12 @@ public TableWriteOperator build(PhysicalPlan child) { LogicalCloseCursor closeCursor = new LogicalCloseCursor(cursor); + LogicalTrendline trendline = + new LogicalTrendline( + relation, + Collections.singletonList( + AstDSL.computation(1, AstDSL.field("testField"), "dummy", "sma"))); + return Stream.of( relation, tableScanBuilder, @@ -163,7 +170,8 @@ public TableWriteOperator build(PhysicalPlan child) { paginate, nested, cursor, - closeCursor) + closeCursor, + trendline) .map(Arguments::of); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index 4931689cfb..046c38ee5b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -670,7 +670,7 @@ public void testTrendline() { trendline( relation("t"), computation(5, field("test_field"), "test_field_alias", "sma"), - computation(1, field("test_field)2"), "test_field_alias_2", "sma"))); + computation(1, field("test_field_2"), "test_field_alias_2", "sma"))); } @Test