diff --git a/Cargo.lock b/Cargo.lock index a99e4e7..00f6cc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -887,6 +887,7 @@ dependencies = [ "axum", "chrono", "http-body-util", + "openadr-vtn", "openadr-wire", "rangemap", "reqwest", diff --git a/openadr-client/Cargo.toml b/openadr-client/Cargo.toml index 138e65b..041a9e5 100644 --- a/openadr-client/Cargo.toml +++ b/openadr-client/Cargo.toml @@ -30,3 +30,4 @@ uuid.workspace = true [dev-dependencies] tokio = { workspace = true, features = ["full", "test-util"] } +openadr-vtn = { path = "../openadr-vtn" } diff --git a/openadr-client/src/bin/cli.rs b/openadr-client/src/bin/cli.rs index 43dd7b9..62d0e84 100644 --- a/openadr-client/src/bin/cli.rs +++ b/openadr-client/src/bin/cli.rs @@ -5,10 +5,7 @@ use openadr_wire::program::ProgramContent; async fn main() -> Result<(), Box> { let client = openadr_client::Client::with_url( "http://localhost:3000/".try_into()?, - Some(ClientCredentials::new( - "admin".to_string(), - "admin".to_string(), - )), + Some(ClientCredentials::admin()), ); let _created_program = client.create_program(ProgramContent::new("name")).await?; // let created_program_1 = client.create_program(ProgramContent::new("name1")).await?; diff --git a/openadr-client/src/bin/everest.rs b/openadr-client/src/bin/everest.rs index 57ad718..bf7b86f 100644 --- a/openadr-client/src/bin/everest.rs +++ b/openadr-client/src/bin/everest.rs @@ -4,7 +4,7 @@ use openadr_wire::{ values_map::Value, }; -use openadr_client::{ClientRef, ProgramClient, Target, Timeline}; +use openadr_client::{ProgramClient, Target, Timeline}; use std::{error::Error, time::Duration}; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::{select, sync::mpsc}; @@ -59,7 +59,7 @@ async fn main() -> Result<(), Box> { } async fn poll_timeline( - mut program: ProgramClient, + mut program: ProgramClient, poll_interval: std::time::Duration, sender: mpsc::Sender, ) -> Result<(), openadr_client::Error> { diff --git a/openadr-client/src/event.rs b/openadr-client/src/event.rs index b582de8..d682f8f 100644 --- a/openadr-client/src/event.rs +++ b/openadr-client/src/event.rs @@ -11,13 +11,13 @@ use openadr_wire::{ }; #[derive(Debug)] -pub struct EventClient { - client: Arc, +pub struct EventClient { + client: Arc, data: Event, } -impl EventClient { - pub(super) fn from_event(client: Arc, event: Event) -> Self { +impl EventClient { + pub(super) fn from_event(client: Arc, event: Event) -> Self { Self { client, data: event, @@ -75,7 +75,7 @@ impl EventClient { } /// Create a new report for the event - pub async fn create_report(&self, report_data: ReportContent) -> Result> { + pub async fn create_report(&self, report_data: ReportContent) -> Result { if report_data.program_id != self.data().program_id { return Err(Error::InvalidParentObject); } @@ -93,7 +93,7 @@ impl EventClient { client_name: Option<&str>, skip: usize, limit: usize, - ) -> Result>> { + ) -> Result> { let skip_str = skip.to_string(); let limit_str = limit.to_string(); @@ -116,7 +116,7 @@ impl EventClient { } /// Get all reports from the VTN for a specific client, trying to paginate whenever possible - pub async fn get_client_reports(&self, client_name: &str) -> Result>> { + pub async fn get_client_reports(&self, client_name: &str) -> Result> { let page_size = self.client.default_page_size(); let mut reports = vec![]; let mut page = 0; @@ -140,7 +140,7 @@ impl EventClient { } /// Get all reports from the VTN, trying to paginate whenever possible - pub async fn get_all_reports(&self) -> Result>> { + pub async fn get_all_reports(&self) -> Result> { let page_size = self.client.default_page_size(); let mut reports = vec![]; let mut page = 0; diff --git a/openadr-client/src/lib.rs b/openadr-client/src/lib.rs index b7ea25d..99c669b 100644 --- a/openadr-client/src/lib.rs +++ b/openadr-client/src/lib.rs @@ -5,16 +5,17 @@ mod report; mod target; mod timeline; +use axum::async_trait; use std::{ - future::Future, + fmt::Debug, sync::Arc, time::{Duration, Instant}, }; - use tokio::sync::RwLock; use axum::body::Body; use http_body_util::BodyExt; +use reqwest::{Method, RequestBuilder, Response}; use tower::{Service, ServiceExt}; use url::Url; @@ -33,11 +34,18 @@ pub(crate) use openadr_wire::{ Program, }; +#[async_trait] +trait HttpClient: Debug { + fn request_builder(&self, method: reqwest::Method, url: Url) -> reqwest::RequestBuilder; + async fn send(&self, req: reqwest::RequestBuilder) -> reqwest::Result; +} + /// Client used for interaction with a VTN. /// /// Can be used to implement both, the VEN and the business logic -pub struct Client { - client_ref: Arc, +#[derive(Debug, Clone)] +pub struct Client { + client_ref: Arc, } pub struct ClientCredentials { @@ -47,6 +55,19 @@ pub struct ClientCredentials { pub default_credential_expires_in: Duration, } +impl Debug for ClientCredentials { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(std::any::type_name::()) + .field("client_id", &self.client_id) + .field("refresh_margin", &self.refresh_margin) + .field( + "default_credential_expires_in", + &self.default_credential_expires_in, + ) + .finish_non_exhaustive() + } +} + impl ClientCredentials { pub fn new(client_id: String, client_secret: String) -> Self { Self { @@ -56,6 +77,10 @@ impl ClientCredentials { default_credential_expires_in: Duration::from_secs(3600), } } + + pub fn admin() -> Self { + Self::new("admin".to_string(), "admin".to_string()) + } } struct AuthToken { @@ -64,23 +89,25 @@ struct AuthToken { since: Instant, } -pub struct ReqwestClientRef { - client: reqwest::Client, +impl Debug for AuthToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct(std::any::type_name::()) + .field("expires_in", &self.expires_in) + .field("since", &self.since) + .finish_non_exhaustive() + } +} + +#[derive(Debug)] +pub struct ClientRef { + client: Box, base_url: url::Url, default_page_size: usize, auth_data: Option, auth_token: RwLock>, } -impl std::fmt::Debug for ReqwestClientRef { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_tuple("ClientRef") - .field(&self.base_url.to_string()) - .finish() - } -} - -impl ReqwestClientRef { +impl ClientRef { /// This ensures the client is authenticated. /// /// We follow the process according to RFC 6749, section 4.4 (client @@ -112,16 +139,19 @@ impl ReqwestClientRef { // we should authenticate let auth_url = self.base_url.join("auth/token")?; - let request = self.client.post(auth_url).form(&AccessTokenRequest { - grant_type: "client_credentials", - scope: None, - client_id: None, - client_secret: None, - }); + let request = + self.client + .request_builder(Method::POST, auth_url) + .form(&AccessTokenRequest { + grant_type: "client_credentials", + scope: None, + client_id: None, + client_secret: None, + }); let request = request.basic_auth(&auth_data.client_id, Some(&auth_data.client_secret)); let request = request.header("Accept", "application/json"); let since = Instant::now(); - let res = request.send().await?; + let res = self.client.send(request).await?; if !res.status().is_success() { let problem = res.json::().await?; return Err(crate::Error::AuthProblem(problem)); @@ -158,42 +188,7 @@ impl ReqwestClientRef { *self.auth_token.write().await = Some(token); Ok(()) } -} - -pub trait ClientRef { - fn get( - &self, - path: &str, - query: &[(&str, &str)], - ) -> impl Future> + Send; - fn post( - &self, - path: &str, - body: &S, - query: &[(&str, &str)], - ) -> impl Future> + Send - where - S: serde::ser::Serialize + Sync, - T: serde::de::DeserializeOwned; - - fn put( - &self, - path: &str, - body: &S, - query: &[(&str, &str)], - ) -> impl Future> + Send - where - S: serde::ser::Serialize + Sync, - T: serde::de::DeserializeOwned; - - fn delete(&self, path: &str, query: &[(&str, &str)]) - -> impl Future> + Send; - - fn default_page_size(&self) -> usize; -} - -impl ReqwestClientRef { async fn request( &self, mut request: reqwest::RequestBuilder, @@ -212,7 +207,7 @@ impl ReqwestClientRef { request = request.bearer_auth(&token.token); } } - let res = request.send().await?; + let res = self.client.send(request).await?; // handle any errors returned by the server if !res.status().is_success() { @@ -222,16 +217,14 @@ impl ReqwestClientRef { Ok(res.json().await?) } -} -impl ClientRef for ReqwestClientRef { async fn get( &self, path: &str, query: &[(&str, &str)], ) -> Result { let url = self.base_url.join(path)?; - let request = self.client.get(url); + let request = self.client.request_builder(Method::GET, url); self.request(request, query).await } @@ -241,7 +234,7 @@ impl ClientRef for ReqwestClientRef { T: serde::de::DeserializeOwned, { let url = self.base_url.join(path)?; - let request = self.client.post(url).json(body); + let request = self.client.request_builder(Method::POST, url).json(body); self.request(request, query).await } @@ -251,13 +244,13 @@ impl ClientRef for ReqwestClientRef { T: serde::de::DeserializeOwned, { let url = self.base_url.join(path)?; - let request = self.client.put(url).json(body); + let request = self.client.request_builder(Method::PUT, url).json(body); self.request(request, query).await } async fn delete(&self, path: &str, query: &[(&str, &str)]) -> Result<()> { let url = self.base_url.join(path)?; - let request = self.client.delete(url); + let request = self.client.request_builder(Method::DELETE, url); self.request(request, query).await } @@ -266,44 +259,55 @@ impl ClientRef for ReqwestClientRef { } } +#[derive(Debug)] +pub struct ReqwestClientRef { + client: reqwest::Client, +} + +#[async_trait] +impl HttpClient for ReqwestClientRef { + fn request_builder(&self, method: reqwest::Method, url: Url) -> RequestBuilder { + self.client.request(method, url) + } + + async fn send(&self, req: RequestBuilder) -> std::result::Result { + req.send().await + } +} + +#[derive(Debug)] pub struct MockClientRef { router: Arc>, - default_page_size: usize, } impl MockClientRef { pub fn new(router: axum::Router) -> Self { MockClientRef { router: Arc::new(tokio::sync::Mutex::new(router)), - default_page_size: 50, } } - async fn request( - &self, - method: axum::http::Method, - path: &str, - body: Option>, - query: &[(&str, &str)], - ) -> Result { - let mut uri = format!("/{path}?"); - let mut it = query.iter().peekable(); + pub fn into_client(self, auth: Option) -> Client { + let client = ClientRef { + client: Box::new(self), + base_url: Url::parse("https://example.com/").unwrap(), + default_page_size: 50, + auth_data: auth, + auth_token: RwLock::new(None), + }; - while let Some((key, value)) = it.next() { - uri.push_str(key); - uri.push('='); - uri.push_str(value); + Client::new(client) + } +} - if it.peek().is_some() { - uri.push('&'); - } - } +#[async_trait] +impl HttpClient for MockClientRef { + fn request_builder(&self, method: reqwest::Method, url: Url) -> RequestBuilder { + reqwest::Client::new().request(method, url) + } - let request = axum::http::Request::builder() - .method(method) - .uri(uri) - .body(Body::from(body.unwrap_or_default())) - .unwrap(); + async fn send(&self, req: reqwest::RequestBuilder) -> reqwest::Result { + let request = axum::http::Request::try_from(req.build().unwrap()).unwrap(); let response = ServiceExt::>::ready(&mut *self.router.lock().await) @@ -313,53 +317,16 @@ impl MockClientRef { .await .unwrap(); - let body = response.into_body().collect().await.unwrap().to_bytes(); + let (parts, body) = response.into_parts(); + let body = body.collect().await.unwrap().to_bytes(); + let body = reqwest::Body::from(body); + let response = axum::http::Response::from_parts(parts, body); - Ok(serde_json::from_slice(&body).unwrap()) + Ok(response.into()) } } -impl ClientRef for MockClientRef { - async fn get( - &self, - path: &str, - query: &[(&str, &str)], - ) -> Result { - self.request(axum::http::Method::GET, path, None, query) - .await - } - - async fn post(&self, path: &str, body: &S, query: &[(&str, &str)]) -> Result - where - S: serde::ser::Serialize + Sync, - T: serde::de::DeserializeOwned, - { - let body = serde_json::to_vec(body)?; - self.request(axum::http::Method::POST, path, Some(body), query) - .await - } - - async fn put(&self, path: &str, body: &S, query: &[(&str, &str)]) -> Result - where - S: serde::ser::Serialize + Sync, - T: serde::de::DeserializeOwned, - { - let body = serde_json::to_vec(body)?; - self.request(axum::http::Method::PUT, path, Some(body), query) - .await - } - - async fn delete(&self, path: &str, query: &[(&str, &str)]) -> Result<()> { - self.request(axum::http::Method::DELETE, path, None, query) - .await - } - - fn default_page_size(&self) -> usize { - self.default_page_size - } -} - -impl Client { +impl Client { /// Create a new client for a VTN located at the specified URL pub fn with_url(base_url: Url, auth: Option) -> Self { let client = reqwest::Client::new(); @@ -373,8 +340,8 @@ impl Client { client: reqwest::Client, auth: Option, ) -> Self { - let client_ref = ReqwestClientRef { - client, + let client_ref = ClientRef { + client: Box::new(ReqwestClientRef { client }), base_url, default_page_size: 50, auth_data: auth, @@ -383,17 +350,15 @@ impl Client { Self::new(client_ref) } -} -impl Client { - pub fn new(client_ref: C) -> Self { + fn new(client_ref: ClientRef) -> Self { Client { client_ref: Arc::new(client_ref), } } /// Create a new program on the VTN - pub async fn create_program(&self, program_data: ProgramContent) -> Result> { + pub async fn create_program(&self, program_data: ProgramContent) -> Result { let program = self.client_ref.post("programs", &program_data, &[]).await?; Ok(ProgramClient::from_program( self.client_ref.clone(), @@ -408,7 +373,7 @@ impl Client { targets: &[&str], skip: usize, limit: usize, - ) -> Result>> { + ) -> Result> { // convert query params let target_type_str = target_type.map(|t| t.to_string()); let skip_str = skip.to_string(); @@ -434,7 +399,7 @@ impl Client { } /// Get a single program from the VTN that matches the given target - pub async fn get_program(&self, target: Target<'_>) -> Result> { + pub async fn get_program(&self, target: Target<'_>) -> Result { let mut programs = self .get_programs_req(Some(target.target_label()), target.target_values(), 0, 2) .await?; @@ -447,7 +412,7 @@ impl Client { } /// Get a list of programs from the VTN with the given query parameters - pub async fn get_program_list(&self, target: Target<'_>) -> Result>> { + pub async fn get_program_list(&self, target: Target<'_>) -> Result> { let page_size = self.client_ref.default_page_size(); let mut programs = vec![]; let mut page = 0; @@ -476,7 +441,7 @@ impl Client { } /// Get all programs from the VTN, trying to paginate whenever possible - pub async fn get_all_programs(&self) -> Result>> { + pub async fn get_all_programs(&self) -> Result> { let page_size = self.client_ref.default_page_size(); let mut programs = vec![]; let mut page = 0; @@ -501,12 +466,12 @@ impl Client { } /// Get a program by name - pub async fn get_program_by_name(&self, name: &str) -> Result> { + pub async fn get_program_by_name(&self, name: &str) -> Result { self.get_program(Target::Program(name)).await } /// Get a program by id - pub async fn get_program_by_id(&self, id: &ProgramId) -> Result> { + pub async fn get_program_by_id(&self, id: &ProgramId) -> Result { let program = self .client_ref .get(&format!("programs/{}", id.as_str()), &[]) diff --git a/openadr-client/src/program.rs b/openadr-client/src/program.rs index 35a82fe..5f69d40 100644 --- a/openadr-client/src/program.rs +++ b/openadr-client/src/program.rs @@ -14,13 +14,13 @@ use crate::{ /// A client for interacting with the data in a specific program and the events /// contained in the program. #[derive(Debug)] -pub struct ProgramClient { - client: Arc, +pub struct ProgramClient { + client: Arc, data: Program, } -impl ProgramClient { - pub(super) fn from_program(client: Arc, program: Program) -> Self { +impl ProgramClient { + pub(super) fn from_program(client: Arc, program: Program) -> Self { Self { client, data: program, @@ -71,7 +71,7 @@ impl ProgramClient { } /// Create a new event on the VTN - pub async fn create_event(&self, event_data: EventContent) -> Result> { + pub async fn create_event(&self, event_data: EventContent) -> Result { if &event_data.program_id != self.id() { return Err(crate::Error::InvalidParentObject); } @@ -100,7 +100,7 @@ impl ProgramClient { targets: &[&str], skip: usize, limit: usize, - ) -> Result>> { + ) -> Result> { // convert query params let target_type_str = target_type.map(|t| t.to_string()); let skip_str = skip.to_string(); @@ -126,7 +126,7 @@ impl ProgramClient { } /// Get a single event from the VTN that matches the given target - pub async fn get_event(&self, target: Target<'_>) -> Result> { + pub async fn get_event(&self, target: Target<'_>) -> Result { let mut events = self .get_events_req(Some(target.target_label()), target.target_values(), 0, 2) .await?; @@ -140,7 +140,7 @@ impl ProgramClient { } /// Get a list of events from the VTN with the given query parameters - pub async fn get_event_list(&self, target: Target<'_>) -> Result>> { + pub async fn get_event_list(&self, target: Target<'_>) -> Result> { let page_size = self.client.default_page_size(); let mut events = vec![]; let mut page = 0; @@ -169,7 +169,7 @@ impl ProgramClient { } /// Get all events from the VTN, trying to paginate whenever possible - pub async fn get_all_events(&self) -> Result>> { + pub async fn get_all_events(&self) -> Result> { let page_size = self.client.default_page_size(); let mut events = vec![]; let mut page = 0; diff --git a/openadr-client/src/report.rs b/openadr-client/src/report.rs index e2e38ca..448afc2 100644 --- a/openadr-client/src/report.rs +++ b/openadr-client/src/report.rs @@ -6,13 +6,13 @@ use crate::error::Result; use crate::ClientRef; #[derive(Debug)] -pub struct ReportClient { - client: Arc, +pub struct ReportClient { + client: Arc, data: Report, } -impl ReportClient { - pub(super) fn from_report(client: Arc, report: Report) -> Self { +impl ReportClient { + pub(super) fn from_report(client: Arc, report: Report) -> Self { Self { client, data: report, diff --git a/openadr-client/tests/basic-read.rs b/openadr-client/tests/basic-read.rs new file mode 100644 index 0000000..c131650 --- /dev/null +++ b/openadr-client/tests/basic-read.rs @@ -0,0 +1,53 @@ +use openadr_wire::program::ProgramContent; + +mod helper { + use std::env::VarError; + + use openadr_client::{Client, ClientCredentials, MockClientRef}; + use openadr_vtn::data_source::AuthInfo; + use url::Url; + + pub fn setup_mock_client() -> Client { + use openadr_vtn::{data_source::InMemoryStorage, jwt::JwtManager, state::AppState}; + + let auth_info = AuthInfo::bl_admin(); + let client_credentials = ClientCredentials::admin(); + + let storage = InMemoryStorage::default(); + storage.auth.try_write().unwrap().push(auth_info); + + let app_state = AppState::new(storage, JwtManager::from_secret(b"test")); + + MockClientRef::new(app_state.into_router()).into_client(Some(client_credentials)) + } + + pub fn setup_url_client(url: Url) -> Client { + Client::with_url(url, Some(ClientCredentials::admin())) + } + + pub fn setup_client() -> Client { + match std::env::var("OPENADR_RS_VTN_URL") { + Ok(url) => match url.parse() { + Ok(url) => setup_url_client(url), + Err(e) => panic!("Could not parse URL: {e}"), + }, + Err(VarError::NotPresent) => setup_mock_client(), + Err(VarError::NotUnicode(e)) => panic!("Could not parse URL: {e:?}"), + } + } +} + +#[tokio::test] +async fn basic_create_read() -> Result<(), openadr_client::Error> { + let client = helper::setup_client(); + + client + .create_program(ProgramContent::new("test-prog")) + .await?; + + let programs = client.get_all_programs().await?; + assert_eq!(programs.len(), 1); + assert_eq!(programs[0].data().program_name, "test-prog"); + + Ok(()) +} diff --git a/openadr-vtn/src/api/event.rs b/openadr-vtn/src/api/event.rs index e8ea009..28ada65 100644 --- a/openadr-vtn/src/api/event.rs +++ b/openadr-vtn/src/api/event.rs @@ -216,11 +216,7 @@ mod test { fn state_with_events(events: Vec) -> AppState { let store = InMemoryStorage::default(); - store.auth.try_write().unwrap().push(AuthInfo { - client_id: "admin".to_string(), - client_secret: "admin".to_string(), - roles: vec![AuthRole::AnyBusiness, AuthRole::UserManager], - }); + store.auth.try_write().unwrap().push(AuthInfo::bl_admin()); { let mut writer = store.events.try_write().unwrap(); diff --git a/openadr-vtn/src/api/program.rs b/openadr-vtn/src/api/program.rs index 65f8507..b98b45d 100644 --- a/openadr-vtn/src/api/program.rs +++ b/openadr-vtn/src/api/program.rs @@ -251,11 +251,7 @@ mod test { fn state_with_programs(programs: Vec) -> AppState { let store = InMemoryStorage::default(); - store.auth.try_write().unwrap().push(AuthInfo { - client_id: "admin".to_string(), - client_secret: "admin".to_string(), - roles: vec![AuthRole::AnyBusiness, AuthRole::UserManager], - }); + store.auth.try_write().unwrap().push(AuthInfo::bl_admin()); { let mut writer = store.programs.try_write().unwrap(); diff --git a/openadr-vtn/src/data_source/mod.rs b/openadr-vtn/src/data_source/mod.rs index 545b582..099723c 100644 --- a/openadr-vtn/src/data_source/mod.rs +++ b/openadr-vtn/src/data_source/mod.rs @@ -83,6 +83,16 @@ pub struct AuthInfo { pub roles: Vec, } +impl AuthInfo { + pub fn bl_admin() -> Self { + Self { + client_id: "admin".to_string(), + client_secret: "admin".to_string(), + roles: vec![AuthRole::AnyBusiness, AuthRole::UserManager], + } + } +} + #[derive(Default, Clone)] pub struct InMemoryStorage { pub programs: Arc>>, diff --git a/openadr-vtn/src/main.rs b/openadr-vtn/src/main.rs index b07ef4f..d6a016e 100644 --- a/openadr-vtn/src/main.rs +++ b/openadr-vtn/src/main.rs @@ -6,7 +6,7 @@ use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::{fmt, EnvFilter}; use openadr_vtn::data_source::{AuthInfo, InMemoryStorage}; -use openadr_vtn::jwt::{AuthRole, JwtManager}; +use openadr_vtn::jwt::JwtManager; use openadr_vtn::state::AppState; #[tokio::main] @@ -21,11 +21,7 @@ async fn main() { info!("listening on http://{}", listener.local_addr().unwrap()); let storage = InMemoryStorage::default(); - storage.auth.write().await.push(AuthInfo { - client_id: "admin".to_string(), - client_secret: "admin".to_string(), - roles: vec![AuthRole::AnyBusiness, AuthRole::UserManager], - }); + storage.auth.write().await.push(AuthInfo::bl_admin()); let state = AppState::new(storage, JwtManager::from_base64_secret("test").unwrap()); if let Err(e) = axum::serve(listener, state.into_router())