Skip to content

Commit

Permalink
feat: Configuration profiles (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiyvamz authored Dec 12, 2024
1 parent f91e999 commit f9ce248
Show file tree
Hide file tree
Showing 25 changed files with 1,078 additions and 124 deletions.
31 changes: 18 additions & 13 deletions common/lib/authentication/aws_credentials_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,34 @@
import { HostInfo } from "../host_info";
import { fromNodeProviderChain } from "@aws-sdk/credential-providers";
import { AwsCredentialIdentityProvider } from "@smithy/types/dist-types/identity/awsCredentialIdentity";
import { WrapperProperties } from "../wrapper_property";
import { AwsWrapperError } from "../utils/errors";
import { Messages } from "../utils/messages";

interface AwsCredentialsProviderHandler {
getAwsCredentialsProvider(hostInfo: HostInfo, properties: Map<string, any>): AwsCredentialIdentityProvider;
}

export class AwsCredentialsManager {
private static handler?: AwsCredentialsProviderHandler;

static setCustomHandler(customHandler: AwsCredentialsProviderHandler) {
AwsCredentialsManager.handler = customHandler;
}

static getProvider(hostInfo: HostInfo, props: Map<string, any>): AwsCredentialIdentityProvider {
return AwsCredentialsManager.handler === undefined
? AwsCredentialsManager.getDefaultProvider()
: AwsCredentialsManager.handler.getAwsCredentialsProvider(hostInfo, props);
}
const awsCredentialProviderHandler = WrapperProperties.CUSTOM_AWS_CREDENTIAL_PROVIDER_HANDLER.get(props);
if (awsCredentialProviderHandler && !AwsCredentialsManager.isAwsCredentialsProviderHandler(awsCredentialProviderHandler)) {
throw new AwsWrapperError(Messages.get("AwsCredentialsManager.wrongHandler"));
}

static resetCustomHandler() {
AwsCredentialsManager.handler = undefined;
return !awsCredentialProviderHandler
? AwsCredentialsManager.getDefaultProvider(WrapperProperties.AWS_PROFILE.get(props))
: awsCredentialProviderHandler.getAwsCredentialsProvider(hostInfo, props);
}

private static getDefaultProvider() {
private static getDefaultProvider(profileName: string | null) {
if (profileName) {
return fromNodeProviderChain({ profile: profileName });
}
return fromNodeProviderChain();
}

private static isAwsCredentialsProviderHandler(arg: any): arg is AwsCredentialsProviderHandler {
return arg.getAwsCredentialsProvider !== undefined;
}
}
50 changes: 48 additions & 2 deletions common/lib/aws_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ import { DefaultTelemetryFactory } from "./utils/telemetry/default_telemetry_fac
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
import { DriverDialect } from "./driver_dialect/driver_dialect";
import { WrapperProperties } from "./wrapper_property";
import { DriverConfigurationProfiles } from "./profile/driver_configuration_profiles";
import { ConfigurationProfile } from "./profile/configuration_profile";
import { AwsWrapperError } from "./utils/errors";
import { Messages } from "./utils/messages";

export abstract class AwsClient extends EventEmitter {
private _defaultPort: number = -1;
Expand All @@ -41,6 +45,7 @@ export abstract class AwsClient extends EventEmitter {
protected _isReadOnly: boolean = false;
protected _isolationLevel: number = 0;
protected _connectionUrlParser: ConnectionUrlParser;
protected _configurationProfile: ConfigurationProfile | null = null;
readonly properties: Map<string, any>;
config: any;
targetClient?: ClientWrapper;
Expand All @@ -58,9 +63,50 @@ export abstract class AwsClient extends EventEmitter {

this.properties = new Map<string, any>(Object.entries(config));

const profileName = WrapperProperties.PROFILE_NAME.get(this.properties);
if (profileName && profileName.length > 0) {
this._configurationProfile = DriverConfigurationProfiles.getProfileConfiguration(profileName);
if (this._configurationProfile) {
const profileProperties = this._configurationProfile.getProperties();
if (profileProperties) {
for (const key of profileProperties.keys()) {
if (this.properties.has(key)) {
// Setting defined by a user has priority over property in configuration profile.
continue;
}
this.properties.set(key, profileProperties.get(key));
}

const connectionProvider = WrapperProperties.CONNECTION_PROVIDER.get(this.properties);
if (!connectionProvider) {
WrapperProperties.CONNECTION_PROVIDER.set(this.properties, this._configurationProfile.getAwsCredentialProvider());
}

const customAwsCredentialProvider = WrapperProperties.CUSTOM_AWS_CREDENTIAL_PROVIDER_HANDLER.get(this.properties);
if (!customAwsCredentialProvider) {
WrapperProperties.CUSTOM_AWS_CREDENTIAL_PROVIDER_HANDLER.set(this.properties, this._configurationProfile.getAwsCredentialProvider());
}

const customDatabaseDialect = WrapperProperties.CUSTOM_DATABASE_DIALECT.get(this.properties);
if (!customDatabaseDialect) {
WrapperProperties.CUSTOM_DATABASE_DIALECT.set(this.properties, this._configurationProfile.getDatabaseDialect());
}
}
} else {
throw new AwsWrapperError(Messages.get("AwsClient.configurationProfileNotFound", profileName));
}
}

this.telemetryFactory = new DefaultTelemetryFactory(this.properties);
const container = new PluginServiceManagerContainer();
this.pluginService = new PluginService(container, this, dbType, knownDialectsByCode, this.properties, driverDialect);
this.pluginService = new PluginService(
container,
this,
dbType,
knownDialectsByCode,
this.properties,
this._configurationProfile?.getDriverDialect() ?? driverDialect
);
this.pluginManager = new PluginManager(
container,
this.properties,
Expand All @@ -71,7 +117,7 @@ export abstract class AwsClient extends EventEmitter {

private async setup() {
await this.telemetryFactory.init();
await this.pluginManager.init();
await this.pluginManager.init(this._configurationProfile);
}

protected async internalConnect() {
Expand Down
107 changes: 70 additions & 37 deletions common/lib/connection_plugin_chain_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import { DeveloperConnectionPluginFactory } from "./plugins/dev/developer_connec
import { ConnectionPluginFactory } from "./plugin_factory";
import { LimitlessConnectionPluginFactory } from "./plugins/limitless/limitless_connection_plugin_factory";
import { FastestResponseStrategyPluginFactory } from "./plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory";
import { ConfigurationProfile } from "./profile/configuration_profile";

/*
Type alias used for plugin factory sorting. It holds a reference to a plugin
Expand Down Expand Up @@ -69,58 +70,90 @@ export class ConnectionPluginChainBuilder {
["executeTime", { factory: ExecuteTimePluginFactory, weight: ConnectionPluginChainBuilder.WEIGHT_RELATIVE_TO_PRIOR_PLUGIN }]
]);

static readonly PLUGIN_WEIGHTS = new Map<typeof ConnectionPluginFactory, number>([
[AuroraInitialConnectionStrategyFactory, 390],
[AuroraConnectionTrackerPluginFactory, 400],
[StaleDnsPluginFactory, 500],
[ReadWriteSplittingPluginFactory, 600],
[FailoverPluginFactory, 700],
[HostMonitoringPluginFactory, 800],
[LimitlessConnectionPluginFactory, 950],
[IamAuthenticationPluginFactory, 1000],
[AwsSecretsManagerPluginFactory, 1100],
[FederatedAuthPluginFactory, 1200],
[OktaAuthPluginFactory, 1300],
[DeveloperConnectionPluginFactory, 1400],
[ConnectTimePluginFactory, ConnectionPluginChainBuilder.WEIGHT_RELATIVE_TO_PRIOR_PLUGIN],
[ExecuteTimePluginFactory, ConnectionPluginChainBuilder.WEIGHT_RELATIVE_TO_PRIOR_PLUGIN]
]);

static async getPlugins(
pluginService: PluginService,
props: Map<string, any>,
connectionProviderManager: ConnectionProviderManager
connectionProviderManager: ConnectionProviderManager,
configurationProfile: ConfigurationProfile | null
): Promise<ConnectionPlugin[]> {
let pluginFactoryInfoList: PluginFactoryInfo[] = [];
const plugins: ConnectionPlugin[] = [];
let pluginCodes: string = props.get(WrapperProperties.PLUGINS.name);
if (pluginCodes == null) {
pluginCodes = WrapperProperties.DEFAULT_PLUGINS;
}

const usingDefault = pluginCodes === WrapperProperties.DEFAULT_PLUGINS;
let usingDefault: boolean = false;

pluginCodes = pluginCodes.trim();

if (pluginCodes !== "") {
const pluginCodeList = pluginCodes.split(",").map((pluginCode) => pluginCode.trim());
let pluginFactoryInfoList: PluginFactoryInfo[] = [];
let lastWeight = 0;
pluginCodeList.forEach((p) => {
if (!ConnectionPluginChainBuilder.PLUGIN_FACTORIES.has(p)) {
throw new AwsWrapperError(Messages.get("PluginManager.unknownPluginCode", p));
}

const factoryInfo = ConnectionPluginChainBuilder.PLUGIN_FACTORIES.get(p);
if (factoryInfo) {
if (factoryInfo.weight === ConnectionPluginChainBuilder.WEIGHT_RELATIVE_TO_PRIOR_PLUGIN) {
lastWeight++;
} else {
lastWeight = factoryInfo.weight;
if (configurationProfile) {
const profilePluginFactories = configurationProfile.getPluginFactories();
if (profilePluginFactories) {
for (const factory of profilePluginFactories) {
const weight = ConnectionPluginChainBuilder.PLUGIN_WEIGHTS.get(factory);
if (!weight) {
throw new AwsWrapperError(Messages.get("PluginManager.unknownPluginWeight", factory.prototype.constructor.name));
}
pluginFactoryInfoList.push({ factory: factoryInfo.factory, weight: lastWeight });
pluginFactoryInfoList.push({ factory: factory, weight: weight });
}
});
usingDefault = true; // We assume that plugin factories in configuration profile is presorted.
}
} else {
let pluginCodes: string = props.get(WrapperProperties.PLUGINS.name);
if (pluginCodes == null) {
pluginCodes = WrapperProperties.DEFAULT_PLUGINS;
}
usingDefault = pluginCodes === WrapperProperties.DEFAULT_PLUGINS;

if (!usingDefault && pluginFactoryInfoList.length > 1 && WrapperProperties.AUTO_SORT_PLUGIN_ORDER.get(props)) {
pluginFactoryInfoList = pluginFactoryInfoList.sort((a, b) => a.weight - b.weight);
pluginCodes = pluginCodes.trim();
if (pluginCodes !== "") {
const pluginCodeList = pluginCodes.split(",").map((pluginCode) => pluginCode.trim());
let lastWeight = 0;
pluginCodeList.forEach((p) => {
if (!ConnectionPluginChainBuilder.PLUGIN_FACTORIES.has(p)) {
throw new AwsWrapperError(Messages.get("PluginManager.unknownPluginCode", p));
}

if (!usingDefault) {
logger.info(
"Plugins order has been rearranged. The following order is in effect: " +
pluginFactoryInfoList.map((pluginFactoryInfo) => pluginFactoryInfo.factory.name.split("Factory")[0]).join(", ")
);
}
const factoryInfo = ConnectionPluginChainBuilder.PLUGIN_FACTORIES.get(p);
if (factoryInfo) {
if (factoryInfo.weight === ConnectionPluginChainBuilder.WEIGHT_RELATIVE_TO_PRIOR_PLUGIN) {
lastWeight++;
} else {
lastWeight = factoryInfo.weight;
}
pluginFactoryInfoList.push({ factory: factoryInfo.factory, weight: lastWeight });
}
});
}
}

if (!usingDefault && pluginFactoryInfoList.length > 1 && WrapperProperties.AUTO_SORT_PLUGIN_ORDER.get(props)) {
pluginFactoryInfoList = pluginFactoryInfoList.sort((a, b) => a.weight - b.weight);

for (const pluginFactoryInfo of pluginFactoryInfoList) {
const factoryObj = new pluginFactoryInfo.factory();
plugins.push(await factoryObj.getInstance(pluginService, props));
if (!usingDefault) {
logger.info(
"Plugins order has been rearranged. The following order is in effect: " +
pluginFactoryInfoList.map((pluginFactoryInfo) => pluginFactoryInfo.factory.name.split("Factory")[0]).join(", ")
);
}
}

for (const pluginFactoryInfo of pluginFactoryInfoList) {
const factoryObj = new pluginFactoryInfo.factory();
plugins.push(await factoryObj.getInstance(pluginService, props));
}

plugins.push(new DefaultPlugin(pluginService, connectionProviderManager));

return plugins;
Expand Down
41 changes: 22 additions & 19 deletions common/lib/database_dialect/database_dialect_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,43 +34,46 @@ export class DatabaseDialectManager implements DatabaseDialectProvider {
*/
private static readonly ENDPOINT_CACHE_EXPIRATION_MS = 86_400_000_000_000; // 24 hours
protected static readonly knownEndpointDialects: CacheMap<string, string> = new CacheMap();
protected readonly knownDialectsByCode: Map<string, DatabaseDialect>;

private static customDialect: DatabaseDialect | null = null;
private readonly rdsHelper: RdsUtils = new RdsUtils();
private readonly dbType;
private canUpdate: boolean = false;
private dialect: DatabaseDialect;
private dialectCode: string = "";
protected readonly knownDialectsByCode: Map<string, DatabaseDialect>;
protected readonly customDialect: DatabaseDialect | null;
protected readonly rdsHelper: RdsUtils = new RdsUtils();
protected readonly dbType: DatabaseType;
protected canUpdate: boolean = false;
protected dialect: DatabaseDialect;
protected dialectCode: string = "";

constructor(knownDialectsByCode: any, dbType: DatabaseType, props: Map<string, any>) {
this.knownDialectsByCode = knownDialectsByCode;
this.dbType = dbType;
this.dialect = this.getDialect(props);
}

static setCustomDialect(dialect: DatabaseDialect) {
DatabaseDialectManager.customDialect = dialect;
}
const dialectSetting = WrapperProperties.CUSTOM_DATABASE_DIALECT.get(props);
if (dialectSetting && !this.isDatabaseDialect(dialectSetting)) {
throw new AwsWrapperError(Messages.get("DatabaseDialectManager.wrongCustomDialect"));
}
this.customDialect = dialectSetting;

static resetCustomDialect() {
DatabaseDialectManager.customDialect = null;
this.dialect = this.getDialect(props);
}

static resetEndpointCache() {
DatabaseDialectManager.knownEndpointDialects.clear();
}

protected isDatabaseDialect(arg: any): arg is DatabaseDialect {
return arg.getDialectName !== undefined;
}

getDialect(props: Map<string, any>): DatabaseDialect {
if (this.dialect) {
return this.dialect;
}

this.canUpdate = false;

if (DatabaseDialectManager.customDialect) {
if (this.customDialect) {
this.dialectCode = DatabaseDialectCodes.CUSTOM;
this.dialect = DatabaseDialectManager.customDialect;
this.dialect = this.customDialect;
this.logCurrentDialect();
return this.dialect;
}
Expand All @@ -87,7 +90,7 @@ export class DatabaseDialectManager implements DatabaseDialectProvider {
this.logCurrentDialect();
return userDialect;
}
throw new AwsWrapperError(Messages.get("DialectManager.unknownDialectCode", dialectCode));
throw new AwsWrapperError(Messages.get("DatabaseDialectManager.unknownDialectCode", dialectCode));
}

if (this.dbType === DatabaseType.MYSQL) {
Expand Down Expand Up @@ -148,7 +151,7 @@ export class DatabaseDialectManager implements DatabaseDialectProvider {
return this.dialect;
}

throw new AwsWrapperError(Messages.get("DialectManager.getDialectError"));
throw new AwsWrapperError(Messages.get("DatabaseDialectManager.getDialectError"));
}

async getDialectForUpdate(targetClient: ClientWrapper, originalHost: string, newHost: string): Promise<DatabaseDialect> {
Expand All @@ -161,7 +164,7 @@ export class DatabaseDialectManager implements DatabaseDialectProvider {
for (const dialectCandidateCode of dialectCandidates) {
const dialectCandidate = this.knownDialectsByCode.get(dialectCandidateCode);
if (!dialectCandidate) {
throw new AwsWrapperError(Messages.get("DialectManager.unknownDialectCode", dialectCandidateCode));
throw new AwsWrapperError(Messages.get("DatabaseDialectManager.unknownDialectCode", dialectCandidateCode));
}

const isDialect = await dialectCandidate.isDialect(targetClient);
Expand Down
10 changes: 6 additions & 4 deletions common/lib/plugin_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
import { TelemetryTraceLevel } from "./utils/telemetry/telemetry_trace_level";
import { ConnectionProvider } from "./connection_provider";
import { ConnectionPluginFactory } from "./plugin_factory";
import { ConfigurationProfile } from "./profile/configuration_profile";

type PluginFunc<T> = (plugin: ConnectionPlugin, targetFunc: () => Promise<T>) => Promise<T>;

Expand Down Expand Up @@ -91,17 +92,18 @@ export class PluginManager {
this.telemetryFactory = telemetryFactory;
}

async init(): Promise<void>;
async init(plugins: ConnectionPlugin[]): Promise<void>;
async init(plugins?: ConnectionPlugin[]) {
async init(configurationProfile?: ConfigurationProfile | null): Promise<void>;
async init(configurationProfile: ConfigurationProfile | null, plugins: ConnectionPlugin[]): Promise<void>;
async init(configurationProfile: ConfigurationProfile | null, plugins?: ConnectionPlugin[]) {
if (this.pluginServiceManagerContainer.pluginService != null) {
if (plugins) {
this._plugins = plugins;
} else {
this._plugins = await ConnectionPluginChainBuilder.getPlugins(
this.pluginServiceManagerContainer.pluginService,
this.props,
this.connectionProviderManager
this.connectionProviderManager,
configurationProfile
);
}
}
Expand Down
Loading

0 comments on commit f9ce248

Please sign in to comment.