Skip to content

Commit

Permalink
Support tinyInt1isBit
Browse files Browse the repository at this point in the history
Motivation:
Aligning with MySQL connector.

Modifications:
Implemented `tinyInt1isBit` flag.

Result:
Improved compatibility with MySQL connectors.
  • Loading branch information
jchrys committed Feb 1, 2025
1 parent 820bdf4 commit abaa2d1
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 34 deletions.
2 changes: 1 addition & 1 deletion r2dbc-mysql/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

<groupId>io.asyncer</groupId>
<artifactId>r2dbc-mysql</artifactId>
<version>1.3.3-SNAPSHOT</version>
<version>1.4.0-SNAPSHOT</version>

<name>Reactive Relational Database Connectivity - MySQL</name>
<url>https://github.com/asyncer-io/r2dbc-mysql</url>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public final class ConnectionContext implements CodecContext {

private final int localInfileBufferSize;

private final boolean tinyInt1isBit;

private final boolean preserveInstants;

private int connectionId = -1;
Expand Down Expand Up @@ -107,12 +109,14 @@ public final class ConnectionContext implements CodecContext {
ZeroDateOption zeroDateOption,
@Nullable Path localInfilePath,
int localInfileBufferSize,
boolean tinyInt1isBit,
boolean preserveInstants,
@Nullable ZoneId timeZone
) {
this.zeroDateOption = requireNonNull(zeroDateOption, "zeroDateOption must not be null");
this.localInfilePath = localInfilePath;
this.localInfileBufferSize = localInfileBufferSize;
this.tinyInt1isBit = tinyInt1isBit;
this.preserveInstants = preserveInstants;
this.timeZone = timeZone;
}
Expand Down Expand Up @@ -216,6 +220,11 @@ public boolean isMariaDb() {
return (capability != null && capability.isMariaDb()) || serverVersion.isMariaDb();
}

@Override
public boolean isTinyInt1isBit() {
return tinyInt1isBit;
}

