diff --git a/README.md b/README.md index 1009233..17509ec 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,8 @@ Some bundles create outputs, the driver captures these in an Azure File Share, t | CNAB_AZURE_LOCATION | The location in which to create the ACI Container Group and Resource Group | | CNAB_AZURE_NAME | The name of the ACI instance to create - if not specified a name will be generated | | CNAB_AZURE_DELETE_RESOURCES | Set to false so as not to delete the RG and ACI container group created, default is true - useful for debugging - only deletes RG if it was created by the driver | +| CNAB_AZURE_CLI_ARM_ENDPOINT | The URL for the Azure Resource Manager when using from the CLI. This defaults to 'https://management.azure.com/ | +| CNAB_AZURE_MSI_AUDIENCE | The 'audience' to include in the Cloud Shell MSI token request. This defaults to 'https://management.azure.com/' but can be changed if needed for clouds other than Azure public. | | CNAB_AZURE_MSI_TYPE | This can be set to either `user` or `system` This value is presented to the invocation image container as `AZURE_MSI_TYPE`| | CNAB_AZURE_SYSTEM_MSI_ROLE | If `CNAB_AZURE_SYSTEM_MSI_ROLE` is set to `system` this defines the role to be assigned to System MSI User, if this is null or empty then the role defaults to `Contributor` | | CNAB_AZURE_SYSTEM_MSI_SCOPE | If `CNAB_AZURE_SYSTEM_MSI_ROLE` is set to `system` this defines the scope to apply the role to System MSI User - if this is null or empty then the scope will be Resource Group that the ACI Instance is being created | diff --git a/pkg/azure/cloudshell.go b/pkg/azure/cloudshell.go index 92de6c5..7ca7e36 100644 --- a/pkg/azure/cloudshell.go +++ b/pkg/azure/cloudshell.go @@ -51,10 +51,6 @@ type cloudrive struct { Size int `json:"diskSizeInGB"` } -type adtoken struct { - Oid string `json:"oid"` -} - // FileShareDetails contains details of the clouddrive FileShare type FileShareDetails struct { Name string @@ -278,8 +274,11 @@ func CheckCanAccessResource(actionID string, scope string) (bool, error) { if !IsInCloudShell() { return false, errors.New("Not Running in CloudShell") } - - oid, err := getOidFromToken() + adalToken, err := GetCloudShellToken() + if err != nil { + return false, fmt.Errorf("Error Getting CloudShellToken: %v", err) + } + oid, err := getFromToken(adalToken.AccessToken, "oid") if err != nil { return false, fmt.Errorf("failed to get Oid: %v ", err) } @@ -302,25 +301,24 @@ func CheckCanAccessResource(actionID string, scope string) (bool, error) { log.Debug("Check Access POST Body ", string(payload)) return makeCheckAccessRequest(payload, scope) } -func getOidFromToken() (string, error) { - adalToken, err := GetCloudShellToken() - if err != nil { - return "", fmt.Errorf("failed to get CloudShell Token: %s", err) - } - - bearerToken := strings.Split(adalToken.AccessToken, ".")[1] +func getFromToken(accessToken string, parameter string) (string, error) { + bearerToken := strings.Split(accessToken, ".")[1] if len(bearerToken) == 0 { - return "", fmt.Errorf("Failed to get bearer token from CloudShell Token: %v ", err) + return "", errors.New("Failed to get bearer token from CloudShell Token") } token, err := base64.RawStdEncoding.DecodeString(bearerToken) if err != nil { return "", fmt.Errorf("Failed to decode Bearer Token: %v ", err) } - adToken := adtoken{} + var adToken map[string]interface{} if err := json.Unmarshal(token, &adToken); err != nil { - return "", fmt.Errorf("failed to unmarshall CloudShell token: %v ", err) + return "", fmt.Errorf("Failed to unmarshall CloudShell token: %v ", err) } - return adToken.Oid, nil + parameterValue, hasParameter := adToken[parameter] + if hasParameter == false { + return "", errors.New("Requested token parameter not present") + } + return parameterValue.(string), err } func makeCheckAccessRequest(payload []byte, scope string) (bool, error) { @@ -330,13 +328,17 @@ func makeCheckAccessRequest(payload []byte, scope string) (bool, error) { if err != nil { return false, fmt.Errorf("Error Getting CloudShellToken: %v", err) } + audUrl, err := getFromToken(adalToken.AccessToken, "aud") + if err != nil { + audUrl = "https://management.azure.com/" + } retry: for i := 1; i < 4; i++ { timeout := time.Duration(time.Duration(i) * time.Second) client := http.Client{ Timeout: timeout, } - url := fmt.Sprintf("https://management.azure.com/%s/providers/Microsoft.Authorization/CheckAccess", scope) + url := fmt.Sprintf("%s%s/providers/Microsoft.Authorization/CheckAccess", audUrl, scope) log.Debug("Check Access URL: ", url) var req *http.Request req, err = http.NewRequest("POST", url, bytes.NewBuffer(payload)) diff --git a/pkg/azure/login_info.go b/pkg/azure/login_info.go index e8cbbf0..3ba6ea7 100644 --- a/pkg/azure/login_info.go +++ b/pkg/azure/login_info.go @@ -124,6 +124,12 @@ func GetCloudShellToken() (*adal.Token, error) { return nil, errors.New("MSI_ENDPOINT environment variable not set") } + MSIAudience := os.Getenv("CNAB_AZURE_MSI_AUDIENCE") + if len(MSIAudience) == 0 { + MSIAudience = "https://management.azure.com/" + } + log.Debug("CloudShell MSI Audience: ", MSIAudience) + timeout := time.Duration(1 * time.Second) client := http.Client{ Timeout: timeout, @@ -136,7 +142,7 @@ func GetCloudShellToken() (*adal.Token, error) { req.Header.Set("Metadata", "true") query := req.URL.Query() query.Add("api-version", "2018-02-01") - query.Add("resource", "https://management.azure.com/") + query.Add("resource", MSIAudience) req.URL.RawQuery = query.Encode() log.Debug("Cloud Shell Token URI: ", req.RequestURI) resp, err := client.Do(req) diff --git a/pkg/driver/aci-driver.go b/pkg/driver/aci-driver.go index 07bd84a..55127fa 100644 --- a/pkg/driver/aci-driver.go +++ b/pkg/driver/aci-driver.go @@ -1208,7 +1208,14 @@ func (d *aciDriver) createCredentialEnvVars(env []containerinstance.EnvironmentV if d.loginInfo.LoginType == az.CLI { log.Debug("Propagating OAuth Token from cli") - t, err := cli.GetTokenFromCLI("https://management.azure.com/") + + ARMEndpoint := os.Getenv("CNAB_AZURE_CLI_ARM_ENDPOINT") + if len(ARMEndpoint) == 0 { + ARMEndpoint = "https://management.azure.com/" + } + log.Debug("CLI ARM Endpoint: ", ARMEndpoint) + + t, err := cli.GetTokenFromCLI(ARMEndpoint) if err != nil { return nil, err }