diff --git a/Cargo.toml b/Cargo.toml index 78361dd..3748d9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "firestore-db-and-auth" -version = "0.6.1" +version = "0.8.0" authors = ["David Gräff "] edition = "2018" license = "MIT" @@ -12,6 +12,8 @@ maintenance = { status = "passively-maintained" } repository = "https://github.com/davidgraeff/firestore-db-and-auth-rs" [dependencies] +bytes = "1.1" +cache_control = "0.2" reqwest = { version = "0.11", default-features = false, features = ["json", "blocking"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -19,6 +21,11 @@ chrono = { version = "0.4", features = ["serde"] } biscuit = "0.5" ring = "0.16" base64 = "0.13" +async-trait = "0.1" +tokio = { version = "1.13", features = ["macros"] } +futures = "0.3" +pin-project = "1.0" +http = "0.2" [dependencies.rocket] version = "0.4.6" diff --git a/examples/create_read_write_document.rs b/examples/create_read_write_document.rs index b5790ba..ee334de 100644 --- a/examples/create_read_write_document.rs +++ b/examples/create_read_write_document.rs @@ -3,6 +3,8 @@ use firestore_db_and_auth::{documents, dto, errors, sessions, Credentials, Fireb use firestore_db_and_auth::documents::WriteResult; use serde::{Deserialize, Serialize}; +use futures::stream::StreamExt; + mod utils; #[derive(Debug, Serialize, Deserialize)] @@ -21,7 +23,7 @@ struct DemoDTOPartial { an_int: u32, } -fn write_document(session: &mut ServiceSession, doc_id: &str) -> errors::Result { +async fn write_document(session: &mut ServiceSession, doc_id: &str) -> errors::Result { println!("Write document"); let obj = DemoDTO { @@ -30,10 +32,10 @@ fn write_document(session: &mut ServiceSession, doc_id: &str) -> errors::Result< a_timestamp: chrono::Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Nanos, true), }; - documents::write(session, "tests", Some(doc_id), &obj, documents::WriteOptions::default()) + documents::write(session, "tests", Some(doc_id), &obj, documents::WriteOptions::default()).await } -fn write_partial_document(session: &mut ServiceSession, doc_id: &str) -> errors::Result { +async fn write_partial_document(session: &mut ServiceSession, doc_id: &str) -> errors::Result { println!("Partial write document"); let obj = DemoDTOPartial { @@ -48,6 +50,7 @@ fn write_partial_document(session: &mut ServiceSession, doc_id: &str) -> errors: &obj, documents::WriteOptions { merge: true }, ) + .await } fn check_write(result: WriteResult, doc_id: &str) { @@ -62,25 +65,25 @@ fn check_write(result: WriteResult, doc_id: &str) { ); } -fn service_account_session(cred: Credentials) -> errors::Result<()> { - let mut session = ServiceSession::new(cred).unwrap(); - let b = session.access_token().to_owned(); +async fn service_account_session(cred: Credentials) -> errors::Result<()> { + let mut session = ServiceSession::new(cred).await.unwrap(); + let b = session.access_token().await.to_owned(); let doc_id = "service_test"; - check_write(write_document(&mut session, doc_id)?, doc_id); + check_write(write_document(&mut session, doc_id).await?, doc_id); // Check if cached value is used - assert_eq!(session.access_token(), b); + assert_eq!(session.access_token().await, b); println!("Read and compare document"); - let read: DemoDTO = documents::read(&mut session, "tests", doc_id)?; + let read: DemoDTO = documents::read(&mut session, "tests", doc_id).await?; assert_eq!(read.a_string, "abcd"); assert_eq!(read.an_int, 14); - check_write(write_partial_document(&mut session, doc_id)?, doc_id); + check_write(write_partial_document(&mut session, doc_id).await?, doc_id); println!("Read and compare document"); - let read: DemoDTOPartial = documents::read(&mut session, "tests", doc_id)?; + let read: DemoDTOPartial = documents::read(&mut session, "tests", doc_id).await?; // Should be updated assert_eq!(read.an_int, 16); @@ -90,14 +93,15 @@ fn service_account_session(cred: Credentials) -> errors::Result<()> { Ok(()) } -fn user_account_session(cred: Credentials) -> errors::Result<()> { - let user_session = utils::user_session_with_cached_refresh_token(&cred)?; +async fn user_account_session(cred: Credentials) -> errors::Result<()> { + let user_session = utils::user_session_with_cached_refresh_token(&cred).await?; assert_eq!(user_session.user_id, utils::TEST_USER_ID); assert_eq!(user_session.project_id(), cred.project_id); println!("user::Session::by_access_token"); - let user_session = sessions::user::Session::by_access_token(&cred, &user_session.access_token_unchecked())?; + let user_session = + sessions::user::Session::by_access_token(&cred, &user_session.access_token_unchecked().await).await?; assert_eq!(user_session.user_id, utils::TEST_USER_ID); @@ -117,13 +121,14 @@ fn user_account_session(cred: Credentials) -> errors::Result<()> { Some(doc_id), &obj, documents::WriteOptions::default(), - )?, + ) + .await?, doc_id, ); // Test reading println!("user::Session documents::read"); - let read: DemoDTO = documents::read(&user_session, "tests", doc_id)?; + let read: DemoDTO = documents::read(&user_session, "tests", doc_id).await?; assert_eq!(read.a_string, "abc"); assert_eq!(read.an_int, 12); @@ -135,14 +140,17 @@ fn user_account_session(cred: Credentials) -> errors::Result<()> { "abc".into(), dto::FieldOperator::EQUAL, "a_string", - )? + ) + .await? .collect(); assert_eq!(results.len(), 1); - let doc: DemoDTO = documents::read_by_name(&user_session, &results.get(0).unwrap().name)?; + let doc: DemoDTO = documents::read_by_name(&user_session, &results.get(0).unwrap().name).await?; assert_eq!(doc.a_string, "abc"); let mut count = 0; - let list_it: documents::List = documents::list(&user_session, "tests".to_owned()); + let list_it = documents::list(&user_session, "tests".to_owned()) + .collect::>>() + .await; for _doc in list_it { count += 1; } @@ -150,7 +158,7 @@ fn user_account_session(cred: Credentials) -> errors::Result<()> { // test if the call fails for a non existing document println!("user::Session documents::delete"); - let r = documents::delete(&user_session, "tests/non_existing", true); + let r = documents::delete(&user_session, "tests/non_existing", true).await; assert!(r.is_err()); match r.err().unwrap() { errors::FirebaseError::APIError(code, message, context) => { @@ -161,7 +169,7 @@ fn user_account_session(cred: Credentials) -> errors::Result<()> { _ => panic!("Expected an APIError"), }; - documents::delete(&user_session, &("tests/".to_owned() + doc_id), false)?; + documents::delete(&user_session, &("tests/".to_owned() + doc_id), false).await?; // Check if document is indeed removed println!("user::Session documents::query"); @@ -171,35 +179,40 @@ fn user_account_session(cred: Credentials) -> errors::Result<()> { "abc".into(), dto::FieldOperator::EQUAL, "a_string", - )? + ) + .await? .count(); assert_eq!(count, 0); println!("user::Session documents::query for f64"); let f: f64 = 13.37; - let count = documents::query(&user_session, "tests", f.into(), dto::FieldOperator::EQUAL, "a_float")?.count(); + + let count = documents::query(&user_session, "tests", f.into(), dto::FieldOperator::EQUAL, "a_float").await?; + + let count = count.count(); assert_eq!(count, 0); Ok(()) } -fn main() -> errors::Result<()> { +#[tokio::main] +async fn main() -> errors::Result<()> { // Search for a credentials file in the root directory use std::path::PathBuf; let mut credential_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")); credential_file.push("firebase-service-account.json"); - let mut cred = Credentials::from_file(credential_file.to_str().unwrap())?; + let cred = Credentials::from_file(credential_file.to_str().unwrap()).await?; // Only download the public keys once, and cache them. - let jwkset = utils::from_cache_file(credential_file.with_file_name("cached_jwks.jwks").as_path(), &cred)?; + let jwkset = utils::from_cache_file(credential_file.with_file_name("cached_jwks.jwks").as_path(), &cred).await?; cred.add_jwks_public_keys(&jwkset); - cred.verify()?; + cred.verify().await?; // Perform some db operations via a service account session - service_account_session(cred.clone())?; + service_account_session(cred.clone()).await?; // Perform some db operations via a firebase user session - user_account_session(cred)?; + user_account_session(cred).await?; Ok(()) } @@ -207,35 +220,35 @@ fn main() -> errors::Result<()> { /// For integration tests and doc code snippets: Create a Credentials instance. /// Necessary public jwk sets are downloaded or re-used if already present. #[cfg(test)] -fn valid_test_credentials() -> errors::Result { +async fn valid_test_credentials() -> errors::Result { use std::path::PathBuf; let mut jwks_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); jwks_path.push("firebase-service-account.jwks"); - let mut cred: Credentials = Credentials::new(include_str!("../firebase-service-account.json"))?; + let mut cred: Credentials = Credentials::new(include_str!("../tests/service-account-test.json"))?; // Only download the public keys once, and cache them. - let jwkset = utils::from_cache_file(jwks_path.as_path(), &cred)?; + let jwkset = utils::from_cache_file(jwks_path.as_path(), &cred).await?; cred.add_jwks_public_keys(&jwkset); cred.verify()?; Ok(cred) } -#[test] -fn valid_test_credentials_test() -> errors::Result<()> { - valid_test_credentials()?; +#[tokio::test] +async fn valid_test_credentials_test() -> errors::Result<()> { + valid_test_credentials().await?; Ok(()) } -#[test] -fn service_account_session_test() -> errors::Result<()> { - service_account_session(valid_test_credentials()?)?; +#[tokio::test] +async fn service_account_session_test() -> errors::Result<()> { + service_account_session(valid_test_credentials().await?).await?; Ok(()) } -#[test] -fn user_account_session_test() -> errors::Result<()> { - user_account_session(valid_test_credentials()?)?; +#[tokio::test] +async fn user_account_session_test() -> errors::Result<()> { + user_account_session(valid_test_credentials().await?).await?; Ok(()) } diff --git a/examples/firebase_user.rs b/examples/firebase_user.rs index 65e7ccb..4ca6ebd 100644 --- a/examples/firebase_user.rs +++ b/examples/firebase_user.rs @@ -3,13 +3,16 @@ use firestore_db_and_auth::*; const TEST_USER_ID: &str = include_str!("test_user_id.txt"); -fn main() -> errors::Result<()> { - let cred = Credentials::from_file("firebase-service-account.json").expect("Read credentials file"); +#[tokio::main] +async fn main() -> errors::Result<()> { + let cred = Credentials::from_file("firebase-service-account.json") + .await + .expect("Read credentials file"); - let user_session = UserSession::by_user_id(&cred, TEST_USER_ID, false)?; + let user_session = UserSession::by_user_id(&cred, TEST_USER_ID, false).await?; println!("users::user_info"); - let user_info_container = users::user_info(&user_session)?; + let user_info_container = users::user_info(&user_session).await?; assert_eq!(user_info_container.users[0].localId.as_ref().unwrap(), TEST_USER_ID); Ok(()) diff --git a/examples/own_auth.rs b/examples/own_auth.rs index 7099f88..014f61f 100644 --- a/examples/own_auth.rs +++ b/examples/own_auth.rs @@ -1,41 +1,37 @@ -use firestore_db_and_auth::errors::FirebaseError::APIError; use firestore_db_and_auth::{documents, errors, Credentials, FirebaseAuthBearer}; /// Define your own structure that will implement the FirebaseAuthBearer trait struct MyOwnSession { /// The google credentials pub credentials: Credentials, - pub blocking_client: reqwest::blocking::Client, pub client: reqwest::Client, access_token: String, } +#[async_trait::async_trait] impl FirebaseAuthBearer for MyOwnSession { fn project_id(&self) -> &str { &self.credentials.project_id } /// An access token. If a refresh token is known and the access token expired, /// the implementation should try to refresh the access token before returning. - fn access_token(&self) -> String { + async fn access_token(&self) -> String { self.access_token.clone() } /// The access token, unchecked. Might be expired or in other ways invalid. - fn access_token_unchecked(&self) -> String { + async fn access_token_unchecked(&self) -> String { self.access_token.clone() } /// The reqwest http client. /// The `Client` holds a connection pool internally, so it is advised that it is reused for multiple, successive connections. - fn client(&self) -> &reqwest::blocking::Client { - &self.blocking_client - } - - fn client_async(&self) -> &reqwest::Client { + fn client(&self) -> &reqwest::Client { &self.client } } -fn main() -> errors::Result<()> { - let credentials = Credentials::from_file("firebase-service-account.json")?; +#[tokio::main] +async fn main() -> errors::Result<()> { + let credentials = Credentials::from_file("firebase-service-account.json").await?; #[derive(serde::Serialize)] struct TestData { an_int: u32, @@ -44,7 +40,6 @@ fn main() -> errors::Result<()> { let session = MyOwnSession { credentials, - blocking_client: reqwest::blocking::Client::new(), client: reqwest::Client::new(), access_token: "The access token".to_owned(), }; @@ -56,12 +51,13 @@ fn main() -> errors::Result<()> { Some("test_doc"), &t, documents::WriteOptions::default(), - )?; + ) + .await?; Ok(()) } -#[test] -fn own_auth_test() { +#[tokio::test] +async fn own_auth_test() { if let Err(APIError(code, str_code, context)) = main() { assert_eq!(str_code, "Request had invalid authentication credentials. Expected OAuth 2 access token, login cookie or other valid authentication credential. See https://developers.google.com/identity/sign-in/web/devconsole-project."); assert_eq!(context, "test_doc"); diff --git a/examples/session_cookie.rs b/examples/session_cookie.rs index 6177ba9..7686f72 100644 --- a/examples/session_cookie.rs +++ b/examples/session_cookie.rs @@ -4,22 +4,23 @@ use chrono::Duration; mod utils; -fn main() -> Result<(), FirebaseError> { +#[tokio::main] +async fn main() -> Result<(), FirebaseError> { // Search for a credentials file in the root directory use std::path::PathBuf; let mut credential_file = PathBuf::from(env!("CARGO_MANIFEST_DIR")); credential_file.push("firebase-service-account.json"); - let mut cred = Credentials::from_file(credential_file.to_str().unwrap())?; + let cred = Credentials::from_file(credential_file.to_str().unwrap()).await?; // Only download the public keys once, and cache them. - let jwkset = utils::from_cache_file(credential_file.with_file_name("cached_jwks.jwks").as_path(), &cred)?; + let jwkset = utils::from_cache_file(credential_file.with_file_name("cached_jwks.jwks").as_path(), &cred).await?; cred.add_jwks_public_keys(&jwkset); - cred.verify()?; + cred.verify().await?; - let user_session = utils::user_session_with_cached_refresh_token(&cred)?; + let user_session = utils::user_session_with_cached_refresh_token(&cred).await?; - let cookie = session_cookie::create(&cred, user_session.access_token(), Duration::seconds(3600))?; + let cookie = session_cookie::create(&cred, user_session.access_token().await, Duration::seconds(3600)).await?; println!("Created session cookie: {}", cookie); Ok(()) diff --git a/examples/utils/mod.rs b/examples/utils/mod.rs index acddd4d..a001ad1 100644 --- a/examples/utils/mod.rs +++ b/examples/utils/mod.rs @@ -5,7 +5,7 @@ use firestore_db_and_auth::jwt::download_google_jwks; #[allow(dead_code)] pub const TEST_USER_ID: &str = include_str!("../test_user_id.txt"); -pub fn user_session_with_cached_refresh_token(cred: &Credentials) -> errors::Result { +pub async fn user_session_with_cached_refresh_token(cred: &Credentials) -> errors::Result { println!("Refresh token from file"); // Read refresh token from file if possible instead of generating a new refresh token each time let refresh_token: String = match std::fs::read_to_string("refresh-token-for-tests.txt") { @@ -21,12 +21,12 @@ pub fn user_session_with_cached_refresh_token(cred: &Credentials) -> errors::Res // Generate a new refresh token if necessary println!("Generate new user auth token"); let user_session: sessions::user::Session = if refresh_token.is_empty() { - let session = sessions::user::Session::by_user_id(&cred, TEST_USER_ID, true)?; + let session = sessions::user::Session::by_user_id(&cred, TEST_USER_ID, true).await?; std::fs::write("refresh-token-for-tests.txt", &session.refresh_token.as_ref().unwrap())?; session } else { println!("user::Session::by_refresh_token"); - sessions::user::Session::by_refresh_token(&cred, &refresh_token)? + sessions::user::Session::by_refresh_token(&cred, &refresh_token).await? }; Ok(user_session) @@ -35,7 +35,7 @@ pub fn user_session_with_cached_refresh_token(cred: &Credentials) -> errors::Res /// Download the two public key JWKS files if necessary and cache the content at the given file path. /// Only use this option in cloud functions if the given file path is persistent. /// You can use [`Credentials::add_jwks_public_keys`] to manually add more public keys later on. -pub fn from_cache_file(cache_file: &std::path::Path, c: &Credentials) -> errors::Result { +pub async fn from_cache_file(cache_file: &std::path::Path, c: &Credentials) -> errors::Result { use std::fs::File; use std::io::BufReader; @@ -46,9 +46,9 @@ pub fn from_cache_file(cache_file: &std::path::Path, c: &Credentials) -> errors: } else { // If not present, download the two jwks (specific service account + google system account), // merge them into one set of keys and store them in the cache file. - let mut jwks = JWKSet::new(&download_google_jwks(&c.client_email)?)?; + let mut jwks = JWKSet::new(&download_google_jwks(&c.client_email).await?.0)?; jwks.keys - .append(&mut JWKSet::new(&download_google_jwks("securetoken@system.gserviceaccount.com")?)?.keys); + .append(&mut JWKSet::new(&download_google_jwks("securetoken@system.gserviceaccount.com").await?.0)?.keys); let f = File::create(cache_file)?; serde_json::to_writer_pretty(f, &jwks)?; jwks @@ -59,15 +59,15 @@ pub fn from_cache_file(cache_file: &std::path::Path, c: &Credentials) -> errors: /// Necessary public jwk sets are downloaded or re-used if already present. #[cfg(test)] #[allow(dead_code)] -pub fn valid_test_credentials() -> errors::Result { +pub async fn valid_test_credentials() -> errors::Result { use std::path::PathBuf; let mut jwks_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); jwks_path.push("firebase-service-account.jwks"); - let mut cred: Credentials = Credentials::new(include_str!("../../firebase-service-account.json"))?; + let mut cred: Credentials = Credentials::new(include_str!("../../tests/service-account-test.json"))?; // Only download the public keys once, and cache them. - let jwkset = from_cache_file(jwks_path.as_path(), &cred)?; + let jwkset = from_cache_file(jwks_path.as_path(), &cred).await?; cred.add_jwks_public_keys(&jwkset); cred.verify()?; diff --git a/src/credentials.rs b/src/credentials.rs index 244f9de..9d57d11 100644 --- a/src/credentials.rs +++ b/src/credentials.rs @@ -2,15 +2,17 @@ //! This module contains the [`crate::credentials::Credentials`] type, used by [`crate::sessions`] to create and maintain //! authentication tokens for accessing the Firebase REST API. -use chrono::Duration; +use chrono::{offset, DateTime, Duration}; use serde::{Deserialize, Serialize}; use serde_json; use std::collections::BTreeMap; +use std::fmt; use std::fs::File; use std::sync::Arc; +use tokio::sync::RwLock; use super::jwt::{create_jwt_encoded, download_google_jwks, verify_access_token, JWKSet, JWT_AUDIENCE_IDENTITY}; -use crate::errors::FirebaseError; +use crate::{errors::FirebaseError, jwt::TokenValidationResult}; use std::io::BufReader; type Error = super::errors::FirebaseError; @@ -19,12 +21,23 @@ type Error = super::errors::FirebaseError; #[derive(Default, Clone)] pub(crate) struct Keys { pub pub_key: BTreeMap>, + pub pub_key_expires_at: Option>, pub secret: Option>, } +impl fmt::Debug for Keys { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Keys") + .field("pub_key_expires_at", &self.pub_key_expires_at) + .field("pub_key", &self.pub_key.keys().collect::>()) + .field("secret", &self.secret.is_some()) + .finish() + } +} + /// Service account credentials /// -/// Especially the service account email is required to retrieve the public java web key set (jwks) +/// Especially the service account email is required to retrieve the public json web key set (jwks) /// for verifying Google Firestore tokens. /// /// The api_key is necessary for interacting with the Firestore REST API. @@ -33,7 +46,7 @@ pub(crate) struct Keys { /// /// The private key is used for signing JWTs (javascript web token). /// A signed jwt, encoded as a base64 string, can be exchanged into a refresh and access token. -#[derive(Serialize, Deserialize, Default, Clone)] +#[derive(Serialize, Deserialize, Default, Clone, Debug)] pub struct Credentials { pub project_id: String, pub private_key_id: String, @@ -42,7 +55,7 @@ pub struct Credentials { pub client_id: String, pub api_key: String, #[serde(default, skip)] - pub(crate) keys: Keys, + pub(crate) keys: Arc>, } /// Converts a PEM (ascii base64) encoded private key into the binary der representation @@ -108,9 +121,9 @@ impl Credentials { /// You need two JWKS files for this crate to work: /// * https://www.googleapis.com/service_accounts/v1/jwk/securetoken@system.gserviceaccount.com /// * https://www.googleapis.com/service_accounts/v1/jwk/{your-service-account-email} - pub fn new(credentials_file_content: &str) -> Result { + pub async fn new(credentials_file_content: &str) -> Result { let mut credentials: Credentials = serde_json::from_str(credentials_file_content)?; - credentials.compute_secret()?; + credentials.compute_secret().await?; Ok(credentials) } @@ -118,10 +131,10 @@ impl Credentials { /// /// This is a convenience method, that reads in the given credentials file and acts otherwise the same as /// the [`Credentials::new`] method. - pub fn from_file(credential_file: &str) -> Result { + pub async fn from_file(credential_file: &str) -> Result { let f = BufReader::new(File::open(credential_file)?); let mut credentials: Credentials = serde_json::from_reader(f)?; - credentials.compute_secret()?; + credentials.compute_secret().await?; Ok(credentials) } @@ -129,9 +142,9 @@ impl Credentials { /// /// This method will also verify that the given JWKs files allow verification of Google access tokens. /// This is a convenience method, you may also just use [`Credentials::add_jwks_public_keys`]. - pub fn with_jwkset(mut self, jwks: &JWKSet) -> Result { - self.add_jwks_public_keys(jwks); - self.verify()?; + pub async fn with_jwkset(self, jwks: &JWKSet) -> Result { + self.add_jwks_public_keys(jwks).await; + self.verify().await?; Ok(self) } @@ -153,15 +166,15 @@ impl Credentials { /// .download_jwkset()?; /// # Ok::<(), firestore_db_and_auth::errors::FirebaseError>(()) /// ``` - pub fn download_jwkset(mut self) -> Result { - self.download_google_jwks()?; - self.verify()?; + pub async fn download_jwkset(self) -> Result { + self.download_google_jwks().await?; + self.verify().await?; Ok(self) } /// Verifies that creating access tokens is possible with the given credentials and public keys. /// Returns an empty result type on success. - pub fn verify(&self) -> Result<(), Error> { + pub async fn verify(&self) -> Result<(), Error> { let access_token = create_jwt_encoded( &self, Some(["admin"].iter()), @@ -169,15 +182,31 @@ impl Credentials { Some(self.client_id.clone()), None, JWT_AUDIENCE_IDENTITY, - )?; - verify_access_token(&self, &access_token)?; + ) + .await?; + verify_access_token(&self, &access_token).await?; Ok(()) } + pub async fn verify_token(&self, token: &str) -> Result { + verify_access_token(&self, token).await + } + /// Find the secret in the jwt set that matches the given key id, if any. /// Used for jws validation - pub fn decode_secret(&self, kid: &str) -> Option> { - self.keys.pub_key.get(kid).and_then(|f| Some(f.clone())) + pub async fn decode_secret(&self, kid: &str) -> Result>, Error> { + let should_refresh = { + let keys = self.keys.read().await; + keys.pub_key_expires_at + .map(|expires_at| expires_at - offset::Utc::now() < Duration::minutes(10)) + .unwrap_or(false) + }; + + if should_refresh { + self.download_google_jwks().await?; + } + + Ok(self.keys.read().await.pub_key.get(kid).and_then(|f| Some(f.clone()))) } /// Add a JSON Web Key Set (JWKS) to allow verification of Google access tokens. @@ -194,60 +223,86 @@ impl Credentials { /// c.verify()?; /// # Ok::<(), firestore_db_and_auth::errors::FirebaseError>(()) /// ``` - pub fn add_jwks_public_keys(&mut self, jwkset: &JWKSet) { + pub async fn add_jwks_public_keys(&self, jwkset: &JWKSet) { + let mut keys = self.keys.write().await; + for entry in jwkset.keys.iter() { if !entry.headers.key_id.is_some() { continue; } let key_id = entry.headers.key_id.as_ref().unwrap().to_owned(); - self.keys - .pub_key - .insert(key_id, Arc::new(entry.ne.jws_public_key_secret())); + keys.pub_key.insert(key_id, Arc::new(entry.ne.jws_public_key_secret())); } } /// If you haven't called [`Credentials::add_jwks_public_keys`] to manually add public keys, /// this method will download one for your google service account and one for the oauth related /// securetoken@system.gserviceaccount.com service account. - pub fn download_google_jwks(&mut self) -> Result<(), Error> { - let jwks = download_google_jwks(&self.client_email)?; - self.add_jwks_public_keys(&JWKSet::new(&jwks)?); - let jwks = download_google_jwks("securetoken@system.gserviceaccount.com")?; - self.add_jwks_public_keys(&JWKSet::new(&jwks)?); + pub async fn download_google_jwks(&self) -> Result<(), Error> { + { + let mut keys = self.keys.write().await; + keys.pub_key = BTreeMap::new(); + } + + let (jwks, max_age_client) = download_google_jwks(&self.client_email).await?; + self.add_jwks_public_keys(&JWKSet::new(&jwks)?).await; + let (jwks, max_age_public) = download_google_jwks("securetoken@system.gserviceaccount.com").await?; + self.add_jwks_public_keys(&JWKSet::new(&jwks)?).await; + + let default_expiration = Duration::hours(2); + let max_age_client = max_age_client.unwrap_or(default_expiration); + let max_age_public = max_age_public.unwrap_or(default_expiration); + + let expires_at = if max_age_client < max_age_public { + max_age_client + } else { + max_age_public + }; + + { + let mut keys = self.keys.write().await; + keys.pub_key_expires_at = Some(offset::Utc::now() + expires_at); + } + Ok(()) } + /// Compute the Rsa keypair by using the private_key of the credentials file. /// You must call this if you have manually created a credentials object. /// /// This is automatically invoked if you use [`Credentials::new`] or [`Credentials::from_file`]. - pub fn compute_secret(&mut self) -> Result<(), Error> { + pub async fn compute_secret(&mut self) -> Result<(), Error> { use biscuit::jws::Secret; use ring::signature; let vec = pem_to_der(&self.private_key)?; let key_pair = signature::RsaKeyPair::from_pkcs8(&vec)?; - self.keys.secret = Some(Arc::new(Secret::RsaKeyPair(Arc::new(key_pair)))); + self.keys.write().await.secret = Some(Arc::new(Secret::RsaKeyPair(Arc::new(key_pair)))); Ok(()) } } #[doc(hidden)] #[allow(dead_code)] -pub fn doctest_credentials() -> Credentials { +pub async fn doctest_credentials() -> Credentials { let jwk_list = JWKSet::new(include_str!("../tests/service-account-test.jwks")).unwrap(); Credentials::new(include_str!("../tests/service-account-test.json")) + .await .expect("Failed to deserialize credentials") .with_jwkset(&jwk_list) + .await .expect("JWK public keys verification failed") } -#[test] -fn deserialize_credentials() { +#[tokio::test] +async fn deserialize_credentials() { let jwk_list = JWKSet::new(include_str!("../tests/service-account-test.jwks")).unwrap(); let c: Credentials = Credentials::new(include_str!("../tests/service-account-test.json")) + .await .expect("Failed to deserialize credentials") .with_jwkset(&jwk_list) + .await .expect("JWK public keys verification failed"); assert_eq!(c.api_key, "api_key"); @@ -256,8 +311,10 @@ fn deserialize_credentials() { credential_file.push("tests/service-account-test.json"); let c = Credentials::from_file(credential_file.to_str().unwrap()) + .await .expect("Failed to open credentials file") .with_jwkset(&jwk_list) + .await .expect("JWK public keys verification failed"); assert_eq!(c.api_key, "api_key"); } diff --git a/src/documents/delete.rs b/src/documents/delete.rs index 725df5b..198904e 100644 --- a/src/documents/delete.rs +++ b/src/documents/delete.rs @@ -11,7 +11,7 @@ use crate::errors::extract_google_api_error_async; /// * 'auth' The authentication token /// * 'path' The relative collection path and document id, for example "my_collection/document_id" /// * 'fail_if_not_existing' If true this method will return an error if the document does not exist. -pub fn delete(auth: &impl FirebaseAuthBearer, path: &str, fail_if_not_existing: bool) -> Result<()> { +pub async fn delete(auth: &impl FirebaseAuthBearer, path: &str, fail_if_not_existing: bool) -> Result<()> { let url = firebase_url(auth.project_id(), path); let query_request = dto::Write { @@ -28,45 +28,7 @@ pub fn delete(auth: &impl FirebaseAuthBearer, path: &str, fail_if_not_existing: let resp = auth .client() .delete(&url) - .bearer_auth(auth.access_token().to_owned()) - .json(&query_request) - .send()?; - - extract_google_api_error(resp, || path.to_owned())?; - - Ok({}) -} - -//#[unstable(feature = "unstable", issue = "1234", reason = "Not yet decided if _async suffix or own module namespace")] -/// -/// Deletes the document at the given path. -/// -/// You cannot use this directly with paths from [`list`] and [`query`] document metadata objects. -/// Those contain an absolute document path. Use [`abs_to_rel`] to convert to a relative path. -/// -/// ## Arguments -/// * 'auth' The authentication token -/// * 'path' The relative collection path and document id, for example "my_collection/document_id" -/// * 'fail_if_not_existing' If true this method will return an error if the document does not exist. -#[cfg(feature = "unstable")] -pub async fn delete_async(auth: &impl FirebaseAuthBearer, path: &str, fail_if_not_existing: bool) -> Result<()> { - let url = firebase_url(auth.project_id(), path); - - let query_request = dto::Write { - current_document: Some(dto::Precondition { - exists: match fail_if_not_existing { - true => Some(true), - false => None, - }, - ..Default::default() - }), - ..Default::default() - }; - - let resp = auth - .client_async() - .delete(&url) - .bearer_auth(auth.access_token().to_owned()) + .bearer_auth(auth.access_token().await.to_owned()) .json(&query_request) .send() .await?; diff --git a/src/documents/list.rs b/src/documents/list.rs index fb37a3d..53ee33e 100644 --- a/src/documents/list.rs +++ b/src/documents/list.rs @@ -1,4 +1,12 @@ use super::*; +use bytes::Bytes; +use core::pin::Pin; +use futures::{ + stream::{self, Stream}, + task::{Context, Poll}, + Future, +}; +use std::boxed::Box; /// List all documents of a given collection. /// @@ -30,25 +38,85 @@ use super::*; /// ## Arguments /// * 'auth' The authentication token /// * 'collection_id' The document path / collection; For example "my_collection" or "a/nested/collection" -pub fn list(auth: &BEARER, collection_id: impl Into) -> List +pub fn list( + auth: &AUTH, + collection_id: impl Into, +) -> Pin> + Send>> where - BEARER: FirebaseAuthBearer, + for<'b> T: Deserialize<'b> + 'static, + AUTH: FirebaseAuthBearer + Clone + Send + Sync + 'static, { + let auth = auth.clone(); let collection_id = collection_id.into(); - List { - url: firebase_url(auth.project_id(), &collection_id), - auth, - next_page_token: None, - documents: vec![], - current: 0, - done: false, - collection_id, - phantom: std::marker::PhantomData, - } + + Box::pin(stream::unfold( + ListInner { + url: firebase_url(auth.project_id(), &collection_id), + auth, + next_page_token: None, + documents: vec![], + current: 0, + done: false, + collection_id: collection_id.to_string(), + }, + |this| async move { + let mut this = this.clone(); + if this.done { + return None; + } + + if this.documents.len() <= this.current { + let url = match &this.next_page_token { + Some(next_page_token) => format!("{}pageToken={}", this.url, next_page_token), + None => this.url.clone(), + }; + + let result = get_new_data(&this.collection_id, &url, &this.auth).await; + match result { + Err(e) => { + this.done = true; + return Some((Err(e), this)); + } + Ok(v) => match v.documents { + None => return None, + Some(documents) => { + this.documents = documents; + this.current = 0; + this.next_page_token = v.next_page_token; + } + }, + } + } + + let doc = this.documents.get(this.current).unwrap().clone(); + + this.current += 1; + + if this.documents.len() <= this.current && this.next_page_token.is_none() { + this.done = true; + } + + let result = document_to_pod(&Bytes::new(), &doc); + match result { + Err(e) => Some((Err(e), this)), + Ok(pod) => Some(( + Ok(( + pod, + dto::Document { + update_time: doc.update_time.clone(), + create_time: doc.create_time.clone(), + name: doc.name.clone(), + fields: None, + }, + )), + this, + )), + } + }, + )) } -#[inline] -fn get_new_data<'a>( +async fn get_new_data<'a>( collection_id: &str, url: &str, auth: &'a impl FirebaseAuthBearer, @@ -56,85 +124,23 @@ fn get_new_data<'a>( let resp = auth .client() .get(url) - .bearer_auth(auth.access_token().to_owned()) - .send()?; + .bearer_auth(auth.access_token().await) + .send() + .await?; - let resp = extract_google_api_error(resp, || collection_id.to_owned())?; + let resp = extract_google_api_error_async(resp, || collection_id.to_owned()).await?; - let json: dto::ListDocumentsResponse = resp.json()?; + let json: dto::ListDocumentsResponse = resp.json().await?; Ok(json) } -/// This type is returned as a result by [`list`]. -/// Use it as an iterator. The paging API is used internally and new pages are fetched lazily. -/// -/// Please note that this API acts as an iterator of same-like documents. -/// This type is not suitable if you want to list documents of different types. -pub struct List<'a, T, BEARER> { - auth: &'a BEARER, +#[derive(Clone)] +struct ListInner { + auth: AUTH, next_page_token: Option, documents: Vec, current: usize, done: bool, url: String, collection_id: String, - phantom: std::marker::PhantomData, -} - -impl<'a, T, BEARER> Iterator for List<'a, T, BEARER> -where - for<'b> T: Deserialize<'b>, - BEARER: FirebaseAuthBearer, -{ - type Item = Result<(T, dto::Document)>; - - fn next(&mut self) -> Option { - if self.done { - return None; - } - - if self.documents.len() <= self.current { - let url = match &self.next_page_token { - Some(next_page_token) => format!("{}pageToken={}", self.url, next_page_token), - None => self.url.clone(), - }; - - let result = get_new_data(&self.collection_id, &url, self.auth); - match result { - Err(e) => { - self.done = true; - return Some(Err(e)); - } - Ok(v) => match v.documents { - None => return None, - Some(documents) => { - self.documents = documents; - self.current = 0; - self.next_page_token = v.next_page_token; - } - }, - }; - } - - let doc = self.documents.get(self.current).unwrap(); - - self.current += 1; - if self.documents.len() <= self.current && self.next_page_token.is_none() { - self.done = true; - } - - let result = document_to_pod(&doc); - match result { - Err(e) => Some(Err(e)), - Ok(pod) => Some(Ok(( - pod, - dto::Document { - update_time: doc.update_time.clone(), - create_time: doc.create_time.clone(), - name: doc.name.clone(), - fields: None, - }, - ))), - } - } } diff --git a/src/documents/mod.rs b/src/documents/mod.rs index a991c87..e5f52d0 100644 --- a/src/documents/mod.rs +++ b/src/documents/mod.rs @@ -2,9 +2,9 @@ //! //! Interact with Firestore documents. //! Please check the root page of this documentation for examples. - +#![allow(unused_imports, dead_code)] use super::dto; -use super::errors::{extract_google_api_error, FirebaseError, Result}; +use super::errors::{extract_google_api_error, extract_google_api_error_async, FirebaseError, Result}; use super::firebase_rest_to_rust::{document_to_pod, pod_to_document}; use super::FirebaseAuthBearer; diff --git a/src/documents/query.rs b/src/documents/query.rs index 9565d13..008fb04 100644 --- a/src/documents/query.rs +++ b/src/documents/query.rs @@ -32,7 +32,7 @@ use std::vec::IntoIter; /// * 'value' The query / filter value. For example "car". /// * 'operator' The query operator. For example "EQUAL". /// * 'field' The query / filter field. For example "type". -pub fn query( +pub async fn query( auth: &impl FirebaseAuthBearer, collection_id: &str, value: serde_json::Value, @@ -67,13 +67,14 @@ pub fn query( let resp = auth .client() .post(&url) - .bearer_auth(auth.access_token().to_owned()) + .bearer_auth(auth.access_token().await) .json(&query_request) - .send()?; + .send() + .await?; - let resp = extract_google_api_error(resp, || collection_id.to_owned())?; + let resp = extract_google_api_error_async(resp, || collection_id.to_owned()).await?; - let json: Option> = resp.json()?; + let json: Option> = resp.json().await?; Ok(Query(json.unwrap_or_default().into_iter())) } @@ -86,6 +87,7 @@ pub fn query( /// /// Please note that this API acts as an iterator of same-like documents. /// This type is not suitable if you want to list documents of different types. +#[derive(Debug)] pub struct Query(IntoIter); impl Iterator for Query { diff --git a/src/documents/read.rs b/src/documents/read.rs index 1d295dc..a4914eb 100644 --- a/src/documents/read.rs +++ b/src/documents/read.rs @@ -7,14 +7,22 @@ use std::io::Read; /// ## Arguments /// * `auth` The authentication token /// * `document_name` The document path / collection and document id; For example `projects/my_project/databases/(default)/documents/tests/test` -pub fn read_by_name(auth: &impl FirebaseAuthBearer, document_name: impl AsRef) -> Result +pub async fn read_by_name(auth: &impl FirebaseAuthBearer, document_name: &str) -> Result where for<'b> T: Deserialize<'b>, { - let resp = request_document(auth, document_name)?; - // Here `resp.json()?` is a method provided by `reqwest` - let json: dto::Document = resp.json()?; - Ok(document_to_pod(&json)?) + let resp = request_document(auth, document_name).await?; + + // We take the raw response first in order to provide + // more complete errors on deserialization failure + let full = resp.bytes().await?; + let json = serde_json::from_slice(&full).map_err(|e| FirebaseError::SerdeVerbose { + doc: Some(String::from(document_name)), + input_doc: String::from_utf8_lossy(&full).to_string(), + ser: e, + })?; + + Ok(document_to_pod(&full, &json)?) } /// @@ -24,12 +32,12 @@ where /// * `auth` The authentication token /// * `path` The document path / collection; For example `my_collection` or `a/nested/collection` /// * `document_id` The document id. Make sure that you do not include the document id to the path argument. -pub fn read(auth: &impl FirebaseAuthBearer, path: &str, document_id: impl AsRef) -> Result +pub async fn read(auth: &impl FirebaseAuthBearer, path: &str, document_id: &str) -> Result where for<'b> T: Deserialize<'b>, { let document_name = document_name(&auth.project_id(), path, document_id); - read_by_name(auth, &document_name) + read_by_name(auth, &document_name).await } /// Return the raw unparsed content of the Firestore document. Methods like @@ -39,39 +47,31 @@ where /// Note that this leverages [`std::io::Read`](https://doc.rust-lang.org/std/io/trait.Read.html) and the `read_to_string()` method to chunk the /// response. This will raise `FirebaseError::IO` if there are errors reading the stream. Please /// see [`read_to_end()`](https://doc.rust-lang.org/std/io/trait.Read.html#method.read_to_end) -pub fn contents(auth: &impl FirebaseAuthBearer, path: &str, document_id: impl AsRef) -> Result { +pub async fn contents(auth: &impl FirebaseAuthBearer, path: &str, document_id: &str) -> Result { let document_name = document_name(&auth.project_id(), path, document_id); - let mut resp = request_document(auth, document_name)?; - let mut text = String::new(); - match resp.read_to_string(&mut text) { - Ok(_bytes) => Ok(text), - Err(e) => Err(FirebaseError::IO(e)), - } + let resp = request_document(auth, &document_name).await?; + resp.text().await.map_err(|e| FirebaseError::Request(e)) } /// Executes the request to retrieve the document. Returns the response from `reqwest` -fn request_document( - auth: &impl FirebaseAuthBearer, - document_name: impl AsRef, -) -> Result { +async fn request_document(auth: &impl FirebaseAuthBearer, document_name: &str) -> Result { let url = firebase_url_base(document_name.as_ref()); let resp = auth .client() .get(&url) - .bearer_auth(auth.access_token().to_owned()) - .send()?; + .bearer_auth(auth.access_token().await) + .send() + .await?; - extract_google_api_error(resp, || document_name.as_ref().to_owned()) + extract_google_api_error_async(resp, || document_name.to_owned()).await } /// Simple method to join the path and document identifier in correct format -fn document_name(project_id: impl AsRef, path: impl AsRef, document_id: impl AsRef) -> String { +fn document_name(project_id: &str, path: &str, document_id: &str) -> String { format!( "projects/{}/databases/(default)/documents/{}/{}", - project_id.as_ref(), - path.as_ref(), - document_id.as_ref() + project_id, path, document_id ) } diff --git a/src/documents/write.rs b/src/documents/write.rs index 52a615d..de2c709 100644 --- a/src/documents/write.rs +++ b/src/documents/write.rs @@ -78,7 +78,7 @@ pub struct WriteOptions { /// * 'document_id' The document id. Make sure that you do not include the document id in the path argument. /// * 'document' The document /// * 'options' Write options -pub fn write( +pub async fn write( auth: &impl FirebaseAuthBearer, path: &str, document_id: Option>, @@ -107,19 +107,21 @@ where }; let resp = builder - .bearer_auth(auth.access_token().to_owned()) + .bearer_auth(auth.access_token().await.to_owned()) .json(&firebase_document) - .send()?; + .send() + .await?; - let resp = extract_google_api_error(resp, || { + let resp = extract_google_api_error_async(resp, || { document_id .as_ref() .and_then(|f| Some(f.as_ref().to_owned())) .or(Some(String::new())) .unwrap() - })?; + }) + .await?; - let result_document: dto::Document = resp.json()?; + let result_document: dto::Document = resp.json().await?; let document_id = Path::new(&result_document.name) .file_name() .ok_or_else(|| FirebaseError::Generic("Resulting documents 'name' field is not a valid path"))? diff --git a/src/dto.rs b/src/dto.rs index 8a00323..aa14644 100644 --- a/src/dto.rs +++ b/src/dto.rs @@ -596,7 +596,7 @@ mod tests { "updateTime": "2020-04-28T14:52:51.250511Z" }"#; let document: Result = serde_json::from_str(&doc); - assert!(document.is_ok(), true); + assert!(document.is_ok()); } #[test] @@ -659,6 +659,6 @@ mod tests { "updateTime": "2020-04-28T14:52:51.250511Z" }"#; let document: Result = serde_json::from_str(&doc); - assert!(document.is_ok(), true); + assert!(document.is_ok()); } } diff --git a/src/errors.rs b/src/errors.rs index 055b2f9..a0c500b 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -40,6 +40,12 @@ pub enum FirebaseError { doc: Option, ser: serde_json::Error, }, + /// Verbose deserialization failure + SerdeVerbose { + doc: Option, + input_doc: String, + ser: serde_json::Error, + }, /// When the credentials.json file contains an invalid private key this error is returned RSA(ring::error::KeyRejected), /// Disk access errors @@ -84,29 +90,37 @@ impl std::convert::From for FirebaseError { impl fmt::Display for FirebaseError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { + match self { FirebaseError::Generic(m) => write!(f, "{}", m), - FirebaseError::APIError(code, ref m, ref context) => { + FirebaseError::APIError(code, m, context) => { write!(f, "API Error! Code {} - {}. Context: {}", code, m, context) } - FirebaseError::UnexpectedResponse(m, status, ref text, ref source) => { + FirebaseError::UnexpectedResponse(m, status, text, source) => { writeln!(f, "{} - {}", &m, status)?; writeln!(f, "{}", text)?; writeln!(f, "{}", source)?; Ok(()) } - FirebaseError::Request(ref e) => e.fmt(f), - FirebaseError::JWT(ref e) => e.fmt(f), - FirebaseError::JWTValidation(ref e) => e.fmt(f), - FirebaseError::RSA(ref e) => e.fmt(f), - FirebaseError::IO(ref e) => e.fmt(f), - FirebaseError::Ser { ref doc, ref ser } => { + FirebaseError::Request(e) => e.fmt(f), + FirebaseError::JWT(e) => e.fmt(f), + FirebaseError::JWTValidation(e) => e.fmt(f), + FirebaseError::RSA(e) => e.fmt(f), + FirebaseError::IO(e) => e.fmt(f), + FirebaseError::Ser { doc, ser } => { if let Some(doc) = doc { writeln!(f, "{} in document {}", ser, doc) } else { ser.fmt(f) } } + FirebaseError::SerdeVerbose { doc, input_doc, ser } => { + let doc = doc.clone().unwrap_or("Unknown document".to_string()); + writeln!( + f, + "Serde deserialization failed for document '{}' with error '{}' on input: '{}'", + doc, ser, input_doc + ) + } } } } @@ -123,6 +137,7 @@ impl error::Error for FirebaseError { FirebaseError::RSA(_) => None, FirebaseError::IO(ref e) => Some(e), FirebaseError::Ser { ref ser, .. } => Some(ser), + FirebaseError::SerdeVerbose { ref ser, .. } => Some(ser), } } } diff --git a/src/firebase_rest_to_rust.rs b/src/firebase_rest_to_rust.rs index 9f68a0e..3ca0735 100644 --- a/src/firebase_rest_to_rust.rs +++ b/src/firebase_rest_to_rust.rs @@ -3,6 +3,7 @@ //! the data types of the Firebase REST API. Those are 1:1 translations of the grpc API //! and deeply nested and wrapped. +use bytes::Bytes; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; @@ -10,7 +11,7 @@ use std::collections::HashMap; use super::dto; use super::errors::{FirebaseError, Result}; -#[derive(Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] struct Wrapper { #[serde(flatten)] extra: HashMap, @@ -116,7 +117,7 @@ pub(crate) fn serde_value_to_firebase_value(v: &serde_json::Value) -> dto::Value /// Internals: /// /// This method uses recursion to decode the given firebase type. -pub fn document_to_pod(document: &dto::Document) -> Result +pub fn document_to_pod(input_doc: &Bytes, document: &dto::Document) -> Result where for<'de> T: Deserialize<'de>, { @@ -137,8 +138,9 @@ where }; let v = serde_json::to_value(r)?; - let r: T = serde_json::from_value(v).map_err(|e| FirebaseError::Ser { + let r: T = serde_json::from_value(v).map_err(|e| FirebaseError::SerdeVerbose { doc: Some(document.name.clone()), + input_doc: String::from_utf8_lossy(input_doc).replace("\n", " ").to_string(), ser: e, })?; Ok(r) diff --git a/src/jwt.rs b/src/jwt.rs index 531cfd3..13aa0ae 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -6,11 +6,13 @@ use serde::{Deserialize, Serialize}; use chrono::{Duration, Utc}; use std::collections::HashSet; +use std::ops::Add; use std::slice::Iter; use crate::errors::FirebaseError; use biscuit::jwa::SignatureAlgorithm; use biscuit::{ClaimPresenceOptions, SingleOrMultiple, ValidationOptions}; +use cache_control::CacheControl; use std::ops::Deref; type Error = super::errors::FirebaseError; @@ -60,40 +62,34 @@ impl JWKSet { } /// Download the Google JWK Set for a given service account. +/// Returns the JWKS alongside the maximum time the JWKS is valid for. /// The resulting set of JWKs need to be added to a credentials object /// for jwk verifications. -pub fn download_google_jwks(account_mail: &str) -> Result { +pub async fn download_google_jwks(account_mail: &str) -> Result<(String, Option), Error> { let url = format!("https://www.googleapis.com/service_accounts/v1/jwk/{}", account_mail); - let resp = reqwest::blocking::Client::new().get(&url).send()?; - Ok(resp.text()?) + let resp = reqwest::Client::new().get(&url).send().await?; + let max_age = resp + .headers() + .get("cache-control") + .and_then(|cache_control| cache_control.to_str().ok()) + .and_then(|cache_control| CacheControl::from_value(cache_control)) + .and_then(|cache_control| cache_control.max_age) + .and_then(|max_age| Duration::from_std(max_age).ok()); + + Ok((resp.text().await?, max_age)) } -/// Download the Google JWK Set for a given service account. -/// The resulting set of JWKs need to be added to a credentials object -/// for jwk verifications. -#[cfg(feature = "unstable")] -pub async fn download_google_jwks_async(account_mail: &str) -> Result { - let resp = reqwest::Client::new() - .get(&format!( - "https://www.googleapis.com/service_accounts/v1/jwk/{}", - account_mail - )) - .send() - .await?; - Ok(resp.text().await?) -} - -pub(crate) fn create_jwt_encoded>( +pub(crate) async fn create_jwt_encoded>( credentials: &Credentials, - scope: Option>, + scope: Option>, duration: chrono::Duration, client_id: Option, user_id: Option, audience: &str, ) -> Result { let jwt = create_jwt(credentials, scope, duration, client_id, user_id, audience)?; - let secret = credentials - .keys + let secret_lock = credentials.keys.read().await; + let secret = secret_lock .secret .as_ref() .ok_or(Error::Generic("No private key added via add_keypair_key!"))?; @@ -119,15 +115,19 @@ pub(crate) fn jwt_update_expiry_if(jwt: &mut AuthClaimsJWT, expire_in_minutes: i let ref mut claims = jwt.payload_mut().unwrap().registered; let now = biscuit::Timestamp::from(Utc::now()); + let now_plus_hour = biscuit::Timestamp::from(Utc::now().add(Duration::hours(1))); + if let Some(issued_at) = claims.issued_at.as_ref() { let diff: Duration = Utc::now().signed_duration_since(issued_at.deref().clone()); if diff.num_minutes() > expire_in_minutes { claims.issued_at = Some(now); + claims.expiry = Some(now_plus_hour); } else { return false; } } else { claims.issued_at = Some(now); + claims.expiry = Some(now_plus_hour); } true @@ -144,8 +144,6 @@ pub(crate) fn create_jwt( where S: AsRef, { - use std::ops::Add; - use biscuit::{ jws::{Header, RegisteredHeader}, ClaimsSet, Empty, RegisteredClaims, JWT, @@ -179,6 +177,7 @@ where Ok(JWT::new_decoded(header, expected_claims)) } +#[derive(Debug)] pub struct TokenValidationResult { pub claims: JwtOAuthPrivateClaims, pub audience: String, @@ -194,7 +193,7 @@ impl TokenValidationResult { } } -pub(crate) fn verify_access_token( +pub(crate) async fn verify_access_token( credentials: &Credentials, access_token: &str, ) -> Result { @@ -208,6 +207,7 @@ pub(crate) fn verify_access_token( .ok_or(FirebaseError::Generic("No jwt kid"))?; let secret = credentials .decode_secret(kid) + .await? .ok_or(FirebaseError::Generic("No secret for kid"))?; let token = token.into_decoded(&secret.deref(), SignatureAlgorithm::RS256)?; @@ -247,7 +247,10 @@ pub mod session_cookie { use super::*; use std::ops::Add; - pub(crate) fn create_jwt_encoded(credentials: &Credentials, duration: chrono::Duration) -> Result { + pub(crate) async fn create_jwt_encoded( + credentials: &Credentials, + duration: chrono::Duration, + ) -> Result { let scope = [ "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/firebase.database", @@ -285,8 +288,8 @@ pub mod session_cookie { }; let jwt = JWT::new_decoded(header, expected_claims); - let secret = credentials - .keys + let secret_lock = credentials.keys.read().await; + let secret = secret_lock .secret .as_ref() .ok_or(Error::Generic("No private key added via add_keypair_key!"))?; diff --git a/src/lib.rs b/src/lib.rs index ade7262..3d6175b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,18 +27,19 @@ pub use sessions::user::Session as UserSession; /// Firestore document methods in [`crate::documents`] expect an object that implements this `FirebaseAuthBearer` trait. /// /// Implement this trait for your own data structure and provide the Firestore project id and a valid access token. +#[async_trait::async_trait] pub trait FirebaseAuthBearer { /// Return the project ID. This is required for the firebase REST API. fn project_id(&self) -> &str; + /// An access token. If a refresh token is known and the access token expired, /// the implementation should try to refresh the access token before returning. - fn access_token(&self) -> String; + async fn access_token(&self) -> String; + /// The access token, unchecked. Might be expired or in other ways invalid. - fn access_token_unchecked(&self) -> String; - /// The reqwest http client. - /// The `Client` holds a connection pool internally, so it is advised that it is reused for multiple, successive connections. - fn client(&self) -> &reqwest::blocking::Client; + async fn access_token_unchecked(&self) -> String; + /// The reqwest http client. /// The `Client` holds a connection pool internally, so it is advised that it is reused for multiple, successive connections. - fn client_async(&self) -> &reqwest::Client; + fn client(&self) -> &reqwest::Client; } diff --git a/src/sessions.rs b/src/sessions.rs index 3857587..5afca2b 100644 --- a/src/sessions.rs +++ b/src/sessions.rs @@ -2,8 +2,9 @@ //! //! A session can be either for a service-account or impersonated via a firebase auth user id. +#![allow(unused_imports)] use super::credentials; -use super::errors::{extract_google_api_error, FirebaseError}; +use super::errors::{extract_google_api_error, extract_google_api_error_async, FirebaseError}; use super::jwt::{ create_jwt, is_expired, jwt_update_expiry_if, verify_access_token, AuthClaimsJWT, JWT_AUDIENCE_FIRESTORE, JWT_AUDIENCE_IDENTITY, @@ -15,6 +16,8 @@ use serde::{Deserialize, Serialize}; use std::cell::RefCell; use std::ops::Deref; use std::slice::Iter; +use std::sync::Arc; +use tokio::sync::RwLock; pub mod user { use super::*; @@ -66,6 +69,7 @@ pub mod user { /// An impersonated session. /// Firestore rules will restrict your access. + #[derive(Clone)] pub struct Session { /// The firebase auth user id pub user_id: String, @@ -74,64 +78,63 @@ pub mod user { pub refresh_token: Option, /// The firebase projects API key, as defined in the credentials object pub api_key: String, - access_token_: RefCell, + + access_token_: Arc>, + project_id_: String, - /// The http client. Replace or modify the client if you have special demands like proxy support - pub client: reqwest::blocking::Client, /// The http client for async operations. Replace or modify the client if you have special demands like proxy support - pub client_async: reqwest::Client, + pub client: reqwest::Client, } + #[async_trait::async_trait] impl super::FirebaseAuthBearer for Session { fn project_id(&self) -> &str { &self.project_id_ } + + async fn access_token_unchecked(&self) -> String { + self.access_token_.read().await.clone() + } + /// Returns the current access token. /// This method will automatically refresh your access token, if it has expired. /// /// If the refresh failed, this will return an empty string. - fn access_token(&self) -> String { - let jwt = self.access_token_.borrow(); - let jwt = jwt.as_str(); + async fn access_token(&self) -> String { + // Let's keep the access token locked for writes for the entirety of this function, + // so we don't have multiple refreshes going on at the same time + let mut jwt = self.access_token_.write().await; if is_expired(&jwt, 0).unwrap() { // Unwrap: the token is always valid at this point - if let Ok(response) = get_new_access_token(&self.api_key, jwt) { - self.access_token_.swap(&RefCell::new(response.id_token.clone())); + if let Ok(response) = get_new_access_token(&self.api_key, &jwt).await { + *jwt = response.id_token.clone(); return response.id_token; } else { // Failed to refresh access token. Return an empty string return String::new(); } } - jwt.to_owned() - } - fn access_token_unchecked(&self) -> String { - self.access_token_.borrow().clone() + jwt.clone() } - fn client(&self) -> &reqwest::blocking::Client { + fn client(&self) -> &reqwest::Client { &self.client } - - fn client_async(&self) -> &reqwest::Client { - &self.client_async - } } /// Gets a new access token via an api_key and a refresh_token. - /// This is a blocking operation. - fn get_new_access_token( + async fn get_new_access_token( api_key: &str, refresh_token: &str, ) -> Result { let request_body = vec![("grant_type", "refresh_token"), ("refresh_token", refresh_token)]; let url = refresh_to_access_endpoint(api_key); - let client = reqwest::blocking::Client::new(); - let response = client.post(&url).form(&request_body).send()?; - Ok(response.json()?) + let client = reqwest::Client::new(); + let response = client.post(&url).form(&request_body).send().await?; + Ok(response.json().await?) } #[allow(non_snake_case)] @@ -185,7 +188,7 @@ pub mod user { /// See: /// * https://firebase.google.com/docs/reference/rest/auth#section-refresh-token /// * https://firebase.google.com/docs/auth/admin/create-custom-tokens#create_custom_tokens_using_a_third-party_jwt_library - pub fn new( + pub async fn new( credentials: &Credentials, user_id: Option<&str>, firebase_tokenid: Option<&str>, @@ -193,7 +196,7 @@ pub mod user { ) -> Result { // Check if current tokenid is still valid if let Some(firebase_tokenid) = firebase_tokenid { - let r = Session::by_access_token(credentials, firebase_tokenid); + let r = Session::by_access_token(credentials, firebase_tokenid).await; if r.is_ok() { let mut r = r.unwrap(); r.refresh_token = refresh_token.and_then(|f| Some(f.to_owned())); @@ -203,7 +206,7 @@ pub mod user { // Check if refresh_token is already sufficient if let Some(refresh_token) = refresh_token { - let r = Session::by_refresh_token(credentials, refresh_token); + let r = Session::by_refresh_token(credentials, refresh_token).await; if r.is_ok() { return r; } @@ -212,7 +215,7 @@ pub mod user { // Neither refresh token nor access token worked or are provided. // Try to get new new tokens for the given user_id via the REST API and the service-account credentials. if let Some(user_id) = user_id { - let r = Session::by_user_id(credentials, user_id, true); + let r = Session::by_user_id(credentials, user_id, true).await; if r.is_ok() { return r; } @@ -228,16 +231,19 @@ pub mod user { /// - `refresh_token` A refresh token. /// /// Async support: This is a blocking operation. - pub fn by_refresh_token(credentials: &Credentials, refresh_token: &str) -> Result { - let r: RefreshTokenToAccessTokenResponse = get_new_access_token(&credentials.api_key, refresh_token)?; + pub async fn by_refresh_token( + credentials: &Credentials, + refresh_token: &str, + ) -> Result { + let r: RefreshTokenToAccessTokenResponse = + get_new_access_token(&credentials.api_key, refresh_token).await?; Ok(Session { user_id: r.user_id, - access_token_: RefCell::new(r.id_token), + access_token_: Arc::new(RwLock::new(r.id_token)), refresh_token: Some(r.refresh_token), project_id_: credentials.project_id.to_owned(), api_key: credentials.api_key.clone(), - client: reqwest::blocking::Client::new(), - client_async: reqwest::Client::new(), + client: reqwest::Client::new(), }) } @@ -250,7 +256,7 @@ pub mod user { /// Google generates only a few dozens of refresh tokens before it starts to invalidate already generated ones. /// For short lived, immutable, non-persisting services you do not want a refresh token. /// - pub fn by_user_id( + pub async fn by_user_id( credentials: &Credentials, user_id: &str, with_refresh_token: bool, @@ -264,28 +270,28 @@ pub mod user { Some(user_id.to_owned()), JWT_AUDIENCE_IDENTITY, )?; - let secret = credentials - .keys + let secret_lock = credentials.keys.read().await; + let secret = secret_lock .secret .as_ref() .ok_or(FirebaseError::Generic("No private key added via add_keypair_key!"))?; let encoded = jwt.encode(&secret.deref())?.encoded()?.encode(); - let resp = reqwest::blocking::Client::new() + let resp = reqwest::Client::new() .post(&token_endpoint(&credentials.api_key)) .json(&CustomJwtToFirebaseID::new(encoded, with_refresh_token)) - .send()?; - let resp = extract_google_api_error(resp, || user_id.to_owned())?; - let r: CustomJwtToFirebaseIDResponse = resp.json()?; + .send() + .await?; + let resp = extract_google_api_error_async(resp, || user_id.to_owned()).await?; + let r: CustomJwtToFirebaseIDResponse = resp.json().await?; Ok(Session { user_id: user_id.to_owned(), - access_token_: RefCell::new(r.idToken), + access_token_: Arc::new(RwLock::new(r.idToken)), refresh_token: r.refreshToken, project_id_: credentials.project_id.to_owned(), api_key: credentials.api_key.clone(), - client: reqwest::blocking::Client::new(), - client_async: reqwest::Client::new(), + client: reqwest::Client::new(), }) } @@ -300,16 +306,15 @@ pub mod user { /// - `credentials` The credentials /// - `access_token` An access token, sometimes called a firebase id token. /// - pub fn by_access_token(credentials: &Credentials, access_token: &str) -> Result { - let result = verify_access_token(&credentials, access_token)?; + pub async fn by_access_token(credentials: &Credentials, access_token: &str) -> Result { + let result = verify_access_token(&credentials, access_token).await?; Ok(Session { user_id: result.subject, project_id_: result.audience, - access_token_: RefCell::new(access_token.to_owned()), + access_token_: Arc::new(RwLock::new(access_token.to_owned())), refresh_token: None, api_key: credentials.api_key.clone(), - client: reqwest::blocking::Client::new(), - client_async: reqwest::Client::new(), + client: reqwest::Client::new(), }) } @@ -325,7 +330,7 @@ pub mod user { /// Google generates only a few dozens of refresh tokens before it starts to invalidate already generated ones. /// For short lived, immutable, non-persisting services you do not want a refresh token. /// - pub fn by_oauth2( + pub async fn by_oauth2( credentials: &Credentials, access_token: String, provider: OAuth2Provider, @@ -346,11 +351,11 @@ pub mod user { return_secure_token, }; - let response = reqwest::blocking::Client::new().post(&uri).json(&json).send()?; + let response = reqwest::Client::new().post(&uri).json(&json).send().await?; - let oauth_response: OAuthResponse = response.json()?; + let oauth_response: OAuthResponse = response.json().await?; - self::Session::by_user_id(&credentials, &oauth_response.local_id, with_refresh_token) + self::Session::by_user_id(&credentials, &oauth_response.local_id, with_refresh_token).await } } } @@ -377,7 +382,7 @@ pub mod session_cookie { } /// https://cloud.google.com/identity-platform/docs/reference/rest/v1/projects/createSessionCookie - #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] + #[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] struct SessionLoginDTO { /// Required. A valid Identity Platform ID token. #[serde(rename = "idToken")] @@ -393,8 +398,6 @@ pub mod session_cookie { #[derive(Debug, Deserialize)] struct Oauth2ResponseDTO { access_token: String, - expires_in: u64, - token_type: String, } /// Firebase Auth provides server-side session cookie management for traditional websites that rely on session cookies. @@ -418,13 +421,13 @@ pub mod session_cookie { /// - `id_token` An access token, sometimes called a firebase id token. /// - `duration` The cookie duration /// - pub fn create( + pub async fn create( credentials: &credentials::Credentials, id_token: String, duration: chrono::Duration, ) -> Result { // Generate the assertion from the admin credentials - let assertion = crate::jwt::session_cookie::create_jwt_encoded(credentials, duration)?; + let assertion = crate::jwt::session_cookie::create_jwt_encoded(credentials, duration).await?; // Request Google Oauth2 to retrieve the access token in order to create a session cookie let client = reqwest::blocking::Client::new(); @@ -455,6 +458,8 @@ pub mod session_cookie { /// Find the service account session defined in here pub mod service_account { + use crate::jwt::TokenValidationResult; + use super::*; use credentials::Credentials; @@ -463,50 +468,60 @@ pub mod service_account { use std::ops::Deref; /// Service account session + #[derive(Clone, Debug)] pub struct Session { /// The google credentials pub credentials: Credentials, - /// The http client. Replace or modify the client if you have special demands like proxy support - pub client: reqwest::blocking::Client, /// The http client for async operations. Replace or modify the client if you have special demands like proxy support - pub client_async: reqwest::Client, - jwt: RefCell, - access_token_: RefCell, + pub client: reqwest::Client, + jwt: Arc>, + access_token_: Arc>, } + #[async_trait::async_trait] impl super::FirebaseAuthBearer for Session { fn project_id(&self) -> &str { &self.credentials.project_id } + /// Return the encoded jwt to be used as bearer token. If the jwt /// issue_at is older than 50 minutes, it will be updated to the current time. - fn access_token(&self) -> String { - let mut jwt = self.jwt.borrow_mut(); - - if jwt_update_expiry_if(&mut jwt, 50) { - if let Some(secret) = self.credentials.keys.secret.as_ref() { - if let Ok(v) = self.jwt.borrow().encode(&secret.deref()) { - if let Ok(v2) = v.encoded() { - self.access_token_.swap(&RefCell::new(v2.encode())); - } - } + async fn access_token(&self) -> String { + // Keeping the JWT and the access token in write mode so this area is + // a single-entrace critical section for refreshes sake + let mut access_token = self.access_token_.write().await; + let maybe_jwt = { + let mut jwt = self.jwt.write().await; + + if jwt_update_expiry_if(&mut jwt, 50) { + self.credentials + .keys + .read() + .await + .secret + .as_ref() + .and_then(|secret| jwt.clone().encode(&secret.deref()).ok()) + } else { + None + } + }; + + if let Some(v) = maybe_jwt { + if let Ok(v) = v.encoded() { + *access_token = v.encode(); } } - self.access_token_.borrow().clone() + access_token.clone() } - fn access_token_unchecked(&self) -> String { - self.access_token_.borrow().clone() + async fn access_token_unchecked(&self) -> String { + self.access_token_.read().await.clone() } - fn client(&self) -> &reqwest::blocking::Client { + fn client(&self) -> &reqwest::Client { &self.client } - - fn client_async(&self) -> &reqwest::Client { - &self.client_async - } } impl Session { @@ -519,7 +534,7 @@ pub mod service_account { /// as bearer token. /// /// See https://developers.google.com/identity/protocols/OAuth2ServiceAccount - pub fn new(credentials: Credentials) -> Result { + pub async fn new(credentials: Credentials) -> Result { let scope: Option> = None; let jwt = create_jwt( &credentials, @@ -529,20 +544,26 @@ pub mod service_account { None, JWT_AUDIENCE_FIRESTORE, )?; - let secret = credentials - .keys - .secret - .as_ref() - .ok_or(FirebaseError::Generic("No private key added via add_keypair_key!"))?; - let encoded = jwt.encode(&secret.deref())?.encoded()?.encode(); + let encoded = { + let secret_lock = credentials.keys.read().await; + let secret = secret_lock + .secret + .as_ref() + .ok_or(FirebaseError::Generic("No private key added via add_keypair_key!"))?; + jwt.encode(&secret.deref())?.encoded()?.encode() + }; Ok(Session { - access_token_: RefCell::new(encoded), - jwt: RefCell::new(jwt), + access_token_: Arc::new(RwLock::new(encoded)), + jwt: Arc::new(RwLock::new(jwt)), + credentials, - client: reqwest::blocking::Client::new(), - client_async: reqwest::Client::new(), + client: reqwest::Client::new(), }) } + + pub async fn verify_token(&self, token: &str) -> Result { + self.credentials.verify_token(token).await + } } } diff --git a/src/users.rs b/src/users.rs index a05a936..15b2163 100644 --- a/src/users.rs +++ b/src/users.rs @@ -2,7 +2,7 @@ //! //! Retrieve firebase user information -use super::errors::{extract_google_api_error, Result}; +use super::errors::{extract_google_api_error_async, Result}; use super::sessions::{service_account, user}; use serde::{Deserialize, Serialize}; @@ -66,20 +66,21 @@ fn firebase_auth_url(v: &str, v2: &str) -> String { /// Error codes: /// - INVALID_ID_TOKEN /// - USER_NOT_FOUND -pub fn user_info(session: &user::Session) -> Result { +pub async fn user_info(session: &user::Session) -> Result { let url = firebase_auth_url("lookup", &session.api_key); let resp = session .client() .post(&url) .json(&UserRequest { - idToken: session.access_token(), + idToken: session.access_token().await, }) - .send()?; + .send() + .await?; - let resp = extract_google_api_error(resp, || session.user_id.to_owned())?; + let resp = extract_google_api_error_async(resp, || session.user_id.to_owned()).await?; - Ok(resp.json()?) + Ok(resp.json().await?) } /// Removes the firebase auth user associated with the given user session @@ -87,17 +88,18 @@ pub fn user_info(session: &user::Session) -> Result { /// Error codes: /// - INVALID_ID_TOKEN /// - USER_NOT_FOUND -pub fn user_remove(session: &user::Session) -> Result<()> { +pub async fn user_remove(session: &user::Session) -> Result<()> { let url = firebase_auth_url("delete", &session.api_key); let resp = session .client() .post(&url) .json(&UserRequest { - idToken: session.access_token(), + idToken: session.access_token().await, }) - .send()?; + .send() + .await?; - extract_google_api_error(resp, || session.user_id.to_owned())?; + extract_google_api_error_async(resp, || session.user_id.to_owned()).await?; Ok({}) } @@ -117,7 +119,12 @@ struct SignInUpUserRequest { pub returnSecureToken: bool, } -fn sign_up_in(session: &service_account::Session, email: &str, password: &str, action: &str) -> Result { +async fn sign_up_in( + session: &service_account::Session, + email: &str, + password: &str, + action: &str, +) -> Result { let url = firebase_auth_url(action, &session.credentials.api_key); let resp = session .client() @@ -127,18 +134,20 @@ fn sign_up_in(session: &service_account::Session, email: &str, password: &str, a password: password.to_owned(), returnSecureToken: true, }) - .send()?; + .send() + .await?; - let resp = extract_google_api_error(resp, || email.to_owned())?; + let resp = extract_google_api_error_async(resp, || email.to_owned()).await?; - let resp: SignInUpUserResponse = resp.json()?; + let resp: SignInUpUserResponse = resp.json().await?; Ok(user::Session::new( &session.credentials, Some(&resp.localId), Some(&resp.idToken), Some(&resp.refreshToken), - )?) + ) + .await?) } /// Creates the firebase auth user with the given email and password and returns @@ -148,8 +157,8 @@ fn sign_up_in(session: &service_account::Session, email: &str, password: &str, a /// EMAIL_EXISTS: The email address is already in use by another account. /// OPERATION_NOT_ALLOWED: Password sign-in is disabled for this project. /// TOO_MANY_ATTEMPTS_TRY_LATER: We have blocked all requests from this device due to unusual activity. Try again later. -pub fn sign_up(session: &service_account::Session, email: &str, password: &str) -> Result { - sign_up_in(session, email, password, "signUp") +pub async fn sign_up(session: &service_account::Session, email: &str, password: &str) -> Result { + sign_up_in(session, email, password, "signUp").await } /// Signs in with the given email and password and returns a user session. @@ -158,6 +167,6 @@ pub fn sign_up(session: &service_account::Session, email: &str, password: &str) /// EMAIL_NOT_FOUND: There is no user record corresponding to this identifier. The user may have been deleted. /// INVALID_PASSWORD: The password is invalid or the user does not have a password. /// USER_DISABLED: The user account has been disabled by an administrator. -pub fn sign_in(session: &service_account::Session, email: &str, password: &str) -> Result { - sign_up_in(session, email, password, "signInWithPassword") +pub async fn sign_in(session: &service_account::Session, email: &str, password: &str) -> Result { + sign_up_in(session, email, password, "signInWithPassword").await }