Skip to content

Commit

Permalink
Introduce TypedExprs (#7907)
Browse files Browse the repository at this point in the history
Summary:
TypedExprs contains convenience methods for working with expressions:

- isFieldAccess, asFieldAccess, isConstant, asConstant.

Allows for shorter easier to read and write code.

Pull Request resolved: #7907

Reviewed By: xiaoxmeng

Differential Revision: D51919498

Pulled By: mbasmanova

fbshipit-source-id: f14678310d792e57be9a2719f836d805be9311d3
  • Loading branch information
mbasmanova authored and facebook-github-bot committed Dec 7, 2023
1 parent dffc39e commit f59e6df
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 36 deletions.
28 changes: 28 additions & 0 deletions velox/core/Expressions.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ class ConstantTypedExpr : public ITypedExpr {
const VectorPtr valueVector_;
};

using ConstantTypedExprPtr = std::shared_ptr<const ConstantTypedExpr>;

/// Evaluates a scalar function or a special form.
///
/// Supported special forms are: and, or, cast, try_cast, coalesce, if, switch,
Expand Down Expand Up @@ -645,4 +647,30 @@ class CastTypedExpr : public ITypedExpr {
};

using CastTypedExprPtr = std::shared_ptr<const CastTypedExpr>;

/// A collection of convenince methods for working with expressions.
class TypedExprs {
public:
/// Returns true if 'expr' is a field access expression.
static bool isFieldAccess(const TypedExprPtr& expr) {
return dynamic_cast<const FieldAccessTypedExpr*>(expr.get()) != nullptr;
}

/// Returns 'expr' as FieldAccessTypedExprPtr or null if not field access
/// expression.
static FieldAccessTypedExprPtr asFieldAccess(const TypedExprPtr& expr) {
return std::dynamic_pointer_cast<const FieldAccessTypedExpr>(expr);
}

/// Returns true if 'expr' is a constant expression.
static bool isConstant(const TypedExprPtr& expr) {
return dynamic_cast<const ConstantTypedExpr*>(expr.get()) != nullptr;
}

/// Returns 'expr' as ConstantTypedExprPtr or null if not field access
/// expression.
static ConstantTypedExprPtr asConstant(const TypedExprPtr& expr) {
return std::dynamic_pointer_cast<const ConstantTypedExpr>(expr);
}
};
} // namespace facebook::velox::core
17 changes: 5 additions & 12 deletions velox/core/PlanNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ RowTypePtr getAggregationOutputType(
std::vector<TypePtr> types;

for (auto& key : groupingKeys) {
auto field =
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(key);
auto field = TypedExprs::asFieldAccess(key);
VELOX_CHECK(field, "Grouping key must be a field reference");
names.push_back(field->name());
types.push_back(field->type());
Expand Down Expand Up @@ -203,12 +202,9 @@ void addKeys(std::stringstream& stream, const std::vector<TypedExprPtr>& keys) {
for (auto i = 0; i < keys.size(); ++i) {
const auto& expr = keys[i];
appendComma(i, stream);
if (auto field =
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(expr)) {
if (auto field = TypedExprs::asFieldAccess(expr)) {
stream << field->name();
} else if (
auto constant =
std::dynamic_pointer_cast<const core::ConstantTypedExpr>(expr)) {
} else if (auto constant = TypedExprs::asConstant(expr)) {
stream << constant->toString();
} else {
stream << expr->toString();
Expand Down Expand Up @@ -493,12 +489,9 @@ ExpandNode::ExpandNode(

for (const auto& rowProjection : projections_) {
for (const auto& columnProjection : rowProjection) {
auto field = std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
columnProjection);
auto constant = std::dynamic_pointer_cast<const core::ConstantTypedExpr>(
columnProjection);
VELOX_USER_CHECK(
field || constant,
TypedExprs::isFieldAccess(columnProjection) ||
TypedExprs::isConstant(columnProjection),
"Unsupported projection expression in Expand plan node. Expected field reference or constant. Got: {} ",
columnProjection->toString());
}
Expand Down
8 changes: 2 additions & 6 deletions velox/exec/Expand.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,11 @@ Expand::Expand(
constantProjection;
constantProjection.reserve(numColumns);
for (const auto& columnProjection : rowProjections) {
if (auto field =
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
columnProjection)) {
if (auto field = core::TypedExprs::asFieldAccess(columnProjection)) {
rowProjection.push_back(inputType->getChildIdx(field->name()));
constantProjection.push_back(nullptr);
} else if (
auto constant =
std::dynamic_pointer_cast<const core::ConstantTypedExpr>(
columnProjection)) {
auto constant = core::TypedExprs::asConstant(columnProjection)) {
rowProjection.push_back(kConstantChannel);
constantProjection.push_back(constant);
} else {
Expand Down
3 changes: 1 addition & 2 deletions velox/exec/FilterProject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ bool checkAddIdentityProjection(
const RowTypePtr& inputType,
column_index_t outputChannel,
std::vector<IdentityProjection>& identityProjections) {
if (auto field = std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
projection)) {
if (auto field = core::TypedExprs::asFieldAccess(projection)) {
const auto& inputs = field->inputs();
if (inputs.empty() ||
(inputs.size() == 1 &&
Expand Down
7 changes: 2 additions & 5 deletions velox/exec/Window.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ Window::WindowFrame Window::createWindowFrame(
}
auto frameChannel = exprToChannel(frame.get(), inputType);
if (frameChannel == kConstantChannel) {
auto constant =
std::dynamic_pointer_cast<const core::ConstantTypedExpr>(frame)
->value();
auto constant = core::TypedExprs::asConstant(frame)->value();
VELOX_CHECK(!constant.isNull(), "Window frame offset must not be null");
auto value = VariantConverter::convert(constant, TypeKind::BIGINT)
.value<int64_t>();
Expand Down Expand Up @@ -120,8 +118,7 @@ void Window::createWindowFunctions() {
for (auto& arg : windowNodeFunction.functionCall->inputs()) {
auto channel = exprToChannel(arg.get(), inputType);
if (channel == kConstantChannel) {
auto constantArg =
std::dynamic_pointer_cast<const core::ConstantTypedExpr>(arg);
auto constantArg = core::TypedExprs::asConstant(arg);
functionArgs.push_back(
{arg->type(), constantArg->toConstantVector(pool()), std::nullopt});
} else {
Expand Down
16 changes: 5 additions & 11 deletions velox/exec/tests/AggregationFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,7 @@ class ApproxDistinctResultVerifier : public ResultVerifier {
return kDefaultError;
}

auto field =
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(args[1]);
auto field = core::TypedExprs::asFieldAccess(args[1]);
VELOX_CHECK_NOT_NULL(field);
auto errorVector =
input->childAt(field->name())->as<SimpleVector<double>>();
Expand All @@ -458,8 +457,7 @@ class ApproxDistinctResultVerifier : public ResultVerifier {
const auto& args = aggregate.call->inputs();
VELOX_CHECK_GE(args.size(), 1)

auto inputField =
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(args[0]);
auto inputField = core::TypedExprs::asFieldAccess(args[0]);
VELOX_CHECK_NOT_NULL(inputField)

std::string countDistinctCall =
Expand Down Expand Up @@ -586,8 +584,7 @@ class ApproxPercentileResultVerifier : public ResultVerifier {
return kDefaultAccuracy;
}

auto field = std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
args[accuracyIndex]);
auto field = core::TypedExprs::asFieldAccess(args[accuracyIndex]);
VELOX_CHECK_NOT_NULL(field);
auto accuracyVector =
input->childAt(field->name())->as<SimpleVector<double>>();
Expand Down Expand Up @@ -693,8 +690,7 @@ class ApproxPercentileResultVerifier : public ResultVerifier {
}

static const std::string& fieldName(const core::TypedExprPtr& expression) {
auto field =
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(expression);
auto field = core::TypedExprs::asFieldAccess(expression);
VELOX_CHECK_NOT_NULL(field);
return field->name();
}
Expand All @@ -711,9 +707,7 @@ class ApproxPercentileResultVerifier : public ResultVerifier {

const auto& percentileExpr = args[percentileIndex];

if (auto constantExpr =
std::dynamic_pointer_cast<const core::ConstantTypedExpr>(
percentileExpr)) {
if (auto constantExpr = core::TypedExprs::asConstant(percentileExpr)) {
if (constantExpr->type()->isDouble()) {
return {constantExpr->value().value<double>()};
}
Expand Down

0 comments on commit f59e6df

Please sign in to comment.