diff --git a/common/lib/aws_client.ts b/common/lib/aws_client.ts index 6fb9fe72..3966da67 100644 --- a/common/lib/aws_client.ts +++ b/common/lib/aws_client.ts @@ -35,6 +35,7 @@ import { DriverConfigurationProfiles } from "./profile/driver_configuration_prof import { ConfigurationProfile } from "./profile/configuration_profile"; import { AwsWrapperError } from "./utils/errors"; import { Messages } from "./utils/messages"; +import { TransactionIsolationLevel } from "./utils/transaction_isolation_level"; export abstract class AwsClient extends EventEmitter { private _defaultPort: number = -1; @@ -42,8 +43,6 @@ export abstract class AwsClient extends EventEmitter { protected pluginManager: PluginManager; protected pluginService: PluginService; protected isConnected: boolean = false; - protected _isReadOnly: boolean = false; - protected _isolationLevel: number = 0; protected _connectionUrlParser: ConnectionUrlParser; protected _configurationProfile: ConfigurationProfile | null = null; readonly properties: Map; @@ -151,8 +150,6 @@ export abstract class AwsClient extends EventEmitter { return this._connectionUrlParser; } - abstract updateSessionStateReadOnly(readOnly: boolean): Promise; - abstract setReadOnly(readOnly: boolean): Promise; abstract isReadOnly(): boolean; @@ -161,9 +158,9 @@ export abstract class AwsClient extends EventEmitter { abstract getAutoCommit(): boolean; - abstract setTransactionIsolation(transactionIsolation: number): Promise; + abstract setTransactionIsolation(level: TransactionIsolationLevel): Promise; - abstract getTransactionIsolation(): number; + abstract getTransactionIsolation(): TransactionIsolationLevel; abstract setSchema(schema: any): Promise; @@ -179,8 +176,6 @@ export abstract class AwsClient extends EventEmitter { abstract rollback(): Promise; - abstract resetState(): void; - async isValid(): Promise { if (!this.targetClient) { return Promise.resolve(false); diff --git a/common/lib/database_dialect/database_dialect.ts b/common/lib/database_dialect/database_dialect.ts index a631f490..7eae8676 100644 --- a/common/lib/database_dialect/database_dialect.ts +++ b/common/lib/database_dialect/database_dialect.ts @@ -19,6 +19,8 @@ import { HostListProviderService } from "../host_list_provider_service"; import { ClientWrapper } from "../client_wrapper"; import { FailoverRestriction } from "../plugins/failover/failover_restriction"; import { ErrorHandler } from "../error_handler"; +import { SessionState } from "../session_state"; +import { TransactionIsolationLevel } from "../utils/transaction_isolation_level"; export enum DatabaseType { MYSQL, @@ -30,6 +32,11 @@ export interface DatabaseDialect { getHostAliasQuery(): string; getHostAliasAndParseResults(targetClient: ClientWrapper): Promise; getServerVersionQuery(): string; + getSetReadOnlyQuery(readOnly: boolean): string; + getSetAutoCommitQuery(autoCommit: boolean): string; + getSetTransactionIsolationQuery(level: TransactionIsolationLevel): string; + getSetCatalogQuery(catalog: string): string; + getSetSchemaQuery(schema: string): string; getDialectUpdateCandidates(): string[]; getErrorHandler(): ErrorHandler; isDialect(targetClient: ClientWrapper): Promise; @@ -39,7 +46,7 @@ export interface DatabaseDialect { getDialectName(): string; getFailoverRestrictions(): FailoverRestriction[]; doesStatementSetReadOnly(statement: string): boolean | undefined; - doesStatementSetTransactionIsolation(statement: string): number | undefined; + doesStatementSetTransactionIsolation(statement: string): TransactionIsolationLevel | undefined; doesStatementSetAutoCommit(statement: string): boolean | undefined; doesStatementSetSchema(statement: string): string | undefined; doesStatementSetCatalog(statement: string): string | undefined; diff --git a/common/lib/plugin_service.ts b/common/lib/plugin_service.ts index 8284c1f1..f4ea810a 100644 --- a/common/lib/plugin_service.ts +++ b/common/lib/plugin_service.ts @@ -44,6 +44,7 @@ import { getWriter } from "./utils/utils"; import { TelemetryFactory } from "./utils/telemetry/telemetry_factory"; import { DriverDialect } from "./driver_dialect/driver_dialect"; import { ConfigurationProfile } from "./profile/configuration_profile"; +import { SessionState } from "./session_state"; export class PluginService implements ErrorHandler, HostListProviderService { private readonly _currentClient: AwsClient; @@ -75,10 +76,10 @@ export class PluginService implements ErrorHandler, HostListProviderService { this.dbDialectProvider = new DatabaseDialectManager(knownDialectsByCode, dbType, this.props); this.driverDialect = driverDialect; this.initialHost = props.get(WrapperProperties.HOST.name); - this.sessionStateService = new SessionStateServiceImpl(this, this.props); container.pluginService = this; this.dialect = WrapperProperties.CUSTOM_DATABASE_DIALECT.get(this.props) ?? this.dbDialectProvider.getDialect(this.props); + this.sessionStateService = new SessionStateServiceImpl(this, this.props); } isInTransaction(): boolean { @@ -333,7 +334,6 @@ export class PluginService implements ErrorHandler, HostListProviderService { this.sessionStateService.begin(); try { - this.getCurrentClient().resetState(); this.getCurrentClient().targetClient = newClient; this._currentHostInfo = hostInfo; await this.sessionStateService.applyCurrentSessionState(this.getCurrentClient()); @@ -432,35 +432,35 @@ export class PluginService implements ErrorHandler, HostListProviderService { private async updateReadOnly(statements: string[]) { const updateReadOnly = SqlMethodUtils.doesSetReadOnly(statements, this.getDialect()); if (updateReadOnly !== undefined) { - await this.getCurrentClient().setReadOnly(updateReadOnly); + this.getSessionStateService().setReadOnly(updateReadOnly); } } private async updateAutoCommit(statements: string[]) { const updateAutoCommit = SqlMethodUtils.doesSetAutoCommit(statements, this.getDialect()); if (updateAutoCommit !== undefined) { - await this.getCurrentClient().setAutoCommit(updateAutoCommit); + this.getSessionStateService().setAutoCommit(updateAutoCommit); } } private async updateCatalog(statements: string[]) { const updateCatalog = SqlMethodUtils.doesSetCatalog(statements, this.getDialect()); if (updateCatalog !== undefined) { - await this.getCurrentClient().setCatalog(updateCatalog); + this.getSessionStateService().setCatalog(updateCatalog); } } private async updateSchema(statements: string[]) { const updateSchema = SqlMethodUtils.doesSetSchema(statements, this.getDialect()); if (updateSchema !== undefined) { - await this.getCurrentClient().setSchema(updateSchema); + this.getSessionStateService().setSchema(updateSchema); } } private async updateTransactionIsolation(statements: string[]) { const updateTransactionIsolation = SqlMethodUtils.doesSetTransactionIsolation(statements, this.getDialect()); if (updateTransactionIsolation !== undefined) { - await this.getCurrentClient().setTransactionIsolation(updateTransactionIsolation); + this.getSessionStateService().setTransactionIsolation(updateTransactionIsolation); } } diff --git a/common/lib/plugins/failover/failover_plugin.ts b/common/lib/plugins/failover/failover_plugin.ts index b6914c66..ac4209a9 100644 --- a/common/lib/plugins/failover/failover_plugin.ts +++ b/common/lib/plugins/failover/failover_plugin.ts @@ -388,9 +388,9 @@ export class FailoverPlugin extends AbstractConnectionPlugin { throw new FailoverFailedError(Messages.get("Failover.unableToConnectToReader")); } - this.pluginService.getCurrentHostInfo()?.removeAlias(Array.from(oldAliases)); await this.pluginService.abortCurrentClient(); await this.pluginService.setCurrentClient(result.client, result.newHost); + this.pluginService.getCurrentHostInfo()?.removeAlias(Array.from(oldAliases)); await this.updateTopology(true); this.failoverReaderSuccessCounter.inc(); } catch (error: any) { diff --git a/common/lib/pool_client_wrapper.ts b/common/lib/pool_client_wrapper.ts index 0cdacb67..12dad3c1 100644 --- a/common/lib/pool_client_wrapper.ts +++ b/common/lib/pool_client_wrapper.ts @@ -18,12 +18,14 @@ import { ClientWrapper } from "./client_wrapper"; import { HostInfo } from "./host_info"; import { uniqueId } from "../logutils"; import { ClientUtils } from "./utils/client_utils"; +import { SessionState } from "./session_state"; export class PoolClientWrapper implements ClientWrapper { readonly client: any; readonly hostInfo: HostInfo; readonly properties: Map; readonly id: string; + readonly sessionState = new SessionState(); constructor(targetClient: any, hostInfo: HostInfo, properties: Map) { this.client = targetClient; diff --git a/common/lib/session_state.ts b/common/lib/session_state.ts index a14af7d7..b284693a 100644 --- a/common/lib/session_state.ts +++ b/common/lib/session_state.ts @@ -14,10 +14,29 @@ limitations under the License. */ -class SessionStateField { +import { DatabaseDialect } from "./database_dialect/database_dialect"; +import { AwsClient } from "./aws_client"; +import { TransactionIsolationLevel } from "./utils/transaction_isolation_level"; + +export abstract class SessionStateField { value?: Type; pristineValue?: Type; + constructor(copy?: SessionStateField) { + if (copy) { + this.value = copy.value; + this.pristineValue = copy.pristineValue; + } + } + + abstract setValue(state: SessionState): void; + + abstract setPristineValue(state: SessionState): void; + + abstract getQuery(dialect: DatabaseDialect, isPristine: boolean): string; + + abstract getClientValue(client: AwsClient): Type; + resetValue(): void { this.value = undefined; } @@ -60,38 +79,123 @@ class SessionStateField { return true; } - copy(): SessionStateField { - const newField: SessionStateField = new SessionStateField(); - if (this.value !== undefined) { - newField.value = this.value; - } + toString() { + return `${this.pristineValue ?? "(blank)"} => ${this.value ?? "(blank)"}`; + } +} - if (this.pristineValue !== undefined) { - newField.pristineValue = this.pristineValue; - } +class AutoCommitState extends SessionStateField { + setValue(state: SessionState) { + this.value = state.autoCommit.value; + } - return newField; + setPristineValue(state: SessionState) { + this.value = state.autoCommit.pristineValue; } - toString() { - return `${this.pristineValue ?? "(blank)"} => ${this.value ?? "(blank)"}`; + getQuery(dialect: DatabaseDialect, isPristine: boolean = false) { + return dialect.getSetAutoCommitQuery(isPristine ? this.pristineValue : this.value); + } + + getClientValue(client: AwsClient): boolean { + return client.getAutoCommit(); + } +} + +class ReadOnlyState extends SessionStateField { + setValue(state: SessionState) { + this.value = state.readOnly.value; + } + + setPristineValue(state: SessionState) { + this.value = state.readOnly.pristineValue; + } + + getQuery(dialect: DatabaseDialect, isPristine: boolean = false) { + return dialect.getSetReadOnlyQuery(this.value); + } + + getClientValue(client: AwsClient): boolean { + return client.isReadOnly(); + } +} + +class CatalogState extends SessionStateField { + setValue(state: SessionState) { + this.value = state.catalog.value; + } + + setPristineValue(state: SessionState) { + this.value = state.catalog.pristineValue; + } + + getQuery(dialect: DatabaseDialect, isPristine: boolean = false) { + return dialect.getSetCatalogQuery(isPristine ? this.pristineValue : this.value); + } + + getClientValue(client: AwsClient): string { + return client.getCatalog(); + } +} + +class SchemaState extends SessionStateField { + setValue(state: SessionState) { + this.value = state.schema.value; + } + + setPristineValue(state: SessionState) { + this.value = state.schema.pristineValue; + } + + getQuery(dialect: DatabaseDialect, isPristine: boolean = false) { + return dialect.getSetSchemaQuery(isPristine ? this.pristineValue : this.value); + } + + getClientValue(client: AwsClient): string { + return client.getSchema(); + } +} + +class TransactionIsolationState extends SessionStateField { + setValue(state: SessionState) { + this.value = state.transactionIsolation.value; + } + + setPristineValue(state: SessionState) { + this.value = state.transactionIsolation.pristineValue; + } + + getQuery(dialect: DatabaseDialect, isPristine: boolean = false) { + return dialect.getSetTransactionIsolationQuery(isPristine ? this.pristineValue : this.value); + } + + getClientValue(client: AwsClient): number { + return client.getTransactionIsolation(); } } export class SessionState { - autoCommit: SessionStateField = new SessionStateField(); - readOnly: SessionStateField = new SessionStateField(); - catalog: SessionStateField = new SessionStateField(); - schema: SessionStateField = new SessionStateField(); - transactionIsolation: SessionStateField = new SessionStateField(); + autoCommit: AutoCommitState = new AutoCommitState(); + readOnly: ReadOnlyState = new ReadOnlyState(); + catalog: CatalogState = new CatalogState(); + schema: SchemaState = new SchemaState(); + transactionIsolation: TransactionIsolationState = new TransactionIsolationState(); + + static setState(target: SessionStateField, source: SessionState): void { + target.setValue(source); + } + + static setPristineState(target: SessionStateField, source: SessionState): void { + target.setPristineValue(source); + } copy(): SessionState { const newSessionState = new SessionState(); - newSessionState.autoCommit = this.autoCommit.copy(); - newSessionState.readOnly = this.readOnly.copy(); - newSessionState.catalog = this.catalog.copy(); - newSessionState.schema = this.schema.copy(); - newSessionState.transactionIsolation = this.transactionIsolation.copy(); + newSessionState.autoCommit = new AutoCommitState(this.autoCommit); + newSessionState.readOnly = new ReadOnlyState(this.readOnly); + newSessionState.catalog = new CatalogState(this.catalog); + newSessionState.schema = new SchemaState(this.schema); + newSessionState.transactionIsolation = new TransactionIsolationState(this.transactionIsolation); return newSessionState; } diff --git a/common/lib/session_state_service.ts b/common/lib/session_state_service.ts index c1098103..ed0a4b5e 100644 --- a/common/lib/session_state_service.ts +++ b/common/lib/session_state_service.ts @@ -15,6 +15,7 @@ */ import { AwsClient } from "./aws_client"; +import { TransactionIsolationLevel } from "./utils/transaction_isolation_level"; export interface SessionStateService { // auto commit @@ -26,26 +27,27 @@ export interface SessionStateService { // read only getReadOnly(): boolean | undefined; setReadOnly(readOnly: boolean): void; - setupPristineReadOnly(): boolean | undefined; - setupPristineReadOnly(readOnly: boolean): boolean | undefined; + setupPristineReadOnly(): void; + setupPristineReadOnly(readOnly: boolean): void; + updateReadOnly(readOnly: boolean): void; // catalog getCatalog(): string | undefined; setCatalog(catalog: string): void; - setupPristineCatalog(): string | undefined; - setupPristineCatalog(catalog: string): string | undefined; + setupPristineCatalog(): void; + setupPristineCatalog(catalog: string): void; // schema getSchema(): string | undefined; setSchema(schema: string): void; - setupPristineSchema(): string | undefined; - setupPristineSchema(schema: string): string | undefined; + setupPristineSchema(): void; + setupPristineSchema(schema: string): void; // transaction isolation - getTransactionIsolation(): number | undefined; - setTransactionIsolation(transactionIsolation: number): void; - setupPristineTransactionIsolation(): number | undefined; - setupPristineTransactionIsolation(transactionIsolation: number): number | undefined; + getTransactionIsolation(): TransactionIsolationLevel | undefined; + setTransactionIsolation(transactionIsolation: TransactionIsolationLevel): void; + setupPristineTransactionIsolation(): void; + setupPristineTransactionIsolation(transactionIsolation: TransactionIsolationLevel): void; reset(): void; diff --git a/common/lib/session_state_service_impl.ts b/common/lib/session_state_service_impl.ts index 8bd3b172..4e25a133 100644 --- a/common/lib/session_state_service_impl.ts +++ b/common/lib/session_state_service_impl.ts @@ -17,11 +17,13 @@ import { WrapperProperties } from "./wrapper_property"; import { SessionStateService } from "./session_state_service"; import { AwsClient } from "./aws_client"; -import { SessionState } from "./session_state"; +import { SessionState, SessionStateField } from "./session_state"; import { PluginService } from "./plugin_service"; import { AwsWrapperError, UnsupportedMethodError } from "./utils/errors"; import { logger } from "../logutils"; import { SessionStateTransferHandler } from "./session_state_transfer_handler"; +import { ClientWrapper } from "./client_wrapper"; +import { TransactionIsolationLevel } from "./utils/transaction_isolation_level"; export class SessionStateServiceImpl implements SessionStateService { protected sessionState: SessionState; @@ -57,68 +59,15 @@ export class SessionStateServiceImpl implements SessionStateService { } } - if (this.sessionState.autoCommit.value !== undefined) { - this.sessionState.autoCommit.resetPristineValue(); - this.setupPristineAutoCommit(); - try { - await newClient.setAutoCommit(this.sessionState.autoCommit.value); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; - } - } - - if (this.sessionState.readOnly.value !== undefined) { - this.sessionState.readOnly.resetPristineValue(); - this.setupPristineReadOnly(); - try { - await newClient.updateSessionStateReadOnly(this.sessionState.readOnly.value); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; - } - } - - if (this.sessionState.catalog.value !== undefined) { - this.sessionState.catalog.resetPristineValue(); - this.setupPristineCatalog(); - try { - await newClient.setCatalog(this.sessionState.catalog.value); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; - } - } - - if (this.sessionState.schema.value !== undefined) { - this.sessionState.schema.resetPristineValue(); - this.setupPristineSchema(); - try { - await newClient.setSchema(this.sessionState.schema.value); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; - } - } + const targetClient: ClientWrapper = newClient.targetClient; - if (this.sessionState.transactionIsolation.value !== undefined) { - this.sessionState.transactionIsolation.resetPristineValue(); - this.setupPristineTransactionIsolation(); - try { - await newClient.setTransactionIsolation(this.sessionState.transactionIsolation.value); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; + // Apply current state for all 5 states: autoCommit, readOnly, catalog, schema, transactionIsolation + for (const key of Object.keys(this.sessionState)) { + const state = this.sessionState[key]; + if (state instanceof SessionStateField) { + await this.applyCurrentState(targetClient, state); + } else { + throw new AwsWrapperError(`Unexpected session state key: ${key}`); } } } @@ -137,58 +86,20 @@ export class SessionStateServiceImpl implements SessionStateService { } } - if (this.copySessionState?.autoCommit.canRestorePristine() && this.copySessionState?.autoCommit.pristineValue !== undefined) { - try { - await client.setAutoCommit(this.copySessionState?.autoCommit.pristineValue); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; - } - } - - if (this.copySessionState?.readOnly.canRestorePristine() && this.copySessionState?.readOnly.pristineValue !== undefined) { - try { - await client.updateSessionStateReadOnly(this.copySessionState?.readOnly.pristineValue); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; - } - } - - if (this.copySessionState?.catalog.canRestorePristine() && this.copySessionState?.catalog.pristineValue !== undefined) { - try { - await client.setCatalog(this.copySessionState?.catalog.pristineValue); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; - } + if (this.copySessionState === undefined) { + return; } - if (this.copySessionState?.schema.canRestorePristine() && this.copySessionState?.schema.pristineValue !== undefined) { - try { - await client.setSchema(this.copySessionState?.schema.pristineValue); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; - } - } + const targetClient: ClientWrapper = client.targetClient; - if (this.copySessionState?.transactionIsolation.canRestorePristine() && this.copySessionState?.transactionIsolation.pristineValue !== undefined) { - try { - await client.setTransactionIsolation(this.copySessionState?.transactionIsolation.pristineValue); - } catch (error: any) { - if (error instanceof UnsupportedMethodError) { - // ignore - } - throw error; + // Set pristine states on all target client. + // The states that will be set are: autoCommit, readonly, schema, catalog, transactionIsolation. + for (const key of Object.keys(this.copySessionState)) { + const state = this.copySessionState[key]; + if (state instanceof SessionStateField) { + await this.setPristineStateOnTarget(targetClient, state, key); + } else { + throw new AwsWrapperError(`Unexpected session state key: ${key}`); } } } @@ -198,30 +109,13 @@ export class SessionStateServiceImpl implements SessionStateService { } setAutoCommit(autoCommit: boolean): void { - if (!this.transferStateEnabledSetting()) { - return; - } - - this.sessionState.autoCommit.value = autoCommit; - this.logCurrentState(); + return this.setState("autoCommit", autoCommit); } setupPristineAutoCommit(): void; setupPristineAutoCommit(autoCommit: boolean): void; setupPristineAutoCommit(autoCommit?: boolean): void { - if (!this.resetStateEnabledSetting()) { - return; - } - - if (this.sessionState.autoCommit.pristineValue !== undefined) { - return; - } - - if (autoCommit !== undefined) { - this.sessionState.autoCommit.pristineValue = autoCommit; - } else { - this.sessionState.autoCommit.pristineValue = this.pluginService.getCurrentClient().getAutoCommit(); - } + return this.setupPristineState(this.sessionState.autoCommit, autoCommit); } getCatalog(): string | undefined { @@ -229,30 +123,13 @@ export class SessionStateServiceImpl implements SessionStateService { } setCatalog(catalog: string): void { - if (!this.transferStateEnabledSetting()) { - return; - } - - this.sessionState.catalog.value = catalog; - this.logCurrentState(); + return this.setState("catalog", catalog); } - setupPristineCatalog(): string | undefined; - setupPristineCatalog(catalog: string): string | undefined; - setupPristineCatalog(catalog?: string): string | undefined { - if (!this.resetStateEnabledSetting()) { - return; - } - - if (this.sessionState.catalog.pristineValue !== undefined) { - return; - } - - if (catalog !== undefined) { - this.sessionState.catalog.pristineValue = catalog; - } else { - this.sessionState.catalog.pristineValue = this.pluginService.getCurrentClient().getCatalog(); - } + setupPristineCatalog(): void; + setupPristineCatalog(catalog: string): void; + setupPristineCatalog(catalog?: string): void { + this.setupPristineState(this.sessionState.catalog, catalog); } getReadOnly(): boolean | undefined { @@ -260,30 +137,18 @@ export class SessionStateServiceImpl implements SessionStateService { } setReadOnly(readOnly: boolean): void { - if (!this.transferStateEnabledSetting()) { - return; - } - - this.sessionState.readOnly.value = readOnly; - this.logCurrentState(); + return this.setState("readOnly", readOnly); } - setupPristineReadOnly(): boolean | undefined; - setupPristineReadOnly(readOnly: boolean): boolean | undefined; - setupPristineReadOnly(readOnly?: boolean): boolean | undefined { - if (!this.resetStateEnabledSetting()) { - return; - } - - if (this.sessionState.readOnly.pristineValue !== undefined) { - return; - } + setupPristineReadOnly(): void; + setupPristineReadOnly(readOnly: boolean): void; + setupPristineReadOnly(readOnly?: boolean): void { + this.setupPristineState(this.sessionState.readOnly, readOnly); + } - if (readOnly !== undefined) { - this.sessionState.readOnly.pristineValue = readOnly; - } else { - this.sessionState.readOnly.pristineValue = this.pluginService.getCurrentClient().isReadOnly(); - } + updateReadOnly(readOnly: boolean): void { + this.pluginService.getSessionStateService().setupPristineReadOnly(readOnly); + this.pluginService.getSessionStateService().setReadOnly(readOnly); } getSchema(): string | undefined { @@ -291,61 +156,95 @@ export class SessionStateServiceImpl implements SessionStateService { } setSchema(schema: string): void { - if (!this.transferStateEnabledSetting()) { - return; - } - - this.sessionState.schema.value = schema; - this.logCurrentState(); + return this.setState("schema", schema); } - setupPristineSchema(): string | undefined; - setupPristineSchema(schema: string): string | undefined; - setupPristineSchema(schema?: string): string | undefined { - if (!this.resetStateEnabledSetting()) { - return; - } + setupPristineSchema(): void; + setupPristineSchema(schema: string): void; + setupPristineSchema(schema?: string): void { + this.setupPristineState(this.sessionState.schema, schema); + } - if (this.sessionState.schema.pristineValue !== undefined) { - return; - } + getTransactionIsolation(): TransactionIsolationLevel | undefined { + return this.sessionState.transactionIsolation.value; + } - if (schema !== undefined) { - this.sessionState.schema.pristineValue = schema; - } else { - this.sessionState.schema.pristineValue = this.pluginService.getCurrentClient().getSchema(); - } + setTransactionIsolation(transactionIsolation: TransactionIsolationLevel): void { + return this.setState("transactionIsolation", transactionIsolation); } - getTransactionIsolation(): number | undefined { - return this.sessionState.transactionIsolation.value; + setupPristineTransactionIsolation(): void; + setupPristineTransactionIsolation(transactionIsolation: TransactionIsolationLevel): void; + setupPristineTransactionIsolation(transactionIsolation?: TransactionIsolationLevel): void { + this.setupPristineState(this.sessionState.transactionIsolation, transactionIsolation); } - setTransactionIsolation(transactionIsolation: number): void { + private setState(state: any, val: Type): void { if (!this.transferStateEnabledSetting()) { return; } - this.sessionState.transactionIsolation.value = transactionIsolation; + this.sessionState[state].value = val; this.logCurrentState(); } - setupPristineTransactionIsolation(): number | undefined; - setupPristineTransactionIsolation(transactionIsolation: number): number | undefined; - setupPristineTransactionIsolation(transactionIsolation?: number): number | undefined { + private async applyCurrentState(targetClient: ClientWrapper, sessionState: SessionStateField): Promise { + if (sessionState.value !== undefined) { + sessionState.resetPristineValue(); + this.setupPristineState(sessionState); + await this.setStateOnTarget(targetClient, sessionState); + } + } + + private async setStateOnTarget(targetClient: ClientWrapper, sessionStateField: SessionStateField): Promise { + try { + await targetClient.query(sessionStateField.getQuery(this.pluginService.getDialect(), false)); + SessionState.setState(sessionStateField, this.sessionState); + } catch (error: any) { + if (error instanceof UnsupportedMethodError) { + // UnsupportedMethodError is thrown if the database does not support setting this state. + // For instance, PostgreSQL does not support setting the catalog and instead supports setting the schema. + // In this case, ignore the error. + return; + } + throw error; + } + } + + private async setPristineStateOnTarget( + targetClient: ClientWrapper, + sessionStateField: SessionStateField, + sessionStateName: string + ): Promise { + if (sessionStateField.canRestorePristine() && sessionStateField.pristineValue !== undefined) { + try { + await targetClient.query(sessionStateField.getQuery(this.pluginService.getDialect(), true)); + this.setState(sessionStateName, sessionStateField.pristineValue); + SessionState.setPristineState(sessionStateField, this.copySessionState); + } catch (error: any) { + if (error instanceof UnsupportedMethodError) { + // UnsupportedMethodError is thrown if the database does not support setting this state. + // For instance, PostgreSQL does not support setting the catalog and instead supports setting the schema. + // In this case, ignore the error. + return; + } + throw error; + } + } + } + + private setupPristineState(state: SessionStateField): void; + private setupPristineState(state: SessionStateField, val: Type): void; + private setupPristineState(state: SessionStateField, val?: Type): void { if (!this.resetStateEnabledSetting()) { return; } - if (this.sessionState.transactionIsolation.pristineValue !== undefined) { + if (state.pristineValue !== undefined) { return; } - if (transactionIsolation !== undefined) { - this.sessionState.transactionIsolation.pristineValue = transactionIsolation; - } else { - this.sessionState.transactionIsolation.pristineValue = this.pluginService.getCurrentClient().getTransactionIsolation(); - } + state.pristineValue = val ?? state.getClientValue(this.pluginService.getCurrentClient()); } begin(): void { diff --git a/common/lib/utils/sql_method_utils.ts b/common/lib/utils/sql_method_utils.ts index 4a5eded7..02f603b5 100644 --- a/common/lib/utils/sql_method_utils.ts +++ b/common/lib/utils/sql_method_utils.ts @@ -20,11 +20,17 @@ import { TransactionIsolationLevel } from "./transaction_isolation_level"; export class SqlMethodUtils { static doesOpenTransaction(sql: string) { const firstStatement = SqlMethodUtils.getFirstSqlStatement(sql); + if (!firstStatement) { + return false; + } return firstStatement.toLowerCase().startsWith("start transaction") || firstStatement.toLowerCase().startsWith("begin"); } static doesCloseTransaction(sql: string) { const firstStatement = SqlMethodUtils.getFirstSqlStatement(sql); + if (!firstStatement) { + return false; + } return ( firstStatement.toLowerCase().startsWith("commit") || firstStatement.toLowerCase().startsWith("rollback") || @@ -47,7 +53,7 @@ export class SqlMethodUtils { } static doesSetAutoCommit(statements: string[], dialect: DatabaseDialect): boolean | undefined { - let autoCommit; + let autoCommit = undefined; for (const statement of statements) { const cleanStatement = statement .toLowerCase() @@ -60,7 +66,7 @@ export class SqlMethodUtils { } static doesSetCatalog(statements: string[], dialect: DatabaseDialect): string | undefined { - let catalog; + let catalog = undefined; for (const statement of statements) { const cleanStatement = statement .toLowerCase() @@ -73,7 +79,7 @@ export class SqlMethodUtils { } static doesSetSchema(statements: string[], dialect: DatabaseDialect): string | undefined { - let schema; + let schema = undefined; for (const statement of statements) { const cleanStatement = statement .toLowerCase() @@ -86,7 +92,7 @@ export class SqlMethodUtils { } static doesSetTransactionIsolation(statements: string[], dialect: DatabaseDialect): TransactionIsolationLevel | undefined { - let transactionIsolation; + let transactionIsolation = undefined; for (const statement of statements) { const cleanStatement = statement .toLowerCase() @@ -100,7 +106,6 @@ export class SqlMethodUtils { static getFirstSqlStatement(sql: string) { const statements = SqlMethodUtils.parseMultiStatementQueries(sql); - if (statements.length === 0) { return sql; } diff --git a/mysql/lib/client.ts b/mysql/lib/client.ts index 1c47e02b..0f1ddc9b 100644 --- a/mysql/lib/client.ts +++ b/mysql/lib/client.ts @@ -39,12 +39,9 @@ export class AwsMySQLClient extends AwsClient { [DatabaseDialectCodes.AURORA_MYSQL, new AuroraMySQLDatabaseDialect()], [DatabaseDialectCodes.RDS_MULTI_AZ_MYSQL, new RdsMultiAZMySQLDatabaseDialect()] ]); - private isAutoCommit: boolean = true; - private catalog = ""; constructor(config: any) { super(config, DatabaseType.MYSQL, AwsMySQLClient.knownDialectsByCode, new MySQLConnectionUrlParser(), new MySQL2DriverDialect()); - this.resetState(); } async connect(): Promise { @@ -82,7 +79,7 @@ export class AwsMySQLClient extends AwsClient { }); } - private async readOnlyQuery(options: QueryOptions, callback?: any): Promise { + private async queryWithoutUpdate(options: QueryOptions, callback?: any): Promise { const host = this.pluginService.getCurrentHostInfo(); return this.pluginManager.execute( @@ -96,61 +93,44 @@ export class AwsMySQLClient extends AwsClient { ); } - async updateSessionStateReadOnly(readOnly: boolean): Promise { - const result = await this.targetClient.queryWithTimeout(`SET SESSION TRANSACTION READ ${readOnly ? "ONLY" : "WRITE"}`); - - this._isReadOnly = readOnly; - this.pluginService.getSessionStateService().setupPristineReadOnly(); - this.pluginService.getSessionStateService().setReadOnly(readOnly); - return result; - } - async setReadOnly(readOnly: boolean): Promise { - const result = await this.readOnlyQuery({ sql: `SET SESSION TRANSACTION READ ${readOnly ? "ONLY" : "WRITE"}` }); - this._isReadOnly = readOnly; this.pluginService.getSessionStateService().setupPristineReadOnly(); - this.pluginService.getSessionStateService().setReadOnly(readOnly); + const result = await this.queryWithoutUpdate({ sql: `SET SESSION TRANSACTION READ ${readOnly ? "ONLY" : "WRITE"}` }); + this.pluginService.getSessionStateService().updateReadOnly(readOnly); return result; } isReadOnly(): boolean { - return this._isReadOnly; + return this.pluginService.getSessionStateService().getReadOnly(); } async setAutoCommit(autoCommit: boolean): Promise { - if (autoCommit === this.getAutoCommit()) { - return; - } - this.pluginService.getSessionStateService().setupPristineAutoCommit(); - this.pluginService.getSessionStateService().setAutoCommit(autoCommit); - this.isAutoCommit = autoCommit; let setting = "1"; if (!autoCommit) { setting = "0"; } - return await this.query({ sql: `SET AUTOCOMMIT=${setting}` }); + const result = await this.queryWithoutUpdate({ sql: `SET AUTOCOMMIT=${setting}` }); + this.pluginService.getSessionStateService().setAutoCommit(autoCommit); + return result; } getAutoCommit(): boolean { - return this.isAutoCommit; + return this.pluginService.getSessionStateService().getAutoCommit(); } async setCatalog(catalog: string): Promise { - if (catalog === this.getCatalog()) { + if (!catalog) { return; } - this.pluginService.getSessionStateService().setupPristineCatalog(); + await this.queryWithoutUpdate({ sql: `USE ${catalog}` }); this.pluginService.getSessionStateService().setCatalog(catalog); - - this.catalog = catalog; - await this.query({ sql: `USE ${catalog}` }); } getCatalog(): string { - return this.catalog; + return this.pluginService.getSessionStateService().getCatalog(); } async setSchema(schema: string): Promise { @@ -162,36 +142,34 @@ export class AwsMySQLClient extends AwsClient { } async setTransactionIsolation(level: TransactionIsolationLevel): Promise { - if (level === this.getTransactionIsolation()) { + if (level == this.getTransactionIsolation()) { return; } this.pluginService.getSessionStateService().setupPristineTransactionIsolation(); - this.pluginService.getSessionStateService().setTransactionIsolation(level); - this._isolationLevel = level; switch (level) { - case 0: - await this.query({ sql: "SET SESSION TRANSACTION ISOLATION LEVEL READ UNCOMMITTED" }); + case TransactionIsolationLevel.TRANSACTION_READ_UNCOMMITTED: + await this.queryWithoutUpdate({ sql: "SET SESSION TRANSACTION ISOLATION LEVEL READ UNCOMMITTED" }); break; - case 1: - await this.query({ sql: "SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED" }); + case TransactionIsolationLevel.TRANSACTION_READ_COMMITTED: + await this.queryWithoutUpdate({ sql: "SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED" }); break; - case 2: - await this.query({ sql: "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ" }); + case TransactionIsolationLevel.TRANSACTION_REPEATABLE_READ: + await this.queryWithoutUpdate({ sql: "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ" }); break; - case 3: - await this.query({ sql: "SET SESSION TRANSACTION ISOLATION LEVEL SERIALIZABLE" }); + case TransactionIsolationLevel.TRANSACTION_SERIALIZABLE: + await this.queryWithoutUpdate({ sql: "SET SESSION TRANSACTION ISOLATION LEVEL SERIALIZABLE" }); break; default: throw new AwsWrapperError(Messages.get("Client.invalidTransactionIsolationLevel", String(level))); } - this._isolationLevel = level; + this.pluginService.getSessionStateService().setTransactionIsolation(level); } - getTransactionIsolation(): number { - return this._isolationLevel; + getTransactionIsolation(): TransactionIsolationLevel { + return this.pluginService.getSessionStateService().getTransactionIsolation(); } async end() { @@ -201,7 +179,7 @@ export class AwsMySQLClient extends AwsClient { return; } - const result = await this.pluginManager.execute( + return await this.pluginManager.execute( this.pluginService.getCurrentHostInfo(), this.properties, "end", @@ -212,7 +190,6 @@ export class AwsMySQLClient extends AwsClient { }, null ); - return result; } async rollback(): Promise { @@ -230,11 +207,4 @@ export class AwsMySQLClient extends AwsClient { null ); } - - resetState() { - this._isReadOnly = false; - this.isAutoCommit = true; - this.catalog = ""; - this._isolationLevel = TransactionIsolationLevel.TRANSACTION_REPEATABLE_READ; - } } diff --git a/mysql/lib/dialect/mysql_database_dialect.ts b/mysql/lib/dialect/mysql_database_dialect.ts index eadd44c5..1b122297 100644 --- a/mysql/lib/dialect/mysql_database_dialect.ts +++ b/mysql/lib/dialect/mysql_database_dialect.ts @@ -18,7 +18,7 @@ import { DatabaseDialect, DatabaseType } from "../../../common/lib/database_dial import { HostListProviderService } from "../../../common/lib/host_list_provider_service"; import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider"; import { ConnectionStringHostListProvider } from "../../../common/lib/host_list_provider/connection_string_host_list_provider"; -import { AwsWrapperError } from "../../../common/lib/utils/errors"; +import { AwsWrapperError, UnsupportedMethodError } from "../../../common/lib/utils/errors"; import { DatabaseDialectCodes } from "../../../common/lib/database_dialect/database_dialect_codes"; import { TransactionIsolationLevel } from "../../../common/lib/utils/transaction_isolation_level"; import { ClientWrapper } from "../../../common/lib/client_wrapper"; @@ -26,7 +26,8 @@ import { ClientUtils } from "../../../common/lib/utils/client_utils"; import { FailoverRestriction } from "../../../common/lib/plugins/failover/failover_restriction"; import { ErrorHandler } from "../../../common/lib/error_handler"; import { MySQLErrorHandler } from "../mysql_error_handler"; -import { error } from "winston"; +import { SessionState } from "../../../common/lib/session_state"; +import { Messages } from "../../../common/lib/utils/messages"; export class MySQLDatabaseDialect implements DatabaseDialect { protected dialectName: string = this.constructor.name; @@ -59,6 +60,43 @@ export class MySQLDatabaseDialect implements DatabaseDialect { return "SHOW VARIABLES LIKE 'version_comment'"; } + getSetReadOnlyQuery(readOnly: boolean): string { + return `SET SESSION TRANSACTION READ ${readOnly ? "ONLY" : "WRITE"}`; + } + + getSetAutoCommitQuery(autoCommit: boolean): string { + return `SET AUTOCOMMIT=${autoCommit}`; + } + + getSetTransactionIsolationQuery(level: TransactionIsolationLevel): string { + let transactionIsolationLevel: string; + switch (level) { + case TransactionIsolationLevel.TRANSACTION_READ_UNCOMMITTED: + transactionIsolationLevel = "READ UNCOMMITTED"; + break; + case TransactionIsolationLevel.TRANSACTION_READ_COMMITTED: + transactionIsolationLevel = "READ COMMITTED"; + break; + case TransactionIsolationLevel.TRANSACTION_REPEATABLE_READ: + transactionIsolationLevel = "REPEATABLE READ"; + break; + case TransactionIsolationLevel.TRANSACTION_SERIALIZABLE: + transactionIsolationLevel = "SERIALIZABLE"; + break; + default: + throw new AwsWrapperError(Messages.get("Client.invalidTransactionIsolationLevel", String(level))); + } + return `SET SESSION TRANSACTION ISOLATION LEVEL ${transactionIsolationLevel}`; + } + + getSetCatalogQuery(catalog: string): string { + return `USE ${catalog}`; + } + + getSetSchemaQuery(schema: string): string { + throw new UnsupportedMethodError(Messages.get("Client.methodNotSupported", "setSchema")); + } + async isDialect(targetClient: ClientWrapper): Promise { return await targetClient .query(this.getServerVersionQuery()) diff --git a/pg/lib/client.ts b/pg/lib/client.ts index 0138e7ad..4ba7e53a 100644 --- a/pg/lib/client.ts +++ b/pg/lib/client.ts @@ -24,12 +24,12 @@ import { PgDatabaseDialect } from "./dialect/pg_database_dialect"; import { AuroraPgDatabaseDialect } from "./dialect/aurora_pg_database_dialect"; import { AwsWrapperError, UnsupportedMethodError } from "../../common/lib/utils/errors"; import { Messages } from "../../common/lib/utils/messages"; -import { TransactionIsolationLevel } from "../../common/lib/utils/transaction_isolation_level"; import { ClientWrapper } from "../../common/lib/client_wrapper"; import { RdsMultiAZPgDatabaseDialect } from "./dialect/rds_multi_az_pg_database_dialect"; import { HostInfo } from "../../common/lib/host_info"; import { TelemetryTraceLevel } from "../../common/lib/utils/telemetry/telemetry_trace_level"; import { NodePostgresDriverDialect } from "./dialect/node_postgres_driver_dialect"; +import { TransactionIsolationLevel } from "../../common/lib/utils/transaction_isolation_level"; export class AwsPGClient extends AwsClient { private static readonly knownDialectsByCode: Map = new Map([ @@ -38,11 +38,9 @@ export class AwsPGClient extends AwsClient { [DatabaseDialectCodes.AURORA_PG, new AuroraPgDatabaseDialect()], [DatabaseDialectCodes.RDS_MULTI_AZ_PG, new RdsMultiAZPgDatabaseDialect()] ]); - private schema: string = ""; constructor(config: any) { super(config, DatabaseType.POSTGRES, AwsPGClient.knownDialectsByCode, new PgConnectionUrlParser(), new NodePostgresDriverDialect()); - this.resetState(); } async connect(): Promise { @@ -75,11 +73,7 @@ export class AwsPGClient extends AwsClient { }); } - async updateSessionStateReadOnly(readOnly: boolean): Promise { - return await this.targetClient.query(`SET SESSION CHARACTERISTICS AS TRANSACTION READ ${readOnly ? "ONLY" : "WRITE"}`); - } - - private async readOnlyQuery(text: string): Promise { + private async queryWithoutUpdate(text: string): Promise { return this.pluginManager.execute( this.pluginService.getCurrentHostInfo(), this.properties, @@ -92,20 +86,14 @@ export class AwsPGClient extends AwsClient { } async setReadOnly(readOnly: boolean): Promise { - let result; - if (readOnly) { - result = await this.readOnlyQuery("SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY"); - } else { - result = await this.readOnlyQuery("SET SESSION CHARACTERISTICS AS TRANSACTION READ WRITE"); - } - this._isReadOnly = readOnly; this.pluginService.getSessionStateService().setupPristineReadOnly(); - this.pluginService.getSessionStateService().setReadOnly(readOnly); + const result = await this.queryWithoutUpdate(`SET SESSION CHARACTERISTICS AS TRANSACTION READ ${readOnly ? "ONLY" : "WRITE"}`); + this.pluginService.getSessionStateService().updateReadOnly(readOnly); return result; } isReadOnly(): boolean { - return this._isReadOnly; + return this.pluginService.getSessionStateService().getReadOnly(); } async setAutoCommit(autoCommit: boolean): Promise { @@ -116,35 +104,34 @@ export class AwsPGClient extends AwsClient { throw new UnsupportedMethodError(Messages.get("Client.methodNotSupported", "getAutoCommit")); } - async setTransactionIsolation(level: number): Promise { - if (level === this.getTransactionIsolation()) { + async setTransactionIsolation(level: TransactionIsolationLevel): Promise { + if (level == this.getTransactionIsolation()) { return; } this.pluginService.getSessionStateService().setupPristineTransactionIsolation(); - this.pluginService.getSessionStateService().setTransactionIsolation(level); - this._isolationLevel = level; switch (level) { - case 0: - await this.query("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"); + case TransactionIsolationLevel.TRANSACTION_READ_UNCOMMITTED: + await this.queryWithoutUpdate("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ UNCOMMITTED"); break; - case 1: - await this.query("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"); + case TransactionIsolationLevel.TRANSACTION_READ_COMMITTED: + await this.queryWithoutUpdate("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"); break; - case 2: - await this.query("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL REPEATABLE READ"); + case TransactionIsolationLevel.TRANSACTION_REPEATABLE_READ: + await this.queryWithoutUpdate("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL REPEATABLE READ"); break; - case 3: - await this.query("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL SERIALIZABLE"); + case TransactionIsolationLevel.TRANSACTION_SERIALIZABLE: + await this.queryWithoutUpdate("SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL SERIALIZABLE"); break; default: throw new AwsWrapperError(Messages.get("Client.invalidTransactionIsolationLevel", String(level))); } + this.pluginService.getSessionStateService().setTransactionIsolation(level); } - getTransactionIsolation(): number { - return this._isolationLevel; + getTransactionIsolation(): TransactionIsolationLevel { + return this.pluginService.getSessionStateService().getTransactionIsolation(); } async setCatalog(catalog: string): Promise { @@ -156,19 +143,22 @@ export class AwsPGClient extends AwsClient { } async setSchema(schema: string): Promise { + if (!schema) { + return; + } + if (schema === this.getSchema()) { return; } this.pluginService.getSessionStateService().setupPristineSchema(); + const result = await this.queryWithoutUpdate(`SET search_path TO ${schema};`); this.pluginService.getSessionStateService().setSchema(schema); - - this.schema = schema; - return await this.query(`SET search_path TO ${schema};`); + return result; } getSchema(): string { - return this.schema; + return this.pluginService.getSessionStateService().getSchema(); } async end() { @@ -178,7 +168,7 @@ export class AwsPGClient extends AwsClient { return; } const hostInfo: HostInfo | null = this.pluginService.getCurrentHostInfo(); - const result = await this.pluginManager.execute( + return await this.pluginManager.execute( hostInfo, this.properties, "end", @@ -189,7 +179,6 @@ export class AwsPGClient extends AwsClient { }, null ); - return result; } async rollback(): Promise { @@ -207,10 +196,4 @@ export class AwsPGClient extends AwsClient { null ); } - - resetState() { - this._isReadOnly = false; - this.schema = ""; - this._isolationLevel = TransactionIsolationLevel.TRANSACTION_READ_COMMITTED; - } } diff --git a/pg/lib/dialect/pg_database_dialect.ts b/pg/lib/dialect/pg_database_dialect.ts index 395cf608..9d1d2d10 100644 --- a/pg/lib/dialect/pg_database_dialect.ts +++ b/pg/lib/dialect/pg_database_dialect.ts @@ -18,13 +18,15 @@ import { DatabaseDialect, DatabaseType } from "../../../common/lib/database_dial import { HostListProviderService } from "../../../common/lib/host_list_provider_service"; import { HostListProvider } from "../../../common/lib/host_list_provider/host_list_provider"; import { ConnectionStringHostListProvider } from "../../../common/lib/host_list_provider/connection_string_host_list_provider"; -import { AwsWrapperError } from "../../../common/lib/utils/errors"; +import { AwsWrapperError, UnsupportedMethodError } from "../../../common/lib/utils/errors"; import { DatabaseDialectCodes } from "../../../common/lib/database_dialect/database_dialect_codes"; import { TransactionIsolationLevel } from "../../../common/lib/utils/transaction_isolation_level"; import { ClientWrapper } from "../../../common/lib/client_wrapper"; import { FailoverRestriction } from "../../../common/lib/plugins/failover/failover_restriction"; import { ErrorHandler } from "../../../common/lib/error_handler"; import { PgErrorHandler } from "../pg_error_handler"; +import { SessionState } from "../../../common/lib/session_state"; +import { Messages } from "../../../common/lib/utils/messages"; export class PgDatabaseDialect implements DatabaseDialect { protected dialectName: string = this.constructor.name; @@ -57,6 +59,43 @@ export class PgDatabaseDialect implements DatabaseDialect { return "SELECT 'version', VERSION()"; } + getSetReadOnlyQuery(readOnly: boolean): string { + return `SET SESSION CHARACTERISTICS AS TRANSACTION READ ${readOnly ? "ONLY" : "WRITE"}`; + } + + getSetAutoCommitQuery(autoCommit: boolean): string { + throw new UnsupportedMethodError(Messages.get("Client.methodNotSupported", "setAutoCommit")); + } + + getSetTransactionIsolationQuery(level: TransactionIsolationLevel): string { + let transactionIsolationLevel: string; + switch (level) { + case TransactionIsolationLevel.TRANSACTION_READ_UNCOMMITTED: + transactionIsolationLevel = "READ UNCOMMITTED"; + break; + case TransactionIsolationLevel.TRANSACTION_READ_COMMITTED: + transactionIsolationLevel = "READ COMMITTED"; + break; + case TransactionIsolationLevel.TRANSACTION_REPEATABLE_READ: + transactionIsolationLevel = "REPEATABLE READ"; + break; + case TransactionIsolationLevel.TRANSACTION_SERIALIZABLE: + transactionIsolationLevel = "SERIALIZABLE"; + break; + default: + throw new AwsWrapperError(Messages.get("Client.invalidTransactionIsolationLevel", String(level))); + } + return `SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL ${transactionIsolationLevel}`; + } + + getSetCatalogQuery(catalog: string): string { + throw new UnsupportedMethodError(Messages.get("Client.methodNotSupported", "setCatalog")); + } + + getSetSchemaQuery(schema: string): string { + return `SET search_path TO ${schema}`; + } + async isDialect(targetClient: ClientWrapper): Promise { return await targetClient .query("SELECT 1 FROM pg_proc LIMIT 1") @@ -131,7 +170,7 @@ export class PgDatabaseDialect implements DatabaseDialect { return undefined; } - doesStatementSetTransactionIsolation(statement: string): number | undefined { + doesStatementSetTransactionIsolation(statement: string): TransactionIsolationLevel | undefined { if (statement.toLowerCase().includes("set session characteristics as transaction isolation level read uncommitted")) { return TransactionIsolationLevel.TRANSACTION_READ_COMMITTED; } diff --git a/tests/integration/container/tests/aurora_failover.test.ts b/tests/integration/container/tests/aurora_failover.test.ts index d425c378..ac7c8f45 100644 --- a/tests/integration/container/tests/aurora_failover.test.ts +++ b/tests/integration/container/tests/aurora_failover.test.ts @@ -26,6 +26,7 @@ import { logger } from "../../../../common/logutils"; import { features, instanceCount } from "./config"; import { TestEnvironmentFeatures } from "./utils/test_environment_features"; import { PluginManager } from "../../../../common/lib"; +import { TransactionIsolationLevel } from "../../../../common/lib/utils/transaction_isolation_level"; const itIf = features.includes(TestEnvironmentFeatures.FAILOVER_SUPPORTED) && @@ -48,7 +49,7 @@ async function initDefaultConfig(host: string, port: number, connectToProxy: boo let config: any = { user: env.databaseInfo.username, host: host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: port, plugins: "failover", @@ -182,7 +183,14 @@ describe("aurora failover", () => { expect(await auroraTestUtility.isDbInstanceWriter(initialWriterId)).toBe(true); await client.setReadOnly(true); - const writerId = await auroraTestUtility.queryInstanceId(client); + await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); + + if (driver === DatabaseEngine.PG) { + await client.setSchema(env.databaseInfo.defaultDbName); + } else if (driver === DatabaseEngine.MYSQL) { + await client.setAutoCommit(false); + await client.setCatalog(env.databaseInfo.defaultDbName); + } // Failover cluster and nominate a new writer await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(); @@ -195,6 +203,14 @@ describe("aurora failover", () => { const currentConnectionId = await auroraTestUtility.queryInstanceId(client); expect(await auroraTestUtility.isDbInstanceWriter(currentConnectionId)).toBe(true); expect(currentConnectionId).not.toBe(initialWriterId); + expect(client.isReadOnly()).toBe(true); + expect(client.getTransactionIsolation()).toBe(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); + if (driver === DatabaseEngine.PG) { + expect(client.getSchema()).toBe(env.databaseInfo.defaultDbName); + } else if (driver === DatabaseEngine.MYSQL) { + expect(client.getAutoCommit()).toBe(false); + expect(client.getCatalog()).toBe(env.databaseInfo.defaultDbName); + } }, 1320000 ); diff --git a/tests/integration/container/tests/autoscaling.test.ts b/tests/integration/container/tests/autoscaling.test.ts index 25435a32..46818447 100644 --- a/tests/integration/container/tests/autoscaling.test.ts +++ b/tests/integration/container/tests/autoscaling.test.ts @@ -22,7 +22,6 @@ import { TestEnvironmentFeatures } from "./utils/test_environment_features"; import { features, instanceCount } from "./config"; import { InternalPooledConnectionProvider } from "../../../../common/lib/internal_pooled_connection_provider"; import { AwsPoolConfig } from "../../../../common/lib/aws_pool_config"; -import { ConnectionProviderManager } from "../../../../common/lib/connection_provider_manager"; import { TestInstanceInfo } from "./utils/test_instance_info"; import { sleep } from "../../../../common/lib/utils/utils"; import { FailoverSuccessError } from "../../../../common/lib/utils/errors"; @@ -50,7 +49,7 @@ async function initDefaultConfig(host: string, port: number, provider: InternalP let config: any = { user: env.databaseInfo.username, host: host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: port, plugins: "readWriteSplitting", @@ -70,7 +69,7 @@ async function initConfigWithFailover(host: string, port: number, provider: Inte let config: any = { user: env.databaseInfo.username, host: host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: port, plugins: "readWriteSplitting,failover", diff --git a/tests/integration/container/tests/basic_connectivity.test.ts b/tests/integration/container/tests/basic_connectivity.test.ts index 979027f7..45aec100 100644 --- a/tests/integration/container/tests/basic_connectivity.test.ts +++ b/tests/integration/container/tests/basic_connectivity.test.ts @@ -70,7 +70,7 @@ describe("basic_connectivity", () => { let props = { user: env.databaseInfo.username, host: env.databaseInfo.clusterReadOnlyEndpoint, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: env.databaseInfo.clusterEndpointPort, plugins: "failover,efm", @@ -96,7 +96,7 @@ describe("basic_connectivity", () => { let props = { user: env.databaseInfo.username, host: env.databaseInfo.clusterEndpoint, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: env.databaseInfo.clusterEndpointPort, plugins: "failover,efm", @@ -122,7 +122,7 @@ describe("basic_connectivity", () => { let props = { user: env.databaseInfo.username, host: env.databaseInfo.instances[0].host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: env.databaseInfo.clusterEndpointPort, plugins: "failover,efm", @@ -148,7 +148,7 @@ describe("basic_connectivity", () => { let props = { user: env.databaseInfo.username, host: env.databaseInfo.instances[0].host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: env.databaseInfo.instanceEndpointPort, plugins: "", @@ -178,7 +178,7 @@ describe("basic_connectivity", () => { let props = { user: env.databaseInfo.username, host: env.proxyDatabaseInfo.instances[0].host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: env.proxyDatabaseInfo.instanceEndpointPort, plugins: "", diff --git a/tests/integration/container/tests/iam_authentication.test.ts b/tests/integration/container/tests/iam_authentication.test.ts index a4eebc54..c9c00d7b 100644 --- a/tests/integration/container/tests/iam_authentication.test.ts +++ b/tests/integration/container/tests/iam_authentication.test.ts @@ -54,7 +54,7 @@ async function initDefaultConfig(host: string): Promise { let props = { user: "jane_doe", host: host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: env.databaseInfo.instanceEndpointPort, plugins: "iam", diff --git a/tests/integration/container/tests/performance.test.ts b/tests/integration/container/tests/performance.test.ts index 5d7c47f7..6144111e 100644 --- a/tests/integration/container/tests/performance.test.ts +++ b/tests/integration/container/tests/performance.test.ts @@ -75,7 +75,7 @@ function initDefaultConfig(host: string, port: number): any { let config: any = { user: env.databaseInfo.username, host: host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: port, failoverTimeoutMs: 250000 diff --git a/tests/integration/container/tests/read_write_splitting.test.ts b/tests/integration/container/tests/read_write_splitting.test.ts index 2d3f2d55..d673dbb3 100644 --- a/tests/integration/container/tests/read_write_splitting.test.ts +++ b/tests/integration/container/tests/read_write_splitting.test.ts @@ -26,7 +26,6 @@ import { TestEnvironmentFeatures } from "./utils/test_environment_features"; import { features, instanceCount } from "./config"; import { InternalPooledConnectionProvider } from "../../../../common/lib/internal_pooled_connection_provider"; import { AwsPoolConfig } from "../../../../common/lib/aws_pool_config"; -import { ConnectionProviderManager } from "../../../../common/lib/connection_provider_manager"; import { InternalPoolMapping } from "../../../../common/lib/utils/internal_pool_mapping"; import { HostInfo } from "../../../../common/lib/host_info"; import { PluginManager } from "../../../../common/lib"; @@ -53,7 +52,7 @@ async function initDefaultConfig(host: string, port: number, connectToProxy: boo let config: any = { user: env.databaseInfo.username, host: host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: port, plugins: "readWriteSplitting", @@ -73,7 +72,7 @@ async function initConfigWithFailover(host: string, port: number, connectToProxy let config: any = { user: env.databaseInfo.username, host: host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: port, plugins: "readWriteSplitting,failover", diff --git a/tests/integration/container/tests/read_write_splitting_performance.test.ts b/tests/integration/container/tests/read_write_splitting_performance.test.ts index 5961e3ab..7fc222fa 100644 --- a/tests/integration/container/tests/read_write_splitting_performance.test.ts +++ b/tests/integration/container/tests/read_write_splitting_performance.test.ts @@ -24,7 +24,6 @@ import { PerfTestUtility } from "./utils/perf_util"; import { ConnectTimePlugin } from "../../../../common/lib/plugins/connect_time_plugin"; import { ExecuteTimePlugin } from "../../../../common/lib/plugins/execute_time_plugin"; import { TestDriver } from "./utils/test_driver"; -import { ConnectionProviderManager } from "../../../../common/lib/connection_provider_manager"; import { InternalPooledConnectionProvider } from "../../../../common/lib/internal_pooled_connection_provider"; import { PluginManager } from "../../../../common/lib"; diff --git a/tests/integration/container/tests/session_state.test.ts b/tests/integration/container/tests/session_state.test.ts new file mode 100644 index 00000000..3ce23be8 --- /dev/null +++ b/tests/integration/container/tests/session_state.test.ts @@ -0,0 +1,184 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { TestEnvironment } from "./utils/test_environment"; +import { ProxyHelper } from "./utils/proxy_helper"; +import { DriverHelper } from "./utils/driver_helper"; +import { logger } from "../../../../common/logutils"; +import { DatabaseEngine } from "./utils/database_engine"; +import { TestEnvironmentFeatures } from "./utils/test_environment_features"; +import { features } from "./config"; +import { DatabaseEngineDeployment } from "./utils/database_engine_deployment"; +import { PluginManager } from "../../../../common/lib"; +import { AwsPGClient } from "../../../../pg/lib"; +import { PluginService } from "../../../../common/lib/plugin_service"; +import { TestDriver } from "./utils/test_driver"; +import { AwsMySQLClient } from "../../../../mysql/lib"; +import { TransactionIsolationLevel } from "../../../../common/lib/utils/transaction_isolation_level"; + +const itIf = + !features.includes(TestEnvironmentFeatures.PERFORMANCE) && !features.includes(TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY) ? it : it.skip; + +let client: any; + +async function executeInstanceQuery(client: any, engine: DatabaseEngine, deployment: DatabaseEngineDeployment, props: any): Promise { + await client.connect(); + + const res = await DriverHelper.executeInstanceQuery(engine, deployment, client); + + expect(res).not.toBeNull(); +} + +beforeEach(async () => { + logger.info(`Test started: ${expect.getState().currentTestName}`); + await ProxyHelper.enableAllConnectivity(); + await TestEnvironment.verifyClusterStatus(); + client = null; +}, 1320000); + +afterEach(async () => { + if (client !== null) { + try { + await client.end(); + } catch (error) { + // pass + } + } + await PluginManager.releaseResources(); + logger.info(`Test finished: ${expect.getState().currentTestName}`); +}, 1320000); + +class TestAwsMySQLClient extends AwsMySQLClient { + getPluginService(): PluginService { + return this.pluginService; + } +} + +class TestAwsPGClient extends AwsPGClient { + getPluginService(): PluginService { + return this.pluginService; + } +} + +describe("session state", () => { + it.only("test update state", async () => { + const env = await TestEnvironment.getCurrent(); + const driver = DriverHelper.getDriverForDatabaseEngine(env.engine); + let initClientFunc; + switch (driver) { + case TestDriver.MYSQL: + initClientFunc = (options: any) => new TestAwsMySQLClient(options); + break; + case TestDriver.PG: + initClientFunc = (options: any) => new TestAwsPGClient(options); + break; + default: + throw new Error("invalid driver"); + } + + let props = { + user: env.databaseInfo.username, + host: env.databaseInfo.clusterEndpoint, + database: env.databaseInfo.defaultDbName, + password: env.databaseInfo.password, + port: env.databaseInfo.clusterEndpointPort + }; + props = DriverHelper.addDriverSpecificConfiguration(props, env.engine); + client = initClientFunc(props); + + const newClient = initClientFunc(props); + + try { + await client.connect(); + await newClient.connect(); + const targetClient = client.targetClient; + const newTargetClient = newClient.targetClient; + + expect(targetClient).not.toEqual(newTargetClient); + if (driver === TestDriver.MYSQL) { + await DriverHelper.executeQuery(env.engine, client, "CREATE DATABASE IF NOT EXISTS testSessionState"); + await client.setReadOnly(true); + await client.setCatalog("testSessionState"); + await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); + await client.setAutoCommit(false); + + // Assert new client's session states are using server default values. + let readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_read_only AS readonly"); + let catalog = await DriverHelper.executeQuery(env.engine, newClient, "SELECT DATABASE() AS catalog"); + let autoCommit = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.autocommit AS autocommit"); + let transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_isolation AS level"); + expect(readOnly[0][0].readonly).toEqual(0); + expect(catalog[0][0].catalog).toEqual(env.databaseInfo.defaultDbName); + expect(autoCommit[0][0].autocommit).toEqual(1); + expect(transactionIsolation[0][0].level).toEqual("REPEATABLE-READ"); + + await client.getPluginService().setCurrentClient(newClient.targetClient); + + expect(client.targetClient).not.toEqual(targetClient); + expect(client.targetClient).toEqual(newTargetClient); + + // Assert new client's session states are set. + readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_read_only AS readonly"); + catalog = await DriverHelper.executeQuery(env.engine, newClient, "SELECT DATABASE() AS catalog"); + autoCommit = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.autocommit AS autocommit"); + transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_isolation AS level"); + expect(readOnly[0][0].readonly).toEqual(1); + expect(catalog[0][0].catalog).toEqual("testSessionState"); + expect(autoCommit[0][0].autocommit).toEqual(0); + expect(transactionIsolation[0][0].level).toEqual("SERIALIZABLE"); + + await client.setReadOnly(false); + await client.setAutoCommit(true); + await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); + } else if (driver === TestDriver.PG) { + // End any current transaction before we can create a new test database. + await DriverHelper.executeQuery(env.engine, client, "END TRANSACTION"); + await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); + await DriverHelper.executeQuery(env.engine, client, "CREATE DATABASE testSessionState"); + await client.setReadOnly(true); + await client.setSchema("testSessionState"); + await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); + + // Assert new client's session states are using server default values. + let readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SHOW transaction_read_only"); + let schema = await DriverHelper.executeQuery(env.engine, newClient, "SHOW search_path"); + let transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SHOW TRANSACTION ISOLATION LEVEL"); + expect(readOnly.rows[0]["transaction_read_only"]).toEqual("off"); + expect(schema.rows[0]["search_path"]).not.toEqual("testSessionState"); + expect(transactionIsolation.rows[0]["transaction_isolation"]).toEqual("read committed"); + + await client.getPluginService().setCurrentClient(newClient.targetClient); + expect(client.targetClient).not.toEqual(targetClient); + expect(client.targetClient).toEqual(newTargetClient); + + // Assert new client's session states are set. + readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SHOW transaction_read_only"); + schema = await DriverHelper.executeQuery(env.engine, newClient, "SHOW search_path"); + transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SHOW TRANSACTION ISOLATION LEVEL"); + expect(readOnly.rows[0]["transaction_read_only"]).toEqual("on"); + expect(schema.rows[0]["search_path"]).toEqual("testsessionstate"); + expect(transactionIsolation.rows[0]["transaction_isolation"]).toEqual("serializable"); + + await client.setReadOnly(false); + await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); + } + } catch (e) { + await client.end(); + await newClient.end(); + throw e; + } + }, 1320000); +}); diff --git a/tests/integration/container/tests/utils/perf_util.ts b/tests/integration/container/tests/utils/perf_util.ts index 34ef7370..94ffd372 100644 --- a/tests/integration/container/tests/utils/perf_util.ts +++ b/tests/integration/container/tests/utils/perf_util.ts @@ -25,7 +25,7 @@ export class PerfTestUtility { let config: any = { user: env.databaseInfo.username, host: host, - database: env.databaseInfo.default_db_name, + database: env.databaseInfo.defaultDbName, password: env.databaseInfo.password, port: port, plugins: "connectTime,executeTime" diff --git a/tests/integration/container/tests/utils/test_database_info.ts b/tests/integration/container/tests/utils/test_database_info.ts index 19bc25ee..abfa1b5b 100644 --- a/tests/integration/container/tests/utils/test_database_info.ts +++ b/tests/integration/container/tests/utils/test_database_info.ts @@ -20,7 +20,7 @@ import { DBInstance } from "@aws-sdk/client-rds/dist-types/models/models_0"; export class TestDatabaseInfo { private readonly _username: string; private readonly _password: string; - private readonly _default_db_name: string; + private readonly _defaultDbName: string; private readonly _clusterEndpoint: string; private readonly _clusterEndpointPort: number; private readonly _clusterReadOnlyEndpoint: string; @@ -32,7 +32,7 @@ export class TestDatabaseInfo { constructor(databaseInfo: { [s: string]: any }) { this._username = String(databaseInfo["username"]); this._password = String(databaseInfo["password"]); - this._default_db_name = String(databaseInfo["defaultDbName"]); + this._defaultDbName = String(databaseInfo["defaultDbName"]); this._clusterEndpoint = String(databaseInfo["clusterEndpoint"]); this._clusterEndpointPort = Number(databaseInfo["clusterEndpointPort"]); this._clusterReadOnlyEndpoint = String(databaseInfo["clusterReadOnlyEndpoint"]); @@ -53,8 +53,8 @@ export class TestDatabaseInfo { return this._password; } - get default_db_name(): string { - return this._default_db_name; + get defaultDbName(): string { + return this._defaultDbName; } get writerInstanceEndpoint() { diff --git a/tests/integration/container/tests/utils/test_environment.ts b/tests/integration/container/tests/utils/test_environment.ts index ba0b2bb3..a0c4243f 100644 --- a/tests/integration/container/tests/utils/test_environment.ts +++ b/tests/integration/container/tests/utils/test_environment.ts @@ -111,7 +111,7 @@ export class TestEnvironment { port: info?.databaseInfo.instanceEndpointPort ?? 5432, user: info?.databaseInfo.username, password: info?.databaseInfo.password, - database: info?.databaseInfo.default_db_name, + database: info?.databaseInfo.defaultDbName, query_timeout: 3000, connectionTimeoutMillis: 3000 }); @@ -135,7 +135,7 @@ export class TestEnvironment { port: info?.databaseInfo.instanceEndpointPort ?? 3306, user: info?.databaseInfo.username, password: info?.databaseInfo.password, - database: info?.databaseInfo.default_db_name, + database: info?.databaseInfo.defaultDbName, connectTimeout: 3000 } as ConnectionOptions); diff --git a/tests/plugin_manager_benchmarks.ts b/tests/plugin_manager_benchmarks.ts index cb318e4c..6fd09a29 100644 --- a/tests/plugin_manager_benchmarks.ts +++ b/tests/plugin_manager_benchmarks.ts @@ -283,7 +283,7 @@ suite( }), add("initHostProviderWith10Plugins", async () => { - const pluginManagerWithPlugins = await initPluginManagerWithPlugins(10, instance(mockPluginService), propsWithPlugins);; + const pluginManagerWithPlugins = await initPluginManagerWithPlugins(10, instance(mockPluginService), propsWithPlugins); return async () => await pluginManagerWithPlugins.initHostProvider( new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }).withHost("host").build(), @@ -325,7 +325,7 @@ suite( }), add("notifyConnectionChangedWith10Plugins", async () => { - const pluginManagerWithPlugins = await initPluginManagerWithPlugins(10, instance(mockPluginService), propsWithPlugins);; + const pluginManagerWithPlugins = await initPluginManagerWithPlugins(10, instance(mockPluginService), propsWithPlugins); return async () => await pluginManagerWithPlugins.notifyConnectionChanged(new Set([HostChangeOptions.INITIAL_CONNECTION]), null); }), diff --git a/tests/unit/session_state_service_impl.test.ts b/tests/unit/session_state_service_impl.test.ts index 4c25a04b..107dc803 100644 --- a/tests/unit/session_state_service_impl.test.ts +++ b/tests/unit/session_state_service_impl.test.ts @@ -14,20 +14,32 @@ limitations under the License. */ -import { anything, instance, mock, reset, spy, verify, when } from "ts-mockito"; +import { anything, instance, mock, reset, spy, when } from "ts-mockito"; import { SessionStateServiceImpl } from "../../common/lib/session_state_service_impl"; import { PluginService } from "../../common/lib/plugin_service"; import { AwsPGClient } from "../../pg/lib"; import { SessionStateService } from "../../common/lib/session_state_service"; import { AwsClient } from "../../common/lib/aws_client"; import { AwsMySQLClient } from "../../mysql/lib"; - +import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; +import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; +import { HostInfoBuilder } from "../../common/lib/host_info_builder"; +import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy"; +import { MySQLDatabaseDialect } from "../../mysql/lib/dialect/mysql_database_dialect"; +import { PgDatabaseDialect } from "../../pg/lib/dialect/pg_database_dialect"; +import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; + +const hostInfoBuilder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); const mockPluginService = mock(PluginService); let awsPGClient: AwsClient; let mockAwsPGClient: AwsClient; let awsMySQLClient: AwsClient; let mockAwsMySQLClient: AwsClient; let sessionStateService: SessionStateService; +const mockPgClientWrapper: PgClientWrapper = mock(PgClientWrapper); +const mockMySQLClientWrapper: MySQLClientWrapper = mock(MySQLClientWrapper); +const hostInfo = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }).withHost("host").build(); +const mockMySQLDriverDialect = mock(MySQL2DriverDialect); describe("testSessionStateServiceImpl", () => { beforeEach(() => { @@ -36,6 +48,13 @@ describe("testSessionStateServiceImpl", () => { awsMySQLClient = new AwsMySQLClient({}); mockAwsMySQLClient = spy(awsMySQLClient); sessionStateService = new SessionStateServiceImpl(instance(mockPluginService), new Map()); + awsPGClient.targetClient = new PgClientWrapper(undefined, hostInfoBuilder.withHost("host").build(), new Map()); + mockAwsPGClient.targetClient = new PgClientWrapper(undefined, hostInfoBuilder.withHost("host").build(), new Map()); + awsMySQLClient.targetClient = new MySQLClientWrapper(undefined, hostInfoBuilder.withHost("host").build(), new Map(), mockMySQLDriverDialect); + mockAwsMySQLClient.targetClient = new MySQLClientWrapper(undefined, hostInfoBuilder.withHost("host").build(), new Map(), mockMySQLDriverDialect); + when(mockMySQLClientWrapper.query(anything())).thenResolve(); + when(mockPgClientWrapper.query(anything())).thenResolve(); + when(mockPluginService.getSessionStateService()).thenReturn(sessionStateService); }); afterEach(() => { @@ -55,10 +74,16 @@ describe("testSessionStateServiceImpl", () => { ])("test reset client readOnly", async (pristineValue: boolean, value: boolean, shouldReset: boolean, driver: number) => { const mockAwsClient = driver === 0 ? mockAwsPGClient : mockAwsMySQLClient; const awsClient = driver === 0 ? awsPGClient : awsMySQLClient; + if (driver === 0) { + awsClient.targetClient = new PgClientWrapper(undefined, hostInfo, new Map()); + when(mockPluginService.getDialect()).thenReturn(new PgDatabaseDialect()); + } else { + awsClient.targetClient = new MySQLClientWrapper(undefined, hostInfo, new Map(), mockMySQLDriverDialect); + when(mockPluginService.getDialect()).thenReturn(new MySQLDatabaseDialect()); + } when(mockPluginService.getCurrentClient()).thenReturn(awsClient); when(mockAwsClient.isReadOnly()).thenReturn(pristineValue); expect(sessionStateService.getReadOnly()).toBe(undefined); - when(mockAwsClient.updateSessionStateReadOnly(anything())).thenResolve(); sessionStateService.setupPristineReadOnly(); sessionStateService.setReadOnly(value); expect(sessionStateService.getReadOnly()).toBe(value); @@ -66,11 +91,12 @@ describe("testSessionStateServiceImpl", () => { sessionStateService.begin(); await sessionStateService.applyPristineSessionState(awsClient); sessionStateService.complete(); - if (shouldReset) { - verify(mockAwsClient.updateSessionStateReadOnly(pristineValue)).once(); + // Should reset to pristine value + expect(sessionStateService.getReadOnly()).toBe(pristineValue); } else { - verify(mockAwsClient.updateSessionStateReadOnly(anything())).never(); + // No-op, value should stay unchanged + expect(sessionStateService.getReadOnly()).toBe(value); } }); @@ -83,6 +109,7 @@ describe("testSessionStateServiceImpl", () => { const mockAwsClient = driver === 0 ? mockAwsPGClient : mockAwsMySQLClient; const awsClient = driver === 0 ? awsPGClient : awsMySQLClient; + when(mockPluginService.getDialect()).thenReturn(new MySQLDatabaseDialect()); when(mockPluginService.getCurrentClient()).thenReturn(awsClient); when(mockAwsClient.getAutoCommit()).thenReturn(pristineValue); expect(sessionStateService.getAutoCommit()).toBe(undefined); @@ -95,9 +122,11 @@ describe("testSessionStateServiceImpl", () => { sessionStateService.complete(); if (shouldReset) { - verify(mockAwsClient.setAutoCommit(pristineValue)).once(); + // Should reset to pristine value + expect(sessionStateService.getAutoCommit()).toBe(pristineValue); } else { - verify(mockAwsClient.setAutoCommit(anything())).never(); + // No-op, value should stay unchanged + expect(sessionStateService.getAutoCommit()).toBe(value); } }); @@ -110,6 +139,7 @@ describe("testSessionStateServiceImpl", () => { const mockAwsClient = driver === 0 ? mockAwsPGClient : mockAwsMySQLClient; const awsClient = driver === 0 ? awsPGClient : awsMySQLClient; + when(mockPluginService.getDialect()).thenReturn(new MySQLDatabaseDialect()); when(mockPluginService.getCurrentClient()).thenReturn(awsClient); when(mockAwsClient.getCatalog()).thenReturn(pristineValue); expect(sessionStateService.getCatalog()).toBe(undefined); @@ -122,9 +152,11 @@ describe("testSessionStateServiceImpl", () => { sessionStateService.complete(); if (shouldReset) { - verify(mockAwsClient.setCatalog(pristineValue)).once(); + // Should reset to pristine value + expect(sessionStateService.getCatalog()).toBe(pristineValue); } else { - verify(mockAwsClient.setCatalog(anything())).never(); + // No-op, value should stay unchanged + expect(sessionStateService.getCatalog()).toBe(value); } }); @@ -137,6 +169,7 @@ describe("testSessionStateServiceImpl", () => { const mockAwsClient = driver === 0 ? mockAwsPGClient : mockAwsMySQLClient; const awsClient = driver === 0 ? awsPGClient : awsMySQLClient; + when(mockPluginService.getDialect()).thenReturn(new PgDatabaseDialect()); when(mockPluginService.getCurrentClient()).thenReturn(awsClient); when(mockAwsClient.getSchema()).thenReturn(pristineValue); expect(sessionStateService.getSchema()).toBe(undefined); @@ -149,9 +182,11 @@ describe("testSessionStateServiceImpl", () => { sessionStateService.complete(); if (shouldReset) { - verify(mockAwsClient.setSchema(pristineValue)).once(); + // Should reset to pristine value + expect(sessionStateService.getSchema()).toBe(pristineValue); } else { - verify(mockAwsClient.setSchema(anything())).never(); + // No-op, value should stay unchanged + expect(sessionStateService.getSchema()).toBe(value); } }); @@ -168,6 +203,12 @@ describe("testSessionStateServiceImpl", () => { const mockAwsClient = driver === 0 ? mockAwsPGClient : mockAwsMySQLClient; const awsClient = driver === 0 ? awsPGClient : awsMySQLClient; + if (driver === 0) { + when(mockPluginService.getDialect()).thenReturn(new PgDatabaseDialect()); + } else { + when(mockPluginService.getDialect()).thenReturn(new MySQLDatabaseDialect()); + } + when(mockPluginService.getCurrentClient()).thenReturn(awsClient); when(mockAwsClient.getTransactionIsolation()).thenReturn(pristineValue); expect(sessionStateService.getTransactionIsolation()).toBe(undefined); @@ -180,9 +221,11 @@ describe("testSessionStateServiceImpl", () => { sessionStateService.complete(); if (shouldReset) { - verify(mockAwsClient.setTransactionIsolation(pristineValue)).once(); + // Should reset to pristine value + expect(sessionStateService.getTransactionIsolation()).toBe(pristineValue); } else { - verify(mockAwsClient.setTransactionIsolation(anything())).never(); + // No-op, value should stay unchanged + expect(sessionStateService.getTransactionIsolation()).toBe(value); } }); }); diff --git a/tests/unit/sql_method_utils.test.ts b/tests/unit/sql_method_utils.test.ts index ed6b9051..13b57e20 100644 --- a/tests/unit/sql_method_utils.test.ts +++ b/tests/unit/sql_method_utils.test.ts @@ -33,7 +33,8 @@ describe("test sql method utils", () => { ["commit", false], [" select 1", false], [" INSERT INTO test_table VALUES (1) ; ", false], - [" set autocommit = 1 ", false] + [" set autocommit = 1 ", false], + ["CREATE DATABASE fooBar", false] ])("test open transaction", (sql: string, expectedResult: boolean) => { expect(SqlMethodUtils.doesOpenTransaction(sql)).toBe(expectedResult); });