Skip to content

Commit

Permalink
Fix execution of lambda expressions
Browse files Browse the repository at this point in the history
Summary:
Lambda expressions can be executed multiple times on the same input
batch, such as with the `reduce` function which applies a lambda
function to each element of an input array. It is important to note
that each invocation receives a new set of inputs, and any state
relevant to one set of inputs should be reset before the next
invocation to avoid unintended consequences. An example of such
failure that we observed is when shared expressions inside `reduce`
inadvertently reused results between invocations because the shared
expressions held onto shared results that were indexed based on input
vector's address; due to sheer chance, some inputs ended up having
the same memory address.
Therefore, this change fixes this bug by ensuring this input specific
state, currently only limited to shared expressions is reset before
every invocation of the lambda.

Reviewed By: mbasmanova

Differential Revision: D56502765
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed Apr 25, 2024
1 parent 6c0bcb4 commit 3fa0df6
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
54 changes: 51 additions & 3 deletions velox/expression/LambdaExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ class ExprCallable : public Callable {
ExprCallable(
RowTypePtr signature,
RowVectorPtr capture,
std::shared_ptr<Expr> body)
std::shared_ptr<Expr> body,
std::vector<std::shared_ptr<Expr>> sharedExprsToReset)
: signature_(std::move(signature)),
capture_(std::move(capture)),
body_(std::move(body)) {}
body_(std::move(body)),
sharedExprsToReset_(std::move(sharedExprsToReset)) {}

bool hasCapture() const override {
return capture_->childrenSize() > signature_->size();
Expand All @@ -53,6 +55,7 @@ class ExprCallable : public Callable {
EvalCtx lambdaCtx = createLambdaCtx(context, row, validRowsInReusedResult);
ScopedVarSetter throwOnError(
lambdaCtx.mutableThrowOnError(), context->throwOnError());
resetSharedExprs();
body_->eval(rows, lambdaCtx, *result);
transformErrorVector(lambdaCtx, context, rows, elementToTopLevelRows);
}
Expand All @@ -68,11 +71,18 @@ class ExprCallable : public Callable {
auto row = createRowVector(context, wrapCapture, args, rows.end());
EvalCtx lambdaCtx = createLambdaCtx(context, row, validRowsInReusedResult);
ScopedVarSetter throwOnError(lambdaCtx.mutableThrowOnError(), false);
resetSharedExprs();
body_->eval(rows, lambdaCtx, *result);
lambdaCtx.swapErrors(elementErrors);
}

private:
void resetSharedExprs() {
for (auto& expr : sharedExprsToReset_) {
expr->reset();
}
}

EvalCtx createLambdaCtx(
EvalCtx* context,
std::shared_ptr<RowVector>& row,
Expand Down Expand Up @@ -129,10 +139,47 @@ class ExprCallable : public Callable {
RowTypePtr signature_;
RowVectorPtr capture_;
std::shared_ptr<Expr> body_;
// List of Shared Exprs that are decendants of 'body_' for which reset() needs
// to be called before calling `body_->eval()`.
std::vector<std::shared_ptr<Expr>> sharedExprsToReset_;
};

void extractSharedExpressions(
const ExprPtr& expr,
std::unordered_set<ExprPtr>& shared) {
for (const auto& input : expr->inputs()) {
if (input->isMultiplyReferenced()) {
shared.insert(input);
continue;
}
extractSharedExpressions(input, shared);
}
}

} // namespace

LambdaExpr::LambdaExpr(
TypePtr type,
RowTypePtr&& signature,
std::vector<std::shared_ptr<FieldReference>>&& capture,
std::shared_ptr<Expr>&& body,
bool trackCpuUsage)
: SpecialForm(
std::move(type),
std::vector<std::shared_ptr<Expr>>(),
"lambda",
false /* supportsFlatNoNullsFastPath */,
trackCpuUsage),
signature_(std::move(signature)),
body_(std::move(body)),
capture_(std::move(capture)) {
std::unordered_set<ExprPtr> shared;
extractSharedExpressions(body_, shared);
for (auto& expr : shared) {
sharedExprsToReset_.push_back(expr);
}
}

void LambdaExpr::computeDistinctFields() {
SpecialForm::computeDistinctFields();
std::vector<FieldReference*> capturedFields;
Expand Down Expand Up @@ -205,7 +252,8 @@ void LambdaExpr::evalSpecialForm(
rows.end(),
values,
0);
auto callable = std::make_shared<ExprCallable>(signature_, capture, body_);
auto callable = std::make_shared<ExprCallable>(
signature_, capture, body_, sharedExprsToReset_);
std::shared_ptr<FunctionVector> functions;
if (!result) {
functions = std::make_shared<FunctionVector>(context.pool(), type_);
Expand Down
17 changes: 7 additions & 10 deletions velox/expression/LambdaExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,7 @@ class LambdaExpr : public SpecialForm {
RowTypePtr&& signature,
std::vector<std::shared_ptr<FieldReference>>&& capture,
std::shared_ptr<Expr>&& body,
bool trackCpuUsage)
: SpecialForm(
std::move(type),
std::vector<std::shared_ptr<Expr>>(),
"lambda",
false /* supportsFlatNoNullsFastPath */,
trackCpuUsage),
signature_(std::move(signature)),
body_(std::move(body)),
capture_(std::move(capture)) {}
bool trackCpuUsage);

bool isConstant() const override {
return false;
Expand Down Expand Up @@ -80,6 +71,12 @@ class LambdaExpr : public SpecialForm {
/// array/map.
ExprPtr body_;

// List of Shared Exprs that are decendants of 'body_' for which reset() needs
// to be called before calling `body_->eval()`.This is because every
// invocation of `body_->eval()` should treat its inputs like a fresh batch
// similar to how we operate in `ExprSet::eval()`.
std::vector<ExprPtr> sharedExprsToReset_;

/// List of field references to columns in the input row vector.
std::vector<std::shared_ptr<FieldReference>> capture_;

Expand Down

0 comments on commit 3fa0df6

Please sign in to comment.