From 1ea73b4dc26a2a5a9582d0f4609ca1530c542d25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20Greinhofer?= Date: Thu, 8 Aug 2024 09:14:20 -0500 Subject: [PATCH] Refactor BNA API (#127) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactors the BNA API to match the Zalando best practices as well as our OpenAPI specification. - Fixes ACCEPT_CONTENT headers in `APIError` - Adds convenience functions associated to `APIErrors` - Expands the BNARequestExt Trait - Fixes the CityPost wrapper - Fixes the logic of the POST /cities endpoint and adds some extra validations - Signed-off-by: Rémy Greinhofer Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com> --- effortless/src/error.rs | 23 ++++- effortless/src/fragment.rs | 7 ++ entity/src/wrappers/city.rs | 3 +- lambdas/src/cities/get-cities-bnas.rs | 57 ++++++------ lambdas/src/cities/get-cities-census.rs | 56 ++++++------ lambdas/src/cities/get-cities.rs | 113 +++++++++++++----------- lambdas/src/cities/mod.rs | 82 +++++++++++++++++ lambdas/src/cities/post-cities.rs | 54 ++++++++--- lambdas/src/lib.rs | 13 +++ migration/src/m20220101_000001_main.rs | 8 +- openapi.yaml | 22 +++-- 11 files changed, 291 insertions(+), 147 deletions(-) create mode 100644 lambdas/src/cities/mod.rs diff --git a/effortless/src/error.rs b/effortless/src/error.rs index 070351f..15825d8 100644 --- a/effortless/src/error.rs +++ b/effortless/src/error.rs @@ -1,4 +1,4 @@ -use lambda_http::{http::StatusCode, Body, Response}; +use lambda_http::{http::header, http::StatusCode, Body, Response}; use serde::{Deserialize, Serialize}; use serde_json::json; use serde_with::skip_serializing_none; @@ -133,6 +133,26 @@ impl APIErrors { errors: errors.to_vec(), } } + + /// Creates an empty `APIErrors`. + pub fn empty() -> Self { + Self { errors: vec![] } + } + + /// Adds an `APIError`. + pub fn add(mut self, value: APIError) { + self.errors.push(value); + } + + /// Extends with an existing `APIErrors`. + pub fn extend(&mut self, value: APIErrors) { + self.errors.extend(value.errors); + } + + /// Returns True if there is no error. + pub fn is_empty(&self) -> bool { + self.errors.is_empty() + } } impl From for APIErrors { @@ -156,6 +176,7 @@ impl From for Response { }; Response::builder() .status(status) + .header(header::CONTENT_TYPE, "application/json") .body(json!(value).to_string().into()) .unwrap() } diff --git a/effortless/src/fragment.rs b/effortless/src/fragment.rs index 6d88914..69a73b7 100644 --- a/effortless/src/fragment.rs +++ b/effortless/src/fragment.rs @@ -150,6 +150,9 @@ pub trait BnaRequestExt { /// If there is no request ID or the event is not coming from an ApiGatewayV2, the /// function returns None. fn apigw_request_id(&self) -> Option; + + /// Returns true if there are path parameters available. + fn has_path_parameters(&self) -> bool; } impl BnaRequestExt for http::Request { @@ -183,6 +186,10 @@ impl BnaRequestExt for http::Request { _ => None, } } + + fn has_path_parameters(&self) -> bool { + !self.path_parameters().is_empty() + } } #[cfg(test)] diff --git a/entity/src/wrappers/city.rs b/entity/src/wrappers/city.rs index 6fee6f1..9da4e84 100644 --- a/entity/src/wrappers/city.rs +++ b/entity/src/wrappers/city.rs @@ -8,6 +8,7 @@ pub struct CityPost { pub latitude: Option, pub longitude: Option, pub name: String, + pub region: Option, pub state: String, pub state_abbrev: Option, pub speed_limit: Option, @@ -21,7 +22,7 @@ impl IntoActiveModel for CityPost { latitude: ActiveValue::Set(self.latitude), longitude: ActiveValue::Set(self.longitude), name: ActiveValue::Set(self.name), - region: ActiveValue::NotSet, + region: ActiveValue::set(self.region), state: ActiveValue::Set(self.state), state_abbrev: ActiveValue::Set(self.state_abbrev), speed_limit: ActiveValue::Set(self.speed_limit), diff --git a/lambdas/src/cities/get-cities-bnas.rs b/lambdas/src/cities/get-cities-bnas.rs index d00e3f8..7174d6b 100644 --- a/lambdas/src/cities/get-cities-bnas.rs +++ b/lambdas/src/cities/get-cities-bnas.rs @@ -1,8 +1,12 @@ use dotenv::dotenv; -use effortless::api::{entry_not_found, missing_parameter, parse_path_parameter}; +use effortless::api::entry_not_found; use entity::{city, summary}; use lambda_http::{run, service_fn, Body, Error, Request, Response}; -use lambdas::{build_paginated_response, database_connect, pagination_parameters}; +use lambdas::{ + build_paginated_response, + cities::{extract_path_parameters, PathParameters}, + database_connect, pagination_parameters, +}; use sea_orm::{EntityTrait, PaginatorTrait}; use serde_json::json; use tracing::info; @@ -10,8 +14,11 @@ use tracing::info; async fn function_handler(event: Request) -> Result, Error> { dotenv().ok(); - // Set the database connection. - let db = database_connect(Some("DATABASE_URL_SECRET_ID")).await?; + // Extract the path parameters. + let params: PathParameters = match extract_path_parameters(&event) { + Ok(p) => p, + Err(e) => return Ok(e.into()), + }; // Retrieve pagination parameters if any. let (page_size, page) = match pagination_parameters(&event) { @@ -19,36 +26,22 @@ async fn function_handler(event: Request) -> Result, Error> { Err(e) => return Ok(e), }; - let country = match parse_path_parameter::(&event, "country") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; - dbg!(&country); - let region = match parse_path_parameter::(&event, "region") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; - let name = match parse_path_parameter::(&event, "name") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; + // Set the database connection. + let db = database_connect(Some("DATABASE_URL_SECRET_ID")).await?; - if country.is_some() && region.is_some() && name.is_some() { - let select = city::Entity::find_by_id((country.unwrap(), region.unwrap(), name.unwrap())) - .find_also_related(summary::Entity); - let model = select - .clone() - .paginate(&db, page_size) - .fetch_page(page - 1) - .await?; - if model.is_empty() { - return Ok(entry_not_found(&event).into()); - } - let total_items = select.count(&db).await?; - build_paginated_response(json!(model), total_items, page, page_size, &event) - } else { - Ok(missing_parameter(&event, "country or region or name").into()) + // Retrieve the city and associated BNA summary(ies). + let select = city::Entity::find_by_id((params.country, params.region, params.name)) + .find_also_related(summary::Entity); + let model = select + .clone() + .paginate(&db, page_size) + .fetch_page(page - 1) + .await?; + if model.is_empty() { + return Ok(entry_not_found(&event).into()); } + let total_items = select.count(&db).await?; + build_paginated_response(json!(model), total_items, page, page_size, &event) } #[tokio::main] diff --git a/lambdas/src/cities/get-cities-census.rs b/lambdas/src/cities/get-cities-census.rs index 5157882..89c8811 100644 --- a/lambdas/src/cities/get-cities-census.rs +++ b/lambdas/src/cities/get-cities-census.rs @@ -1,8 +1,12 @@ use dotenv::dotenv; -use effortless::api::{entry_not_found, missing_parameter, parse_path_parameter}; +use effortless::api::entry_not_found; use entity::{census, city}; use lambda_http::{run, service_fn, Body, Error, Request, Response}; -use lambdas::{build_paginated_response, database_connect, pagination_parameters}; +use lambdas::{ + build_paginated_response, + cities::{extract_path_parameters, PathParameters}, + database_connect, pagination_parameters, +}; use sea_orm::{EntityTrait, PaginatorTrait}; use serde_json::json; use tracing::info; @@ -26,8 +30,11 @@ async fn main() -> Result<(), Error> { async fn function_handler(event: Request) -> Result, Error> { dotenv().ok(); - // Set the database connection. - let db = database_connect(Some("DATABASE_URL_SECRET_ID")).await?; + // Extract the path parameters. + let params: PathParameters = match extract_path_parameters(&event) { + Ok(p) => p, + Err(e) => return Ok(e.into()), + }; // Retrieve pagination parameters if any. let (page_size, page) = match pagination_parameters(&event) { @@ -35,35 +42,22 @@ async fn function_handler(event: Request) -> Result, Error> { Err(e) => return Ok(e), }; - let country = match parse_path_parameter::(&event, "country") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; - let region = match parse_path_parameter::(&event, "region") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; - let name = match parse_path_parameter::(&event, "name") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; + // Set the database connection. + let db = database_connect(Some("DATABASE_URL_SECRET_ID")).await?; - if country.is_some() && region.is_some() && name.is_some() { - let select = city::Entity::find_by_id((country.unwrap(), region.unwrap(), name.unwrap())) - .find_also_related(census::Entity); - let model = select - .clone() - .paginate(&db, page_size) - .fetch_page(page - 1) - .await?; - if model.is_empty() { - return Ok(entry_not_found(&event).into()); - } - let total_items = select.count(&db).await?; - build_paginated_response(json!(model), total_items, page, page_size, &event) - } else { - Ok(missing_parameter(&event, "country or region or name").into()) + // Retrieve the city and associated census(es). + let select = city::Entity::find_by_id((params.country, params.region, params.name)) + .find_also_related(census::Entity); + let model = select + .clone() + .paginate(&db, page_size) + .fetch_page(page - 1) + .await?; + if model.is_empty() { + return Ok(entry_not_found(&event).into()); } + let total_items = select.count(&db).await?; + build_paginated_response(json!(model), total_items, page, page_size, &event) } // #[cfg(test)] diff --git a/lambdas/src/cities/get-cities.rs b/lambdas/src/cities/get-cities.rs index 3a7c877..75d90e6 100644 --- a/lambdas/src/cities/get-cities.rs +++ b/lambdas/src/cities/get-cities.rs @@ -1,8 +1,12 @@ use dotenv::dotenv; -use effortless::api::{entry_not_found, parse_path_parameter}; +use effortless::{api::entry_not_found, fragment::BnaRequestExt}; use entity::city; use lambda_http::{run, service_fn, Body, Error, IntoResponse, Request, Response}; -use lambdas::{build_paginated_response, database_connect, pagination_parameters}; +use lambdas::{ + build_paginated_response, + cities::{extract_path_parameters, PathParameters}, + database_connect, pagination_parameters_2, +}; use sea_orm::{EntityTrait, PaginatorTrait}; use serde_json::json; use tracing::info; @@ -13,43 +17,43 @@ async fn function_handler(event: Request) -> Result, Error> { // Set the database connection. let db = database_connect(Some("DATABASE_URL_SECRET_ID")).await?; - // Retrieve pagination parameters if any. - let (page_size, page) = match pagination_parameters(&event) { - Ok((page_size, page)) => (page_size, page), - Err(e) => return Ok(e), - }; - - let country = match parse_path_parameter::(&event, "country") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; - let region = match parse_path_parameter::(&event, "region") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; - let name = match parse_path_parameter::(&event, "name") { - Ok(value) => value, - Err(e) => return Ok(e.into()), - }; + // With params. + if event.has_path_parameters() { + let params: PathParameters = match extract_path_parameters(&event) { + Ok(p) => p, + Err(e) => return Ok(e.into()), + }; - if country.is_some() && region.is_some() && name.is_some() { - let select = city::Entity::find_by_id((country.unwrap(), region.unwrap(), name.unwrap())); + let select = city::Entity::find_by_id((params.country, params.region, params.name)); let model = select.one(&db).await?; let res: Response = match model { Some(model) => json!(model).into_response().await, None => entry_not_found(&event).into(), }; - Ok(res) - } else { - let select = city::Entity::find(); - let body = select - .clone() - .paginate(&db, page_size) - .fetch_page(page - 1) - .await?; - let total_items = select.count(&db).await?; - build_paginated_response(json!(body), total_items, page, page_size, &event) + return Ok(res); } + + // Retrieve pagination parameters if any. + let pagination = match pagination_parameters_2(&event) { + Ok(p) => p, + Err(e) => return Ok(e), + }; + + // Without params. + let select = city::Entity::find(); + let body = select + .clone() + .paginate(&db, pagination.page_size) + .fetch_page(pagination.page - 1) + .await?; + let total_items = select.count(&db).await?; + build_paginated_response( + json!(body), + total_items, + pagination.page, + pagination.page_size, + &event, + ) } #[tokio::main] @@ -70,25 +74,30 @@ async fn main() -> Result<(), Error> { #[cfg(test)] mod tests { - // use super::*; - // use lambda_http::{http, RequestExt}; - // use std::collections::HashMap; + use super::*; + use lambda_http::{http, RequestExt}; + use std::collections::HashMap; - // #[tokio::test] - // async fn test_handler_opportunity() { - // let event = http::Request::builder() - // .header(http::header::CONTENT_TYPE, "application/json") - // .body(Body::Empty) - // .expect("failed to build request") - // .with_path_parameters(HashMap::from([ - // ("country".to_string(), "United%20States".to_string()), - // ("region".to_string(), "Texas".to_string()), - // ("name".to_string(), "Austin".to_string()), - // ])) - // .with_request_context(lambda_http::request::RequestContext::ApiGatewayV2( - // lambda_http::aws_lambda_events::apigw::ApiGatewayV2httpRequestContext::default(), - // )); - // let r = function_handler(event).await.unwrap(); - // dbg!(r); - // } + #[test] + fn test_extract_path_parameters() { + let country: String = String::from("United States"); + let region: String = String::from("Texas"); + let name: String = String::from("Austin"); + let event = http::Request::builder() + .header(http::header::CONTENT_TYPE, "application/json") + .body(Body::Empty) + .expect("failed to build request") + .with_path_parameters(HashMap::from([ + ("country".to_string(), country.clone()), + ("region".to_string(), region.clone()), + ("name".to_string(), name.clone()), + ])) + .with_request_context(lambda_http::request::RequestContext::ApiGatewayV2( + lambda_http::aws_lambda_events::apigw::ApiGatewayV2httpRequestContext::default(), + )); + let r = extract_path_parameters(&event).unwrap(); + assert_eq!(r.country, country); + assert_eq!(r.region, region); + assert_eq!(r.name, name); + } } diff --git a/lambdas/src/cities/mod.rs b/lambdas/src/cities/mod.rs new file mode 100644 index 0000000..048c902 --- /dev/null +++ b/lambdas/src/cities/mod.rs @@ -0,0 +1,82 @@ +//! Module for the /cities enpoint. + +use effortless::{api::parse_path_parameter, error::APIErrors}; +use lambda_http::Request; + +/// Represent the path parameters for the /cities enpoint. +pub struct PathParameters { + /// Country name. + pub country: String, + /// Region name. + pub region: String, + /// City name. + pub name: String, +} + +/// Extract the path parameters for the /cities endpoint. +pub fn extract_path_parameters(event: &Request) -> Result { + let mut api_errors = APIErrors::empty(); + + let country = match parse_path_parameter::(event, "country") { + Ok(value) => value, + Err(e) => { + api_errors.extend(e); + None + } + }; + + let region = match parse_path_parameter::(event, "region") { + Ok(value) => value, + Err(e) => { + api_errors.extend(e); + None + } + }; + let name = match parse_path_parameter::(event, "name") { + Ok(value) => value, + Err(e) => { + api_errors.extend(e); + None + } + }; + + if !api_errors.is_empty() { + return Err(api_errors); + } + + Ok(PathParameters { + country: country.unwrap(), + region: region.unwrap(), + name: name.unwrap(), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use lambda_http::{http, Body, RequestExt}; + use std::collections::HashMap; + + #[test] + fn test_extract_path_parameters() { + let country: String = String::from("United States"); + let region: String = String::from("Texas"); + let name: String = String::from("Austin"); + let event = http::Request::builder() + .header(http::header::CONTENT_TYPE, "application/json") + .body(Body::Empty) + .expect("failed to build request") + .with_path_parameters(HashMap::from([ + ("country".to_string(), country.clone()), + ("region".to_string(), region.clone()), + ("name".to_string(), name.clone()), + ])) + .with_request_context(lambda_http::request::RequestContext::ApiGatewayV2( + lambda_http::aws_lambda_events::apigw::ApiGatewayV2httpRequestContext::default(), + )); + let r = extract_path_parameters(&event).unwrap(); + assert_eq!(r.country, country); + assert_eq!(r.region, region); + assert_eq!(r.name, name); + } +} diff --git a/lambdas/src/cities/post-cities.rs b/lambdas/src/cities/post-cities.rs index 1ba68c7..ad7f57e 100644 --- a/lambdas/src/cities/post-cities.rs +++ b/lambdas/src/cities/post-cities.rs @@ -1,13 +1,18 @@ use dotenv::dotenv; use effortless::api::{invalid_body, parse_request_body}; use entity::{ + country, prelude::*, wrappers::{self, city::CityPost}, }; -use lambda_http::{run, service_fn, Body, Error, IntoResponse, Request, Response}; +use lambda_http::{ + http::{header, StatusCode}, + run, service_fn, Body, Error, Request, Response, +}; use lambdas::database_connect; -use sea_orm::ActiveModelTrait; -use sea_orm::{ActiveValue, EntityTrait, IntoActiveModel}; +use sea_orm::{ + ActiveModelTrait, ActiveValue, ColumnTrait, EntityTrait, IntoActiveModel, QueryFilter, +}; use serde_json::json; use tracing::info; use uuid::Uuid; @@ -37,13 +42,10 @@ async fn function_handler(event: Request) -> Result, Error> { Err(e) => return Ok(e.into()), }; - // Check if the country is US and set the region accordingly. + // Extract some country information. let country = wrapper.country.clone(); let state_full = wrapper.state.clone(); - let region: Option = match country.to_lowercase().eq("United States") { - true => None, - false => Some(country), - }; + let region = wrapper.region.clone(); // Turn the model wrapper into an active model. let mut active_city = wrapper.into_active_model(); @@ -54,12 +56,26 @@ async fn function_handler(event: Request) -> Result, Error> { // Get the database connection. let db = database_connect(Some("DATABASE_URL_SECRET_ID")).await?; - // Set the region if needed. - if region.is_none() { - let state_region_model = StateRegionCrosswalk::find_by_id(state_full) + // Ensure the country is a valid one. + if Country::find() + .filter(country::Column::Name.eq(&country)) + .one(&db) + .await? + .is_none() + { + return Ok(invalid_body( + &event, + "the country `{country}` is not in the list of countries supported by the BNA", + ) + .into()); + } + + // If the country is the United States, set the region to the standardized state abbreviation. + if country.to_lowercase().eq("united states") { + match StateRegionCrosswalk::find_by_id(state_full) .one(&db) - .await?; - match state_region_model { + .await? + { Some(model) => { let region: wrappers::BnaRegion = model.region.into(); active_city.region = ActiveValue::Set(Some(region.to_string())); @@ -68,10 +84,20 @@ async fn function_handler(event: Request) -> Result, Error> { } } + // If the region is not set, ensure it is equal to the country. + if region.is_none() { + active_city.region = ActiveValue::Set(Some(country)); + } + // And insert a new entry. info!("inserting City into database: {:?}", active_city); let city = active_city.insert(&db).await?; - Ok(json!(city).into_response().await) + let response = Response::builder() + .status(StatusCode::CREATED) + .header(header::CONTENT_TYPE, "application/json") + .body(Body::Text(json!(city).to_string())) + .expect("unable to build http::Response"); + Ok(response) } #[cfg(test)] diff --git a/lambdas/src/lib.rs b/lambdas/src/lib.rs index 82a0e7e..cff82dc 100644 --- a/lambdas/src/lib.rs +++ b/lambdas/src/lib.rs @@ -1,3 +1,4 @@ +pub mod cities; pub mod link_header; use bnacore::aws::get_aws_secrets_value; @@ -92,6 +93,18 @@ pub fn pagination_parameters(event: &Request) -> APIResult<(u64, u64)> { Ok((page_size, page)) } +/// Represent the query parameters related to the pagination. +pub struct PaginationParameters { + /// The number of items per page. + pub page_size: u64, + /// The result page being returned. + pub page: u64, +} + +pub fn pagination_parameters_2(event: &Request) -> Result> { + pagination_parameters(event).map(|(page_size, page)| PaginationParameters { page_size, page }) +} + /// Builds a paginated Response. /// /// Builds a Response struct which contains the pagination information in the headers. diff --git a/migration/src/m20220101_000001_main.rs b/migration/src/m20220101_000001_main.rs index 5cf1af6..954ebe1 100644 --- a/migration/src/m20220101_000001_main.rs +++ b/migration/src/m20220101_000001_main.rs @@ -435,6 +435,8 @@ enum City { CityId, /// Country. Country, + /// Creation date. + CreatedAt, /// City latitude as defined in the U.S. census. Latitude, /// City longitude as defined in the U.S. census. @@ -443,14 +445,12 @@ enum City { Name, /// Assigned region. Region, + /// City speed limit (if different from the state speed limit). + SpeedLimit, /// State name. State, /// Two-letter state abbreviation. StateAbbrev, - /// City speed limit (if different from the state speed limit). - SpeedLimit, - /// Creation date. - CreatedAt, /// Update date. UpdatedAt, } diff --git a/openapi.yaml b/openapi.yaml index ec4f842..d0fe6cd 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -836,7 +836,9 @@ components: description: "State name" example: "Antwerp" state_abbrev: - $ref: "#/components/schemas/state_abbrev" + type: number + description: "A short version of the state name, usually 2 or 3 character long." + example: "VAN" updated_at: type: array description: "Date and time" @@ -881,7 +883,9 @@ components: description: "State name" example: "Antwerp" state_abbrev: - $ref: "#/components/schemas/state_abbrev" + type: number + description: "A short version of the state name, usually 2 or 3 character long." + example: "VAN" speed_limit: type: number description: "Speed limit in kilometer per hour (km/h)." @@ -1006,7 +1010,10 @@ components: description: "detailed error message" example: "the entry was not found" status: - $ref: "#/components/schemas/status" + type: integer + description: "HTTP status associated with the error" + minimum: 400 + example: 404 title: type: string description: "Error title" @@ -1111,19 +1118,10 @@ components: - $ref: "#/components/schemas/header" example: source: Parameter "/bnas/analysis/e6aade5a-b343-120b-dbaa-bd916cd99221?" - state_abbrev: - type: number - description: "A short version of the state name, usually 2 or 3 character long." - example: "VAN" state_machine_id: type: string description: "ID of the AWS state machine that was used to run the pipeline" example: "38f4f54e-98d6-4048-8c0f-99cde05a7e76" - status: - type: integer - description: "HTTP status associated with the error" - minimum: 400 - example: 404 step: type: string enum: