Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(oauth): update all auth info in spec in oauth/update action #12514

Merged
merged 2 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions packages/fx-core/src/component/driver/oauth/update.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {
import { OauthNameTooLongError } from "./error/oauthNameTooLong";
import { UpdateOauthArgs } from "./interface/updateOauthArgs";
import { logMessageKeys } from "./utility/constants";
import { getandValidateOauthInfoFromSpec } from "./utility/utility";
import { getandValidateOauthInfoFromSpec, OauthInfo } from "./utility/utility";
import { OauthDisablePKCEError } from "./error/oauthDisablePKCEError";

const actionName = "oauth/update"; // DO NOT MODIFY the name
Expand All @@ -44,7 +44,6 @@ export class UpdateOauthDriver implements StepDriver {
this.validateArgs(args);

const authInfo = await getandValidateOauthInfoFromSpec(args, context, actionName);
const domain = authInfo.domain;
const appStudioTokenRes = await context.m365TokenProvider.getAccessToken({
scopes: AppStudioScopes,
});
Expand All @@ -61,7 +60,7 @@ export class UpdateOauthDriver implements StepDriver {
throw new OauthDisablePKCEError(actionName);
}

const diffMsgs = this.compareOauthRegistration(getOauthRes, args, domain);
const diffMsgs = this.compareOauthRegistration(getOauthRes, args, authInfo);
// If there is no difference, skip the update
if (!diffMsgs || diffMsgs.length === 0) {
const summary = getLocalizedString(logMessageKeys.skipUpdateOauth);
Expand All @@ -76,7 +75,9 @@ export class UpdateOauthDriver implements StepDriver {

// If there is difference, ask user to confirm the update
// Skip confirm if only targetUrlsShouldStartWith is different when the url contains devtunnel
if (!this.shouldSkipConfirm(diffMsgs, getOauthRes.targetUrlsShouldStartWith, domain)) {
if (
!this.shouldSkipConfirm(diffMsgs, getOauthRes.targetUrlsShouldStartWith, authInfo.domain)
) {
const userConfirm = await context.ui!.confirm!({
name: "confirm-update-oauth",
title: getLocalizedString("driver.oauth.confirm.update", diffMsgs.join(",\n")),
Expand All @@ -87,7 +88,7 @@ export class UpdateOauthDriver implements StepDriver {
}
}

const oauth = this.mapArgsToOauthRegistration(args, domain);
const oauth = this.mapArgsToOauthRegistration(args, authInfo);
await teamsDevPortalClient.updateOauthRegistration(
appStudioToken,
oauth,
Expand Down Expand Up @@ -179,7 +180,7 @@ export class UpdateOauthDriver implements StepDriver {
private compareOauthRegistration(
current: OauthRegistration,
input: UpdateOauthArgs,
domain: string[]
authInfo: OauthInfo
): string[] {
const diffMsgs: string[] = [];
if (current.description !== input.name) {
Expand All @@ -201,6 +202,7 @@ export class UpdateOauthDriver implements StepDriver {
}

// Compare domain
const domain = authInfo.domain;
if (
current.targetUrlsShouldStartWith.length !== domain.length ||
!current.targetUrlsShouldStartWith.every((value) => domain.includes(value)) ||
Expand All @@ -213,7 +215,46 @@ export class UpdateOauthDriver implements StepDriver {
);
}

if (current.isPKCEEnabled !== input.isPKCEEnabled) {
// TODO: Need to separate the logic for different flows
// Compare authorizationEndpoint
if (
authInfo.authorizationEndpoint &&
current.authorizationEndpoint !== authInfo.authorizationEndpoint
) {
diffMsgs.push(
`authorizationEndpoint: ${current.authorizationEndpoint} => ${authInfo.authorizationEndpoint}`
);
}

// Compare tokenExchangeEndpoint
if (
authInfo.tokenExchangeEndpoint &&
current.tokenExchangeEndpoint !== authInfo.tokenExchangeEndpoint
) {
diffMsgs.push(
`tokenExchangeEndpoint: ${current.tokenExchangeEndpoint} => ${authInfo.tokenExchangeEndpoint}`
);
}

// Compare tokenRefreshEndpoint
if (current.tokenRefreshEndpoint !== authInfo.tokenRefreshEndpoint) {
diffMsgs.push(
`tokenRefreshEndpoint: ${current.tokenRefreshEndpoint ?? "Undefined"} => ${
authInfo.tokenRefreshEndpoint ?? "Undefined"
}`
);
}

// Compare scopes
if (!this.compareScopes(current.scopes, authInfo.scopes)) {
diffMsgs.push(
`scopes: ${current.scopes.join(",")} => ${
authInfo.scopes ? authInfo.scopes.join(",") : "Undefined"
}`
);
}

if (!!current.isPKCEEnabled !== !!input.isPKCEEnabled) {
diffMsgs.push(
`isPKCEEnabled: ${(!!current.isPKCEEnabled).toString()} => ${(!!input.isPKCEEnabled).toString()}`
);
Expand All @@ -233,7 +274,10 @@ export class UpdateOauthDriver implements StepDriver {
);
}

private mapArgsToOauthRegistration(args: UpdateOauthArgs, domain: string[]): OauthRegistration {
private mapArgsToOauthRegistration(
args: UpdateOauthArgs,
authInfo: OauthInfo
): OauthRegistration {
const targetAudience = args.targetAudience
? (args.targetAudience as OauthRegistrationTargetAudience)
: undefined;
Expand All @@ -243,11 +287,24 @@ export class UpdateOauthDriver implements StepDriver {

return {
description: args.name,
targetUrlsShouldStartWith: domain,
targetUrlsShouldStartWith: authInfo.domain,
applicableToApps: applicableToApps,
m365AppId: applicableToApps === OauthRegistrationAppType.SpecificApp ? args.appId : "",
targetAudience: targetAudience,
isPKCEEnabled: !!args.isPKCEEnabled,
authorizationEndpoint: authInfo.authorizationEndpoint,
tokenExchangeEndpoint: authInfo.tokenExchangeEndpoint,
tokenRefreshEndpoint: authInfo.tokenRefreshEndpoint,
scopes: authInfo.scopes ?? [],
} as OauthRegistration;
}

private compareScopes(current: string[], input: string[] | undefined): boolean {
return (
!!input &&
current.length === input.length &&
current.every((value) => input.includes(value)) &&
input.every((value) => current.includes(value))
);
}
}
72 changes: 34 additions & 38 deletions packages/fx-core/src/component/driver/oauth/utility/utility.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,46 +60,42 @@ export async function getandValidateOauthInfoFromSpec(
});
validateDomain(domains, actionName);

if ("flow" in args) {
const authInfoArray = operations
.map((value) => {
let authInfo;
switch (args.flow) {
case "authorizationCode":
default:
authInfo = (value.auth?.authScheme as OpenAPIV3.OAuth2SecurityScheme).flows
.authorizationCode;
}
return {
authorizationUrl: authInfo!.authorizationUrl,
tokenUrl: authInfo!.tokenUrl,
refreshUrl: authInfo!.refreshUrl,
scopes: Object.keys(authInfo!.scopes),
};
})
.reduce((accumulator: AuthInfo[], currentValue) => {
if (!accumulator.find((item) => isEqual(item, currentValue))) {
accumulator.push(currentValue);
}
return accumulator;
}, []);
// Need to separate the logic for different flows
const flow = "flow" in args ? args.flow : "authorizationCode";
const authInfoArray = operations
.map((value) => {
let authInfo;
switch (flow) {
case "authorizationCode":
default:
authInfo = (value.auth?.authScheme as OpenAPIV3.OAuth2SecurityScheme).flows
.authorizationCode;
}
return {
authorizationUrl: authInfo!.authorizationUrl,
tokenUrl: authInfo!.tokenUrl,
refreshUrl: authInfo!.refreshUrl,
scopes: Object.keys(authInfo!.scopes),
};
})
.reduce((accumulator: AuthInfo[], currentValue) => {
if (!accumulator.find((item) => isEqual(item, currentValue))) {
accumulator.push(currentValue);
}
return accumulator;
}, []);

if (authInfoArray.length !== 1) {
throw new OauthAuthInfoInvalid(actionName);
}
const authInfo = authInfoArray[0];
return {
domain: domains,
authorizationEndpoint: authInfo.authorizationUrl,
tokenExchangeEndpoint: authInfo.tokenUrl,
tokenRefreshEndpoint: authInfo.refreshUrl,
scopes: authInfo.scopes,
};
} else {
return {
domain: domains,
};
if (authInfoArray.length !== 1) {
throw new OauthAuthInfoInvalid(actionName);
}
const authInfo = authInfoArray[0];
return {
domain: domains,
authorizationEndpoint: authInfo.authorizationUrl,
tokenExchangeEndpoint: authInfo.tokenUrl,
tokenRefreshEndpoint: authInfo.refreshUrl,
scopes: authInfo.scopes,
};
}

function validateDomain(domain: string[], actionName: string): void {
Expand Down
Loading
Loading