diff --git a/common/lib/plugin_service.ts b/common/lib/plugin_service.ts index 5572ac08..4571728e 100644 --- a/common/lib/plugin_service.ts +++ b/common/lib/plugin_service.ts @@ -151,13 +151,6 @@ export class PluginService implements ErrorHandler, HostListProviderService { return this.getCurrentClient().connectionUrlParser; } - getConnectionProvider(hostInfo: HostInfo | null, props: Map): ConnectionProvider { - if (!this.pluginServiceManagerContainer.pluginManager) { - throw new AwsWrapperError("Plugin manager should not be undefined"); - } - return this.pluginServiceManagerContainer.pluginManager.getConnectionProvider(hostInfo, props); - } - getDialect(): DatabaseDialect { return this.dialect; } diff --git a/common/lib/plugins/read_write_splitting_plugin.ts b/common/lib/plugins/read_write_splitting_plugin.ts index 6a5fa789..b00334c4 100644 --- a/common/lib/plugins/read_write_splitting_plugin.ts +++ b/common/lib/plugins/read_write_splitting_plugin.ts @@ -30,6 +30,7 @@ import { ClientWrapper } from "../client_wrapper"; import { getWriter, logAndThrowError } from "../utils/utils"; import { CanReleaseResources } from "../can_release_resources"; import { InternalPooledConnectionProvider } from "../internal_pooled_connection_provider"; +import { PoolClientWrapper } from "../pool_client_wrapper"; export class ReadWriteSplittingPlugin extends AbstractConnectionPlugin implements CanReleaseResources { private static readonly subscribedMethods: Set = new Set(["initHostProvider", "connect", "notifyConnectionChanged", "query"]); @@ -197,7 +198,7 @@ export class ReadWriteSplittingPlugin extends AbstractConnectionPlugin implement props.set(WrapperProperties.HOST.name, writerHost.host); try { const targetClient = await this.pluginService.connect(writerHost, props); - this.isWriterClientFromInternalPool = this.pluginService.getConnectionProvider(writerHost, props) instanceof InternalPooledConnectionProvider; + this.isWriterClientFromInternalPool = targetClient instanceof PoolClientWrapper; this.setWriterClient(targetClient, writerHost); await this.switchCurrentTargetClientTo(this.writerTargetClient, writerHost); } catch (any) { @@ -290,7 +291,7 @@ export class ReadWriteSplittingPlugin extends AbstractConnectionPlugin implement try { targetClient = await this.pluginService.connect(host, props); - this.isReaderClientFromInternalPool = this.pluginService.getConnectionProvider(host, props) instanceof InternalPooledConnectionProvider; + this.isReaderClientFromInternalPool = targetClient instanceof PoolClientWrapper; readerHost = host; break; } catch (any) { diff --git a/tests/unit/read_write_splitting.test.ts b/tests/unit/read_write_splitting.test.ts index 086216ee..511d7718 100644 --- a/tests/unit/read_write_splitting.test.ts +++ b/tests/unit/read_write_splitting.test.ts @@ -35,9 +35,11 @@ import { InternalPooledConnectionProvider } from "../../common/lib/internal_pool import { AwsPoolConfig } from "../../common/lib/aws_pool_config"; import { ConnectionProviderManager } from "../../common/lib/connection_provider_manager"; import { InternalPoolMapping } from "../../common/lib/utils/internal_pool_mapping"; -import { NodePostgresDriverDialect } from "../../pg/lib/dialect/node_postgres_driver_dialect"; import { DriverDialect } from "../../common/lib/driver_dialect/driver_dialect"; import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; +import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; +import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; +import { PoolClientWrapper } from "../../common/lib/pool_client_wrapper"; const properties: Map = new Map(); const builder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }); @@ -63,14 +65,9 @@ const mockDialect: MySQLDatabaseDialect = mock(MySQLDatabaseDialect); const mockDriverDialect: DriverDialect = mock(MySQL2DriverDialect); const mockChanges: Set = mock(Set); -const clientWrapper: ClientWrapper = { - client: undefined, - hostInfo: mockHostInfo, - properties: new Map() -}; - -const mockReaderWrapper: ClientWrapper = mock(clientWrapper); -const mockWriterWrapper: ClientWrapper = mock(clientWrapper); +const mockReaderWrapper: ClientWrapper = mock(PgClientWrapper); +const mockWriterWrapper: ClientWrapper = mock(MySQLClientWrapper); +const poolClientWrapper: ClientWrapper = new PoolClientWrapper(undefined, writerHost, new Map()); const clientWrapper_undefined: any = undefined; @@ -89,7 +86,7 @@ describe("reader write splitting test", () => { when(mockPluginService.isInTransaction()).thenReturn(false); when(mockPluginService.getDialect()).thenReturn(mockDialect); when(mockPluginService.getDriverDialect()).thenReturn(mockDriverDialect); - when(mockDriverDialect.connect(anything())).thenReturn(Promise.resolve(mockReaderClient)); + when(mockDriverDialect.connect(anything(), anything())).thenReturn(Promise.resolve(mockReaderWrapper)); properties.clear(); }); @@ -411,11 +408,14 @@ describe("reader write splitting test", () => { when(mockPluginService.getCurrentClient()).thenReturn(instance(mockWriterClient)); when(await mockWriterClient.isValid()).thenReturn(true); when(mockPluginService.getCurrentHostInfo()).thenReturn(writerHost).thenReturn(writerHost).thenReturn(readerHost1); - when(mockDriverDialect.connect(anything())).thenReturn(Promise.resolve(mockReaderClient)); - when(mockPluginService.connect(anything(), anything())).thenResolve(mockReaderWrapper); - const config: AwsPoolConfig = new AwsPoolConfig({ idleTimeoutMillis: 7000, maxConnections: 10, maxIdleConnections: 10 }); + when(mockDriverDialect.connect(anything(), anything())).thenReturn(Promise.resolve(poolClientWrapper)); + when(mockPluginService.connect(anything(), anything())).thenResolve(poolClientWrapper); + const config: AwsPoolConfig = new AwsPoolConfig({ + idleTimeoutMillis: 7000, + maxConnections: 10, + maxIdleConnections: 10 + }); const provider: InternalPooledConnectionProvider = new InternalPooledConnectionProvider(config); - when(mockPluginService.getConnectionProvider(anything(), anything())).thenReturn(provider); ConnectionProviderManager.setConnectionProvider(provider); @@ -426,7 +426,7 @@ describe("reader write splitting test", () => { const spyTarget = instance(target); await spyTarget.switchClientIfRequired(true); await spyTarget.switchClientIfRequired(false); - verify(target.closeTargetClientIfIdle(mockReaderWrapper)).once(); + verify(target.closeTargetClientIfIdle(poolClientWrapper)).once(); }); it("test pooled writer connection after set read only", async () => { @@ -441,19 +441,22 @@ describe("reader write splitting test", () => { .thenReturn(readerHost1) .thenReturn(readerHost1) .thenReturn(writerHost); - when(mockDriverDialect.connect(anything())).thenReturn(Promise.resolve(mockWriterClient)); - when(mockPluginService.connect(writerHost, anything())).thenResolve(mockWriterWrapper); - when(mockPluginService.connect(readerHost1, anything())).thenResolve(mockReaderWrapper); - when(mockPluginService.connect(readerHost2, anything())).thenResolve(mockReaderWrapper); - - const config: AwsPoolConfig = new AwsPoolConfig({ idleTimeoutMillis: 7000, maxConnections: 10, maxIdleConnections: 10 }); + when(mockDriverDialect.connect(anything(), anything())).thenReturn(Promise.resolve(poolClientWrapper)); + when(mockPluginService.connect(writerHost, anything())).thenResolve(poolClientWrapper); + when(mockPluginService.connect(readerHost1, anything())).thenResolve(poolClientWrapper); + when(mockPluginService.connect(readerHost2, anything())).thenResolve(poolClientWrapper); + + const config: AwsPoolConfig = new AwsPoolConfig({ + idleTimeoutMillis: 7000, + maxConnections: 10, + maxIdleConnections: 10 + }); const myKeyFunc: InternalPoolMapping = { getKey: (hostInfo: HostInfo, props: Map) => { return hostInfo.url + "someKey"; } }; const provider: InternalPooledConnectionProvider = new InternalPooledConnectionProvider(config, myKeyFunc); - when(mockPluginService.getConnectionProvider(anything(), anything())).thenReturn(provider); ConnectionProviderManager.setConnectionProvider(provider); @@ -472,6 +475,6 @@ describe("reader write splitting test", () => { await spyTarget.switchClientIfRequired(false); await spyTarget.switchClientIfRequired(true); - verify(target.closeTargetClientIfIdle(mockWriterWrapper)).once(); + verify(target.closeTargetClientIfIdle(poolClientWrapper)).twice(); }); });