Skip to content

Commit

Permalink
feat: aws secrets manager plugin (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
crystall-bitquill authored Apr 9, 2024
1 parent 29a4426 commit 0460cc2
Show file tree
Hide file tree
Showing 12 changed files with 1,211 additions and 255 deletions.
191 changes: 191 additions & 0 deletions common/lib/authentication/aws_secrets_manager_plugin.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import {
SecretsManagerClientConfig,
SecretsManagerClient,
SecretsManagerServiceException,
GetSecretValueCommand
} from "@aws-sdk/client-secrets-manager";
import { logger } from "../../logutils";
import { AbstractConnectionPlugin } from "../abstract_connection_plugin";
import { ConnectionPlugin } from "../connection_plugin";
import { HostInfo } from "../host_info";
import { ConnectionPluginFactory } from "../plugin_factory";
import { PluginService } from "../plugin_service";
import { AwsWrapperError } from "../utils/aws_wrapper_error";
import { Messages } from "../utils/messages";
import { WrapperProperties } from "../wrapper_property";

export class AwsSecretsManagerPlugin extends AbstractConnectionPlugin {
private static SUBSCRIBED_METHODS: Set<string> = new Set<string>(["connect", "forceConnect"]);
private static SECRETS_ARN_PATTERN: RegExp = new RegExp("^arn:aws:secretsmanager:(?<region>[^:\\n]*):[^:\\n]*:([^:/\\n]*[:/])?(.*)$");
private readonly pluginService: PluginService;
private secret: Secret | null = null;
static secretsCache: Map<string, Secret> = new Map();
secretKey: SecretCacheKey;
secretsManagerClient: SecretsManagerClient;

constructor(pluginService: PluginService, properties: Map<string, any>) {
super();

this.pluginService = pluginService;
const secretId = WrapperProperties.SECRET_ID.get(properties);
const endpoint = WrapperProperties.SECRET_ENDPOINT.get(properties);
let region = WrapperProperties.SECRET_REGION.get(properties);
const config: SecretsManagerClientConfig = {};

if (!secretId) {
throw new AwsWrapperError(Messages.get("AwsSecretsManagerConnectionPlugin.missingRequiredConfigParameter"));
}

if (!region) {
const groups = secretId.match(AwsSecretsManagerPlugin.SECRETS_ARN_PATTERN)?.groups;
if (groups?.region) {
region = groups.region;
}
config.region = region;
}

if (endpoint) {
config.endpoint = endpoint;
}

this.secretKey = new SecretCacheKey(secretId, region);
this.secretsManagerClient = new SecretsManagerClient(config);
}

getSubscribedMethods(): Set<string> {
return AwsSecretsManagerPlugin.SUBSCRIBED_METHODS;
}

connect<T>(hostInfo: HostInfo, props: Map<string, any>, isInitialConnection: boolean, connectFunc: () => Promise<T>): Promise<T> {
return this.connectInternal(hostInfo, props, connectFunc);
}

forceConnect<T>(hostInfo: HostInfo, props: Map<string, any>, isInitialConnection: boolean, forceConnectFunc: () => Promise<T>): Promise<T> {
return this.connectInternal(hostInfo, props, forceConnectFunc);
}

private async connectInternal<T>(hostInfo: HostInfo, props: Map<string, any>, connectFunc: () => Promise<T>): Promise<T> {
let secretWasFetched = await this.updateSecret(false);
try {
WrapperProperties.USER.set(props, this.secret?.username ?? "");
WrapperProperties.PASSWORD.set(props, this.secret?.password ?? "");
this.pluginService.updateConfigWithProperties(props);
return await connectFunc();
} catch (error) {
if (error instanceof Error) {
if ((error.message.includes("password authentication failed") || error.message.includes("Access denied")) && !secretWasFetched) {
// Login unsuccessful with cached credentials
// Try to re-fetch credentials and try again

secretWasFetched = await this.updateSecret(true);
if (secretWasFetched) {
WrapperProperties.USER.set(props, this.secret?.username ?? "");
WrapperProperties.PASSWORD.set(props, this.secret?.password ?? "");
return await connectFunc();
}
}
logger.debug(Messages.get("AwsSecretsManagerConnectionPlugin.unhandledException", error.name, error.message));
}
throw error;
}
}

private async updateSecret(forceRefresh: boolean): Promise<boolean> {
let fetched = false;
this.secret = AwsSecretsManagerPlugin.secretsCache.get(JSON.stringify(this.secretKey)) ?? null;

if (!this.secret || forceRefresh) {
try {
this.secret = await this.fetchLatestCredentials();
fetched = true;
AwsSecretsManagerPlugin.secretsCache.set(JSON.stringify(this.secretKey), this.secret);
} catch (error) {
if (error instanceof SecretsManagerServiceException) {
this.logAndThrowError(Messages.get("AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials"));
} else if (error instanceof Error && error.message.includes("AWS SDK error")) {
this.logAndThrowError(Messages.get("AwsSecretsManagerConnectionPlugin.endpointOverrideInvalidConnection", error.message));
} else {
this.logAndThrowError(Messages.get("AwsSecretsManagerConnectionPlugin.unhandledException", JSON.stringify(error)));
}
}
}

return fetched;
}

private async fetchLatestCredentials(): Promise<Secret> {
const commandInput = {
SecretId: this.secretKey.secretId
};
const command = new GetSecretValueCommand(commandInput);
const result = await this.secretsManagerClient.send(command);
const secret = new Secret(JSON.parse(result.SecretString ?? "").username, JSON.parse(result.SecretString ?? "").password);
if (secret && secret.username && secret.password) {
return secret;
}
throw new AwsWrapperError(Messages.get("AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials"));
}

private logAndThrowError(message: string) {
logger.debug(message);
throw new AwsWrapperError(message);
}
}

export class SecretCacheKey {
private readonly _secretId: string;
private readonly _region: string | null;

constructor(secretId: string, region: string) {
this._secretId = secretId;
this._region = region;
}

get secretId(): string {
return this._secretId;
}

get region(): string | null {
return this._region;
}
}

export class Secret {
private readonly _username: string;
private readonly _password: string;

constructor(username: string, password: string) {
this._username = username;
this._password = password;
}

get username(): string {
return this._username;
}

get password(): string | null {
return this._password;
}
}

export class AwsSecretsManagerPluginFactory implements ConnectionPluginFactory {
getInstance(pluginService: PluginService, properties: Map<string, any>): ConnectionPlugin {
return new AwsSecretsManagerPlugin(pluginService, new Map(properties));
}
}
3 changes: 1 addition & 2 deletions common/lib/authentication/iam_authentication_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
const token = await this.generateAuthenticationToken(hostInfo, props, host, port, region);
logger.debug(Messages.get("IamAuthenticationPlugin.generatedNewIamToken", token));
WrapperProperties.PASSWORD.set(props, token);
this.pluginService.updateCredentials(props);
IamAuthenticationPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry));
}
this.pluginService.updateConfigWithProperties(props);

