Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Stack Connectors][Microsoft Defender] Add caching of OAuth access token to connector #206975

Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,6 @@ describe('Microsoft Defender for Endpoint Connector', () => {
connectorMock = microsoftDefenderEndpointConnectorMocks.create();
});

describe('Access Token management', () => {
it('should call API to generate as new token', async () => {
await connectorMock.instanceMock.isolateHost(
{ id: '1-2-3', comment: 'foo' },
connectorMock.usageCollector
);

expect(connectorMock.instanceMock.request).toHaveBeenCalledWith(
expect.objectContaining({
url: `${connectorMock.options.config.oAuthServerUrl}/${connectorMock.options.config.tenantId}/oauth2/v2.0/token`,
method: 'POST',
data: {
grant_type: 'client_credentials',
client_id: connectorMock.options.config.clientId,
scope: connectorMock.options.config.oAuthScope,
client_secret: connectorMock.options.secrets.clientSecret,
},
}),
connectorMock.usageCollector
);
});
});

describe('#testConnector', () => {
it('should return expected response', async () => {
Object.entries(connectorMock.apiMock).forEach(([url, responseFn]) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import { actionsConfigMock } from '@kbn/actions-plugin/server/actions_config.moc
import { loggingSystemMock } from '@kbn/core-logging-server-mocks';
import { actionsMock } from '@kbn/actions-plugin/server/mocks';
import { ConnectorUsageCollector } from '@kbn/actions-plugin/server/usage';
import { ConnectorToken } from '@kbn/actions-plugin/server/types';
import { ConnectorTokenClient } from '@kbn/actions-plugin/server/lib/connector_token_client';
import {
MicrosoftDefenderEndpointConfig,
MicrosoftDefenderEndpointMachine,
Expand All @@ -31,6 +33,67 @@ export interface CreateMicrosoftDefenderConnectorMockResponse {
usageCollector: ConnectorUsageCollector;
}

const createConnectorTokenMock = (overrides: Partial<ConnectorToken> = {}): ConnectorToken => {
const expiresAt = new Date();
expiresAt.setMinutes(expiresAt.getMinutes() + 30);

return {
id: '1',
connectorId: '123',
tokenType: 'access_token',
token: 'testtokenvalue',
expiresAt: expiresAt.toISOString(),
createdAt: '2025-01-16T13:02:43.494Z',
updatedAt: '2025-01-16T13:02:43.494Z',
...overrides,
};
};

const applyConnectorTokenClientInstanceMock = (
connectorTokenClient: ConnectorTokenClient
): void => {
// Make connector token client a mocked class instance
let cachedTokenMock: ConnectorToken | null = null;

jest.spyOn(connectorTokenClient, 'updateOrReplace');
jest
.spyOn(connectorTokenClient, 'create')
.mockImplementation(
async ({ connectorId, token, expiresAtMillis: expiresAt, tokenType = 'access_token' }) => {
cachedTokenMock = createConnectorTokenMock({
connectorId,
token,
expiresAt,
tokenType,
});
return cachedTokenMock;
}
);
jest
.spyOn(connectorTokenClient, 'update')
.mockImplementation(
async ({ token, expiresAtMillis: expiresAt, tokenType = 'access_token' }) => {
if (cachedTokenMock) {
cachedTokenMock = {
...cachedTokenMock,
token,
expiresAt,
tokenType,
};
}

return cachedTokenMock;
}
);
jest.spyOn(connectorTokenClient, 'get').mockImplementation(async () => {
return { hasErrors: !cachedTokenMock, connectorToken: cachedTokenMock };
});
jest.spyOn(connectorTokenClient, 'deleteConnectorTokens').mockImplementation(async () => {
cachedTokenMock = null;
return [];
});
};

const createMicrosoftDefenderConnectorMock = (): CreateMicrosoftDefenderConnectorMockResponse => {
const apiUrl = 'https://api.mock__microsoft.com';
const options: CreateMicrosoftDefenderConnectorMockResponse['options'] = {
Expand Down Expand Up @@ -102,6 +165,8 @@ const createMicrosoftDefenderConnectorMock = (): CreateMicrosoftDefenderConnecto
}
);

applyConnectorTokenClientInstanceMock(options.services.connectorTokenClient);

return {
options,
apiMock,
Expand Down Expand Up @@ -175,4 +240,6 @@ export const microsoftDefenderEndpointConnectorMocks = Object.freeze({
create: createMicrosoftDefenderConnectorMock,
createMachineMock: createMicrosoftMachineMock,
createMachineActionMock: createMicrosoftMachineAction,
applyConnectorTokenClientMock: applyConnectorTokenClientInstanceMock,
createConnectorToken: createConnectorTokenMock,
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import {
CreateMicrosoftDefenderConnectorMockResponse,
microsoftDefenderEndpointConnectorMocks,
} from './mocks';
import { OAuthTokenManager } from './o_auth_token_manager';
import { ConnectorTokenClient } from '@kbn/actions-plugin/server/lib/connector_token_client';

describe('Microsoft Defender for Endpoint oAuth token manager', () => {
let testMock: CreateMicrosoftDefenderConnectorMockResponse;
let msOAuthManagerMock: OAuthTokenManager;
let connectorTokenManagerClientMock: jest.Mocked<ConnectorTokenClient>;

beforeEach(() => {
testMock = microsoftDefenderEndpointConnectorMocks.create();
connectorTokenManagerClientMock = testMock.options.services
.connectorTokenClient as jest.Mocked<ConnectorTokenClient>;
msOAuthManagerMock = new OAuthTokenManager({
...testMock.options,
apiRequest: async (...args) => testMock.instanceMock.request(...args),
});
});

it('should call MS api to generate new token', async () => {
await msOAuthManagerMock.get(testMock.usageCollector);

expect(testMock.instanceMock.request).toHaveBeenCalledWith(
expect.objectContaining({
url: `${testMock.options.config.oAuthServerUrl}/${testMock.options.config.tenantId}/oauth2/v2.0/token`,
method: 'POST',
data: {
grant_type: 'client_credentials',
client_id: testMock.options.config.clientId,
scope: testMock.options.config.oAuthScope,
client_secret: testMock.options.secrets.clientSecret,
},
}),
testMock.usageCollector
);
});

it('should use cached token when one exists', async () => {
const {
connectorId,
token,
expiresAt: expiresAtMillis,
tokenType,
} = microsoftDefenderEndpointConnectorMocks.createConnectorToken();
await connectorTokenManagerClientMock.create({
connectorId,
token,
expiresAtMillis,
tokenType,
});
await msOAuthManagerMock.get(testMock.usageCollector);

expect(testMock.instanceMock.request).not.toHaveBeenCalled();
expect(connectorTokenManagerClientMock.get).toHaveBeenCalledWith({
connectorId: '1',
tokenType: 'access_token',
});
});

it('should call MS API to generate new token when the cached token is expired', async () => {
const { connectorId, token, tokenType } =
microsoftDefenderEndpointConnectorMocks.createConnectorToken();
await connectorTokenManagerClientMock.create({
connectorId,
token,
expiresAtMillis: '2024-01-16T13:02:43.494Z',
tokenType,
});
await msOAuthManagerMock.get(testMock.usageCollector);

expect(connectorTokenManagerClientMock.get).toHaveBeenCalledWith({
connectorId: '1',
tokenType: 'access_token',
});
expect(testMock.instanceMock.request).toHaveBeenCalledWith(
expect.objectContaining({
url: `${testMock.options.config.oAuthServerUrl}/${testMock.options.config.tenantId}/oauth2/v2.0/token`,
}),
testMock.usageCollector
);
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import { ServiceParams, SubActionConnector } from '@kbn/actions-plugin/server';
import { ConnectorUsageCollector } from '@kbn/actions-plugin/server/usage';
import { ConnectorToken } from '@kbn/actions-plugin/server/types';
import { MicrosoftDefenderEndpointDoNotValidateResponseSchema } from '../../../common/microsoft_defender_endpoint/schema';
import {
MicrosoftDefenderEndpointConfig,
Expand All @@ -15,8 +16,12 @@ import {
} from '../../../common/microsoft_defender_endpoint/types';

export class OAuthTokenManager {
private accessToken: string = '';
private connectorToken: ConnectorToken | null = null;
private readonly oAuthTokenUrl: string;
// NOTE: this `tokenType` here MUST be `access_token` due to the use of
// `ConnectorTokenClient.updateOrCreate()` method, which hardcodes the `tokenType`
private readonly tokenType = 'access_token';
private generatingNewTokenPromise: Promise<void> | null = null;

constructor(
private readonly params: ServiceParams<
Expand All @@ -34,48 +39,127 @@ export class OAuthTokenManager {
this.oAuthTokenUrl = url.toString();
}

private async generateNewToken(connectorUsageCollector: ConnectorUsageCollector): Promise<void> {
// FYI: API Docs: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow#get-a-token
const { oAuthScope, clientId } = this.params.config;
const newToken = await this.params.apiRequest<MicrosoftDefenderEndpointApiTokenResponse>(
{
url: this.oAuthTokenUrl,
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
data: {
grant_type: 'client_credentials',
client_id: clientId,
scope: oAuthScope,
client_secret: this.params.secrets.clientSecret,
private isTokenExpired(token: ConnectorToken): boolean {
const now = new Date();
now.setSeconds(now.getSeconds() - 5); // Allows for a threshold of -5s before considering the token expired

const isExpired = token.expiresAt < now.toISOString();

if (isExpired) {
this.params.logger.debug(`Cached access token expired at [${token.expiresAt}]`);
}

return isExpired;
}

private async retrieveOrGenerateNewTokenIfNeeded(
connectorUsageCollector: ConnectorUsageCollector
): Promise<void> {
if (this.generatingNewTokenPromise) {
return await this.generatingNewTokenPromise;
}

this.generatingNewTokenPromise = (async () => {
const {
connector: { id: connectorId },
logger,
} = this.params;
const connectorTokenClient = this.params.services.connectorTokenClient;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this services.connectorTokenClient a functionality exposed by response ops team, or you created that in a separate PR before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is exposed by the ResponseOps framework. Its purpose is to only store tokens so they can be reused across instances of kibana.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, I'll try to rebuild the token cache in CS to use this one 👍


if (!this.connectorToken) {
logger.debug(`Retrieving cached connector access token (if any)`);

const cachedToken = await connectorTokenClient.get({
connectorId,
tokenType: this.tokenType,
});

if (cachedToken.connectorToken) {
this.connectorToken = cachedToken.connectorToken;

const logToken = {
...this.connectorToken,
token: '[redacted]',
};

logger.debug(() => `using cached access token:\n${JSON.stringify(logToken, null, 2)}`);
} else {
logger.debug(`No cached access token found`);
}
}

if (this.connectorToken && !this.isTokenExpired(this.connectorToken)) {
logger.debug('Cached token is not expired - no need to request a new one');
return;
}

logger.debug(`Requesting a new Microsoft access token for connector id [${connectorId}]]`);

// FYI: API Docs: https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-client-creds-grant-flow#get-a-token
const { oAuthScope, clientId } = this.params.config;
const tokenRequestDate = Date.now();
const newToken = await this.params.apiRequest<MicrosoftDefenderEndpointApiTokenResponse>(
{
url: this.oAuthTokenUrl,
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
data: {
grant_type: 'client_credentials',
client_id: clientId,
scope: oAuthScope,
client_secret: this.params.secrets.clientSecret,
},
responseSchema: MicrosoftDefenderEndpointDoNotValidateResponseSchema,
},
responseSchema: MicrosoftDefenderEndpointDoNotValidateResponseSchema,
},
connectorUsageCollector
);

this.params.logger.debug(
() =>
`Successfully created an access token for Microsoft Defend for Endpoint:\n${JSON.stringify({
...newToken.data,
access_token: '[REDACTED]',
})}`
);

this.accessToken = newToken.data.access_token;
connectorUsageCollector
);

logger.debug(
() =>
`Successfully created an access token for Microsoft Defend for Endpoint:\n${JSON.stringify(
{
...newToken.data,
access_token: '[REDACTED]',
}
)}`
);

await connectorTokenClient.updateOrReplace({
connectorId,
tokenRequestDate,
deleteExisting: true,
token: this.connectorToken,
newToken: newToken.data.access_token,
expiresInSec: newToken.data.expires_in,
});

const updatedCachedToken = await connectorTokenClient.get({
connectorId,
tokenType: this.tokenType,
});

if (!updatedCachedToken.connectorToken) {
throw new Error(`Failed to retrieve cached [${this.tokenType}] after it was updated.`);
}

this.connectorToken = updatedCachedToken.connectorToken;
})().finally(() => {
this.generatingNewTokenPromise = null;
});

return this.generatingNewTokenPromise;
}

/**
* Returns the Bearer token that should be used in API calls
*/
public async get(connectorUsageCollector: ConnectorUsageCollector): Promise<string> {
if (!this.accessToken) {
await this.generateNewToken(connectorUsageCollector);
}
await this.retrieveOrGenerateNewTokenIfNeeded(connectorUsageCollector);

if (!this.accessToken) {
if (!this.connectorToken) {
throw new Error('Access token for Microsoft Defend for Endpoint not available!');
}

return this.accessToken;
return this.connectorToken.token;
}
}
Loading