Skip to content

Commit

Permalink
feat: aws secrets manager plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
crystall-bitquill committed Apr 8, 2024
1 parent 760353f commit 0ecf281
Show file tree
Hide file tree
Showing 13 changed files with 1,216 additions and 258 deletions.
188 changes: 188 additions & 0 deletions common/lib/authentication/aws_secrets_manager_plugin.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import {
SecretsManagerClientConfig,
SecretsManagerClient,
SecretsManagerServiceException,
GetSecretValueCommand
} from "@aws-sdk/client-secrets-manager";
import { logger } from "../../logutils";
import { AbstractConnectionPlugin } from "../abstract_connection_plugin";
import { ConnectionPlugin } from "../connection_plugin";
import { HostInfo } from "../host_info";
import { ConnectionPluginFactory } from "../plugin_factory";
import { PluginService } from "../plugin_service";
import { AwsWrapperError } from "../utils/aws_wrapper_error";
import { Messages } from "../utils/messages";
import { WrapperProperties } from "../wrapper_property";

export class AwsSecretsManagerPlugin extends AbstractConnectionPlugin {
private static SUBSCRIBED_METHODS: Set<string> = new Set<string>(["connect", "forceConnect"]);
private static SECRETS_ARN_PATTERN: RegExp = new RegExp("^arn:aws:secretsmanager:(?<region>[^:\\n]*):[^:\\n]*:([^:/\\n]*[:/])?(.*)$");
static secretsCache: Map<string, Secret> = new Map();
private secret: Secret | null = null;
secretKey: SecretCacheKey;
secretsManagerClient: SecretsManagerClient;

constructor(properties: Map<string, any>) {
super();

const secretId = WrapperProperties.SECRET_ID.get(properties);
const endpoint = WrapperProperties.SECRET_ENDPOINT.get(properties);
let region = WrapperProperties.SECRET_REGION.get(properties);
const config: SecretsManagerClientConfig = {};

if (!secretId) {
throw new AwsWrapperError(Messages.get("AwsSecretsManagerConnectionPlugin.missingRequiredConfigParameter"));
}

if (!region) {
const groups = secretId.match(AwsSecretsManagerPlugin.SECRETS_ARN_PATTERN)?.groups;
if (groups?.region) {
region = groups.region;
}
config.region = region;
}

if (endpoint) {
config.endpoint = endpoint;
}

this.secretKey = new SecretCacheKey(secretId, region);
this.secretsManagerClient = new SecretsManagerClient(config);
}

getSubscribedMethods(): Set<string> {
return AwsSecretsManagerPlugin.SUBSCRIBED_METHODS;
}

connect<T>(hostInfo: HostInfo, props: Map<string, any>, isInitialConnection: boolean, connectFunc: () => Promise<T>): Promise<T> {
return this.connectInternal(hostInfo, props, connectFunc);
}

forceConnect<T>(hostInfo: HostInfo, props: Map<string, any>, isInitialConnection: boolean, forceConnectFunc: () => Promise<T>): Promise<T> {
return this.connectInternal(hostInfo, props, forceConnectFunc);
}

private async connectInternal<T>(hostInfo: HostInfo, props: Map<string, any>, connectFunc: () => Promise<T>): Promise<T> {
let secretWasFetched = await this.updateSecret(false);
try {
WrapperProperties.USER.set(props, this.secret?.username ?? "");
WrapperProperties.PASSWORD.set(props, this.secret?.password ?? "");
return await connectFunc();
} catch (error) {
if (error instanceof Error) {
if (error.message.includes("password authentication failed") || error.message.includes("Access denied")) {
if (!secretWasFetched) {
// Login unsuccessful with cached credentials
// Try to re-fetch credentials and try again

secretWasFetched = await this.updateSecret(true);
if (secretWasFetched) {
WrapperProperties.USER.set(props, this.secret?.username ?? "");
WrapperProperties.PASSWORD.set(props, this.secret?.password ?? "");
return await connectFunc();
}
}
}
logger.debug(Messages.get("AwsSecretsManagerConnectionPlugin.unhandledException", error.name, error.message));
}
throw error;
}
}

private async updateSecret(forceRefresh: boolean): Promise<boolean> {
let fetched = false;
this.secret = AwsSecretsManagerPlugin.secretsCache.get(JSON.stringify(this.secretKey)) ?? null;

if (!this.secret || forceRefresh) {
try {
this.secret = await this.fetchLatestCredentials();
fetched = true;
AwsSecretsManagerPlugin.secretsCache.set(JSON.stringify(this.secretKey), this.secret);
} catch (error) {
if (error instanceof SecretsManagerServiceException) {
logger.debug(Messages.get("AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials"));
throw new AwsWrapperError(Messages.get("AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials"));
} else if (error instanceof Error && error.message.includes("AWS SDK error")) {
logger.debug(Messages.get("AwsSecretsManagerConnectionPlugin.endpointOverrideInvalidConnection", error.message));
throw new AwsWrapperError(Messages.get("AwsSecretsManagerConnectionPlugin.endpointOverrideInvalidConnection", error.message));
} else {
logger.debug(Messages.get("AwsSecretsManagerConnectionPlugin.unhandledException", JSON.stringify(error)));
throw new AwsWrapperError(Messages.get("AwsSecretsManagerConnectionPlugin.unhandledException", JSON.stringify(error)));
}
}
}

return fetched;
}

private async fetchLatestCredentials(): Promise<Secret> {
const commandInput = {
SecretId: this.secretKey.secretId
};
const command = new GetSecretValueCommand(commandInput);
const result = await this.secretsManagerClient.send(command);
const secret = new Secret(JSON.parse(result.SecretString ?? "").username, JSON.parse(result.SecretString ?? "").password);
if (secret && secret.username && secret.password) {
return secret;
}
throw new AwsWrapperError(Messages.get("AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials"));
}
}

