Skip to content

Commit

Permalink
Split flat expression evaluator implementation from implementation of…
Browse files Browse the repository at this point in the history
… the legacy CelExpression interface.

PiperOrigin-RevId: 537376126
  • Loading branch information
jnthntatum authored and copybara-github committed Jun 2, 2023
1 parent 30bfb80 commit 58397d7
Show file tree
Hide file tree
Showing 28 changed files with 508 additions and 278 deletions.
9 changes: 4 additions & 5 deletions eval/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,10 @@ cc_library(
"//eval/internal:interop",
"//eval/public:ast_traverse_native",
"//eval/public:ast_visitor_native",
"//eval/public:cel_expression",
"//eval/public:source_position_native",
"//internal:status_macros",
"//runtime:function_registry",
"//runtime:runtime_options",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/log:check",
Expand Down Expand Up @@ -187,6 +185,7 @@ cc_library(
deps = [
":flat_expr_builder",
"//base:ast",
"//eval/eval:cel_expression_flat_impl",
"//eval/eval:evaluator_core",
"//eval/public:cel_expression",
"//extensions/protobuf:ast_converters",
Expand Down Expand Up @@ -232,6 +231,7 @@ cc_library(
":flat_expr_builder_extensions",
":resolver",
"//base:ast_internal",
"//base:builtins",
"//base:data",
"//base:function",
"//base:handle",
Expand All @@ -241,13 +241,11 @@ cc_library(
"//eval/eval:evaluator_core",
"//eval/internal:errors",
"//eval/internal:interop",
"//eval/public:activation",
"//eval/public:cel_builtins",
"//eval/public:cel_expression",
"//eval/public:cel_value",
"//eval/public/containers:container_backed_list_impl",
"//extensions/protobuf:memory_manager",
"//internal:status_macros",
"//runtime:activation",
"//runtime:function_overload_reference",
"//runtime:function_registry",
"@com_google_absl//absl/container:flat_hash_map",
Expand Down Expand Up @@ -438,6 +436,7 @@ cc_test(
"//base:data",
"//base:memory",
"//base/internal:ast_impl",
"//eval/eval:cel_expression_flat_impl",
"//eval/eval:evaluator_core",
"//eval/public:builtin_func_registrar",
"//eval/public:cel_options",
Expand Down
14 changes: 10 additions & 4 deletions eval/compiler/cel_expression_builder_flat_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "base/ast.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/eval/evaluator_core.h"
#include "extensions/protobuf/ast_converters.h"
#include "internal/status_macros.h"
Expand All @@ -46,8 +47,10 @@ CelExpressionBuilderFlatImpl::CreateExpression(
CEL_ASSIGN_OR_RETURN(
std::unique_ptr<Ast> converted_ast,
cel::extensions::CreateAstFromParsedExpr(*expr, source_info));
return flat_expr_builder_.CreateExpressionImpl(std::move(converted_ast),
warnings);
CEL_ASSIGN_OR_RETURN(FlatExpression impl,
flat_expr_builder_.CreateExpressionImpl(
std::move(converted_ast), warnings));
return std::make_unique<CelExpressionFlatImpl>(std::move(impl));
}

absl::StatusOr<std::unique_ptr<CelExpression>>
Expand All @@ -65,8 +68,11 @@ CelExpressionBuilderFlatImpl::CreateExpression(
CEL_ASSIGN_OR_RETURN(
std::unique_ptr<Ast> converted_ast,
cel::extensions::CreateAstFromCheckedExpr(*checked_expr));
return flat_expr_builder_.CreateExpressionImpl(std::move(converted_ast),
warnings);

CEL_ASSIGN_OR_RETURN(FlatExpression impl,
flat_expr_builder_.CreateExpressionImpl(
std::move(converted_ast), warnings));
return std::make_unique<CelExpressionFlatImpl>(std::move(impl));
}

absl::StatusOr<std::unique_ptr<CelExpression>>
Expand Down
44 changes: 23 additions & 21 deletions eval/compiler/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "base/ast_internal.h"
#include "base/builtins.h"
#include "base/function.h"
#include "base/handle.h"
#include "base/internal/ast_impl.h"
#include "base/kind.h"
#include "base/type_provider.h"
#include "base/value.h"
#include "base/values/bytes_value.h"
#include "base/values/error_value.h"
Expand All @@ -25,38 +27,36 @@
#include "eval/eval/evaluator_core.h"
#include "eval/internal/errors.h"
#include "eval/internal/interop.h"
#include "eval/public/activation.h"
#include "eval/public/cel_builtins.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_value.h"
#include "eval/public/containers/container_backed_list_impl.h"
#include "extensions/protobuf/memory_manager.h"
#include "internal/status_macros.h"
#include "runtime/activation.h"
#include "runtime/function_overload_reference.h"
#include "runtime/function_registry.h"

namespace cel::ast::internal {

namespace {

using ::cel::builtin::kAnd;
using ::cel::builtin::kOr;
using ::cel::builtin::kTernary;
using ::cel::extensions::ProtoMemoryManager;
using ::cel::interop_internal::CreateErrorValueFromView;
using ::cel::interop_internal::CreateLegacyListValue;
using ::cel::interop_internal::CreateNoMatchingOverloadError;
using ::cel::interop_internal::ModernValueToLegacyValueOrDie;
using ::google::api::expr::runtime::Activation;
using ::google::api::expr::runtime::CelEvaluationListener;
using ::google::api::expr::runtime::CelExpressionFlatEvaluationState;
using ::google::api::expr::runtime::CelValue;
using ::google::api::expr::runtime::ContainerBackedListImpl;
using ::google::api::expr::runtime::EvaluationListener;
using ::google::api::expr::runtime::ExecutionFrame;
using ::google::api::expr::runtime::ExecutionPath;
using ::google::api::expr::runtime::ExecutionPathView;
using ::google::api::expr::runtime::FlatExpressionEvaluatorState;
using ::google::api::expr::runtime::PlannerContext;
using ::google::api::expr::runtime::ProgramOptimizer;
using ::google::api::expr::runtime::Resolver;
using ::google::api::expr::runtime::builtin::kAnd;
using ::google::api::expr::runtime::builtin::kOr;
using ::google::api::expr::runtime::builtin::kTernary;

using ::google::protobuf::Arena;

Expand Down Expand Up @@ -201,11 +201,8 @@ class ConstantFoldingTransform {
}
// short-circuiting affects evaluation of logic combinators, so we do
// not fold them here
if (!all_constant ||
call_expr.function() == google::api::expr::runtime::builtin::kAnd ||
call_expr.function() == google::api::expr::runtime::builtin::kOr ||
call_expr.function() ==
google::api::expr::runtime::builtin::kTernary) {
if (!all_constant || call_expr.function() == cel::builtin::kAnd ||
call_expr.function() == kOr || call_expr.function() == kTernary) {
return false;
}

Expand Down Expand Up @@ -392,8 +389,11 @@ bool ConstantFoldingTransform::Transform(const Expr& expr, Expr& out_) {

class ConstantFoldingExtension : public ProgramOptimizer {
public:
explicit ConstantFoldingExtension(google::protobuf::Arena* arena)
: arena_(arena), state_(kDefaultStackLimit, arena) {}
explicit ConstantFoldingExtension(google::protobuf::Arena* arena,
const TypeProvider& type_provider)
: arena_(arena),
memory_manager_(arena),
state_(kDefaultStackLimit, type_provider, memory_manager_) {}

absl::Status OnPreVisit(google::api::expr::runtime::PlannerContext& context,
const Expr& node) override;
Expand All @@ -410,9 +410,10 @@ class ConstantFoldingExtension : public ProgramOptimizer {
static constexpr size_t kDefaultStackLimit = 4;

google::protobuf::Arena* arena_;
ProtoMemoryManager memory_manager_;
Activation empty_;
CelEvaluationListener null_listener_;
CelExpressionFlatEvaluationState state_;
EvaluationListener null_listener_;
FlatExpressionEvaluatorState state_;

std::vector<IsConst> is_const_;
};
Expand Down Expand Up @@ -498,7 +499,7 @@ absl::Status ConstantFoldingExtension::OnPostVisit(PlannerContext& context,
node.const_expr().constant_kind());
} else {
ExecutionPathView subplan = context.GetSubplan(node);
ExecutionFrame frame(subplan, empty_, context.options(), &state_);
ExecutionFrame frame(subplan, empty_, context.options(), state_);
state_.Reset();
// Update stack size to accommodate sub expression.
// This only results in a vector resize if the new maxsize is greater than
Expand Down Expand Up @@ -531,8 +532,9 @@ void FoldConstants(

google::api::expr::runtime::ProgramOptimizerFactory
CreateConstantFoldingExtension(google::protobuf::Arena* arena) {
return [=](PlannerContext&, const AstImpl&) {
return std::make_unique<ConstantFoldingExtension>(arena);
return [=](PlannerContext& ctx, const AstImpl&) {
return std::make_unique<ConstantFoldingExtension>(
arena, ctx.type_registry().GetTypeProvider());
};
}

Expand Down
12 changes: 4 additions & 8 deletions eval/compiler/flat_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <utility>
#include <vector>

#include "absl/base/macros.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/log/check.h"
Expand Down Expand Up @@ -1121,8 +1120,7 @@ void ComprehensionVisitor::PostVisit(const cel::ast::internal::Expr* expr) {

// TODO(uncreated-issue/31): move ast conversion to client responsibility and
// update pre-processing steps to work without mutating the input AST.
absl::StatusOr<std::unique_ptr<CelExpression>>
FlatExprBuilder::CreateExpressionImpl(
absl::StatusOr<FlatExpression> FlatExprBuilder::CreateExpressionImpl(
std::unique_ptr<Ast> ast, std::vector<absl::Status>* warnings) const {
ExecutionPath execution_path;

Expand Down Expand Up @@ -1179,14 +1177,12 @@ FlatExprBuilder::CreateExpressionImpl(
return visitor.progress_status();
}

std::unique_ptr<CelExpression> expression_impl =
std::make_unique<CelExpressionFlatImpl>(std::move(execution_path),
options_);

if (warnings != nullptr) {
*warnings = std::move(warnings_builder).warnings();
}
return expression_impl;

return FlatExpression(std::move(execution_path),
type_registry_.GetTypeProvider(), options_);
}

} // namespace google::api::expr::runtime
3 changes: 1 addition & 2 deletions eval/compiler/flat_expr_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "absl/status/statusor.h"
#include "base/ast.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "eval/public/cel_expression.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "google/protobuf/arena.h"
Expand Down Expand Up @@ -85,7 +84,7 @@ class FlatExprBuilder {

// TODO(uncreated-issue/45): Add overload for cref AST. At the moment, all the users
// can pass ownership of a freshly converted AST.
absl::StatusOr<std::unique_ptr<CelExpression>> CreateExpressionImpl(
absl::StatusOr<FlatExpression> CreateExpressionImpl(
std::unique_ptr<cel::ast::Ast> ast,
std::vector<absl::Status>* warnings) const;

Expand Down
5 changes: 3 additions & 2 deletions eval/compiler/regex_precompilation_optimization_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "eval/compiler/cel_expression_builder_flat_impl.h"
#include "eval/compiler/flat_expr_builder.h"
#include "eval/compiler/flat_expr_builder_extensions.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/eval/evaluator_core.h"
#include "eval/public/builtin_func_registrar.h"
#include "eval/public/cel_options.h"
Expand Down Expand Up @@ -99,8 +100,8 @@ MATCHER_P(ExpressionPlanSizeIs, size, "") {
dynamic_cast<CelExpressionFlatImpl*>(plan.get());

if (impl == nullptr) return false;
*result_listener << "got size " << impl->path().size();
return impl->path().size() == size;
*result_listener << "got size " << impl->flat_expression().path().size();
return impl->flat_expression().path().size() == size;
}

TEST_F(RegexPrecompilationExtensionTest, OptimizeableExpression) {
Expand Down
Loading

0 comments on commit 58397d7

Please sign in to comment.