Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vscode/cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"**/test-resources.json",
"**/test-resources-post.ps1",
"**/assets.json",
"**/*.pfx",
".config",
".devcontainer/devcontainer.json",
".devcontainer/Dockerfile",
Expand Down
31 changes: 23 additions & 8 deletions sdk/identity/azure_identity/src/client_assertion_credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ pub(crate) mod tests {

pub fn is_valid_request(
expected_authority: String,
expected_assertion: Option<String>,
) -> impl Fn(&Request) -> azure_core::Result<()> {
let expected_url = format!("{expected_authority}/oauth2/v2.0/token");
move |req: &Request| {
Expand All @@ -212,20 +213,29 @@ pub(crate) mod tests {
content_type::APPLICATION_X_WWW_FORM_URLENCODED.as_str(),
req.headers().get_str(&headers::CONTENT_TYPE).unwrap()
);
let expected_params = [
("client_assertion", FAKE_ASSERTION),
("client_assertion_type", ASSERTION_TYPE),
("client_id", FAKE_CLIENT_ID),
("grant_type", "client_credentials"),
("scope", &LIVE_TEST_SCOPES.join(" ")),
];
let body = match req.body() {
Body::Bytes(bytes) => str::from_utf8(bytes).unwrap(),
_ => panic!("unexpected body type"),
};
let actual_params: HashMap<String, String> = form_urlencoded::parse(body.as_bytes())
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
let assertion = actual_params
.get("client_assertion")
.expect("request body should contain client_assertion");
match &expected_assertion {
Some(expected) => assert_eq!(expected, assertion),
None => assert!(
!assertion.is_empty(),
"expected client_assertion to be present"
),
}
let expected_params = [
("client_assertion_type", ASSERTION_TYPE),
("client_id", FAKE_CLIENT_ID),
("grant_type", "client_credentials"),
("scope", &LIVE_TEST_SCOPES.join(" ")),
];
for (key, value) in expected_params.iter() {
assert_eq!(
*value,
Expand Down Expand Up @@ -263,6 +273,7 @@ pub(crate) mod tests {
)],
Some(Arc::new(is_valid_request(
FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
Some(FAKE_ASSERTION.to_string()),
))),
);
let credential = ClientAssertionCredential::new(
Expand Down Expand Up @@ -297,6 +308,7 @@ pub(crate) mod tests {
vec![token_response()],
Some(Arc::new(is_valid_request(
FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
Some(FAKE_ASSERTION.to_string()),
))),
);
let credential = ClientAssertionCredential::new(
Expand Down Expand Up @@ -335,7 +347,10 @@ pub(crate) mod tests {
for (cloud, expected_authority) in cloud_configuration_cases() {
let mock = MockSts::new(
vec![token_response()],
Some(Arc::new(is_valid_request(expected_authority))),
Some(Arc::new(is_valid_request(
expected_authority,
Some(FAKE_ASSERTION.to_string()),
))),
);
let credential = ClientAssertionCredential::new(
FAKE_TENANT_ID.to_string(),
Expand Down
197 changes: 188 additions & 9 deletions sdk/identity/azure_identity/src/client_certificate_credential.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

use crate::{authentication_error, get_authority_host, EntraIdTokenResponse, TokenCache};
use crate::{
authentication_error, deserialize, get_authority_host, EntraIdErrorResponse,
EntraIdTokenResponse, TokenCache,
};
use azure_core::{
base64,
credentials::{AccessToken, Secret, TokenCredential, TokenRequestOptions},
error::{Error, ErrorKind, ResultExt},
http::{
headers::{self, content_type},
request::Request,
ClientOptions, Method, Pipeline, PipelineSendOptions, Url,
ClientOptions, Method, Pipeline, PipelineSendOptions, StatusCode, Url,
},
time::{Duration, OffsetDateTime},
Uuid,
Expand Down Expand Up @@ -128,7 +131,7 @@ impl ClientCertificateCredential {
base64::encode_url_safe(part)
}

async fn get_token(
async fn get_token_impl(
&self,
scopes: &[&str],
options: Option<TokenRequestOptions<'_>>,
Expand Down Expand Up @@ -245,11 +248,27 @@ impl ClientCertificateCredential {
}),
)
.await?;
let response: EntraIdTokenResponse = rsp.into_body().json()?;
Ok(AccessToken::new(
response.access_token,
OffsetDateTime::now_utc() + Duration::seconds(response.expires_in),
))

match rsp.status() {
StatusCode::Ok => {
let response: EntraIdTokenResponse =
deserialize(stringify!(ClientCertificateCredential), rsp)?;
Ok(AccessToken::new(
response.access_token,
OffsetDateTime::now_utc() + Duration::seconds(response.expires_in),
))
}
_ => {
let error_response: EntraIdErrorResponse =
deserialize(stringify!(ClientCertificateCredential), rsp)?;
let message = if error_response.error_description.is_empty() {
"authentication failed".to_string()
} else {
error_response.error_description.clone()
};
Err(Error::with_message(ErrorKind::Credential, message))
}
}
}
}

Expand All @@ -273,8 +292,168 @@ impl TokenCredential for ClientCertificateCredential {
options: Option<TokenRequestOptions<'_>>,
) -> azure_core::Result<AccessToken> {
self.cache
.get_token(scopes, options, |s, o| self.get_token(s, o))
.get_token(scopes, options, |s, o| self.get_token_impl(s, o))
.await
.map_err(authentication_error::<Self>)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
client_assertion_credential::tests::is_valid_request, tests::*, TSG_LINK_ERROR_TEXT,
};
use azure_core::{
http::{headers::Headers, BufResponse, StatusCode, Transport},
Bytes,
};
use std::sync::{Arc, LazyLock};

static TEST_CERT: LazyLock<String> = LazyLock::new(|| {
let pfx = std::fs::read(concat!(
env!("CARGO_MANIFEST_DIR"),
"/tests/certificate.pfx"
))
.expect("failed to read test certificate");
base64::encode(pfx)
});

#[tokio::test]
async fn cloud_configuration() {
for (cloud, expected_authority) in cloud_configuration_cases() {
let sts = MockSts::new(
vec![token_response()],
Some(Arc::new(is_valid_request(expected_authority, None))),
);
let credential = ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
Secret::new(TEST_CERT.to_string()),
Secret::new(""),
Some(ClientCertificateCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(sts))),
cloud: Some(Arc::new(cloud)),
..Default::default()
},
..Default::default()
}),
)
.expect("valid credential");

credential
.get_token(LIVE_TEST_SCOPES, None)
.await
.expect("token");
}
}

#[tokio::test]
async fn get_token_error() {
let description = "AADSTS7000215: Invalid client certificate.";
let sts = MockSts::new(
vec![BufResponse::from_bytes(
StatusCode::BadRequest,
Headers::default(),
Bytes::from(format!(
r#"{{"error":"invalid_client","error_description":"{description}","error_codes":[7000215],"timestamp":"2025-04-04 21:10:04Z","trace_id":"...","correlation_id":"...","error_uri":"https://login.microsoftonline.com/error?code=7000215"}}"#,
)),
)],
Some(Arc::new(is_valid_request(
FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
None,
))),
);
let credential = ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
TEST_CERT.to_string(),
Secret::new(""),
Some(ClientCertificateCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(sts))),
..Default::default()
},
..Default::default()
}),
)
.expect("valid credential");