export class SecretCacheKey {
private readonly _secretId: string;
private readonly _region: string | null;

constructor(secretId: string, region: string) {
this._secretId = secretId;
this._region = region;
}

get secretId(): string {
return this._secretId;
}

get region(): string | null {
return this._region;
}
}

export class Secret {
private readonly _username: string;
private readonly _password: string;

constructor(username: string, password: string) {
this._username = username;
this._password = password;
}

get username(): string {
return this._username;
}

get password(): string | null {
return this._password;
}
}

export class AwsSecretsManagerPluginFactory implements ConnectionPluginFactory {
getInstance(pluginService: PluginService, properties: Map<string, any>): ConnectionPlugin {
return new AwsSecretsManagerPlugin(new Map(properties));
}
}
2 changes: 0 additions & 2 deletions common/lib/authentication/iam_authentication_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
const token = await this.generateAuthenticationToken(hostInfo, props, host, port, region);
logger.debug(Messages.get("IamAuthenticationPlugin.generatedNewIamToken", token));
WrapperProperties.PASSWORD.set(props, token);
this.pluginService.updateCredentials(props);
IamAuthenticationPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry));
}

Expand All @@ -98,7 +97,6 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
const token = await this.generateAuthenticationToken(hostInfo, props, host, port, region);
logger.debug(Messages.get("IamAuthenticationPlugin.generatedNewIamToken", token));
WrapperProperties.PASSWORD.set(props, token);
this.pluginService.updateCredentials(props);
IamAuthenticationPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry));
return connectFunc();
}
Expand Down
13 changes: 0 additions & 13 deletions common/lib/aws_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,6 @@ export abstract class AwsClient {
return this._connectFunc;
}

updateCredentials(properties: Map<string, any>): void {
const user = WrapperProperties.USER.get(properties);
const pass = WrapperProperties.PASSWORD.get(properties);

if (this.targetClient.user != user) {
this.targetClient.user = user;
}

if (this.targetClient.password != pass) {
this.targetClient.password = pass;
}
}

abstract executeQuery(props: Map<string, any>, sql: string): Promise<any>;

abstract end(): Promise<any>;
Expand Down
4 changes: 3 additions & 1 deletion common/lib/connection_plugin_chain_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { WrapperProperties } from "./wrapper_property";
import { AwsWrapperError } from "./utils/aws_wrapper_error";
import { Messages } from "./utils/messages";
import { DefaultPlugin } from "./plugins/default_plugin";
import { AwsSecretsManagerPluginFactory } from "./authentication/aws_secrets_manager_plugin";

