Skip to content

Commit

Permalink
add audience to custom_service_account (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
dacozai authored Oct 14, 2024
1 parent 9ed8eef commit 127c336
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions src/custom_service_account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct CustomServiceAccount {
signer: Signer,
tokens: RwLock<HashMap<Vec<String>, Arc<Token>>>,
subject: Option<String>,
audience: Option<String>,
}

impl CustomServiceAccount {
Expand Down Expand Up @@ -59,6 +60,12 @@ impl CustomServiceAccount {
self
}

/// Set the `Audience` to impersonate a user
pub fn with_audience(mut self, audience: String) -> Self {
self.audience = Some(audience);
self
}

fn new(credentials: ServiceAccountKey, client: HttpClient) -> Result<Self, Error> {
debug!(project = ?credentials.project_id, email = credentials.client_email, "found credentials");
Ok(Self {
Expand All @@ -67,13 +74,19 @@ impl CustomServiceAccount {
credentials,
tokens: RwLock::new(HashMap::new()),
subject: None,
audience: None,
})
}

#[instrument(level = Level::DEBUG, skip(self))]
async fn fetch_token(&self, scopes: &[&str]) -> Result<Arc<Token>, Error> {
let jwt =
Claims::new(&self.credentials, scopes, self.subject.as_deref()).to_jwt(&self.signer)?;
let jwt = Claims::new(
&self.credentials,
scopes,
self.subject.as_deref(),
self.audience.as_deref(),
)
.to_jwt(&self.signer)?;
let body = Bytes::from(
form_urlencoded::Serializer::new(String::new())
.extend_pairs(&[("grant_type", GRANT_TYPE), ("assertion", jwt.as_str())])
Expand Down Expand Up @@ -156,7 +169,12 @@ pub(crate) struct Claims<'a> {
}

impl<'a> Claims<'a> {
pub(crate) fn new(key: &'a ServiceAccountKey, scopes: &[&str], sub: Option<&'a str>) -> Self {
pub(crate) fn new(
key: &'a ServiceAccountKey,
scopes: &[&str],
sub: Option<&'a str>,
aud: Option<&'a str>,
) -> Self {
let mut scope = String::with_capacity(16);
for (i, s) in scopes.iter().enumerate() {
if i != 0 {
Expand All @@ -169,7 +187,7 @@ impl<'a> Claims<'a> {
let iat = Utc::now().timestamp();
Claims {
iss: &key.client_email,
aud: &key.token_uri,
aud: aud.unwrap_or(&key.token_uri),
exp: iat + 3600 - 5, // Max validity is 1h
iat,
sub,
Expand Down

0 comments on commit 127c336

Please sign in to comment.