try {
return connectFunc();
Expand All @@ -98,7 +98,6 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
const token = await this.generateAuthenticationToken(hostInfo, props, host, port, region);
logger.debug(Messages.get("IamAuthenticationPlugin.generatedNewIamToken", token));
WrapperProperties.PASSWORD.set(props, token);
this.pluginService.updateCredentials(props);
IamAuthenticationPlugin.tokenCache.set(cacheKey, new TokenInfo(token, tokenExpiry));
return connectFunc();
}
Expand Down
13 changes: 0 additions & 13 deletions common/lib/aws_client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,6 @@ export abstract class AwsClient {
return this._connectFunc;
}

updateCredentials(properties: Map<string, any>): void {
const user = WrapperProperties.USER.get(properties);
const pass = WrapperProperties.PASSWORD.get(properties);

if (this.targetClient.user != user) {
this.targetClient.user = user;
}

if (this.targetClient.password != pass) {
this.targetClient.password = pass;
}
}

abstract executeQuery(props: Map<string, any>, sql: string): Promise<any>;

abstract end(): Promise<any>;
Expand Down
2 changes: 2 additions & 0 deletions common/lib/connection_plugin_chain_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { WrapperProperties } from "./wrapper_property";
import { AwsWrapperError } from "./utils/aws_wrapper_error";
import { Messages } from "./utils/messages";
import { DefaultPlugin } from "./plugins/default_plugin";
import { AwsSecretsManagerPluginFactory } from "./authentication/aws_secrets_manager_plugin";

export class PluginFactoryInfo {}