export class PluginFactoryInfo {}

Expand All @@ -33,6 +34,7 @@ export class ConnectionPluginChainBuilder {

static readonly PLUGIN_FACTORIES = new Map<string, FactoryClass>([
["iam", IamAuthenticationPluginFactory],
["awsSecretsManager", AwsSecretsManagerPluginFactory],
["failover", FailoverPluginFactory]
]);

Expand Down Expand Up @@ -64,7 +66,7 @@ export class ConnectionPluginChainBuilder {
});
}

plugins.push(new DefaultPlugin());
plugins.push(new DefaultPlugin(pluginService));

return plugins;
}
Expand Down
22 changes: 16 additions & 6 deletions common/lib/plugin_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ export class PluginService implements ErrorHandler, HostListProviderService {
return this._currentClient;
}

updateCredentials(properties: Map<string, any>) {
this.getCurrentClient().updateCredentials(properties);
}

getConnectionUrlParser(): ConnectionUrlParser {
return this.getCurrentClient().connectionUrlParser;
}
Expand Down Expand Up @@ -197,6 +193,20 @@ export class PluginService implements ErrorHandler, HostListProviderService {

setAvailability(hostAliases: Set<string>, availability: HostAvailability) {}

replaceTargetClient(props: Map<string, any>): void {
const createClientFunc = this.getCurrentClient().getCreateClientFunc();
if (createClientFunc) {
if (this.getCurrentClient().targetClient) {
this.getCurrentClient().end();
}
const newTargetClient = createClientFunc(Object.fromEntries(props));
this.getCurrentClient().targetClient = newTargetClient;
return;
}
throw new AwsWrapperError("AwsClient is missing create target client function."); // This should not be reached
}

// TODO: use replaceTargetClient method instead
async createTargetClientAndConnect(hostInfo: HostInfo, props: Map<string, any>, forceConnect: boolean): Promise<AwsClient> {
if (this.pluginServiceManagerContainer.pluginManager) {
return await this.pluginServiceManagerContainer.pluginManager.createTargetClientAndConnect(hostInfo, props, this._currentClient, forceConnect);
Expand All @@ -210,14 +220,14 @@ export class PluginService implements ErrorHandler, HostListProviderService {
if (connectFunc) {
return this.pluginServiceManagerContainer.pluginManager?.connect(hostInfo, props, false, connectFunc);
}
throw new AwsWrapperError("AwsClient is missing target client connect functions."); // This should not be reached
throw new AwsWrapperError("AwsClient is missing target client connect function."); // This should not be reached
}

forceConnect(hostInfo: HostInfo, props: Map<string, any>) {
const connectFunc = this._currentClient.getConnectFunc();
if (connectFunc) {
return this.pluginServiceManagerContainer.pluginManager?.forceConnect(hostInfo, props, false, connectFunc);
}
throw new AwsWrapperError("AwsClient is missing target client connect functions."); // This should not be reached
throw new AwsWrapperError("AwsClient is missing target client connect function."); // This should not be reached
}
}
1 change: 1 addition & 0 deletions common/lib/plugins/default_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { Messages } from "../utils/messages";
import { HostListProviderService } from "../host_list_provider_service";
import { HostInfo } from "../host_info";
import { AbstractConnectionPlugin } from "../abstract_connection_plugin";
import { AwsWrapperError } from "../utils/aws_wrapper_error";

export class DefaultPlugin extends AbstractConnectionPlugin {
id: string = uniqueId("_defaultPlugin");
Expand Down
4 changes: 4 additions & 0 deletions common/lib/utils/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
"IamAuthenticationPlugin.invalidPort": "Port number: %d is not valid. Port number should be greater than zero. Falling back to default port.",
"IamAuthenticationPlugin.unhandledException": "Unhandled exception: %s",
"IamAuthenticationPlugin.connectException": "Error occurred while opening a connection: %s",
"AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials": "Was not able to either fetch or read the database credentials from AWS Secrets Manager. Ensure the correct secretId and region properties have been provided.",
"AwsSecretsManagerConnectionPlugin.missingRequiredConfigParameter": "Configuration parameter 'secretId' is required.",
"AwsSecretsManagerConnectionPlugin.unhandledException": "Unhandled exception: '%s'",
"AwsSecretsManagerConnectionPlugin.endpointOverrideInvalidConnection": "A connection to the provided endpoint could not be established: '%s'.",
"PluginManager.PipelineNone": "A pipeline was requested but the created pipeline evaluated to None.",
"ClusterAwareReaderFailoverHandler.invalidTopology": "'%s' was called with an invalid (null or empty) topology",
"ClusterAwareReaderFailoverHandler.attemptingReaderConnection": "Trying to connect to reader: '%s', with properties '%s'",
Expand Down
11 changes: 10 additions & 1 deletion common/lib/wrapper_property.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ export class WrapperProperties {
WrapperProperties.DEFAULT_TOKEN_EXPIRATION_SEC
);

static readonly SECRET_ID = new WrapperProperty<string>("secretsManagerSecretId", "The name or the ARN of the secret to retrieve.", null);
static readonly SECRET_REGION = new WrapperProperty<string>("secretsManagerRegion", "The region of the secret to retrieve.", null);
static readonly SECRET_ENDPOINT = new WrapperProperty<string>("secretsManagerEndpoint", "The endpoint of the secret to retrieve.", null);

static readonly CLUSTER_TOPOLOGY_REFRESH_RATE_MS = new WrapperProperty<number>(
"clusterTopologyRefreshRateMs",
"Cluster topology refresh rate in millis. " +
Expand Down Expand Up @@ -98,7 +102,12 @@ export class WrapperProperties {

static removeWrapperProperties<T>(config: T): T {
const copy = Object.assign({}, config);
const persistingProperties = [WrapperProperties.USER.name, WrapperProperties.PASSWORD.name, WrapperProperties.DATABASE.name, WrapperProperties.PORT.name];
const persistingProperties = [
WrapperProperties.USER.name,
WrapperProperties.PASSWORD.name,
WrapperProperties.DATABASE.name,
WrapperProperties.PORT.name
];

Object.values(WrapperProperties).forEach((prop) => {
if (prop instanceof WrapperProperty) {
Expand Down
16 changes: 12 additions & 4 deletions mysql/lib/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ export class AwsMySQLClient extends AwsClient {
constructor(config: any) {
super(config, new MySQLErrorHandler(), new AuroraMySQLDatabaseDialect(), new MySQLConnectionUrlParser());
this.config = config;
this.targetClient = createConnection(WrapperProperties.removeWrapperProperties(config));
this._createClientFunc = (config: any) => {
return createConnection(WrapperProperties.removeWrapperProperties(config));
};
Expand All @@ -43,18 +42,27 @@ export class AwsMySQLClient extends AwsClient {

async connect(): Promise<Connection> {
await this.internalConnect();
let conn: Promise<Connection> = this.pluginManager.connect(this.pluginService.getCurrentHostInfo(), this.properties, true, async () => {
const conn: Promise<Connection> = this.pluginManager.connect(this.pluginService.getCurrentHostInfo(), this.properties, true, async () => {
this.targetClient = createConnection(WrapperProperties.removeWrapperProperties(this.config));
return this.targetClient.promise().connect();
});
this.isConnected = true;
return conn;
}

executeQuery(props: Map<string, any>, sql: string): Promise<Query> {
async executeQuery(props: Map<string, any>, sql: string): Promise<Query> {
if (!this.isConnected) {
await this.connect(); // client.connect is not required for MySQL clients
this.isConnected = true;
}
return this.targetClient.promise().query({ sql: sql });
}

query(options: QueryOptions, callback?: any): Promise<Query> {
async query(options: QueryOptions, callback?: any): Promise<Query> {
if (!this.isConnected) {
await this.connect(); // client.connect is not required for MySQL clients
this.isConnected = true;
}
const host = this.pluginService.getCurrentHostInfo();
return this.pluginManager.execute(
host,
Expand Down
Loading

0 comments on commit 0ecf281

Please sign in to comment.