Skip to content

Commit

Permalink
ADH-5332
Browse files Browse the repository at this point in the history
- fixed comments
  • Loading branch information
VitekArkhipov committed Dec 5, 2024
1 parent dbdf0bd commit c14c03f
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,7 @@ public AdbSqlClient(ConnectionFactory connectionFactory,

private ConnectorExpressionRewriter<ParameterizedExpression> createExpressionRewriter(DataTypeMapper dataTypeMapper)
{
Predicate<ConnectorSession> pushdownWithCollateEnabled =
AdbSessionProperties::isEnableStringPushdownWithCollate;
Predicate<ConnectorSession> datetimeComparisonEnabled =
AdbPushdownSessionProperties::isPushdownDatetimeComparison;
Predicate<ConnectorSession> decimalArithmeticsEnabled =
AdbPushdownSessionProperties::isPushdownDecimalArithmetics;
Predicate<ConnectorSession> doubleArithmeticsEnabled =
AdbPushdownSessionProperties::isPushdownDoubleArithmetics;
Predicate<ConnectorSession> functionDatePartEnabled = AdbPushdownSessionProperties::isPushdownFunctionDatePart;
Predicate<ConnectorSession> functionLikeEnabled = AdbPushdownSessionProperties::isPushdownFunctionLike;
Predicate<ConnectorSession> functionSubstringEnabled =
AdbPushdownSessionProperties::isPushdownFunctionSubstring;
Predicate<ConnectorSession> functionUpperEnabled = AdbPushdownSessionProperties::isPushdownFunctionUpper;
Predicate<ConnectorSession> functionLowerEnabled = AdbPushdownSessionProperties::isPushdownFunctionLower;
return JdbcConnectorExpressionRewriterBuilder.newBuilder()
JdbcConnectorExpressionRewriterBuilder rewriterBuilder = JdbcConnectorExpressionRewriterBuilder.newBuilder()
.addStandardRules(this::quoted)
.add(new AdbRewriteInexactNumericConstant())
.add(new AdbRewriteBooleanConstant())
Expand All @@ -216,9 +202,17 @@ private ConnectorExpressionRewriter<ParameterizedExpression> createExpressionRew
.withTypeClass("string_type", ImmutableSet.of("char", "varchar"))
.withTypeClass("datetime_type",
ImmutableSet.of("date", "time", "timestamp", "timestamp with time zone"))
.add(new RewriteIn());
addPushdownCollateRules(AdbSessionProperties::isEnableStringPushdownWithCollate, rewriterBuilder);
addDatetimeComparisonRules(AdbPushdownSessionProperties::isPushdownDatetimeComparison, rewriterBuilder);
addDecimalArithmeticRules(AdbPushdownSessionProperties::isPushdownDecimalArithmetics, rewriterBuilder);
addDoubleArithmeticRules(AdbPushdownSessionProperties::isPushdownDoubleArithmetics, rewriterBuilder);
addFunctionDatePartRules(AdbPushdownSessionProperties::isPushdownFunctionDatePart, rewriterBuilder);
addFunctionLikeRules(AdbPushdownSessionProperties::isPushdownFunctionLike, rewriterBuilder);
addFunctionSubstringRules(AdbPushdownSessionProperties::isPushdownFunctionSubstring, rewriterBuilder);
return rewriterBuilder
.map("$not($is_null(value))").to("value IS NOT NULL")
.map("$not(value: boolean)").to("NOT value")
.add(new RewriteIn())
.map("$is_null(value)").to("value IS NULL")
.map("$nullif(first, second)").to("NULLIF(first, second)")
.map("$equal(left, right)").to("left = right")
Expand All @@ -228,69 +222,102 @@ private ConnectorExpressionRewriter<ParameterizedExpression> createExpressionRew
.map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right")
.map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right")
.map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right")
.when(pushdownWithCollateEnabled).map("$less_than(left: string_type, right: string_type)")
.to("left < right COLLATE \"C\"")
.when(pushdownWithCollateEnabled).map("$less_than_or_equal(left: string_type, right: string_type)")
.to("left <= right COLLATE \"C\"")
.when(pushdownWithCollateEnabled).map("$greater_than(left: string_type, right: string_type)")
.to("left > right COLLATE \"C\"")
.when(pushdownWithCollateEnabled).map("$greater_than_or_equal(left: string_type, right: string_type)")
.to("left >= right COLLATE \"C\"")
.when(datetimeComparisonEnabled).map("$less_than(left: datetime_type, right: datetime_type)")
.to("left < right")
.when(datetimeComparisonEnabled).map("$less_than_or_equal(left: datetime_type, right: datetime_type)")
.to("left <= right")
.when(datetimeComparisonEnabled).map("$greater_than(left: datetime_type, right: datetime_type)")
.to("left > right")
.when(datetimeComparisonEnabled)
.map("$greater_than_or_equal(left: datetime_type, right: datetime_type)").to("left >= right")
.map("$add(left: integer_type, right: integer_type)").to("left + right")
.map("$subtract(left: integer_type, right: integer_type)").to("left - right")
.map("$multiply(left: integer_type, right: integer_type)").to("left * right")
.map("$divide(left: integer_type, right: integer_type)").to("left / right")
.map("$modulus(left: integer_type, right: integer_type)").to("left % right")
.map("$negate(value: integer_type)").to("-value")
.when(decimalArithmeticsEnabled).map("$add(left: decimal_type, right: decimal_type)").to("left + right")
.when(decimalArithmeticsEnabled).map("$subtract(left: decimal_type, right: decimal_type)")
.add(new AdbRewriteDatetimeArithmetics())
.add(new AdbRewriteCast(dataTypeMapper))
.when(AdbPushdownSessionProperties::isPushdownFunctionUpper).map("upper(arg)").to("UPPER(arg)")
.when(AdbPushdownSessionProperties::isPushdownFunctionLower).map("lower(arg)").to("LOWER(arg)")
.build();
}

private void addPushdownCollateRules(Predicate<ConnectorSession> predicate,
JdbcConnectorExpressionRewriterBuilder builder)
{
builder.when(predicate).map("$less_than(left: string_type, right: string_type)")
.to("left < right COLLATE \"C\"")
.when(predicate).map("$less_than_or_equal(left: string_type, right: string_type)")
.to("left <= right COLLATE \"C\"").when(predicate)
.map("$greater_than(left: string_type, right: string_type)")
.to("left > right COLLATE \"C\"").when(predicate)
.map("$greater_than_or_equal(left: string_type, right: string_type)")
.to("left >= right COLLATE \"C\"");
}

private void addDatetimeComparisonRules(Predicate<ConnectorSession> predicate,
JdbcConnectorExpressionRewriterBuilder builder)
{
builder.when(predicate).map("$less_than(left: datetime_type, right: datetime_type)")
.to("left < right")
.when(predicate).map("$less_than_or_equal(left: datetime_type, right: datetime_type)")
.to("left <= right").when(predicate).map("$greater_than(left: datetime_type, right: datetime_type)")
.to("left > right").when(predicate)
.map("$greater_than_or_equal(left: datetime_type, right: datetime_type)").to("left >= right");
}

private void addDecimalArithmeticRules(Predicate<ConnectorSession> predicate,
JdbcConnectorExpressionRewriterBuilder builder)
{
builder.when(predicate).map("$add(left: decimal_type, right: decimal_type)").to("left + right")
.when(predicate).map("$subtract(left: decimal_type, right: decimal_type)")
.to("left - right")
.when(decimalArithmeticsEnabled).map("$multiply(left: decimal_type, right: decimal_type)")
.when(predicate).map("$multiply(left: decimal_type, right: decimal_type)")
.to("left * right")
.when(decimalArithmeticsEnabled).map("$divide(left: decimal_type, right: decimal_type)")
.when(predicate).map("$divide(left: decimal_type, right: decimal_type)")
.to("left / right")
.when(decimalArithmeticsEnabled).map("$modulus(left: decimal_type, right: decimal_type)")
.when(predicate).map("$modulus(left: decimal_type, right: decimal_type)")
.to("left % right")
.when(decimalArithmeticsEnabled).map("$negate(value: decimal_type)").to("-value")
.when(doubleArithmeticsEnabled).map("$add(left: double_type, right: double_type)").to("left + right")
.when(doubleArithmeticsEnabled).map("$subtract(left: double_type, right: double_type)")
.when(predicate).map("$negate(value: decimal_type)").to("-value");
}

private void addDoubleArithmeticRules(Predicate<ConnectorSession> predicate,
JdbcConnectorExpressionRewriterBuilder builder)
{
builder.when(predicate).map("$add(left: double_type, right: double_type)").to("left + right")
.when(predicate).map("$subtract(left: double_type, right: double_type)")
.to("left - right")
.when(doubleArithmeticsEnabled).map("$multiply(left: double_type, right: double_type)")
.when(predicate).map("$multiply(left: double_type, right: double_type)")
.to("left * right")
.when(doubleArithmeticsEnabled).map("$divide(left: double_type, right: double_type)").to("left / right")
.when(doubleArithmeticsEnabled).map("$modulus(left: double_type, right: double_type)")
.when(predicate).map("$divide(left: double_type, right: double_type)").to("left / right")
.when(predicate).map("$modulus(left: double_type, right: double_type)")
.to("left % right")
.when(doubleArithmeticsEnabled).map("$negate(value: double_type)").to("-value")
.add(new AdbRewriteDatetimeArithmetics())
.add(new AdbRewriteCast(dataTypeMapper))
.when(functionDatePartEnabled).map("year(arg: timestamp)").to("DATE_PART('isoyear', arg)")
.when(functionDatePartEnabled).map("quarter(arg: timestamp)").to("DATE_PART('quarter', arg)")
.when(functionDatePartEnabled).map("month(arg: timestamp)").to("DATE_PART('month', arg)")
.when(functionDatePartEnabled).map("week(arg: timestamp)").to("DATE_PART('week', arg)")
.when(functionDatePartEnabled).map("day(arg: timestamp)").to("DATE_PART('day', arg)")
.when(functionDatePartEnabled).map("day_of_week(arg: timestamp)").to("DATE_PART('isodow', arg)")
.when(functionDatePartEnabled).map("day_of_year(arg: timestamp)").to("DATE_PART('doy', arg)")
.when(functionDatePartEnabled).map("hour(arg: timestamp)").to("DATE_PART('hour', arg)")
.when(functionDatePartEnabled).map("minute(arg: timestamp)").to("DATE_PART('minute', arg)")
.when(functionLikeEnabled).map("$like(value: string_type, pattern): boolean").to("value LIKE pattern")
.when(functionLikeEnabled).map("$like(value: string_type, pattern, escape): boolean")
.to("value LIKE pattern ESCAPE escape")
.when(functionSubstringEnabled).map("substring(arg1: string_type, arg2: integer_type)")
.when(predicate).map("$negate(value: double_type)").to("-value");
}

private void addFunctionDatePartRules(Predicate<ConnectorSession> predicate,
JdbcConnectorExpressionRewriterBuilder builder)
{
builder.when(predicate).map("year(arg: timestamp)").to("DATE_PART('isoyear', arg)")
.when(predicate).map("quarter(arg: timestamp)").to("DATE_PART('quarter', arg)")
.when(predicate).map("month(arg: timestamp)").to("DATE_PART('month', arg)")
.when(predicate).map("week(arg: timestamp)").to("DATE_PART('week', arg)")
.when(predicate).map("day(arg: timestamp)").to("DATE_PART('day', arg)")
.when(predicate).map("day_of_week(arg: timestamp)").to("DATE_PART('isodow', arg)")
.when(predicate).map("day_of_year(arg: timestamp)").to("DATE_PART('doy', arg)")
.when(predicate).map("hour(arg: timestamp)").to("DATE_PART('hour', arg)")
.when(predicate).map("minute(arg: timestamp)").to("DATE_PART('minute', arg)");
}

private void addFunctionLikeRules(Predicate<ConnectorSession> predicate,
JdbcConnectorExpressionRewriterBuilder builder)
{
builder.when(predicate).map("$like(value: string_type, pattern): boolean").to("value LIKE pattern")
.when(predicate).map("$like(value: string_type, pattern, escape): boolean")
.to("value LIKE pattern ESCAPE escape");
}

private void addFunctionSubstringRules(Predicate<ConnectorSession> predicate,
JdbcConnectorExpressionRewriterBuilder builder)
{
builder.when(predicate).map("substring(arg1: string_type, arg2: integer_type)")
.to("SUBSTRING(arg1 FROM CAST(arg2 AS INT))")
.when(functionSubstringEnabled)
.when(predicate)
.map("substring(arg1: string_type, arg2: integer_type, arg3: integer_type)")
.to("SUBSTRING(arg1 FROM CAST(arg2 AS INT) FOR CAST(arg3 AS INT))")
.when(functionUpperEnabled).map("upper(arg)").to("UPPER(arg)")
.when(functionLowerEnabled).map("lower(arg)").to("LOWER(arg)")
.build();
.to("SUBSTRING(arg1 FROM CAST(arg2 AS INT) FOR CAST(arg3 AS INT))");
}

private AggregateFunctionRewriter<JdbcExpression, ParameterizedExpression> createAggregationFunctionRewriter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,14 @@ public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction,
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(columnHandle.getColumnType().equals(aggregateFunction.getOutputType()));
Optional<String> suffix = Optional.empty();
String suffix = "";
if (columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType) {
suffix = Optional.of(" COLLATE \"C\"");
suffix = " COLLATE \"C\"";
}
ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow();
return Optional.of(
new JdbcExpression(
format("%s(%s%s)", aggregateFunction.getFunctionName(), rewrittenArgument.expression(),
suffix.orElse("")),
format("%s(%s%s)", aggregateFunction.getFunctionName(), rewrittenArgument.expression(), suffix),
rewrittenArgument.parameters(),
columnHandle.getJdbcTypeHandle()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public class CsvFormatConfig
implements DataFormatConfig
{
private char delimiter = '|';
private Optional<String> nullValue = Optional.empty();
private String nullValue;
private String encoding = "UTF-8";

public static CsvFormatConfig create()
Expand All @@ -38,7 +38,7 @@ public CsvFormatConfig delimiter(char delimiter)

public CsvFormatConfig nullValue(String nullValue)
{
this.nullValue = Optional.ofNullable(nullValue);
this.nullValue = nullValue;
return this;
}

Expand All @@ -49,7 +49,7 @@ public char getDelimiter()

public Optional<String> getNullValue()
{
return nullValue;
return Optional.ofNullable(nullValue);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,23 @@ public class AdbRewriteDatetimeArithmetics
{
private static final String TYPE_NAME_DAY_SECOND = "interval day to second";
private static final String TYPE_NAME_YEAR_MONTH = "interval year to month";
private static final Capture<ConnectorExpression> ARG0 = Capture.newCapture();
private static final Capture<Constant> ARG1 = Capture.newCapture();
private static final Capture<ConnectorExpression> CONNECTOR_EXPR_CAPTURE = Capture.newCapture();
private static final Capture<Constant> CONSTANT_CAPTURE = Capture.newCapture();
private static final Pattern<Call> PATTERN = ConnectorExpressionPatterns.call()
.with(ConnectorExpressionPatterns.functionName()
.matching(n -> StandardFunctions.ADD_FUNCTION_NAME.equals(n) ||
StandardFunctions.SUBTRACT_FUNCTION_NAME.equals(n)))
.with(ConnectorExpressionPatterns.argument(0)
.matching(ConnectorExpressionPatterns.expression()
.capturedAs(ARG0)
.capturedAs(CONNECTOR_EXPR_CAPTURE)
.with(ConnectorExpressionPatterns.type()
.matching(type -> type == DateType.DATE
|| type instanceof TimeType
|| type instanceof TimestampType
|| type instanceof TimestampWithTimeZoneType))))
.with(ConnectorExpressionPatterns.argument(1)
.matching(ConnectorExpressionPatterns.constant()
.matching(c -> isInterval(c.getType())).capturedAs(ARG1)));
.matching(c -> isInterval(c.getType())).capturedAs(CONSTANT_CAPTURE)));

@Override
public Pattern<Call> getPattern()
Expand All @@ -72,23 +72,21 @@ public boolean isEnabled(ConnectorSession session)
public Optional<ParameterizedExpression> rewrite(Call expression, Captures captures,
RewriteContext<ParameterizedExpression> context)
{
Constant arg1 = captures.get(ARG1);
Constant arg1 = captures.get(CONSTANT_CAPTURE);
if (arg1.getValue() == null) {
return Optional.empty();
}
Optional<ParameterizedExpression> arg0 = context.defaultRewrite(captures.get(CONNECTOR_EXPR_CAPTURE));
if (arg0.isEmpty()) {
return Optional.empty();
}
else {
Optional<ParameterizedExpression> arg0 = context.defaultRewrite(captures.get(ARG0));
if (arg0.isEmpty()) {
return Optional.empty();
}
else {
String arg1Caption = String.format("interval '%d %s'", (Long) arg1.getValue(),
isDaySecondInterval(arg1.getType()) ? "milliseconds" : "months");
String operator = expression.getFunctionName() == StandardFunctions.ADD_FUNCTION_NAME ? "+" : "-";
return Optional.of(new ParameterizedExpression(
String.format("%s %s %s", arg0.get().expression(), operator, arg1Caption),
arg0.get().parameters()));
}
String arg1Caption = String.format("interval '%d %s'", (Long) arg1.getValue(),
isDaySecondInterval(arg1.getType()) ? "milliseconds" : "months");
String operator = expression.getFunctionName() == StandardFunctions.ADD_FUNCTION_NAME ? "+" : "-";
return Optional.of(new ParameterizedExpression(
String.format("%s %s %s", arg0.get().expression(), operator, arg1Caption),
arg0.get().parameters()));
}
}

Expand Down
Loading

0 comments on commit c14c03f

Please sign in to comment.