let err = credential
.get_token(LIVE_TEST_SCOPES, None)
.await
.expect_err("expected error");
assert!(matches!(err.kind(), ErrorKind::Credential));
assert!(
err.to_string().contains(description),
"expected error description from the response, got '{}'",
err
);
assert!(
err.to_string()
.contains(&format!("{TSG_LINK_ERROR_TEXT}#client-cert")),
"expected error to contain a link to the troubleshooting guide, got '{err}'",
);
}

#[tokio::test]
async fn get_token_success() {
let sts = MockSts::new(
vec![token_response()],
Some(Arc::new(is_valid_request(
FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
None,
))),
);
let credential = ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
TEST_CERT.to_string(),
Secret::new(""),
Some(ClientCertificateCredentialOptions {
client_options: ClientOptions {
transport: Some(Transport::new(Arc::new(sts))),
..Default::default()
},
..Default::default()
}),
)
.expect("valid credential");
let token = credential
.get_token(LIVE_TEST_SCOPES, None)
.await
.expect("token");

assert_eq!(FAKE_TOKEN, token.token.secret());
let lifetime =
token.expires_on.unix_timestamp() - OffsetDateTime::now_utc().unix_timestamp();
assert!(
(3600..3601).contains(&lifetime),
"token should expire in ~3600 seconds but actually expires in {} seconds",
lifetime
);

let cached_token = credential
.get_token(LIVE_TEST_SCOPES, None)
.await
.expect("cached token");
assert_eq!(token.token.secret(), cached_token.token.secret());
assert_eq!(token.expires_on, cached_token.expires_on);
}

#[tokio::test]
async fn no_scopes() {
ClientCertificateCredential::new(
FAKE_TENANT_ID.to_string(),
FAKE_CLIENT_ID.to_string(),
TEST_CERT.to_string(),
Secret::new(""),
None,
)
.expect("valid credential")
.get_token(&[], None)
.await
.expect_err("no scopes provided");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ mod tests {
)],
Some(Arc::new(is_valid_request(
FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
Some(FAKE_ASSERTION.to_string()),
))),
);
let cred = WorkloadIdentityCredential::new(Some(WorkloadIdentityCredentialOptions {
Expand Down Expand Up @@ -289,6 +290,7 @@ mod tests {
)],
Some(Arc::new(is_valid_request(
FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
Some(FAKE_ASSERTION.to_string()),
))),
);
let cred = WorkloadIdentityCredential::new(Some(WorkloadIdentityCredentialOptions {
Expand Down Expand Up @@ -402,6 +404,7 @@ mod tests {
)],
Some(Arc::new(is_valid_request(
FAKE_PUBLIC_CLOUD_AUTHORITY.to_string(),
Some(FAKE_ASSERTION.to_string()),
))),
);
let cred = WorkloadIdentityCredential::new(Some(WorkloadIdentityCredentialOptions {
Expand Down
Binary file not shown.