Skip to content
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import { decodeJWT } from '@aws-amplify/core/internals/utils';

import { refreshAuthTokens } from '../../../src/providers/cognito/utils/refreshAuthTokens';
Expand Down Expand Up @@ -60,6 +63,7 @@ describe('refreshToken', () => {
});

it('should refresh token', async () => {
const clientMetadata = { 'app-version': '1.0.0' };
const expectedOutput = {
accessToken: decodeJWT(mockAccessToken),
idToken: decodeJWT(mockAccessToken),
Expand All @@ -82,6 +86,7 @@ describe('refreshToken', () => {
},
},
username: mockedUsername,
clientMetadata,
});

// stringify and re-parse for JWT equality
Expand All @@ -93,6 +98,7 @@ describe('refreshToken', () => {
expect.objectContaining({
ClientId: 'aaaaaaaaaaaa',
RefreshToken: mockedRefreshToken,
ClientMetadata: clientMetadata,
}),
);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,40 @@ describe('TokenOrchestrator', () => {
expect(tokens?.accessToken).toEqual(validAuthTokens.accessToken);
});
});

describe('setClientMetadataProvider', () => {
it('should use clientMetadataProvider for token refresh', async () => {
const clientMetadata = { 'app-version': '1.0.0' };
const clientMetadataProvider = () => Promise.resolve(clientMetadata);

mockTokenRefresher.mockResolvedValue({
accessToken: { payload: {} },
idToken: { payload: {} },
clockDrift: 0,
refreshToken: 'newRefreshToken',
username: 'testuser',
});

tokenOrchestrator.setTokenRefresher(mockTokenRefresher);
tokenOrchestrator.setAuthTokenStore(mockAuthTokenStore);
tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider);

mockAuthTokenStore.loadTokens.mockResolvedValue({
accessToken: { payload: { exp: 1 } },
idToken: { payload: { exp: 1 } },
clockDrift: 0,
refreshToken: 'refreshToken',
username: 'testuser',
});
mockAuthTokenStore.getLastAuthUser.mockResolvedValue('testuser');

await tokenOrchestrator.getTokens({ forceRefresh: true });

expect(mockTokenRefresher).toHaveBeenCalledWith(
expect.objectContaining({
clientMetadata,
}),
);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
FetchAuthSessionOptions,
KeyValueStorageInterface,
defaultStorage,
Expand Down Expand Up @@ -38,6 +39,12 @@ export class CognitoUserPoolsTokenProvider
this.authTokenStore.setKeyValueStorage(keyValueStorage);
}

setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void {
this.tokenOrchestrator.setClientMetadataProvider(clientMetadataProvider);
}

setAuthConfig(authConfig: AuthConfig) {
this.authTokenStore.setAuthConfig(authConfig);
this.tokenOrchestrator.setAuthConfig(authConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
CognitoUserPoolConfig,
FetchAuthSessionOptions,
Hub,
Expand All @@ -19,7 +20,7 @@ import { assertServiceError } from '../../../errors/utils/assertServiceError';
import { AuthError } from '../../../errors/AuthError';
import { oAuthStore } from '../utils/oauth/oAuthStore';
import { addInflightPromise } from '../utils/oauth/inflightPromise';
import { CognitoAuthSignInDetails } from '../types';
import { ClientMetadata, CognitoAuthSignInDetails } from '../types';

import {
AuthTokenOrchestrator,
Expand All @@ -32,6 +33,7 @@ import {

export class TokenOrchestrator implements AuthTokenOrchestrator {
private authConfig?: AuthConfig;
clientMetadataProvider?: ClientMetadataProvider;
tokenStore?: AuthTokenStore;
tokenRefresher?: TokenRefresher;
inflightPromise: Promise<void> | undefined;
Expand Down Expand Up @@ -94,6 +96,12 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
return this.tokenRefresher;
}

setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void {
this.clientMetadataProvider = clientMetadataProvider;
}

async getTokens(
options?: FetchAuthSessionOptions,
): Promise<
Expand Down Expand Up @@ -130,6 +138,8 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
tokens = await this.refreshTokens({
tokens,
username,
clientMetadata:
options?.clientMetadata ?? (await this.clientMetadataProvider?.()),
});

if (tokens === null) {
Expand All @@ -147,16 +157,19 @@ export class TokenOrchestrator implements AuthTokenOrchestrator {
private async refreshTokens({
tokens,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
username: string;
clientMetadata?: ClientMetadata;
}): Promise<CognitoAuthTokens | null> {
try {
const { signInDetails } = tokens;
const newTokens = await this.getTokenRefresher()({
tokens,
authConfig: this.authConfig,
username,
clientMetadata,
});
newTokens.signInDetails = signInDetails;
await this.setTokens({ tokens: newTokens });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@
import {
AuthConfig,
AuthTokens,
ClientMetadataProvider,
FetchAuthSessionOptions,
KeyValueStorageInterface,
TokenProvider,
} from '@aws-amplify/core';

import { CognitoAuthSignInDetails } from '../types';
import { ClientMetadata, CognitoAuthSignInDetails } from '../types';

export type TokenRefresher = ({
tokens,
authConfig,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
authConfig?: AuthConfig;
username: string;
clientMetadata?: ClientMetadata;
}) => Promise<CognitoAuthTokens>;

export type AuthKeys<AuthKey extends string> = Record<AuthKey, string>;
Expand Down Expand Up @@ -66,6 +69,9 @@ export interface AuthTokenOrchestrator {
export interface CognitoUserPoolTokenProviderType extends TokenProvider {
setKeyValueStorage(keyValueStorage: KeyValueStorageInterface): void;
setAuthConfig(authConfig: AuthConfig): void;
setClientMetadataProvider(
clientMetadataProvider: ClientMetadataProvider,
): void;
}

export type CognitoAuthTokens = AuthTokens & {
Expand Down
2 changes: 1 addition & 1 deletion packages/auth/src/providers/cognito/types/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export const cognitoHostedUIIdentityProviderMap: Record<AuthProvider, string> =
/**
* Arbitrary key/value pairs that may be passed as part of certain Cognito requests
*/
export type ClientMetadata = Record<string, string>;
export type { ClientMetadata } from '@aws-amplify/core';

/**
* Allowed values for preferredChallenge
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@ import { assertAuthTokensWithRefreshToken } from '../utils/types';
import { AuthError } from '../../../errors/AuthError';
import { createCognitoUserPoolEndpointResolver } from '../factories';
import { createGetTokensFromRefreshTokenClient } from '../../../foundation/factories/serviceClients/cognitoIdentityProvider';
import { ClientMetadata } from '../types';

const refreshAuthTokensFunction: TokenRefresher = async ({
tokens,
authConfig,
username,
clientMetadata,
}: {
tokens: CognitoAuthTokens;
authConfig?: AuthConfig;
username: string;
clientMetadata?: ClientMetadata;
}): Promise<CognitoAuthTokens> => {
assertTokenProviderConfig(authConfig?.Cognito);
const { userPoolId, userPoolClientId, userPoolEndpoint } = authConfig.Cognito;
Expand All @@ -41,6 +44,7 @@ const refreshAuthTokensFunction: TokenRefresher = async ({
ClientId: userPoolClientId,
RefreshToken: tokens.refreshToken,
DeviceKey: tokens.deviceMetadata?.deviceKey,
ClientMetadata: clientMetadata,
},
);

Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ export {
OAuthConfig,
CognitoUserPoolConfig,
JWT,
ClientMetadata,
ClientMetadataProvider,
} from './singleton/Auth/types';
export { decodeJWT } from './singleton/Auth/utils';
export {
Expand Down
11 changes: 11 additions & 0 deletions packages/core/src/singleton/Auth/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@
import { StrictUnion } from '../../types';
import { AtLeastOne } from '../types';

/**
* Arbitrary key/value pairs that may be passed as part of certain Cognito requests
*/
export type ClientMetadata = Record<string, string>;

/**
* Function type for providing client metadata for Cognito operations
*/
export type ClientMetadataProvider = () => Promise<ClientMetadata>;

// From https://github.com/awslabs/aws-jwt-verify/blob/main/src/safe-json-parse.ts
// From https://github.com/awslabs/aws-jwt-verify/blob/main/src/jwt-model.ts
interface JwtPayloadStandardFields {
Expand Down Expand Up @@ -66,6 +76,7 @@ export interface TokenProvider {

export interface FetchAuthSessionOptions {
forceRefresh?: boolean;
clientMetadata?: ClientMetadata;
}

export interface AuthTokens {
Expand Down
Loading