From 4d60344c418be83680fa33aa3c3c3c2485174ccc Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Wed, 23 Oct 2024 19:01:20 +0000 Subject: [PATCH 01/19] added ast tree to simplify expression lifetime management --- .../cudf/ast/detail/expression_parser.hpp | 3 +- cpp/include/cudf/ast/expressions.hpp | 72 +++++++++++++++++-- cpp/src/ast/expression_parser.cpp | 2 +- cpp/src/ast/expressions.cpp | 21 +++--- cpp/src/io/parquet/predicate_pushdown.cpp | 7 +- cpp/src/io/parquet/reader_impl_helpers.hpp | 2 +- 6 files changed, 87 insertions(+), 20 deletions(-) diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index a254171ef11..1d27bcbe6b5 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -300,7 +301,7 @@ class expression_parser { * @return The indices of the operands stored in the data references. */ std::vector visit_operands( - std::vector> operands); + cudf::host_span const> operands); /** * @brief Add a data reference to the internal list. diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index 4299ee5f20f..4ec7ed4c4ee 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -22,6 +22,8 @@ #include #include +#include +#include namespace CUDF_EXPORT cudf { namespace ast { @@ -478,9 +480,9 @@ class operation : public expression { * * @return Vector of operands */ - [[nodiscard]] std::vector> get_operands() const + [[nodiscard]] cudf::host_span const> get_operands() const { - return operands; + return cudf::host_span const>{operands, arity}; } /** @@ -498,16 +500,19 @@ class operation : public expression { table_view const& right, rmm::cuda_stream_view stream) const override { - return std::any_of(operands.cbegin(), - operands.cend(), + auto operands = get_operands(); + return std::any_of(operands.begin(), + operands.end(), [&left, &right, &stream](std::reference_wrapper subexpr) { return subexpr.get().may_evaluate_null(left, right, stream); }); }; private: - ast_operator const op; - std::vector> const operands; + ast_operator op; + // TODO: replace with cuda::std::inplace_vector, 2> + std::reference_wrapper operands[2]; + size_t arity; }; /** @@ -552,6 +557,61 @@ class column_name_reference : public expression { std::string column_name; }; +/** + * @brief An AST expression tree. it owns and contains multiple dependent expressions. All the + * expressions are destroyed once the tree is destructed. + */ +class tree { + public: + tree() = default; + tree(tree const&) = delete; + tree(tree&&) = default; + tree& operator=(tree const&) = delete; + tree& operator=(tree&&) = default; + ~tree() = default; + + /** + @brief Add an expression to the AST tree + @param expr AST expression to be added + @param args Arguments to use to construct the ast expression + @returns a reference to the added expression + */ + template + expression const& emplace(Args&&... args) + { + static_assert(std::is_base_of_v); + return *expressions.emplace_back(std::make_unique(std::forward(args)...)); + } + + /** + @brief Add an expression to the AST tree + @param expr AST expression to be added + @returns a reference to the added expression + */ + template + expression const& push(Expr expr) + { + return emplace(std::move(expr)); + } + + expression const& front() const { return *expressions.front(); } + + expression const& back() const { return *expressions.back(); } + + size_t size() const { return expressions.size(); } + + expression const& at(size_t index) { return *expressions.at(index); } + + expression const& operator[](size_t index) const { return *expressions[index]; } + + cudf::host_span const> get_expressions() const { return expressions; } + + private: + // TODO: ideally we'd use a custom bump allocator for constructing the expression objects and + // release all the objects at once. + std::vector> expressions; +}; + /** @} */ // end of group } // namespace ast diff --git a/cpp/src/ast/expression_parser.cpp b/cpp/src/ast/expression_parser.cpp index 3b650d791aa..f01374d758e 100644 --- a/cpp/src/ast/expression_parser.cpp +++ b/cpp/src/ast/expression_parser.cpp @@ -210,7 +210,7 @@ cudf::data_type expression_parser::output_type() const } std::vector expression_parser::visit_operands( - std::vector> operands) + cudf::host_span const> operands) { auto operand_data_reference_indices = std::vector(); for (auto const& operand : operands) { diff --git a/cpp/src/ast/expressions.cpp b/cpp/src/ast/expressions.cpp index b45b9d0c78c..e0df22a3600 100644 --- a/cpp/src/ast/expressions.cpp +++ b/cpp/src/ast/expressions.cpp @@ -26,33 +26,35 @@ namespace cudf { namespace ast { -operation::operation(ast_operator op, expression const& input) : op(op), operands({input}) +operation::operation(ast_operator op, expression const& input) + : op{op}, operands{input, input}, arity{1} { - if (cudf::ast::detail::ast_operator_arity(op) != 1) { - CUDF_FAIL("The provided operator is not a unary operator."); - } + CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 1, + "The provided operator is not a unary operator."); } operation::operation(ast_operator op, expression const& left, expression const& right) - : op(op), operands({left, right}) + : op{op}, operands{left, right}, arity{2} { - if (cudf::ast::detail::ast_operator_arity(op) != 2) { - CUDF_FAIL("The provided operator is not a binary operator."); - } + CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2, + "The provided operator is not a binary operator."); } cudf::size_type literal::accept(detail::expression_parser& visitor) const { return visitor.visit(*this); } + cudf::size_type column_reference::accept(detail::expression_parser& visitor) const { return visitor.visit(*this); } + cudf::size_type operation::accept(detail::expression_parser& visitor) const { return visitor.visit(*this); } + cudf::size_type column_name_reference::accept(detail::expression_parser& visitor) const { return visitor.visit(*this); @@ -63,16 +65,19 @@ auto literal::accept(detail::expression_transformer& visitor) const { return visitor.visit(*this); } + auto column_reference::accept(detail::expression_transformer& visitor) const -> decltype(visitor.visit(*this)) { return visitor.visit(*this); } + auto operation::accept(detail::expression_transformer& visitor) const -> decltype(visitor.visit(*this)) { return visitor.visit(*this); } + auto column_name_reference::accept(detail::expression_transformer& visitor) const -> decltype(visitor.visit(*this)) { diff --git a/cpp/src/io/parquet/predicate_pushdown.cpp b/cpp/src/io/parquet/predicate_pushdown.cpp index f0a0bc0b51b..b51fdb23a9a 100644 --- a/cpp/src/io/parquet/predicate_pushdown.cpp +++ b/cpp/src/io/parquet/predicate_pushdown.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -374,7 +375,7 @@ class stats_expression_converter : public ast::detail::expression_transformer { private: std::vector> visit_operands( - std::vector> operands) + cudf::host_span const> operands) { std::vector> transformed_operands; for (auto const& operand : operands) { @@ -551,7 +552,7 @@ std::reference_wrapper named_to_reference_converter::visi std::vector> named_to_reference_converter::visit_operands( - std::vector> operands) + cudf::host_span const> operands) { std::vector> transformed_operands; for (auto const& operand : operands) { @@ -621,7 +622,7 @@ class names_from_expression : public ast::detail::expression_transformer { } private: - void visit_operands(std::vector> operands) + void visit_operands(cudf::host_span const> operands) { for (auto const& operand : operands) { operand.get().accept(*this); diff --git a/cpp/src/io/parquet/reader_impl_helpers.hpp b/cpp/src/io/parquet/reader_impl_helpers.hpp index 6487c92f48f..fd692c0cdd6 100644 --- a/cpp/src/io/parquet/reader_impl_helpers.hpp +++ b/cpp/src/io/parquet/reader_impl_helpers.hpp @@ -425,7 +425,7 @@ class named_to_reference_converter : public ast::detail::expression_transformer private: std::vector> visit_operands( - std::vector> operands); + cudf::host_span const> operands); std::unordered_map column_name_to_index; std::optional> _stats_expr; From 58c12ba0605fdeaabe56abc60eb5e034b6e177b1 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Wed, 23 Oct 2024 19:43:03 +0000 Subject: [PATCH 02/19] updated copyright notices --- cpp/src/ast/expression_parser.cpp | 2 +- cpp/src/ast/expressions.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/ast/expression_parser.cpp b/cpp/src/ast/expression_parser.cpp index f01374d758e..1739670d81c 100644 --- a/cpp/src/ast/expression_parser.cpp +++ b/cpp/src/ast/expression_parser.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/src/ast/expressions.cpp b/cpp/src/ast/expressions.cpp index e0df22a3600..065d0738711 100644 --- a/cpp/src/ast/expressions.cpp +++ b/cpp/src/ast/expressions.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 04414db67fae0552c28876a25910b24448472938 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Thu, 24 Oct 2024 16:10:23 +0100 Subject: [PATCH 03/19] updated ast tree documentation --- cpp/include/cudf/ast/expressions.hpp | 35 ++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index 4ec7ed4c4ee..b1ce3648deb 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -563,16 +563,20 @@ class column_name_reference : public expression { */ class tree { public: - tree() = default; + /** + * @brief construct an empty ast tree + */ + tree() = default; + tree(tree&&) = default; + tree& operator=(tree&&) = default; + ~tree() = default; + + // the tree is not copyable tree(tree const&) = delete; - tree(tree&&) = default; tree& operator=(tree const&) = delete; - tree& operator=(tree&&) = default; - ~tree() = default; /** @brief Add an expression to the AST tree - @param expr AST expression to be added @param args Arguments to use to construct the ast expression @returns a reference to the added expression */ @@ -594,16 +598,37 @@ class tree { return emplace(std::move(expr)); } + /** + @brief get the first expression in the tree + */ expression const& front() const { return *expressions.front(); } + /** + @brief get the last expression in the tree + */ expression const& back() const { return *expressions.back(); } + /** + @brief get the number of expressions added to the tree + */ size_t size() const { return expressions.size(); } + /** + @brief get the expression at a checked index in the tree + @returns the expression at the specified index + */ expression const& at(size_t index) { return *expressions.at(index); } + /** + @brief get the expression at an unchecked index in the tree + @returns the expression at the specified index + */ expression const& operator[](size_t index) const { return *expressions[index]; } + /** + @brief get an immutable span to the expressions in the tree + @returns all expressions added to the tree + */ cudf::host_span const> get_expressions() const { return expressions; } private: From 7d398f13fe91cff540ab9dbd1a41d966d31df74d Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 28 Oct 2024 13:49:45 +0000 Subject: [PATCH 04/19] fixed ast tree docs formatting --- cpp/include/cudf/ast/expressions.hpp | 34 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index b1ce3648deb..abb0787ad94 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -576,10 +576,10 @@ class tree { tree& operator=(tree const&) = delete; /** - @brief Add an expression to the AST tree - @param args Arguments to use to construct the ast expression - @returns a reference to the added expression - */ + * @brief Add an expression to the AST tree + * @param args Arguments to use to construct the ast expression + * @returns a reference to the added expression + */ template expression const& emplace(Args&&... args) { @@ -588,10 +588,10 @@ class tree { } /** - @brief Add an expression to the AST tree - @param expr AST expression to be added - @returns a reference to the added expression - */ + * @brief Add an expression to the AST tree + * @param expr AST expression to be added + * @returns a reference to the added expression + */ template expression const& push(Expr expr) { @@ -599,35 +599,35 @@ class tree { } /** - @brief get the first expression in the tree + * @brief get the first expression in the tree */ expression const& front() const { return *expressions.front(); } /** - @brief get the last expression in the tree + * @brief get the last expression in the tree */ expression const& back() const { return *expressions.back(); } /** - @brief get the number of expressions added to the tree + * @brief get the number of expressions added to the tree */ size_t size() const { return expressions.size(); } /** - @brief get the expression at a checked index in the tree - @returns the expression at the specified index + * @brief get the expression at a checked index in the tree + * @returns the expression at the specified index */ expression const& at(size_t index) { return *expressions.at(index); } /** - @brief get the expression at an unchecked index in the tree - @returns the expression at the specified index + * @brief get the expression at an unchecked index in the tree + * @returns the expression at the specified index */ expression const& operator[](size_t index) const { return *expressions[index]; } /** - @brief get an immutable span to the expressions in the tree - @returns all expressions added to the tree + * @brief get an immutable span to the expressions in the tree + * @returns all expressions added to the tree */ cudf::host_span const> get_expressions() const { return expressions; } From 6c824116b2a4728f553dec00f6a0fbcc0711a6da Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 28 Oct 2024 13:51:25 +0000 Subject: [PATCH 05/19] temporarily changed back to using vector for storing ast operation's operands --- cpp/include/cudf/ast/expressions.hpp | 17 ++++++----------- cpp/src/ast/expressions.cpp | 5 ++--- 2 files changed, 8 insertions(+), 14 deletions(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index abb0787ad94..41757b949d7 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -480,9 +480,9 @@ class operation : public expression { * * @return Vector of operands */ - [[nodiscard]] cudf::host_span const> get_operands() const + [[nodiscard]] std::vector> const& get_operands() const { - return cudf::host_span const>{operands, arity}; + return operands; } /** @@ -500,9 +500,8 @@ class operation : public expression { table_view const& right, rmm::cuda_stream_view stream) const override { - auto operands = get_operands(); - return std::any_of(operands.begin(), - operands.end(), + return std::any_of(operands.cbegin(), + operands.cend(), [&left, &right, &stream](std::reference_wrapper subexpr) { return subexpr.get().may_evaluate_null(left, right, stream); }); @@ -510,9 +509,7 @@ class operation : public expression { private: ast_operator op; - // TODO: replace with cuda::std::inplace_vector, 2> - std::reference_wrapper operands[2]; - size_t arity; + std::vector> operands; }; /** @@ -629,11 +626,9 @@ class tree { * @brief get an immutable span to the expressions in the tree * @returns all expressions added to the tree */ - cudf::host_span const> get_expressions() const { return expressions; } + std::vector> const& get_expressions() const { return expressions; } private: - // TODO: ideally we'd use a custom bump allocator for constructing the expression objects and - // release all the objects at once. std::vector> expressions; }; diff --git a/cpp/src/ast/expressions.cpp b/cpp/src/ast/expressions.cpp index 065d0738711..3afa06611c5 100644 --- a/cpp/src/ast/expressions.cpp +++ b/cpp/src/ast/expressions.cpp @@ -26,15 +26,14 @@ namespace cudf { namespace ast { -operation::operation(ast_operator op, expression const& input) - : op{op}, operands{input, input}, arity{1} +operation::operation(ast_operator op, expression const& input) : op{op}, operands{input} { CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 1, "The provided operator is not a unary operator."); } operation::operation(ast_operator op, expression const& left, expression const& right) - : op{op}, operands{left, right}, arity{2} + : op{op}, operands{left, right} { CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2, "The provided operator is not a binary operator."); From e3b482372893fe002293bed6739517ec749e2af1 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 28 Oct 2024 14:01:38 +0000 Subject: [PATCH 06/19] made operation arity check throw std::invalid_argument --- cpp/src/ast/expressions.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/src/ast/expressions.cpp b/cpp/src/ast/expressions.cpp index 3afa06611c5..3af4211912b 100644 --- a/cpp/src/ast/expressions.cpp +++ b/cpp/src/ast/expressions.cpp @@ -23,20 +23,24 @@ #include #include +#include + namespace cudf { namespace ast { operation::operation(ast_operator op, expression const& input) : op{op}, operands{input} { CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 1, - "The provided operator is not a unary operator."); + "The provided operator is not a unary operator.", + std::invalid_argument); } operation::operation(ast_operator op, expression const& left, expression const& right) : op{op}, operands{left, right} { CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2, - "The provided operator is not a binary operator."); + "The provided operator is not a binary operator.", + std::invalid_argument); } cudf::size_type literal::accept(detail::expression_parser& visitor) const From 342b184531f4a0015127508448c6bd296b78874b Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 28 Oct 2024 23:00:42 +0000 Subject: [PATCH 07/19] corrected ast builder push return type --- cpp/include/cudf/ast/expressions.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index 41757b949d7..c07a7b25d87 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -578,10 +578,13 @@ class tree { * @returns a reference to the added expression */ template - expression const& emplace(Args&&... args) + Expr const& emplace(Args&&... args) { static_assert(std::is_base_of_v); - return *expressions.emplace_back(std::make_unique(std::forward(args)...)); + std::unique_ptr expr = std::make_unique(std::forward(args)...); + Expr const& expr_ref = *expr; + expressions.emplace_back(std::static_pointer_cast(std::move(expr))); + return expr_ref; } /** @@ -590,7 +593,7 @@ class tree { * @returns a reference to the added expression */ template - expression const& push(Expr expr) + Expr const& push(Expr expr) { return emplace(std::move(expr)); } From daa882ba4900765d90736a1bb03f89d06b71d713 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 28 Oct 2024 23:01:23 +0000 Subject: [PATCH 08/19] updated documentation for ast tree member functions --- cpp/include/cudf/ast/expressions.hpp | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index c07a7b25d87..b471dea7026 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -563,10 +563,19 @@ class tree { /** * @brief construct an empty ast tree */ - tree() = default; - tree(tree&&) = default; + tree() = default; + + /** + * @brief Moves the ast tree + */ + tree(tree&&) = default; + + /** + * @brief move-assigns the AST tree + */ tree& operator=(tree&&) = default; - ~tree() = default; + + ~tree() = default; // the tree is not copyable tree(tree const&) = delete; @@ -600,27 +609,32 @@ class tree { /** * @brief get the first expression in the tree + * @returns the first inserted expression into the tree */ expression const& front() const { return *expressions.front(); } /** * @brief get the last expression in the tree + * @returns the last inserted expression into the tree */ expression const& back() const { return *expressions.back(); } /** * @brief get the number of expressions added to the tree + * @returns the number of expressions added to the tree */ size_t size() const { return expressions.size(); } /** - * @brief get the expression at a checked index in the tree + * @brief get the expression at an index in the tree. index is checked. + * @param index index of expression in the ast tree * @returns the expression at the specified index */ expression const& at(size_t index) { return *expressions.at(index); } /** - * @brief get the expression at an unchecked index in the tree + * @brief get the expression at an index in the tree. index is unchecked. + * @param index index of expression in the ast tree * @returns the expression at the specified index */ expression const& operator[](size_t index) const { return *expressions[index]; } From 239bfd4fb0dd7526e97f39cb40da9621fcc0db0f Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 28 Oct 2024 23:02:20 +0000 Subject: [PATCH 09/19] started drafting ast tree test --- cpp/tests/ast/ast_tree_tests.cpp | 71 ++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 cpp/tests/ast/ast_tree_tests.cpp diff --git a/cpp/tests/ast/ast_tree_tests.cpp b/cpp/tests/ast/ast_tree_tests.cpp new file mode 100644 index 00000000000..a010bf60235 --- /dev/null +++ b/cpp/tests/ast/ast_tree_tests.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +TEST_F(AstTreeTest, ExpressionTree) +{ + // compute y = mx + c, across multiple columns and apply weights to them and then fold them + auto a = column_wrapper{3, 20, 1, 50}; + auto b = column_wrapper{10, 7, 20, 0}; + auto c = column_wrapper{10, 7, 20, 0}; + auto d = column_wrapper{10, 7, 20, 0}; + auto e = column_wrapper{10, 7, 20, 0}; + auto f = column_wrapper{10, 7, 20, 0}; + auto table = cudf::table_view{{a, b, c, d, e, f}}; + cudf::ast::tree tree; + + auto const& a_ref = tree.push(cudf::ast::column_reference(0)); + auto const& b_ref = tree.push(cudf::ast::column_reference(1)); + auto const& c_ref = tree.push(cudf::ast::column_reference(2)); + auto const& d_ref = tree.push(cudf::ast::column_reference(3)); + auto const& e_ref = tree.push(cudf::ast::column_reference(4)); + auto const& f_ref = tree.push(cudf::ast::column_reference(5)); + auto const& literal = tree.push(cudf::ast::literal{255}); + + /// compute: (a + b) - c + auto const& op_0 = tree.push(cudf::ast::operation{ + cudf::ast::ast_operator::SUBTRACT, + tree.push(cudf::ast::operation{cudf::ast::ast_operator::ADD, a_ref, b_ref}), + c_ref}); + + auto const& op_1 = tree.push(cudf::ast::operation{ + cudf::ast::ast_operator::MULTIPLY, + tree.push(cudf::ast::operation{cudf::ast::ast_operator::SUBTRACR, d_ref, e_ref}), + e_ref}); + + auto result = cudf::compute_column( + table, tree.push(cudf::ast::operation{cudf::ast::ast_operator::ADD, op_0, op_1})); + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); +} From ea1d4305979e70651d04677b73d08982bd51fee1 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Tue, 29 Oct 2024 15:10:37 +0000 Subject: [PATCH 10/19] updated ast tree expression management --- cpp/include/cudf/ast/expressions.hpp | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index b471dea7026..c2ccc726e1d 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -590,8 +590,8 @@ class tree { Expr const& emplace(Args&&... args) { static_assert(std::is_base_of_v); - std::unique_ptr expr = std::make_unique(std::forward(args)...); - Expr const& expr_ref = *expr; + auto expr = std::make_shared(std::forward(args)...); + Expr const& expr_ref = *expr; expressions.emplace_back(std::static_pointer_cast(std::move(expr))); return expr_ref; } @@ -639,14 +639,10 @@ class tree { */ expression const& operator[](size_t index) const { return *expressions[index]; } - /** - * @brief get an immutable span to the expressions in the tree - * @returns all expressions added to the tree - */ - std::vector> const& get_expressions() const { return expressions; } - private: - std::vector> expressions; + // TODO: use better ownership semantics, the shared_ptr here is redundant. consider using a bump + // allocator with type-erased deleters. + std::vector> expressions; }; /** @} */ // end of group From 4b7eebdc97135b747cc9918ca2a682db79b7a096 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Tue, 29 Oct 2024 15:22:36 +0000 Subject: [PATCH 11/19] added ast tree tests --- cpp/tests/CMakeLists.txt | 2 +- cpp/tests/ast/ast_tree_tests.cpp | 94 +++++++++++++++++--------------- 2 files changed, 52 insertions(+), 44 deletions(-) diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b78a64d0e55..11d2c27bf2c 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -650,7 +650,7 @@ ConfigureTest(ENCODE_TEST encode/encode_tests.cpp) # ################################################################################################## # * ast tests ------------------------------------------------------------------------------------- -ConfigureTest(AST_TEST ast/transform_tests.cpp) +ConfigureTest(AST_TEST ast/transform_tests.cpp ast/ast_tree_tests.cpp) # ################################################################################################## # * lists tests ---------------------------------------------------------------------------------- diff --git a/cpp/tests/ast/ast_tree_tests.cpp b/cpp/tests/ast/ast_tree_tests.cpp index a010bf60235..611426fa7e2 100644 --- a/cpp/tests/ast/ast_tree_tests.cpp +++ b/cpp/tests/ast/ast_tree_tests.cpp @@ -14,58 +14,66 @@ * limitations under the License. */ -#include #include #include -#include -#include #include #include -#include -#include -#include #include -#include -#include -#include #include #include #include -TEST_F(AstTreeTest, ExpressionTree) +template +using column_wrapper = cudf::test::fixed_width_column_wrapper; + +TEST(AstTreeTest, ExpressionTree) { - // compute y = mx + c, across multiple columns and apply weights to them and then fold them - auto a = column_wrapper{3, 20, 1, 50}; - auto b = column_wrapper{10, 7, 20, 0}; - auto c = column_wrapper{10, 7, 20, 0}; - auto d = column_wrapper{10, 7, 20, 0}; - auto e = column_wrapper{10, 7, 20, 0}; - auto f = column_wrapper{10, 7, 20, 0}; - auto table = cudf::table_view{{a, b, c, d, e, f}}; - cudf::ast::tree tree; - - auto const& a_ref = tree.push(cudf::ast::column_reference(0)); - auto const& b_ref = tree.push(cudf::ast::column_reference(1)); - auto const& c_ref = tree.push(cudf::ast::column_reference(2)); - auto const& d_ref = tree.push(cudf::ast::column_reference(3)); - auto const& e_ref = tree.push(cudf::ast::column_reference(4)); - auto const& f_ref = tree.push(cudf::ast::column_reference(5)); - auto const& literal = tree.push(cudf::ast::literal{255}); - - /// compute: (a + b) - c - auto const& op_0 = tree.push(cudf::ast::operation{ - cudf::ast::ast_operator::SUBTRACT, - tree.push(cudf::ast::operation{cudf::ast::ast_operator::ADD, a_ref, b_ref}), - c_ref}); - - auto const& op_1 = tree.push(cudf::ast::operation{ - cudf::ast::ast_operator::MULTIPLY, - tree.push(cudf::ast::operation{cudf::ast::ast_operator::SUBTRACR, d_ref, e_ref}), - e_ref}); - - auto result = cudf::compute_column( - table, tree.push(cudf::ast::operation{cudf::ast::ast_operator::ADD, op_0, op_1})); - - CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); + namespace ast = cudf::ast; + using op = ast::ast_operator; + using operation = ast::operation; + + // computes (y = mx + c)... and linearly interpolates them using interpolator t + auto m0_col = column_wrapper{10, 20, 50, 100}; + auto x0_col = column_wrapper{10, 5, 2, 1}; + auto c0_col = column_wrapper{100, 100, 100, 100}; + + auto m1_col = column_wrapper{10, 20, 50, 100}; + auto x1_col = column_wrapper{20, 10, 4, 2}; + auto c1_col = column_wrapper{200, 200, 200, 200}; + + auto one_scalar = cudf::numeric_scalar{1}; + auto t_scalar = cudf::numeric_scalar{0.5F}; + + auto table = cudf::table_view{{m0_col, x0_col, c0_col, m1_col, x1_col, c1_col}}; + + ast::tree tree{}; + + auto const& one = tree.push(ast::literal{one_scalar}); + auto const& t = tree.push(ast::literal{t_scalar}); + auto const& m0 = tree.push(ast::column_reference(0)); + auto const& x0 = tree.push(ast::column_reference(1)); + auto const& c0 = tree.push(ast::column_reference(2)); + auto const& m1 = tree.push(ast::column_reference(3)); + auto const& x1 = tree.push(ast::column_reference(4)); + auto const& c1 = tree.push(ast::column_reference(5)); + + // compute: y = mx + c + auto const& y0 = tree.push(operation{op::ADD, tree.push(operation{op::MUL, m0, x0}), c0}); + + // compute: y = mx + c + auto const& y1 = tree.push(operation{op::ADD, tree.push(operation{op::MUL, m1, x1}), c1}); + + // compute weighted: (1 - t) * y + auto const& y0_w = tree.push(operation{op::MUL, tree.push(operation{op::SUB, one, t}), y0}); + + // compute weighted: y = t * y + auto const& y1_w = tree.push(operation{op::MUL, t, y1}); + + // add weighted: result = lerp(y0, y1, t) = (1 - t) * y0 + t * y1 + auto result = cudf::compute_column(table, tree.push(operation{op::ADD, y0_w, y1_w})); + + auto expected = column_wrapper{300, 300, 300, 300}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view()); } From e17c2effafab6d73042eed54412a476fc6d399c5 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Tue, 29 Oct 2024 15:22:59 +0000 Subject: [PATCH 12/19] updated ast transform test --- cpp/tests/ast/transform_tests.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index a4bde50a21e..a197d3432e9 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -538,9 +538,10 @@ TEST_F(TransformTest, UnaryTrigonometry) TEST_F(TransformTest, ArityCheckFailure) { auto col_ref_0 = cudf::ast::column_reference(0); - EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0), cudf::logic_error); + EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0), + std::invalid_argument); EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ABS, col_ref_0, col_ref_0), - cudf::logic_error); + std::invalid_argument); } TEST_F(TransformTest, StringComparison) From d9fa6d8155c2cc08b0252dda909ff2cf45c51107 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Tue, 29 Oct 2024 15:46:37 +0000 Subject: [PATCH 13/19] updated ast tree documentation --- cpp/include/cudf/ast/expressions.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index c2ccc726e1d..54b52cdea3e 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -572,6 +572,7 @@ class tree { /** * @brief move-assigns the AST tree + * @returns a reference to the move-assigned tree */ tree& operator=(tree&&) = default; From 8f2265c723296a7f2b2e2b387d7dff29e1137aa8 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 4 Nov 2024 18:09:45 +0000 Subject: [PATCH 14/19] Update cpp/include/cudf/ast/expressions.hpp Co-authored-by: Lawrence Mitchell --- cpp/include/cudf/ast/expressions.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index 54b52cdea3e..e579bc05b18 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -555,7 +555,7 @@ class column_name_reference : public expression { }; /** - * @brief An AST expression tree. it owns and contains multiple dependent expressions. All the + * @brief An AST expression tree. It owns and contains multiple dependent expressions. All the * expressions are destroyed once the tree is destructed. */ class tree { From 8a6dac693b308f1363b3d62ec676cafc0768e307 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 4 Nov 2024 18:09:55 +0000 Subject: [PATCH 15/19] Update cpp/include/cudf/ast/expressions.hpp Co-authored-by: Lawrence Mitchell --- cpp/include/cudf/ast/expressions.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index e579bc05b18..bf4d2375592 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -556,7 +556,7 @@ class column_name_reference : public expression { /** * @brief An AST expression tree. It owns and contains multiple dependent expressions. All the - * expressions are destroyed once the tree is destructed. + * expressions are destroyed when the tree is destructed. */ class tree { public: From 26b2989352e89d636cdb4d8018ebecb2fde4a01f Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 4 Nov 2024 18:10:08 +0000 Subject: [PATCH 16/19] Update cpp/include/cudf/ast/expressions.hpp Co-authored-by: Lawrence Mitchell --- cpp/include/cudf/ast/expressions.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index bf4d2375592..4b4ffec50d6 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -641,7 +641,7 @@ class tree { expression const& operator[](size_t index) const { return *expressions[index]; } private: - // TODO: use better ownership semantics, the shared_ptr here is redundant. consider using a bump + // TODO: use better ownership semantics, the shared_ptr here is redundant. Consider using a bump // allocator with type-erased deleters. std::vector> expressions; }; From aa1d17b1960ecf4ab0ac25abc5fef0b70867388b Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 4 Nov 2024 18:10:16 +0000 Subject: [PATCH 17/19] Update cpp/include/cudf/ast/expressions.hpp Co-authored-by: Lawrence Mitchell --- cpp/include/cudf/ast/expressions.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index 4b4ffec50d6..b4a29a7d845 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -634,7 +634,7 @@ class tree { expression const& at(size_t index) { return *expressions.at(index); } /** - * @brief get the expression at an index in the tree. index is unchecked. + * @brief get the expression at an index in the tree. Index is unchecked. * @param index index of expression in the ast tree * @returns the expression at the specified index */ From 4cd078fed91f9bb8a227daee845416c23a7c5965 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 4 Nov 2024 18:10:23 +0000 Subject: [PATCH 18/19] Update cpp/include/cudf/ast/expressions.hpp Co-authored-by: Lawrence Mitchell --- cpp/include/cudf/ast/expressions.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/cudf/ast/expressions.hpp b/cpp/include/cudf/ast/expressions.hpp index b4a29a7d845..bcc9ad1b391 100644 --- a/cpp/include/cudf/ast/expressions.hpp +++ b/cpp/include/cudf/ast/expressions.hpp @@ -627,7 +627,7 @@ class tree { size_t size() const { return expressions.size(); } /** - * @brief get the expression at an index in the tree. index is checked. + * @brief get the expression at an index in the tree. Index is checked. * @param index index of expression in the ast tree * @returns the expression at the specified index */ From 279f0470fb66eea505542d459676de1c88940749 Mon Sep 17 00:00:00 2001 From: Basit Ayantunde Date: Mon, 4 Nov 2024 18:10:38 +0000 Subject: [PATCH 19/19] Update cpp/tests/ast/ast_tree_tests.cpp Co-authored-by: Lawrence Mitchell --- cpp/tests/ast/ast_tree_tests.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/tests/ast/ast_tree_tests.cpp b/cpp/tests/ast/ast_tree_tests.cpp index 611426fa7e2..1a960c68e23 100644 --- a/cpp/tests/ast/ast_tree_tests.cpp +++ b/cpp/tests/ast/ast_tree_tests.cpp @@ -58,16 +58,16 @@ TEST(AstTreeTest, ExpressionTree) auto const& x1 = tree.push(ast::column_reference(4)); auto const& c1 = tree.push(ast::column_reference(5)); - // compute: y = mx + c + // compute: y0 = m0 x0 + c0 auto const& y0 = tree.push(operation{op::ADD, tree.push(operation{op::MUL, m0, x0}), c0}); - // compute: y = mx + c + // compute: y1 = m1 x1 + c1 auto const& y1 = tree.push(operation{op::ADD, tree.push(operation{op::MUL, m1, x1}), c1}); - // compute weighted: (1 - t) * y + // compute weighted: (1 - t) * y0 auto const& y0_w = tree.push(operation{op::MUL, tree.push(operation{op::SUB, one, t}), y0}); - // compute weighted: y = t * y + // compute weighted: y = t * y1 auto const& y1_w = tree.push(operation{op::MUL, t, y1}); // add weighted: result = lerp(y0, y1, t) = (1 - t) * y0 + t * y1