Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ast tree to simplify expression lifetime management #17156

Open
wants to merge 29 commits into
base: branch-24.12
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
4d60344
added ast tree to simplify expression lifetime management
lamarrr Oct 23, 2024
bf9d71f
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Oct 23, 2024
58c12ba
updated copyright notices
lamarrr Oct 23, 2024
479a7b0
Merge branch 'ast-expr-enhancement' of https://github.com/lamarrr/cud…
lamarrr Oct 23, 2024
04414db
updated ast tree documentation
lamarrr Oct 24, 2024
7d398f1
fixed ast tree docs formatting
lamarrr Oct 28, 2024
6c82411
temporarily changed back to using vector for storing ast operation's …
lamarrr Oct 28, 2024
e3b4823
made operation arity check throw std::invalid_argument
lamarrr Oct 28, 2024
5695f29
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Oct 28, 2024
342b184
corrected ast builder push return type
lamarrr Oct 28, 2024
daa882b
updated documentation for ast tree member functions
lamarrr Oct 28, 2024
239bfd4
started drafting ast tree test
lamarrr Oct 28, 2024
ea1d430
updated ast tree expression management
lamarrr Oct 29, 2024
4b7eebd
added ast tree tests
lamarrr Oct 29, 2024
e17c2ef
updated ast transform test
lamarrr Oct 29, 2024
d9fa6d8
updated ast tree documentation
lamarrr Oct 29, 2024
e33a7a0
Merge remote-tracking branch 'upstream/branch-24.12' into ast-expr-en…
lamarrr Oct 29, 2024
e356ee5
Merge remote-tracking branch 'upstream/branch-24.12' into ast-expr-en…
lamarrr Oct 29, 2024
8f2265c
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
8a6dac6
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
26b2989
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
aa1d17b
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
4cd078f
Update cpp/include/cudf/ast/expressions.hpp
lamarrr Nov 4, 2024
279f047
Update cpp/tests/ast/ast_tree_tests.cpp
lamarrr Nov 4, 2024
984e34f
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Nov 4, 2024
6afd563
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Nov 4, 2024
1dc38fc
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Nov 6, 2024
72b2f58
Merge branch 'branch-24.12' into ast-expr-enhancement
lamarrr Nov 6, 2024
dccfeb9
Merge branch 'branch-24.12' into ast-expr-enhancement
vyasr Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cpp/include/cudf/ast/detail/expression_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#include <cudf/ast/expressions.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/span.hpp>

#include <thrust/scan.h>
vyasr marked this conversation as resolved.
Show resolved Hide resolved

#include <functional>
#include <numeric>
Expand Down Expand Up @@ -296,7 +300,7 @@ class expression_parser {
* @return The indices of the operands stored in the data references.
*/
std::vector<cudf::size_type> visit_operands(
std::vector<std::reference_wrapper<expression const>> operands);
cudf::host_span<std::reference_wrapper<cudf::ast::expression const> const> operands);

/**
* @brief Add a data reference to the internal list.
Expand Down
100 changes: 97 additions & 3 deletions cpp/include/cudf/ast/expressions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include <cudf/utilities/error.hpp>

#include <cstdint>
#include <memory>
#include <vector>

namespace CUDF_EXPORT cudf {
namespace ast {
Expand Down Expand Up @@ -478,7 +480,7 @@ class operation : public expression {
*
* @return Vector of operands
*/
[[nodiscard]] std::vector<std::reference_wrapper<expression const>> get_operands() const
[[nodiscard]] std::vector<std::reference_wrapper<expression const>> const& get_operands() const
{
return operands;
}
Expand Down Expand Up @@ -506,8 +508,8 @@ class operation : public expression {
};

private:
ast_operator const op;
std::vector<std::reference_wrapper<expression const>> const operands;
ast_operator op;
std::vector<std::reference_wrapper<expression const>> operands;
};

/**
Expand Down Expand Up @@ -552,6 +554,98 @@ class column_name_reference : public expression {
std::string column_name;
};

/**
* @brief An AST expression tree. It owns and contains multiple dependent expressions. All the
* expressions are destroyed when the tree is destructed.
*/
class tree {
 public:
  /**
   * @brief construct an empty ast tree
   */
  tree() = default;

