diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSqlClient.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSqlClient.java index bd4f11ee2e07..20507ff7e22a 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSqlClient.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/AdbSqlClient.java @@ -188,21 +188,7 @@ public AdbSqlClient(ConnectionFactory connectionFactory, private ConnectorExpressionRewriter createExpressionRewriter(DataTypeMapper dataTypeMapper) { - Predicate pushdownWithCollateEnabled = - AdbSessionProperties::isEnableStringPushdownWithCollate; - Predicate datetimeComparisonEnabled = - AdbPushdownSessionProperties::isPushdownDatetimeComparison; - Predicate decimalArithmeticsEnabled = - AdbPushdownSessionProperties::isPushdownDecimalArithmetics; - Predicate doubleArithmeticsEnabled = - AdbPushdownSessionProperties::isPushdownDoubleArithmetics; - Predicate functionDatePartEnabled = AdbPushdownSessionProperties::isPushdownFunctionDatePart; - Predicate functionLikeEnabled = AdbPushdownSessionProperties::isPushdownFunctionLike; - Predicate functionSubstringEnabled = - AdbPushdownSessionProperties::isPushdownFunctionSubstring; - Predicate functionUpperEnabled = AdbPushdownSessionProperties::isPushdownFunctionUpper; - Predicate functionLowerEnabled = AdbPushdownSessionProperties::isPushdownFunctionLower; - return JdbcConnectorExpressionRewriterBuilder.newBuilder() + JdbcConnectorExpressionRewriterBuilder rewriterBuilder = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) .add(new AdbRewriteInexactNumericConstant()) .add(new AdbRewriteBooleanConstant()) @@ -216,9 +202,17 @@ private ConnectorExpressionRewriter createExpressionRew .withTypeClass("string_type", ImmutableSet.of("char", "varchar")) .withTypeClass("datetime_type", ImmutableSet.of("date", "time", "timestamp", "timestamp with time zone")) + .add(new RewriteIn()); + addPushdownCollateRules(AdbSessionProperties::isEnableStringPushdownWithCollate, rewriterBuilder); + addDatetimeComparisonRules(AdbPushdownSessionProperties::isPushdownDatetimeComparison, rewriterBuilder); + addDecimalArithmeticRules(AdbPushdownSessionProperties::isPushdownDecimalArithmetics, rewriterBuilder); + addDoubleArithmeticRules(AdbPushdownSessionProperties::isPushdownDoubleArithmetics, rewriterBuilder); + addFunctionDatePartRules(AdbPushdownSessionProperties::isPushdownFunctionDatePart, rewriterBuilder); + addFunctionLikeRules(AdbPushdownSessionProperties::isPushdownFunctionLike, rewriterBuilder); + addFunctionSubstringRules(AdbPushdownSessionProperties::isPushdownFunctionSubstring, rewriterBuilder); + return rewriterBuilder .map("$not($is_null(value))").to("value IS NOT NULL") .map("$not(value: boolean)").to("NOT value") - .add(new RewriteIn()) .map("$is_null(value)").to("value IS NULL") .map("$nullif(first, second)").to("NULLIF(first, second)") .map("$equal(left, right)").to("left = right") @@ -228,69 +222,102 @@ private ConnectorExpressionRewriter createExpressionRew .map("$less_than_or_equal(left: numeric_type, right: numeric_type)").to("left <= right") .map("$greater_than(left: numeric_type, right: numeric_type)").to("left > right") .map("$greater_than_or_equal(left: numeric_type, right: numeric_type)").to("left >= right") - .when(pushdownWithCollateEnabled).map("$less_than(left: string_type, right: string_type)") - .to("left < right COLLATE \"C\"") - .when(pushdownWithCollateEnabled).map("$less_than_or_equal(left: string_type, right: string_type)") - .to("left <= right COLLATE \"C\"") - .when(pushdownWithCollateEnabled).map("$greater_than(left: string_type, right: string_type)") - .to("left > right COLLATE \"C\"") - .when(pushdownWithCollateEnabled).map("$greater_than_or_equal(left: string_type, right: string_type)") - .to("left >= right COLLATE \"C\"") - .when(datetimeComparisonEnabled).map("$less_than(left: datetime_type, right: datetime_type)") - .to("left < right") - .when(datetimeComparisonEnabled).map("$less_than_or_equal(left: datetime_type, right: datetime_type)") - .to("left <= right") - .when(datetimeComparisonEnabled).map("$greater_than(left: datetime_type, right: datetime_type)") - .to("left > right") - .when(datetimeComparisonEnabled) - .map("$greater_than_or_equal(left: datetime_type, right: datetime_type)").to("left >= right") .map("$add(left: integer_type, right: integer_type)").to("left + right") .map("$subtract(left: integer_type, right: integer_type)").to("left - right") .map("$multiply(left: integer_type, right: integer_type)").to("left * right") .map("$divide(left: integer_type, right: integer_type)").to("left / right") .map("$modulus(left: integer_type, right: integer_type)").to("left % right") .map("$negate(value: integer_type)").to("-value") - .when(decimalArithmeticsEnabled).map("$add(left: decimal_type, right: decimal_type)").to("left + right") - .when(decimalArithmeticsEnabled).map("$subtract(left: decimal_type, right: decimal_type)") + .add(new AdbRewriteDatetimeArithmetics()) + .add(new AdbRewriteCast(dataTypeMapper)) + .when(AdbPushdownSessionProperties::isPushdownFunctionUpper).map("upper(arg)").to("UPPER(arg)") + .when(AdbPushdownSessionProperties::isPushdownFunctionLower).map("lower(arg)").to("LOWER(arg)") + .build(); + } + + private void addPushdownCollateRules(Predicate predicate, + JdbcConnectorExpressionRewriterBuilder builder) + { + builder.when(predicate).map("$less_than(left: string_type, right: string_type)") + .to("left < right COLLATE \"C\"") + .when(predicate).map("$less_than_or_equal(left: string_type, right: string_type)") + .to("left <= right COLLATE \"C\"").when(predicate) + .map("$greater_than(left: string_type, right: string_type)") + .to("left > right COLLATE \"C\"").when(predicate) + .map("$greater_than_or_equal(left: string_type, right: string_type)") + .to("left >= right COLLATE \"C\""); + } + + private void addDatetimeComparisonRules(Predicate predicate, + JdbcConnectorExpressionRewriterBuilder builder) + { + builder.when(predicate).map("$less_than(left: datetime_type, right: datetime_type)") + .to("left < right") + .when(predicate).map("$less_than_or_equal(left: datetime_type, right: datetime_type)") + .to("left <= right").when(predicate).map("$greater_than(left: datetime_type, right: datetime_type)") + .to("left > right").when(predicate) + .map("$greater_than_or_equal(left: datetime_type, right: datetime_type)").to("left >= right"); + } + + private void addDecimalArithmeticRules(Predicate predicate, + JdbcConnectorExpressionRewriterBuilder builder) + { + builder.when(predicate).map("$add(left: decimal_type, right: decimal_type)").to("left + right") + .when(predicate).map("$subtract(left: decimal_type, right: decimal_type)") .to("left - right") - .when(decimalArithmeticsEnabled).map("$multiply(left: decimal_type, right: decimal_type)") + .when(predicate).map("$multiply(left: decimal_type, right: decimal_type)") .to("left * right") - .when(decimalArithmeticsEnabled).map("$divide(left: decimal_type, right: decimal_type)") + .when(predicate).map("$divide(left: decimal_type, right: decimal_type)") .to("left / right") - .when(decimalArithmeticsEnabled).map("$modulus(left: decimal_type, right: decimal_type)") + .when(predicate).map("$modulus(left: decimal_type, right: decimal_type)") .to("left % right") - .when(decimalArithmeticsEnabled).map("$negate(value: decimal_type)").to("-value") - .when(doubleArithmeticsEnabled).map("$add(left: double_type, right: double_type)").to("left + right") - .when(doubleArithmeticsEnabled).map("$subtract(left: double_type, right: double_type)") + .when(predicate).map("$negate(value: decimal_type)").to("-value"); + } + + private void addDoubleArithmeticRules(Predicate predicate, + JdbcConnectorExpressionRewriterBuilder builder) + { + builder.when(predicate).map("$add(left: double_type, right: double_type)").to("left + right") + .when(predicate).map("$subtract(left: double_type, right: double_type)") .to("left - right") - .when(doubleArithmeticsEnabled).map("$multiply(left: double_type, right: double_type)") + .when(predicate).map("$multiply(left: double_type, right: double_type)") .to("left * right") - .when(doubleArithmeticsEnabled).map("$divide(left: double_type, right: double_type)").to("left / right") - .when(doubleArithmeticsEnabled).map("$modulus(left: double_type, right: double_type)") + .when(predicate).map("$divide(left: double_type, right: double_type)").to("left / right") + .when(predicate).map("$modulus(left: double_type, right: double_type)") .to("left % right") - .when(doubleArithmeticsEnabled).map("$negate(value: double_type)").to("-value") - .add(new AdbRewriteDatetimeArithmetics()) - .add(new AdbRewriteCast(dataTypeMapper)) - .when(functionDatePartEnabled).map("year(arg: timestamp)").to("DATE_PART('isoyear', arg)") - .when(functionDatePartEnabled).map("quarter(arg: timestamp)").to("DATE_PART('quarter', arg)") - .when(functionDatePartEnabled).map("month(arg: timestamp)").to("DATE_PART('month', arg)") - .when(functionDatePartEnabled).map("week(arg: timestamp)").to("DATE_PART('week', arg)") - .when(functionDatePartEnabled).map("day(arg: timestamp)").to("DATE_PART('day', arg)") - .when(functionDatePartEnabled).map("day_of_week(arg: timestamp)").to("DATE_PART('isodow', arg)") - .when(functionDatePartEnabled).map("day_of_year(arg: timestamp)").to("DATE_PART('doy', arg)") - .when(functionDatePartEnabled).map("hour(arg: timestamp)").to("DATE_PART('hour', arg)") - .when(functionDatePartEnabled).map("minute(arg: timestamp)").to("DATE_PART('minute', arg)") - .when(functionLikeEnabled).map("$like(value: string_type, pattern): boolean").to("value LIKE pattern") - .when(functionLikeEnabled).map("$like(value: string_type, pattern, escape): boolean") - .to("value LIKE pattern ESCAPE escape") - .when(functionSubstringEnabled).map("substring(arg1: string_type, arg2: integer_type)") + .when(predicate).map("$negate(value: double_type)").to("-value"); + } + + private void addFunctionDatePartRules(Predicate predicate, + JdbcConnectorExpressionRewriterBuilder builder) + { + builder.when(predicate).map("year(arg: timestamp)").to("DATE_PART('isoyear', arg)") + .when(predicate).map("quarter(arg: timestamp)").to("DATE_PART('quarter', arg)") + .when(predicate).map("month(arg: timestamp)").to("DATE_PART('month', arg)") + .when(predicate).map("week(arg: timestamp)").to("DATE_PART('week', arg)") + .when(predicate).map("day(arg: timestamp)").to("DATE_PART('day', arg)") + .when(predicate).map("day_of_week(arg: timestamp)").to("DATE_PART('isodow', arg)") + .when(predicate).map("day_of_year(arg: timestamp)").to("DATE_PART('doy', arg)") + .when(predicate).map("hour(arg: timestamp)").to("DATE_PART('hour', arg)") + .when(predicate).map("minute(arg: timestamp)").to("DATE_PART('minute', arg)"); + } + + private void addFunctionLikeRules(Predicate predicate, + JdbcConnectorExpressionRewriterBuilder builder) + { + builder.when(predicate).map("$like(value: string_type, pattern): boolean").to("value LIKE pattern") + .when(predicate).map("$like(value: string_type, pattern, escape): boolean") + .to("value LIKE pattern ESCAPE escape"); + } + + private void addFunctionSubstringRules(Predicate predicate, + JdbcConnectorExpressionRewriterBuilder builder) + { + builder.when(predicate).map("substring(arg1: string_type, arg2: integer_type)") .to("SUBSTRING(arg1 FROM CAST(arg2 AS INT))") - .when(functionSubstringEnabled) + .when(predicate) .map("substring(arg1: string_type, arg2: integer_type, arg3: integer_type)") - .to("SUBSTRING(arg1 FROM CAST(arg2 AS INT) FOR CAST(arg3 AS INT))") - .when(functionUpperEnabled).map("upper(arg)").to("UPPER(arg)") - .when(functionLowerEnabled).map("lower(arg)").to("LOWER(arg)") - .build(); + .to("SUBSTRING(arg1 FROM CAST(arg2 AS INT) FOR CAST(arg3 AS INT))"); } private AggregateFunctionRewriter createAggregationFunctionRewriter( diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementMinMax.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementMinMax.java index 394de7d3877a..6e4d3583271f 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementMinMax.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/aggregation/AdbImplementMinMax.java @@ -55,15 +55,14 @@ public Optional rewrite(AggregateFunction aggregateFunction, Variable argument = captures.get(ARGUMENT); JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName()); verify(columnHandle.getColumnType().equals(aggregateFunction.getOutputType())); - Optional suffix = Optional.empty(); + String suffix = ""; if (columnHandle.getColumnType() instanceof CharType || columnHandle.getColumnType() instanceof VarcharType) { - suffix = Optional.of(" COLLATE \"C\""); + suffix = " COLLATE \"C\""; } ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); return Optional.of( new JdbcExpression( - format("%s(%s%s)", aggregateFunction.getFunctionName(), rewrittenArgument.expression(), - suffix.orElse("")), + format("%s(%s%s)", aggregateFunction.getFunctionName(), rewrittenArgument.expression(), suffix), rewrittenArgument.parameters(), columnHandle.getJdbcTypeHandle())); } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvFormatConfig.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvFormatConfig.java index bb8438664ab4..ba053db476a3 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvFormatConfig.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/encode/csv/CsvFormatConfig.java @@ -22,7 +22,7 @@ public class CsvFormatConfig implements DataFormatConfig { private char delimiter = '|'; - private Optional nullValue = Optional.empty(); + private String nullValue; private String encoding = "UTF-8"; public static CsvFormatConfig create() @@ -38,7 +38,7 @@ public CsvFormatConfig delimiter(char delimiter) public CsvFormatConfig nullValue(String nullValue) { - this.nullValue = Optional.ofNullable(nullValue); + this.nullValue = nullValue; return this; } @@ -49,7 +49,7 @@ public char getDelimiter() public Optional getNullValue() { - return nullValue; + return Optional.ofNullable(nullValue); } @Override diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteDatetimeArithmetics.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteDatetimeArithmetics.java index 086e5850a6ba..b7646ed9a5f3 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteDatetimeArithmetics.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/expression/AdbRewriteDatetimeArithmetics.java @@ -38,15 +38,15 @@ public class AdbRewriteDatetimeArithmetics { private static final String TYPE_NAME_DAY_SECOND = "interval day to second"; private static final String TYPE_NAME_YEAR_MONTH = "interval year to month"; - private static final Capture ARG0 = Capture.newCapture(); - private static final Capture ARG1 = Capture.newCapture(); + private static final Capture CONNECTOR_EXPR_CAPTURE = Capture.newCapture(); + private static final Capture CONSTANT_CAPTURE = Capture.newCapture(); private static final Pattern PATTERN = ConnectorExpressionPatterns.call() .with(ConnectorExpressionPatterns.functionName() .matching(n -> StandardFunctions.ADD_FUNCTION_NAME.equals(n) || StandardFunctions.SUBTRACT_FUNCTION_NAME.equals(n))) .with(ConnectorExpressionPatterns.argument(0) .matching(ConnectorExpressionPatterns.expression() - .capturedAs(ARG0) + .capturedAs(CONNECTOR_EXPR_CAPTURE) .with(ConnectorExpressionPatterns.type() .matching(type -> type == DateType.DATE || type instanceof TimeType @@ -54,7 +54,7 @@ public class AdbRewriteDatetimeArithmetics || type instanceof TimestampWithTimeZoneType)))) .with(ConnectorExpressionPatterns.argument(1) .matching(ConnectorExpressionPatterns.constant() - .matching(c -> isInterval(c.getType())).capturedAs(ARG1))); + .matching(c -> isInterval(c.getType())).capturedAs(CONSTANT_CAPTURE))); @Override public Pattern getPattern() @@ -72,23 +72,21 @@ public boolean isEnabled(ConnectorSession session) public Optional rewrite(Call expression, Captures captures, RewriteContext context) { - Constant arg1 = captures.get(ARG1); + Constant arg1 = captures.get(CONSTANT_CAPTURE); if (arg1.getValue() == null) { return Optional.empty(); } + Optional arg0 = context.defaultRewrite(captures.get(CONNECTOR_EXPR_CAPTURE)); + if (arg0.isEmpty()) { + return Optional.empty(); + } else { - Optional arg0 = context.defaultRewrite(captures.get(ARG0)); - if (arg0.isEmpty()) { - return Optional.empty(); - } - else { - String arg1Caption = String.format("interval '%d %s'", (Long) arg1.getValue(), - isDaySecondInterval(arg1.getType()) ? "milliseconds" : "months"); - String operator = expression.getFunctionName() == StandardFunctions.ADD_FUNCTION_NAME ? "+" : "-"; - return Optional.of(new ParameterizedExpression( - String.format("%s %s %s", arg0.get().expression(), operator, arg1Caption), - arg0.get().parameters())); - } + String arg1Caption = String.format("interval '%d %s'", (Long) arg1.getValue(), + isDaySecondInterval(arg1.getType()) ? "milliseconds" : "months"); + String operator = expression.getFunctionName() == StandardFunctions.ADD_FUNCTION_NAME ? "+" : "-"; + return Optional.of(new ParameterizedExpression( + String.format("%s %s %s", arg0.get().expression(), operator, arg1Caption), + arg0.get().parameters())); } } diff --git a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/GpfdistModule.java b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/GpfdistModule.java index 69da9958a90d..ed4810298148 100644 --- a/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/GpfdistModule.java +++ b/plugin/trino-adb/src/main/java/io/trino/plugin/adb/connector/protocol/gpfdist/GpfdistModule.java @@ -77,7 +77,9 @@ public void setup(Binder binder) Multibinder createExtTableQueryFactories = Multibinder.newSetBinder(binder, CreateExternalTableQueryFactory.class); - Multibinder insertDataQueryFactories = Multibinder.newSetBinder(binder, InsertDataQueryFactory.class);createExtTableQueryFactories.addBinding().to(CreateReadableExternalTableQueryFactory.class) + Multibinder insertDataQueryFactories = + Multibinder.newSetBinder(binder, InsertDataQueryFactory.class); + createExtTableQueryFactories.addBinding().to(CreateReadableExternalTableQueryFactory.class) .in(Scopes.SINGLETON); createExtTableQueryFactories.addBinding().to(CreateWritableExternalTableQueryFactory.class) .in(Scopes.SINGLETON); @@ -85,7 +87,8 @@ public void setup(Binder binder) insertDataQueryFactories.addBinding().to(InsertDataFromExternalTableQueryFactory.class).in(Scopes.SINGLETON); insertDataQueryFactories.addBinding().to(InsertDataToExternalTableQueryFactory.class).in(Scopes.SINGLETON); - binder.bind(ExternalTableFormatConfigFactory.class).to(ExternalTableFormatConfigFactoryImpl.class).in(Scopes.SINGLETON); + binder.bind(ExternalTableFormatConfigFactory.class).to(ExternalTableFormatConfigFactoryImpl.class) + .in(Scopes.SINGLETON); OptionalBinder.newOptionalBinder(binder, ConnectorPageSinkProvider.class).setBinding() .to(GpfdistPageSinkProvider.class).in(Scopes.SINGLETON); @@ -95,9 +98,12 @@ public void setup(Binder binder) binder.bind(GpfdistUnloadMetadataFactory.class).to(GpfdistUnloadMetadataFactoryImpl.class).in(Scopes.SINGLETON); binder.bind(GpfdistLocationFactory.class).to(GpfdistLocationFactoryImpl.class).in(Scopes.SINGLETON); binder.bind(GpfdistUnloadMetadataFactory.class).to(GpfdistUnloadMetadataFactoryImpl.class).in(Scopes.SINGLETON); - binder.bind(new TypeLiteral>() {}).to(ReadContextManager.class).in(Scopes.SINGLETON); - binder.bind(new TypeLiteral>() {}).to(WriteContextManager.class).in(Scopes.SINGLETON); - binder.bind(new TypeLiteral>() {}).to(GpfdistInputDataProcessorFactory.class).in(Scopes.SINGLETON); + binder.bind(new TypeLiteral>() {}).to(ReadContextManager.class) + .in(Scopes.SINGLETON); + binder.bind(new TypeLiteral>() {}).to(WriteContextManager.class) + .in(Scopes.SINGLETON); + binder.bind(new TypeLiteral>() {}).to(GpfdistInputDataProcessorFactory.class) + .in(Scopes.SINGLETON); binder.bind(NodeInfo.class).in(Scopes.SINGLETON); binder.bind(HttpServerInfo.class).in(Scopes.SINGLETON); binder.bind(RequestStats.class).in(Scopes.SINGLETON);