From cd141ef2d08e37ce7ade65c14ddc339fc39a2a48 Mon Sep 17 00:00:00 2001 From: Bikramjeet Vig Date: Wed, 24 Apr 2024 15:03:59 -0700 Subject: [PATCH] Fix execution of lambda expressions Summary: Lambda expressions can be executed multiple times on the same input batch, such as with the `reduce` function, which applies a lambda function to each element of an input array. It is important to note that each invocation receives a new set of inputs, and any state relevant to one set of inputs should be reset before the next invocation to avoid unintended consequences. One such failure that we observed occurred when shared expressions inside `reduce` inadvertently reused results between invocations, because the shared expressions held onto shared results that were indexed by the input vector's address; by sheer chance, some inputs ended up having the same memory address. This change fixes the bug by ensuring that this input-specific state, currently limited to shared expressions, is reset before every invocation of the lambda. 
Differential Revision: D56502765 --- velox/expression/LambdaExpr.cpp | 53 +++++++++++++++++++++++++++++++-- velox/expression/LambdaExpr.h | 17 +++++------ 2 files changed, 57 insertions(+), 13 deletions(-) diff --git a/velox/expression/LambdaExpr.cpp b/velox/expression/LambdaExpr.cpp index 739e46cbaa221..77962f704f343 100644 --- a/velox/expression/LambdaExpr.cpp +++ b/velox/expression/LambdaExpr.cpp @@ -32,10 +32,12 @@ class ExprCallable : public Callable { ExprCallable( RowTypePtr signature, RowVectorPtr capture, - std::shared_ptr body) + std::shared_ptr body, + std::vector> sharedExprsToReset) : signature_(std::move(signature)), capture_(std::move(capture)), - body_(std::move(body)) {} + body_(std::move(body)), + sharedExprsToReset_(std::move(sharedExprsToReset)) {} bool hasCapture() const override { return capture_->childrenSize() > signature_->size(); @@ -53,6 +55,7 @@ class ExprCallable : public Callable { EvalCtx lambdaCtx = createLambdaCtx(context, row, validRowsInReusedResult); ScopedVarSetter throwOnError( lambdaCtx.mutableThrowOnError(), context->throwOnError()); + resetSharedExprs(); body_->eval(rows, lambdaCtx, *result); transformErrorVector(lambdaCtx, context, rows, elementToTopLevelRows); } @@ -68,11 +71,18 @@ class ExprCallable : public Callable { auto row = createRowVector(context, wrapCapture, args, rows.end()); EvalCtx lambdaCtx = createLambdaCtx(context, row, validRowsInReusedResult); ScopedVarSetter throwOnError(lambdaCtx.mutableThrowOnError(), false); + resetSharedExprs(); body_->eval(rows, lambdaCtx, *result); lambdaCtx.swapErrors(elementErrors); } private: + void resetSharedExprs() { + for (auto& expr : sharedExprsToReset_) { + expr->reset(); + } + } + EvalCtx createLambdaCtx( EvalCtx* context, std::shared_ptr& row, @@ -129,10 +139,46 @@ class ExprCallable : public Callable { RowTypePtr signature_; RowVectorPtr capture_; std::shared_ptr body_; + // List of Shared Exprs that are descendants of 'body_' for which reset() + // needs to be called 
before calling `body_->eval()`. + std::vector> sharedExprsToReset_; }; +void extractSharedExpressions( + const ExprPtr& expr, + std::unordered_set& shared) { + for (const auto& input : expr->inputs()) { + if (input->isMultiplyReferenced()) { + shared.insert(input); + } + extractSharedExpressions(input, shared); + } +} + } // namespace +LambdaExpr::LambdaExpr( + TypePtr type, + RowTypePtr&& signature, + std::vector>&& capture, + std::shared_ptr&& body, + bool trackCpuUsage) + : SpecialForm( + std::move(type), + std::vector>(), + "lambda", + false /* supportsFlatNoNullsFastPath */, + trackCpuUsage), + signature_(std::move(signature)), + body_(std::move(body)), + capture_(std::move(capture)) { + std::unordered_set shared; + extractSharedExpressions(body_, shared); + for (auto& expr : shared) { + sharedExprsToReset_.push_back(std::move(expr)); + } +} + void LambdaExpr::computeDistinctFields() { SpecialForm::computeDistinctFields(); std::vector capturedFields; @@ -205,7 +251,8 @@ void LambdaExpr::evalSpecialForm( rows.end(), values, 0); - auto callable = std::make_shared(signature_, capture, body_); + auto callable = std::make_shared( + signature_, capture, body_, sharedExprsToReset_); std::shared_ptr functions; if (!result) { functions = std::make_shared(context.pool(), type_); diff --git a/velox/expression/LambdaExpr.h b/velox/expression/LambdaExpr.h index 81327b81a8805..bba0dc0e85f50 100644 --- a/velox/expression/LambdaExpr.h +++ b/velox/expression/LambdaExpr.h @@ -33,16 +33,7 @@ class LambdaExpr : public SpecialForm { RowTypePtr&& signature, std::vector>&& capture, std::shared_ptr&& body, - bool trackCpuUsage) - : SpecialForm( - std::move(type), - std::vector>(), - "lambda", - false /* supportsFlatNoNullsFastPath */, - trackCpuUsage), - signature_(std::move(signature)), - body_(std::move(body)), - capture_(std::move(capture)) {} + bool trackCpuUsage); bool isConstant() const override { return false; @@ -80,6 +71,12 @@ class LambdaExpr : public SpecialForm { /// 
array/map. ExprPtr body_; + // List of Shared Exprs that are descendants of 'body_' for which reset() + // needs to be called before calling `body_->eval()`. This is because every + // invocation of `body_->eval()` should treat its inputs like a fresh batch, + // similar to how we operate in `ExprSet::eval()`. + std::vector sharedExprsToReset_; + /// List of field references to columns in the input row vector. std::vector> capture_;