  /**
   * @brief Moves the ast tree
   */
  tree(tree&&) = default;

  /**
   * @brief move-assigns the AST tree
   * @returns a reference to the move-assigned tree
   */
  tree& operator=(tree&&) = default;

  ~tree() = default;

  // the tree is not copyable
  tree(tree const&)            = delete;
  tree& operator=(tree const&) = delete;

  /**
   * @brief Add an expression to the AST tree
   *
   * The expression is constructed in its own allocation and the tree stores only a pointer to
   * it, so growth of the internal vector on later insertions does not invalidate the returned
   * reference.
   *
   * @tparam Expr Type of the expression to construct, must derive from `expression`
   * @tparam Args Types of the arguments forwarded to the constructor of `Expr`
   * @param args Arguments to use to construct the ast expression
   * @returns a reference to the added expression
   */
  template <typename Expr, typename... Args>
  Expr const& emplace(Args&&... args)
  {
    static_assert(std::is_base_of_v<expression, Expr>,
                  "Expr must derive from cudf::ast::expression");
    auto expr            = std::make_shared<Expr>(std::forward<Args>(args)...);
    Expr const& expr_ref = *expr;
    expressions.emplace_back(std::static_pointer_cast<expression>(std::move(expr)));
    return expr_ref;
  }

  /**
   * @brief Add an expression to the AST tree
   *
   * The expression is taken by value and moved into the tree via `emplace`.
   *
   * @tparam Expr Type of the expression to add, must derive from `expression`
   * @param expr AST expression to be added
   * @returns a reference to the added expression
   */
  template <typename Expr>
  Expr const& push(Expr expr)
  {
    return emplace<Expr>(std::move(expr));
  }

  /**
   * @brief get the first expression in the tree
   *
   * The tree must not be empty; calling this on an empty tree is undefined behavior.
   *
   * @returns the first inserted expression into the tree
   */
  expression const& front() const { return *expressions.front(); }

  /**
   * @brief get the last expression in the tree
   *
   * The tree must not be empty; calling this on an empty tree is undefined behavior.
   *
   * @returns the last inserted expression into the tree
   */
  expression const& back() const { return *expressions.back(); }

  /**
   * @brief get the number of expressions added to the tree
   * @returns the number of expressions added to the tree
   */
  size_t size() const { return expressions.size(); }

  /**
   * @brief get the expression at an index in the tree. Index is checked.
   *
   * Declared `const`, like the other accessors, so that it can be used through a
   * `tree const&`.
   *
   * @throws std::out_of_range if `index >= size()`
   * @param index index of expression in the ast tree
   * @returns the expression at the specified index
   */
  expression const& at(size_t index) const { return *expressions.at(index); }

  /**
   * @brief get the expression at an index in the tree. Index is unchecked.
   *
   * `index` must be less than `size()`; no bounds check is performed.
   *
   * @param index index of expression in the ast tree
   * @returns the expression at the specified index
   */
  expression const& operator[](size_t index) const { return *expressions[index]; }

 private:
  // TODO: use better ownership semantics, the shared_ptr here is redundant. Consider using a bump
  // allocator with type-erased deleters.
  std::vector<std::shared_ptr<expression>> expressions;
};

/** @} */ // end of group
} // namespace ast

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/ast/expression_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ cudf::data_type expression_parser::output_type() const
}

