diff --git a/velox/expression/LambdaExpr.cpp b/velox/expression/LambdaExpr.cpp index 739e46cbaa22..67160a7d7fcc 100644 --- a/velox/expression/LambdaExpr.cpp +++ b/velox/expression/LambdaExpr.cpp @@ -32,10 +32,12 @@ class ExprCallable : public Callable { ExprCallable( RowTypePtr signature, RowVectorPtr capture, - std::shared_ptr body) + std::shared_ptr body, + std::vector> sharedExprsToReset) : signature_(std::move(signature)), capture_(std::move(capture)), - body_(std::move(body)) {} + body_(std::move(body)), + sharedExprsToReset_(std::move(sharedExprsToReset)) {} bool hasCapture() const override { return capture_->childrenSize() > signature_->size(); @@ -53,6 +55,7 @@ class ExprCallable : public Callable { EvalCtx lambdaCtx = createLambdaCtx(context, row, validRowsInReusedResult); ScopedVarSetter throwOnError( lambdaCtx.mutableThrowOnError(), context->throwOnError()); + resetSharedExprs(); body_->eval(rows, lambdaCtx, *result); transformErrorVector(lambdaCtx, context, rows, elementToTopLevelRows); } @@ -68,11 +71,18 @@ class ExprCallable : public Callable { auto row = createRowVector(context, wrapCapture, args, rows.end()); EvalCtx lambdaCtx = createLambdaCtx(context, row, validRowsInReusedResult); ScopedVarSetter throwOnError(lambdaCtx.mutableThrowOnError(), false); + resetSharedExprs(); body_->eval(rows, lambdaCtx, *result); lambdaCtx.swapErrors(elementErrors); } private: + void resetSharedExprs() { + for (auto& expr : sharedExprsToReset_) { + expr->reset(); + } + } + EvalCtx createLambdaCtx( EvalCtx* context, std::shared_ptr& row, @@ -129,10 +139,47 @@ class ExprCallable : public Callable { RowTypePtr signature_; RowVectorPtr capture_; std::shared_ptr body_; + // List of Shared Exprs that are descendants of 'body_' for which reset() needs + // to be called before calling `body_->eval()`. 
+ std::vector> sharedExprsToReset_; }; +void extractSharedExpressions( + const ExprPtr& expr, + std::unordered_set& shared) { + for (const auto& input : expr->inputs()) { + if (input->isMultiplyReferenced()) { + shared.insert(input); + continue; + } + extractSharedExpressions(input, shared); + } +} + } // namespace +LambdaExpr::LambdaExpr( + TypePtr type, + RowTypePtr&& signature, + std::vector>&& capture, + std::shared_ptr&& body, + bool trackCpuUsage) + : SpecialForm( + std::move(type), + std::vector>(), + "lambda", + false /* supportsFlatNoNullsFastPath */, + trackCpuUsage), + signature_(std::move(signature)), + body_(std::move(body)), + capture_(std::move(capture)) { + std::unordered_set shared; + extractSharedExpressions(body_, shared); + for (auto& expr : shared) { + sharedExprsToReset_.push_back(expr); + } +} + void LambdaExpr::computeDistinctFields() { SpecialForm::computeDistinctFields(); std::vector capturedFields; @@ -205,7 +252,8 @@ void LambdaExpr::evalSpecialForm( rows.end(), values, 0); - auto callable = std::make_shared(signature_, capture, body_); + auto callable = std::make_shared( + signature_, capture, body_, sharedExprsToReset_); std::shared_ptr functions; if (!result) { functions = std::make_shared(context.pool(), type_); diff --git a/velox/expression/LambdaExpr.h b/velox/expression/LambdaExpr.h index 81327b81a880..bba0dc0e85f5 100644 --- a/velox/expression/LambdaExpr.h +++ b/velox/expression/LambdaExpr.h @@ -33,16 +33,7 @@ class LambdaExpr : public SpecialForm { RowTypePtr&& signature, std::vector>&& capture, std::shared_ptr&& body, - bool trackCpuUsage) - : SpecialForm( - std::move(type), - std::vector>(), - "lambda", - false /* supportsFlatNoNullsFastPath */, - trackCpuUsage), - signature_(std::move(signature)), - body_(std::move(body)), - capture_(std::move(capture)) {} + bool trackCpuUsage); bool isConstant() const override { return false; @@ -80,6 +71,12 @@ class LambdaExpr : public SpecialForm { /// array/map. 
 ExprPtr body_; + // List of Shared Exprs that are descendants of 'body_' for which reset() needs + // to be called before calling `body_->eval()`. This is because every + // invocation of `body_->eval()` should treat its inputs like a fresh batch + // similar to how we operate in `ExprSet::eval()`. + std::vector sharedExprsToReset_; + /// List of field references to columns in the input row vector. std::vector> capture_;