Expand All @@ -33,6 +34,7 @@ export class ConnectionPluginChainBuilder {

static readonly PLUGIN_FACTORIES = new Map<string, FactoryClass>([
["iam", IamAuthenticationPluginFactory],
["secretsManager", AwsSecretsManagerPluginFactory],
["failover", FailoverPluginFactory]
]);

Expand Down
26 changes: 20 additions & 6 deletions common/lib/plugin_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ export class PluginService implements ErrorHandler, HostListProviderService {
return this._currentClient;
}

updateCredentials(properties: Map<string, any>) {
this.getCurrentClient().updateCredentials(properties);
}

getConnectionUrlParser(): ConnectionUrlParser {
return this.getCurrentClient().connectionUrlParser;
}
Expand Down Expand Up @@ -197,6 +193,24 @@ export class PluginService implements ErrorHandler, HostListProviderService {

setAvailability(hostAliases: Set<string>, availability: HostAvailability) {}

updateConfigWithProperties(props: Map<string, any>) {
this._currentClient.config = Object.fromEntries(props.entries());
}

replaceTargetClient(props: Map<string, any>): void {
const createClientFunc = this.getCurrentClient().getCreateClientFunc();
if (createClientFunc) {
if (this.getCurrentClient().targetClient) {
this.getCurrentClient().end();
}
const newTargetClient = createClientFunc(Object.fromEntries(props));
this.getCurrentClient().targetClient = newTargetClient;
return;
}
throw new AwsWrapperError("AwsClient is missing create target client function."); // This should not be reached
}

// TODO: use replaceTargetClient method instead
async createTargetClientAndConnect(hostInfo: HostInfo, props: Map<string, any>, forceConnect: boolean): Promise<AwsClient> {
if (this.pluginServiceManagerContainer.pluginManager) {
return await this.pluginServiceManagerContainer.pluginManager.createTargetClientAndConnect(hostInfo, props, this._currentClient, forceConnect);
Expand All @@ -210,14 +224,14 @@ export class PluginService implements ErrorHandler, HostListProviderService {
if (connectFunc) {
return this.pluginServiceManagerContainer.pluginManager?.connect(hostInfo, props, false, connectFunc);
}
throw new AwsWrapperError("AwsClient is missing target client connect functions."); // This should not be reached
throw new AwsWrapperError("AwsClient is missing target client connect function."); // This should not be reached
}

forceConnect(hostInfo: HostInfo, props: Map<string, any>) {
const connectFunc = this._currentClient.getConnectFunc();
if (connectFunc) {
return this.pluginServiceManagerContainer.pluginManager?.forceConnect(hostInfo, props, false, connectFunc);
}
throw new AwsWrapperError("AwsClient is missing target client connect functions."); // This should not be reached
throw new AwsWrapperError("AwsClient is missing target client connect function."); // This should not be reached
}
}
4 changes: 4 additions & 0 deletions common/lib/utils/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
"IamAuthenticationPlugin.invalidPort": "Port number: %d is not valid. Port number should be greater than zero. Falling back to default port.",
"IamAuthenticationPlugin.unhandledException": "Unhandled exception: %s",
"IamAuthenticationPlugin.connectException": "Error occurred while opening a connection: %s",
"AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials": "Was not able to either fetch or read the database credentials from AWS Secrets Manager. Ensure the correct secretId and region properties have been provided.",
"AwsSecretsManagerConnectionPlugin.missingRequiredConfigParameter": "Configuration parameter 'secretId' is required.",
"AwsSecretsManagerConnectionPlugin.unhandledException": "Unhandled exception: '%s'",
"AwsSecretsManagerConnectionPlugin.endpointOverrideInvalidConnection": "A connection to the provided endpoint could not be established: '%s'.",
"PluginManager.PipelineNone": "A pipeline was requested but the created pipeline evaluated to None.",
"ClusterAwareReaderFailoverHandler.invalidTopology": "'%s' was called with an invalid (null or empty) topology",
"ClusterAwareReaderFailoverHandler.attemptingReaderConnection": "Trying to connect to reader: '%s', with properties '%s'",
Expand Down
4 changes: 4 additions & 0 deletions common/lib/wrapper_property.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ export class WrapperProperties {
WrapperProperties.DEFAULT_TOKEN_EXPIRATION_SEC
);

static readonly SECRET_ID = new WrapperProperty<string>("secretId", "The name or the ARN of the secret to retrieve.", null);
static readonly SECRET_REGION = new WrapperProperty<string>("secretRegion", "The region of the secret to retrieve.", null);
static readonly SECRET_ENDPOINT = new WrapperProperty<string>("secretEndpoint", "The endpoint of the secret to retrieve.", null);

static readonly CLUSTER_TOPOLOGY_REFRESH_RATE_MS = new WrapperProperty<number>(
"clusterTopologyRefreshRateMs",
"Cluster topology refresh rate in millis. " +
Expand Down
14 changes: 11 additions & 3 deletions mysql/lib/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ export class AwsMySQLClient extends AwsClient {
constructor(config: any) {
super(config, new MySQLErrorHandler(), new AuroraMySQLDatabaseDialect(), new MySQLConnectionUrlParser());
this.config = config;
this.targetClient = createConnection(WrapperProperties.removeWrapperProperties(config));
this._createClientFunc = (config: any) => {
return createConnection(WrapperProperties.removeWrapperProperties(config));
};
Expand All @@ -44,17 +43,26 @@ export class AwsMySQLClient extends AwsClient {
async connect(): Promise<Connection> {
await this.internalConnect();
const conn: Promise<Connection> = this.pluginManager.connect(this.pluginService.getCurrentHostInfo(), this.properties, true, async () => {
this.targetClient = createConnection(WrapperProperties.removeWrapperProperties(this.config));
return this.targetClient.promise().connect();
});
this.isConnected = true;
return conn;
}

executeQuery(props: Map<string, any>, sql: string): Promise<Query> {
async executeQuery(props: Map<string, any>, sql: string): Promise<Query> {
if (!this.isConnected) {
await this.connect(); // client.connect is not required for MySQL clients
this.isConnected = true;
}
return this.targetClient.promise().query({ sql: sql });
}

query(options: QueryOptions, callback?: any): Promise<Query> {
async query(options: QueryOptions, callback?: any): Promise<Query> {
if (!this.isConnected) {
await this.connect(); // client.connect is not required for MySQL clients
this.isConnected = true;
}
const host = this.pluginService.getCurrentHostInfo();
return this.pluginManager.execute(
host,
Expand Down
Loading

0 comments on commit 0460cc2

Please sign in to comment.