std::vector<cudf::size_type> expression_parser::visit_operands(
std::vector<std::reference_wrapper<expression const>> operands)
cudf::host_span<std::reference_wrapper<expression const> const> operands)
{
auto operand_data_reference_indices = std::vector<cudf::size_type>();
for (auto const& operand : operands) {
Expand Down
24 changes: 16 additions & 8 deletions cpp/src/ast/expressions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,41 @@
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>

#include <stdexcept>

namespace cudf {
namespace ast {

operation::operation(ast_operator op, expression const& input) : op(op), operands({input})
operation::operation(ast_operator op, expression const& input) : op{op}, operands{input}
{
if (cudf::ast::detail::ast_operator_arity(op) != 1) {
CUDF_FAIL("The provided operator is not a unary operator.");
}
CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 1,
"The provided operator is not a unary operator.",
std::invalid_argument);
}

operation::operation(ast_operator op, expression const& left, expression const& right)
: op(op), operands({left, right})
: op{op}, operands{left, right}
{
if (cudf::ast::detail::ast_operator_arity(op) != 2) {
CUDF_FAIL("The provided operator is not a binary operator.");
}
CUDF_EXPECTS(cudf::ast::detail::ast_operator_arity(op) == 2,
"The provided operator is not a binary operator.",
std::invalid_argument);
}

// Visitor hook for the expression parser: passes this literal to the parser's `visit`
// overload and returns the value it produces.
cudf::size_type literal::accept(detail::expression_parser& visitor) const
{
  auto const& self = *this;
  return visitor.visit(self);
}

// Visitor hook for the expression parser: passes this column reference to the parser's
// `visit` overload and returns the value it produces.
cudf::size_type column_reference::accept(detail::expression_parser& visitor) const
{
  auto const& self = *this;
  return visitor.visit(self);
}

// Visitor hook for the expression parser: passes this operation to the parser's `visit`
// overload and returns the value it produces.
cudf::size_type operation::accept(detail::expression_parser& visitor) const
{
  auto const& self = *this;
  return visitor.visit(self);
}

cudf::size_type column_name_reference::accept(detail::expression_parser& visitor) const
{
return visitor.visit(*this);
Expand All @@ -60,16 +65,19 @@ auto literal::accept(detail::expression_transformer& visitor) const
{
return visitor.visit(*this);
}

// Visitor hook for expression transformers: passes this column reference to the
// transformer's `visit` overload and returns its result unchanged.
auto column_reference::accept(detail::expression_transformer& visitor) const
  -> decltype(visitor.visit(*this))
{
  auto const& self = *this;
  return visitor.visit(self);
}

// Visitor hook for expression transformers: passes this operation to the transformer's
// `visit` overload and returns its result unchanged.
auto operation::accept(detail::expression_transformer& visitor) const
  -> decltype(visitor.visit(*this))
{
  auto const& self = *this;
  return visitor.visit(self);
}

auto column_name_reference::accept(detail::expression_transformer& visitor) const
-> decltype(visitor.visit(*this))
{
Expand Down
7 changes: 4 additions & 3 deletions cpp/src/io/parquet/predicate_pushdown.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/span.hpp>
#include <cudf/utilities/traits.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

Expand Down Expand Up @@ -373,7 +374,7 @@ class stats_expression_converter : public ast::detail::expression_transformer {

private:
std::vector<std::reference_wrapper<ast::expression const>> visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands)
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
std::vector<std::reference_wrapper<ast::expression const>> transformed_operands;
for (auto const& operand : operands) {
Expand Down Expand Up @@ -553,7 +554,7 @@ std::reference_wrapper<ast::expression const> named_to_reference_converter::visi

std::vector<std::reference_wrapper<ast::expression const>>
named_to_reference_converter::visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands)
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
std::vector<std::reference_wrapper<ast::expression const>> transformed_operands;
for (auto const& operand : operands) {
Expand Down Expand Up @@ -623,7 +624,7 @@ class names_from_expression : public ast::detail::expression_transformer {
}

private:
void visit_operands(std::vector<std::reference_wrapper<ast::expression const>> operands)
void visit_operands(cudf::host_span<std::reference_wrapper<ast::expression const> const> operands)
{
for (auto const& operand : operands) {
operand.get().accept(*this);
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/io/parquet/reader_impl_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ class named_to_reference_converter : public ast::detail::expression_transformer

private:
std::vector<std::reference_wrapper<ast::expression const>> visit_operands(
std::vector<std::reference_wrapper<ast::expression const>> operands);
cudf::host_span<std::reference_wrapper<ast::expression const> const> operands);

std::unordered_map<std::string, size_type> column_name_to_index;
std::optional<std::reference_wrapper<ast::expression const>> _stats_expr;
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ ConfigureTest(ENCODE_TEST encode/encode_tests.cpp)

# ##################################################################################################
# * ast tests -------------------------------------------------------------------------------------
ConfigureTest(AST_TEST ast/transform_tests.cpp)
ConfigureTest(AST_TEST ast/transform_tests.cpp ast/ast_tree_tests.cpp)

# ##################################################################################################
# * lists tests ----------------------------------------------------------------------------------
Expand Down
79 changes: 79 additions & 0 deletions cpp/tests/ast/ast_tree_tests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf_test/column_utilities.hpp>
#include <cudf_test/column_wrapper.hpp>
#include <cudf_test/testing_main.hpp>

#include <cudf/ast/expressions.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/transform.hpp>
#include <cudf/types.hpp>

// Shorthand for the fixed-width test column wrapper holding elements of type T.
template <typename T>
using column_wrapper = cudf::test::fixed_width_column_wrapper<T>;

// Builds a linear-interpolation expression whose nodes are all owned by a single ast::tree
// and checks the column produced by cudf::compute_column against the expected values.
TEST(AstTreeTest, ExpressionTree)
{
  namespace ast   = cudf::ast;
  using op        = ast::ast_operator;
  using operation = ast::operation;

  // Inputs for two lines of the form y = m * x + c; the two results are then blended with
  // the interpolation factor t.
  column_wrapper<float> slope0{10, 20, 50, 100};
  column_wrapper<float> input0{10, 5, 2, 1};
  column_wrapper<float> offset0{100, 100, 100, 100};

  column_wrapper<float> slope1{10, 20, 50, 100};
  column_wrapper<float> input1{20, 10, 4, 2};
  column_wrapper<float> offset1{200, 200, 200, 200};

  cudf::numeric_scalar<float> unit_value{1};
  cudf::numeric_scalar<float> blend_value{0.5F};

  cudf::table_view const inputs{{slope0, input0, offset0, slope1, input1, offset1}};

  ast::tree exprs;

  // Leaf nodes: the two literals followed by one reference per table column.
  auto const& one = exprs.push(ast::literal{unit_value});
  auto const& t   = exprs.push(ast::literal{blend_value});
  auto const& m0  = exprs.push(ast::column_reference(0));
  auto const& x0  = exprs.push(ast::column_reference(1));
  auto const& c0  = exprs.push(ast::column_reference(2));
  auto const& m1  = exprs.push(ast::column_reference(3));
  auto const& x1  = exprs.push(ast::column_reference(4));
  auto const& c1  = exprs.push(ast::column_reference(5));

  // y0 = m0 * x0 + c0
  auto const& m0_x0 = exprs.push(operation{op::MUL, m0, x0});
  auto const& y0    = exprs.push(operation{op::ADD, m0_x0, c0});

  // y1 = m1 * x1 + c1
  auto const& m1_x1 = exprs.push(operation{op::MUL, m1, x1});
  auto const& y1    = exprs.push(operation{op::ADD, m1_x1, c1});

  // first weighted term: (1 - t) * y0
  auto const& one_minus_t = exprs.push(operation{op::SUB, one, t});
  auto const& y0_weighted = exprs.push(operation{op::MUL, one_minus_t, y0});

  // second weighted term: t * y1
  auto const& y1_weighted = exprs.push(operation{op::MUL, t, y1});

  // result = lerp(y0, y1, t) = (1 - t) * y0 + t * y1
  auto const& lerp = exprs.push(operation{op::ADD, y0_weighted, y1_weighted});
  auto result      = cudf::compute_column(inputs, lerp);

  column_wrapper<float> expected{300, 300, 300, 300};

  CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view());
}
5 changes: 3 additions & 2 deletions cpp/tests/ast/transform_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,9 +530,10 @@ TEST_F(TransformTest, UnaryTrigonometry)
TEST_F(TransformTest, ArityCheckFailure)
{
auto col_ref_0 = cudf::ast::column_reference(0);
EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0), cudf::logic_error);
EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0),
std::invalid_argument);
EXPECT_THROW(cudf::ast::operation(cudf::ast::ast_operator::ABS, col_ref_0, col_ref_0),
cudf::logic_error);
std::invalid_argument);
}

TEST_F(TransformTest, StringComparison)
Expand Down
Loading