public boolean isNoBackslashEscapes() {
return (serverStatuses & ServerStatuses.NO_BACKSLASH_ESCAPES) != 0;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,24 +134,26 @@ public final class MySqlConnectionConfiguration {

private final boolean metrics;

private final boolean tinyInt1isBit;

private MySqlConnectionConfiguration(
boolean isHost, String domain, int port, MySqlSslConfiguration ssl,
boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout,
ZeroDateOption zeroDateOption,
boolean preserveInstants,
String connectionTimeZone,
boolean forceConnectionTimeZoneToSession,
String user, @Nullable CharSequence password, @Nullable String database,
boolean createDatabaseIfNotExist, @Nullable Predicate<String> preferPrepareStatement,
List<String> sessionVariables, @Nullable Duration lockWaitTimeout, @Nullable Duration statementTimeout,
@Nullable Path loadLocalInfilePath, int localInfileBufferSize,
int queryCacheSize, int prepareCacheSize,
Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel,
@Nullable LoopResources loopResources,
Extensions extensions, @Nullable Publisher<String> passwordPublisher,
@Nullable AddressResolverGroup<?> resolver,
boolean metrics
) {
boolean isHost, String domain, int port, MySqlSslConfiguration ssl,
boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout,
ZeroDateOption zeroDateOption,
boolean preserveInstants,
String connectionTimeZone,
boolean forceConnectionTimeZoneToSession,
String user, @Nullable CharSequence password, @Nullable String database,
boolean createDatabaseIfNotExist, @Nullable Predicate<String> preferPrepareStatement,
List<String> sessionVariables, @Nullable Duration lockWaitTimeout, @Nullable Duration statementTimeout,
@Nullable Path loadLocalInfilePath, int localInfileBufferSize,
int queryCacheSize, int prepareCacheSize,
Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel,
@Nullable LoopResources loopResources,
Extensions extensions, @Nullable Publisher<String> passwordPublisher,
@Nullable AddressResolverGroup<?> resolver,
boolean metrics,
boolean tinyInt1isBit) {
this.isHost = isHost;
this.domain = domain;
this.port = port;
Expand Down Expand Up @@ -182,6 +184,7 @@ private MySqlConnectionConfiguration(
this.passwordPublisher = passwordPublisher;
this.resolver = resolver;
this.metrics = metrics;
this.tinyInt1isBit = tinyInt1isBit;
}

/**
Expand Down Expand Up @@ -321,6 +324,10 @@ boolean isMetrics() {
return metrics;
}

boolean isTinyInt1isBit() {
return tinyInt1isBit;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -359,7 +366,8 @@ public boolean equals(Object o) {
extensions.equals(that.extensions) &&
Objects.equals(passwordPublisher, that.passwordPublisher) &&
Objects.equals(resolver, that.resolver) &&
metrics == that.metrics;
metrics == that.metrics &&
tinyInt1isBit == that.tinyInt1isBit;
}

@Override
Expand All @@ -374,7 +382,7 @@ public int hashCode() {
loadLocalInfilePath, localInfileBufferSize,
queryCacheSize, prepareCacheSize,
compressionAlgorithms, zstdCompressionLevel,
loopResources, extensions, passwordPublisher, resolver, metrics);
loopResources, extensions, passwordPublisher, resolver, metrics, tinyInt1isBit);
}

@Override
Expand Down Expand Up @@ -409,7 +417,8 @@ private String buildCommonToStringPart() {
", extensions=" + extensions +
", passwordPublisher=" + passwordPublisher +
", resolver=" + resolver +
", metrics=" + metrics;
", metrics=" + metrics +
", tinyint1isBit=" + tinyInt1isBit;
}

/**
Expand Down Expand Up @@ -511,6 +520,8 @@ public static final class Builder {

private boolean metrics;

private boolean tinyInt1isBit = true;

/**
* Builds an immutable {@link MySqlConnectionConfiguration} with current options.
*
Expand Down Expand Up @@ -545,11 +556,11 @@ public MySqlConnectionConfiguration build() {
loadLocalInfilePath,
localInfileBufferSize, queryCacheSize, prepareCacheSize,
compressionAlgorithms, zstdCompressionLevel, loopResources,
Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver, metrics);
Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver, metrics, tinyInt1isBit);
}

/**
* Configures the database. Default no database.
* Configures the database. Default no database.
*
* @param database the database, or {@code null} if no database want to be login.
* @return this {@link Builder}.
Expand Down Expand Up @@ -1207,6 +1218,20 @@ public Builder metrics(boolean enabled) {
return this;
}

/**
* Option to whether the driver should interpret MySQL's TINYINT(1) as a BIT type.
* When enabled, TINYINT(1) columns (both SIGNED and UNSIGNED) will be treated as
* BIT. default to {@code true}.
*
* @param tinyInt1isBit {@code true} to treat TINYINT(1) as BIT
* @return this {@link Builder}
* @since 1.4.0
*/
public Builder tinyInt1isBit(boolean tinyInt1isBit) {
this.tinyInt1isBit = tinyInt1isBit;
return this;
}

private SslMode requireSslMode() {
SslMode sslMode = this.sslMode;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ private static Mono<MySqlConnection> getMySqlConnection(
configuration.getZeroDateOption(),
configuration.getLoadLocalInfilePath(),
configuration.getLocalInfileBufferSize(),
configuration.isTinyInt1isBit(),
configuration.isPreserveInstants(),
connectionTimeZone
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr
*/
public static final Option<Boolean> METRICS = Option.valueOf("metrics");

/**
* Option to whether the driver should interpret MySQL's TINYINT(1) as a BIT type.
* When enabled, TINYINT(1) columns (both SIGNED and UNSIGNED) will be treated as
* BIT. default to {@code true}.
*
* @since 1.4.0
*/
public static final Option<Boolean> TINY_INT_1_IS_BIT = Option.valueOf("tinyInt1isBit");

@Override
public ConnectionFactory create(ConnectionFactoryOptions options) {
requireNonNull(options, "connectionFactoryOptions must not be null");
Expand Down Expand Up @@ -424,7 +433,9 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) {
mapper.optional(STATEMENT_TIMEOUT).as(Duration.class, Duration::parse)
.to(builder::statementTimeout);
mapper.optional(METRICS).asBoolean()
.to(builder::metrics);
.to(builder::metrics);
mapper.optional(TINY_INT_1_IS_BIT).asBoolean()
.to(builder::tinyInt1isBit);

return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,10 @@ public interface CodecContext {
* @return if is MariaDB.
*/
boolean isMariaDb();

/**
*
* @return true if tinyInt(1) is treated as bit.
*/
boolean isTinyInt1isBit();
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
*/
final class DefaultCodecs implements Codecs {

private static final Integer INTEGER_ONE = Integer.valueOf(1);

private static final List<Codec<?>> DEFAULT_CODECS = InternalArrays.asImmutableList(
ByteCodec.INSTANCE,
ShortCodec.INSTANCE,
Expand Down Expand Up @@ -137,6 +139,7 @@ private DefaultCodecs(List<Codec<?>> codecs) {
* Note: this method should NEVER release {@code buf} because of it come from {@code MySqlRow} which will release
* this buffer.
*/
@Nullable
@Override
public <T> T decode(FieldValue value, MySqlReadableMetadata metadata, Class<?> type, boolean binary,
CodecContext context) {
Expand All @@ -151,7 +154,7 @@ public <T> T decode(FieldValue value, MySqlReadableMetadata metadata, Class<?> t
return null;
}

Class<?> target = chooseClass(metadata, type);
Class<?> target = chooseClass(metadata, type, context);

if (value instanceof NormalFieldValue) {
return decodeNormal((NormalFieldValue) value, metadata, target, binary, context);
Expand All @@ -162,6 +165,7 @@ public <T> T decode(FieldValue value, MySqlReadableMetadata metadata, Class<?> t
throw new IllegalArgumentException("Unknown value " + value.getClass().getSimpleName());
}

@Nullable
@Override
public <T> T decode(FieldValue value, MySqlReadableMetadata metadata, ParameterizedType type,
boolean binary, CodecContext context) {
Expand Down Expand Up @@ -359,18 +363,27 @@ private <T> T decodeMassive(LargeFieldValue value, MySqlReadableMetadata metadat
* @param type the {@link Class} specified by the user.
* @return the {@link Class} to use for decoding.
*/
private static Class<?> chooseClass(final MySqlReadableMetadata metadata, Class<?> type) {
final Class<?> javaType = getDefaultJavaType(metadata);
private static Class<?> chooseClass(final MySqlReadableMetadata metadata, Class<?> type,
final CodecContext codecContext) {
final Class<?> javaType = getDefaultJavaType(metadata, codecContext);
return type.isAssignableFrom(javaType) ? javaType : type;
}

private static Class<?> getDefaultJavaType(final MySqlReadableMetadata metadata) {
private static Class<?> getDefaultJavaType(final MySqlReadableMetadata metadata, final CodecContext codecContext) {
final MySqlType type = metadata.getType();
final Integer precision = metadata.getPrecision();

if (INTEGER_ONE.equals(precision) && (type == MySqlType.TINYINT || type == MySqlType.TINYINT_UNSIGNED)
&& codecContext.isTinyInt1isBit()) {
return Boolean.class;
}

// ref: https://github.com/asyncer-io/r2dbc-mysql/issues/277
// BIT(1) should be treated as Boolean by default.
if (type == MySqlType.BIT && Integer.valueOf(1).equals(metadata.getPrecision())) {
if (INTEGER_ONE.equals(precision) && type == MySqlType.BIT) {
return Boolean.class;
}

return type.getJavaType();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void getTimeZone() {
String id = i < 0 ? "UTC" + i : "UTC+" + i;
ConnectionContext context = new ConnectionContext(
ZeroDateOption.USE_NULL, null,
8192, true, ZoneId.of(id));
8192, true, true, ZoneId.of(id));

assertThat(context.getTimeZone()).isEqualTo(ZoneId.of(id));
}
Expand All @@ -48,7 +48,7 @@ void getTimeZone() {
@Test
void setTwiceTimeZone() {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, null);
8192, true, true, null);

context.initSession(
Caches.createPrepareCache(0),
Expand All @@ -70,7 +70,7 @@ void setTwiceTimeZone() {
@Test
void badSetTimeZone() {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, ZoneId.systemDefault());
8192, true, true, ZoneId.systemDefault());
assertThatIllegalStateException().isThrownBy(() -> context.initSession(
Caches.createPrepareCache(0),
IsolationLevel.REPEATABLE_READ,
Expand All @@ -91,7 +91,7 @@ public static ConnectionContext mock(boolean isMariaDB) {

public static ConnectionContext mock(boolean isMariaDB, ZoneId zoneId) {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, zoneId);
8192, true, true, zoneId);

context.initHandshake(1, ServerVersion.parse(isMariaDB ? "11.2.22.MOCKED" : "8.0.11.MOCKED"),
Capability.of(~(isMariaDB ? 1 : 0)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,31 @@ void loadDataLocalInfile(String name) throws URISyntaxException, IOException {
.doOnNext(it -> assertThat(it).isEqualTo(json)));
}

@Test
public void tinyInt1isBitTrueTestValue1() {
complete(connection -> Mono.from(connection.createStatement("CREATE TEMPORARY TABLE `test` (`id` INT NOT NULL PRIMARY KEY, `value` TINYINT(1))").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(connection.createStatement("INSERT INTO `test` VALUES (1, 1)").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(connection.createStatement("SELECT `value` FROM `test`").execute())
.flatMap(result -> result.map((row, metadata) -> row.get("value", Object.class)))
.doOnNext(value -> assertThat(value).isInstanceOf(Boolean.class))
.doOnNext(value -> assertThat(value).isEqualTo(true))
);
}

@Test
public void tinyInt1isBitTrueTestValue0() {
complete(connection -> Mono.from(connection.createStatement("CREATE TEMPORARY TABLE `test` (`id` INT NOT NULL PRIMARY KEY, `value` TINYINT(1))").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(connection.createStatement("INSERT INTO `test` VALUES (1, 0)").execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.thenMany(connection.createStatement("SELECT `value` FROM `test`").execute())
.flatMap(result -> result.map((row, metadata) -> row.get("value", Object.class)))
.doOnNext(value -> assertThat(value).isInstanceOf(Boolean.class))
.doOnNext(value -> assertThat(value).isEqualTo(false)));
}

@Test
void batchCrud() {
// TODO: spilt it to multiple test cases and move it to BatchIntegrationTest
Expand Down
Loading

0 comments on commit abaa2d1

Please sign in to comment.