From dc662aa65a59f4e7dab87dba5278d6d80328eeba Mon Sep 17 00:00:00 2001 From: Sergii Mikhtoniuk Date: Wed, 11 Sep 2024 15:35:52 -0700 Subject: [PATCH] Break up all-in-one error type --- README.md | 6 +- examples/simple_service.rs | 47 +++++---- src/atom.rs | 80 +++++++++++++++- src/context.rs | 26 ++--- src/error.rs | 192 ++++++++++++++++++++++++++++++------- src/handlers.rs | 99 ++++++++----------- src/metadata.rs | 48 +++++----- tests/test_handlers.rs | 34 ++++--- 8 files changed, 370 insertions(+), 162 deletions(-) diff --git a/README.md b/README.md index 09bf331..2adaf07 100644 --- a/README.md +++ b/README.md @@ -15,17 +15,17 @@ Query using [xh](https://github.com/ducaale/xh): Service root: ```sh -xh GET 'http://localhost:3000/' +xh GET 'http://localhost:50051/' ``` Metadata: ```sh -xh GET 'http://localhost:3000/$metadata' +xh GET 'http://localhost:50051/$metadata' ``` Query collection: ```sh -xh GET 'http://localhost:3000/tickers.spy/?$select=offset,from_symbol,to_symbol,close&$top=5' +xh GET 'http://localhost:50051/tickers.spy/?$select=offset,from_symbol,to_symbol,close&$top=5' ``` ## Status diff --git a/examples/simple_service.rs b/examples/simple_service.rs index e4ab6a2..c30d1ad 100644 --- a/examples/simple_service.rs +++ b/examples/simple_service.rs @@ -4,12 +4,12 @@ use chrono::{DateTime, Utc}; use datafusion::arrow::datatypes::SchemaRef; use datafusion::{prelude::*, sql::TableReference}; -use axum::response::{Response, Result as AxumResult}; +use axum::response::Response; use datafusion_odata::{ collection::{CollectionAddr, QueryParams, QueryParamsRaw}, context::{CollectionContext, OnUnsupported, ServiceContext}, - error::{Error, Result}, + error::{CollectionNotFound, ODataError}, handlers::{MEDIA_TYPE_ATOM, MEDIA_TYPE_XML}, }; @@ -25,7 +25,7 @@ const DEFAULT_MAX_ROWS: usize = 100; pub async fn odata_service_handler( axum::extract::State(query_ctx): axum::extract::State, host: axum::extract::Host, -) -> AxumResult> { +) -> Result, ODataError> { let ctx = Arc::new(ODataContext::new_service(query_ctx, host)); datafusion_odata::handlers::odata_service_handler(axum::Extension(ctx)).await } @@ -35,7 +35,7 @@ pub async fn odata_service_handler( pub async fn odata_metadata_handler( axum::extract::State(query_ctx): axum::extract::State, host: axum::extract::Host, -) -> AxumResult> { +) -> Result, ODataError> { let ctx = ODataContext::new_service(query_ctx, host); datafusion_odata::handlers::odata_metadata_handler(axum::Extension(Arc::new(ctx))).await } @@ -48,12 +48,9 @@ pub async fn odata_collection_handler( axum::extract::Path(collection_path_element): axum::extract::Path, query: axum::extract::Query, headers: axum::http::HeaderMap, -) -> AxumResult> { +) -> Result, ODataError> { let Some(addr) = CollectionAddr::decode(&collection_path_element) else { - return Ok(axum::response::Response::builder() - .status(http::StatusCode::NOT_FOUND) - .body("".into()) - .map_err(Error::from)?); + Err(CollectionNotFound::new(collection_path_element))? }; let ctx = Arc::new(ODataContext::new_collection(query_ctx, host, addr)); @@ -99,7 +96,7 @@ impl ServiceContext for ODataContext { self.service_base_url.clone() } - async fn list_collections(&self) -> Result>> { + async fn list_collections(&self) -> Result>, ODataError> { let cnames = self.query_ctx.catalog_names(); assert_eq!( cnames.len(), @@ -142,21 +139,21 @@ impl ServiceContext for ODataContext { #[async_trait::async_trait] impl CollectionContext for ODataContext { - fn addr(&self) -> Result<&CollectionAddr> { + fn addr(&self) -> Result<&CollectionAddr, ODataError> { Ok(self.addr.as_ref().unwrap()) } - fn service_base_url(&self) -> Result { + fn service_base_url(&self) -> Result { Ok(self.service_base_url.clone()) } - fn collection_base_url(&self) -> Result { + fn collection_base_url(&self) -> Result { let service_base_url = &self.service_base_url; let collection_name = self.collection_name()?; Ok(format!("{service_base_url}{collection_name}")) } - fn collection_name(&self) -> Result { + fn collection_name(&self) -> Result { Ok(self.addr()?.name.clone()) } @@ -164,19 +161,31 @@ impl CollectionContext for ODataContext { Utc::now() } - async fn schema(&self) -> Result { + async fn schema(&self) -> Result { Ok(self .query_ctx .table_provider(TableReference::bare(self.collection_name()?)) - .await? + .await + .map_err(|e| { + ODataError::handle_no_table_as_collection_not_found( + self.collection_name().unwrap(), + e, + ) + })? .schema()) } - async fn query(&self, query: QueryParams) -> Result { + async fn query(&self, query: QueryParams) -> Result { let df = self .query_ctx .table(TableReference::bare(self.collection_name()?)) - .await?; + .await + .map_err(|e| { + ODataError::handle_no_table_as_collection_not_found( + self.collection_name().unwrap(), + e, + ) + })?; query .apply( @@ -187,7 +196,7 @@ impl CollectionContext for ODataContext { DEFAULT_MAX_ROWS, usize::MAX, ) - .map_err(Error::from) + .map_err(ODataError::internal) } fn on_unsupported_feature(&self) -> OnUnsupported { diff --git a/src/atom.rs b/src/atom.rs index 826670f..5e68888 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -9,7 +9,7 @@ use quick_xml::events::*; use crate::{ context::{CollectionContext, OnUnsupported}, - error::Result, + error::ODataError, }; /////////////////////////////////////////////////////////////////////////////// @@ -72,7 +72,7 @@ pub fn write_atom_feed_from_records( updated_time: DateTime, on_unsupported: OnUnsupported, writer: &mut quick_xml::Writer, -) -> Result<()> +) -> Result<(), ODataError> where W: std::io::Write, { @@ -91,6 +91,42 @@ where collection_base_url.pop(); } + write_atom_feed_from_records_impl( + schema, + record_batches, + ctx, + updated_time, + on_unsupported, + writer, + service_base_url, + collection_base_url, + collection_name, + type_namespace, + type_name, + ) + .map_err(ODataError::internal) +} + +// TODO: Use erased dyn Writer type +// TODO: Extract `CollectionInfo` type to avoid propagating +// a bunch of individual parameters +#[allow(clippy::too_many_arguments)] +fn write_atom_feed_from_records_impl( + schema: &Schema, + record_batches: Vec, + ctx: &dyn CollectionContext, + updated_time: DateTime, + on_unsupported: OnUnsupported, + writer: &mut quick_xml::Writer, + service_base_url: String, + collection_base_url: String, + collection_name: String, + type_namespace: String, + type_name: String, +) -> std::result::Result<(), quick_xml::Error> +where + W: std::io::Write, +{ let fq_type = format!("{type_namespace}.{type_name}"); let mut columns = Vec::new(); @@ -283,7 +319,7 @@ pub fn write_atom_entry_from_record( updated_time: DateTime, on_unsupported: OnUnsupported, writer: &mut quick_xml::Writer, -) -> Result<()> +) -> Result<(), ODataError> where W: std::io::Write, { @@ -304,6 +340,42 @@ where collection_base_url.pop(); } + write_atom_entry_from_record_impl( + schema, + batch, + ctx, + updated_time, + on_unsupported, + writer, + service_base_url, + collection_base_url, + collection_name, + type_namespace, + type_name, + ) + .map_err(ODataError::internal) +} + +// TODO: Use erased dyn Writer type +// TODO: Extract `CollectionInfo` type to avoid propagating +// a bunch of individual parameters +#[allow(clippy::too_many_arguments)] +fn write_atom_entry_from_record_impl( + schema: &Schema, + batch: RecordBatch, + ctx: &dyn CollectionContext, + updated_time: DateTime, + on_unsupported: OnUnsupported, + writer: &mut quick_xml::Writer, + service_base_url: String, + collection_base_url: String, + collection_name: String, + type_namespace: String, + type_name: String, +) -> std::result::Result<(), quick_xml::Error> +where + W: std::io::Write, +{ let fq_type = format!("{type_namespace}.{type_name}"); let mut columns = Vec::new(); @@ -522,6 +594,8 @@ fn encode_date_time(dt: &DateTime) -> BytesText<'static> { BytesText::from_escaped(s) } +/////////////////////////////////////////////////////////////////////////////// + #[cfg(test)] mod tests { use super::*; diff --git a/src/context.rs b/src/context.rs index 204891a..2acc3a6 100644 --- a/src/context.rs +++ b/src/context.rs @@ -5,7 +5,7 @@ use datafusion::{arrow::datatypes::SchemaRef, dataframe::DataFrame}; use crate::{ collection::{CollectionAddr, QueryParams}, - error::{Error, Result}, + error::{KeyColumnNotAssigned, ODataError}, }; /////////////////////////////////////////////////////////////////////////////// @@ -18,43 +18,47 @@ pub const DEFAULT_NAMESPACE: &str = "default"; pub trait ServiceContext: Send + Sync { fn service_base_url(&self) -> String; - async fn list_collections(&self) -> Result>>; + async fn list_collections(&self) -> Result>, ODataError>; fn on_unsupported_feature(&self) -> OnUnsupported; } +/////////////////////////////////////////////////////////////////////////////// + #[async_trait::async_trait] pub trait CollectionContext: Send + Sync { - fn addr(&self) -> Result<&CollectionAddr>; + fn addr(&self) -> Result<&CollectionAddr, ODataError>; - fn service_base_url(&self) -> Result; + fn service_base_url(&self) -> Result; - fn collection_base_url(&self) -> Result; + fn collection_base_url(&self) -> Result; - fn collection_namespace(&self) -> Result { + fn collection_namespace(&self) -> Result { Ok(DEFAULT_NAMESPACE.to_string()) } - fn collection_name(&self) -> Result; + fn collection_name(&self) -> Result; // Synthetic column name that will be used to propagate entity IDs fn key_column_alias(&self) -> String { "__id__".to_string() } - fn key_column(&self) -> Result { - Err(Error::KeyColumnNotAssigned) + fn key_column(&self) -> Result { + Err(KeyColumnNotAssigned)? } async fn last_updated_time(&self) -> DateTime; - async fn schema(&self) -> Result; + async fn schema(&self) -> Result; - async fn query(&self, query: QueryParams) -> Result; + async fn query(&self, query: QueryParams) -> Result; fn on_unsupported_feature(&self) -> OnUnsupported; } +/////////////////////////////////////////////////////////////////////////////// + pub enum OnUnsupported { /// Return an error or crash Error, diff --git a/src/error.rs b/src/error.rs index 561d7ce..0ecc58f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,35 +1,159 @@ -use std::io; - -use thiserror::Error as ThisError; - -pub type Result = core::result::Result; - -#[derive(ThisError, Debug)] -pub enum Error { - #[error("IO error {0}")] - IO(#[from] io::Error), - #[error("Datafusion error: {0}")] - Datafusion(#[from] datafusion::error::DataFusionError), - #[error("Address parse error: {0}")] - AddrParse(#[from] std::net::AddrParseError), - #[error("Axum error: {0}")] - AxumError(#[from] axum::Error), - #[error("Http error: {0}")] - HttpError(#[from] http::Error), - #[error("Quickxml error: {0}")] - QuickXMLError(#[from] quick_xml::Error), - #[error("Quickxml deserialise error: {0}")] - QuickXMLDeError(#[from] quick_xml::DeError), - #[error("From utf8 error: {0}")] - FromUtf8Error(#[from] std::string::FromUtf8Error), - #[error("Unsupported data type: {0}")] - UnsupportedDataType(datafusion::arrow::datatypes::DataType), - #[error("Unsupported feature: {0}")] - UnsupportedFeature(String), - #[error("Collection not found: {0}")] - CollectionNotFound(String), - #[error("Collection address not assigned")] - CollectionAddressNotAssigned, - #[error("Key column not assigned")] - KeyColumnNotAssigned, +use datafusion::arrow::datatypes::DataType; + +/////////////////////////////////////////////////////////////////////////////// + +#[derive(thiserror::Error, Debug)] +pub enum ODataError { + #[error(transparent)] + UnsupportedDataType(#[from] UnsupportedDataType), + #[error(transparent)] + UnsupportedFeature(#[from] UnsupportedFeature), + #[error(transparent)] + CollectionNotFound(#[from] CollectionNotFound), + #[error(transparent)] + CollectionAddressNotAssigned(#[from] CollectionAddressNotAssigned), + #[error(transparent)] + KeyColumnNotAssigned(#[from] KeyColumnNotAssigned), + #[error(transparent)] + Internal(InternalError), +} + +impl ODataError { + pub fn internal(error: impl Into>) -> Self { + Self::Internal(InternalError::new(error)) + } + + pub fn handle_no_table_as_collection_not_found( + collection: impl Into, + err: datafusion::error::DataFusionError, + ) -> Self { + match err { + datafusion::error::DataFusionError::Plan(e) if e.contains("No table named") => { + Self::CollectionNotFound(CollectionNotFound::new(collection)) + } + _ => Self::internal(err), + } + } +} + +impl axum::response::IntoResponse for ODataError { + fn into_response(self) -> axum::response::Response { + match self { + Self::Internal(_) => { + (http::StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() + } + Self::CollectionNotFound(e) => e.into_response(), + Self::UnsupportedDataType(e) => e.into_response(), + Self::UnsupportedFeature(e) => e.into_response(), + Self::CollectionAddressNotAssigned(e) => e.into_response(), + Self::KeyColumnNotAssigned(e) => e.into_response(), + } + } +} + +/////////////////////////////////////////////////////////////////////////////// + +#[derive(thiserror::Error, Debug)] +#[error("Internal error")] +pub struct InternalError { + #[source] + pub source: Box, +} + +impl InternalError { + pub fn new(error: impl Into>) -> Self { + Self { + source: error.into(), + } + } +} + +/////////////////////////////////////////////////////////////////////////////// + +#[derive(thiserror::Error, Debug)] +#[error("Collection {collection} not found")] +pub struct CollectionNotFound { + pub collection: String, +} + +impl CollectionNotFound { + pub fn new(collection: impl Into) -> Self { + Self { + collection: collection.into(), + } + } +} + +impl axum::response::IntoResponse for CollectionNotFound { + fn into_response(self) -> axum::response::Response { + (http::StatusCode::NOT_FOUND, self.to_string()).into_response() + } +} + +/////////////////////////////////////////////////////////////////////////////// + +#[derive(thiserror::Error, Debug)] +#[error("Key column not assigned")] +pub struct KeyColumnNotAssigned; + +impl axum::response::IntoResponse for KeyColumnNotAssigned { + fn into_response(self) -> axum::response::Response { + (http::StatusCode::NOT_IMPLEMENTED, self.to_string()).into_response() + } +} + +/////////////////////////////////////////////////////////////////////////////// + +#[derive(thiserror::Error, Debug)] +#[error("Collection address not assigned")] +pub struct CollectionAddressNotAssigned; + +impl axum::response::IntoResponse for CollectionAddressNotAssigned { + fn into_response(self) -> axum::response::Response { + (http::StatusCode::NOT_IMPLEMENTED, self.to_string()).into_response() + } +} + +/////////////////////////////////////////////////////////////////////////////// + +#[derive(thiserror::Error, Debug)] +#[error("Unsupported data type: {data_type}")] +pub struct UnsupportedDataType { + pub data_type: DataType, +} + +impl UnsupportedDataType { + pub fn new(data_type: DataType) -> Self { + Self { data_type } + } } + +impl axum::response::IntoResponse for UnsupportedDataType { + fn into_response(self) -> axum::response::Response { + (http::StatusCode::NOT_IMPLEMENTED, self.to_string()).into_response() + } +} + +/////////////////////////////////////////////////////////////////////////////// + +#[derive(thiserror::Error, Debug)] +#[error("Unsupported feature: {feature}")] +pub struct UnsupportedFeature { + pub feature: String, +} + +impl UnsupportedFeature { + pub fn new(feature: impl Into) -> Self { + Self { + feature: feature.into(), + } + } +} + +impl axum::response::IntoResponse for UnsupportedFeature { + fn into_response(self) -> axum::response::Response { + (http::StatusCode::NOT_IMPLEMENTED, self.to_string()).into_response() + } +} + +/////////////////////////////////////////////////////////////////////////////// diff --git a/src/handlers.rs b/src/handlers.rs index c891fbd..0abfaa3 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -1,15 +1,11 @@ use std::sync::Arc; -use axum::{ - extract::Query, - response::{IntoResponse, Response, Result}, - Extension, -}; +use axum::{extract::Query, response::Response, Extension}; use crate::{ collection::QueryParamsRaw, context::{CollectionContext, OnUnsupported, ServiceContext, DEFAULT_NAMESPACE}, - error::Error, + error::{ODataError, UnsupportedDataType}, metadata::{ to_edm_type, DataServices, Edmx, EntityContainer, EntityKey, EntitySet, EntityType, Property, PropertyRef, @@ -17,14 +13,18 @@ use crate::{ service::{Collection, Service, Workspace}, }; +/////////////////////////////////////////////////////////////////////////////// + pub const MEDIA_TYPE_ATOM: &str = "application/atom+xml;type=feed;charset=utf-8"; pub const MEDIA_TYPE_XML: &str = "application/xml;charset=utf-8"; const DEFAULT_COLLECTION_RESPONSE_SIZE: usize = 512_000; +/////////////////////////////////////////////////////////////////////////////// + pub async fn odata_service_handler( Extension(odata_ctx): Extension>, -) -> Result> { +) -> Result, ODataError> { let mut collections = Vec::new(); for coll in odata_ctx.list_collections().await? { @@ -42,15 +42,19 @@ pub async fn odata_service_handler( }, ); - Ok(Response::builder() + let xml = write_object_to_xml("service", &service)?; + + Response::builder() .header(http::header::CONTENT_TYPE.as_str(), MEDIA_TYPE_XML) - .body(write_object_to_xml("service", &service)?) - .map_err(Error::from)?) + .body(xml) + .map_err(ODataError::internal) } +/////////////////////////////////////////////////////////////////////////////// + pub async fn odata_metadata_handler( Extension(odata_ctx): Extension>, -) -> Result> { +) -> Result, ODataError> { let mut entity_types = Vec::new(); let mut entity_container = EntityContainer { name: DEFAULT_NAMESPACE.to_string(), @@ -67,11 +71,7 @@ pub async fn odata_metadata_handler( Ok(typ) => typ, Err(err) => match odata_ctx.on_unsupported_feature() { OnUnsupported::Error => { - return Err(Error::UnsupportedFeature(format!( - "Unsupported field type {:?}", - field.data_type() - )) - .into()); + Err(UnsupportedDataType::new(field.data_type().clone()))? } OnUnsupported::Warn => { tracing::error!( @@ -92,7 +92,7 @@ pub async fn odata_metadata_handler( // https://www.odata.org/documentation/odata-version-3-0/common-schema-definition-language-csdl/#csdl6.3 let property_ref_name = match coll.key_column() { Ok(kc) => kc, - Err(Error::KeyColumnNotAssigned) => match properties.first() { + Err(ODataError::KeyColumnNotAssigned(_)) => match properties.first() { Some(prop) => prop.name.clone(), None => collection_name.to_string(), }, @@ -103,7 +103,7 @@ pub async fn odata_metadata_handler( error_dbg = ?err, "Failed to get key column", ); - return Err(err.into()); + Err(err)? } }; @@ -127,24 +127,28 @@ pub async fn odata_metadata_handler( vec![entity_container], )])); - Ok(Response::builder() + let xml = write_object_to_xml("edmx:Edmx", &metadata)?; + + Response::builder() .header(http::header::CONTENT_TYPE.as_str(), MEDIA_TYPE_XML) - .body(write_object_to_xml("edmx:Edmx", &metadata)?) - .map_err(Error::from)?) + .body(xml) + .map_err(ODataError::internal) } +/////////////////////////////////////////////////////////////////////////////// + pub async fn odata_collection_handler( Extension(ctx): Extension>, Query(query): Query, _headers: axum::http::HeaderMap, -) -> Result> { +) -> Result, ODataError> { let query = query.decode(); tracing::debug!(?query, "Decoded query"); - let df = ctx.query(query).await.map_err(Error::from)?; + let df = ctx.query(query).await.map_err(ODataError::from)?; let schema: datafusion::arrow::datatypes::Schema = df.schema().clone().into(); - let record_batches = df.collect().await.map_err(Error::from)?; + let record_batches = df.collect().await.map_err(ODataError::internal)?; let num_rows: usize = record_batches.iter().map(|b| b.num_rows()).sum(); let raw_bytes: usize = record_batches @@ -162,11 +166,9 @@ pub async fn odata_collection_handler( ctx.last_updated_time().await, ctx.on_unsupported_feature(), &mut writer, - ) - .map_err(Error::from)?; + )?; } else { let num_rows: usize = record_batches.iter().map(|b| b.num_rows()).sum(); - // TODO assert!(num_rows <= 1, "Request by key returned {} rows", num_rows); assert!( record_batches.len() <= 1, @@ -175,10 +177,10 @@ pub async fn odata_collection_handler( ); if record_batches.len() != 1 || record_batches[0].num_rows() != 1 { - return Ok(Response::builder() + return Response::builder() .status(http::StatusCode::NOT_FOUND) - .body("".into()) - .map_err(Error::from)?); + .body(String::new()) + .map_err(ODataError::internal); } crate::atom::write_atom_entry_from_record( @@ -188,11 +190,10 @@ pub async fn odata_collection_handler( ctx.last_updated_time().await, ctx.on_unsupported_feature(), &mut writer, - ) - .map_err(Error::from)?; + )?; } - let body = String::from_utf8(writer.into_inner()).map_err(Error::from)?; + let body = String::from_utf8(writer.into_inner()).map_err(ODataError::internal)?; tracing::debug!( media_type = MEDIA_TYPE_ATOM, @@ -202,13 +203,15 @@ pub async fn odata_collection_handler( "Prepared a response" ); - Ok(Response::builder() + Response::builder() .header(http::header::CONTENT_TYPE.as_str(), MEDIA_TYPE_ATOM) .body(body) - .map_err(Error::from)?) + .map_err(ODataError::internal) } -fn write_object_to_xml(tag: &str, object: &T) -> Result +/////////////////////////////////////////////////////////////////////////////// + +fn write_object_to_xml(tag: &str, object: &T) -> Result where T: serde::ser::Serialize, { @@ -219,29 +222,11 @@ where .write_event(quick_xml::events::Event::Decl( quick_xml::events::BytesDecl::new("1.0", Some("utf-8"), None), )) - .map_err(Error::from)?; + .map_err(ODataError::internal)?; writer .write_serializable(tag, object) - .map_err(Error::from)?; - - Ok(String::from_utf8(writer.into_inner()).map_err(Error::from)?) -} + .map_err(ODataError::internal)?; -impl IntoResponse for Error { - fn into_response(self) -> Response { - tracing::error!("Error: {self}"); - match self { - Error::Datafusion(datafusion::error::DataFusionError::Plan(ref e)) => { - if e.contains("No table named") { - return (http::StatusCode::NOT_FOUND, "Not found").into_response(); - } - } - Error::CollectionNotFound(_) => { - return (http::StatusCode::NOT_FOUND, "Not found").into_response(); - } - _ => {} - } - (http::StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response() - } + Ok(String::from_utf8(writer.into_inner()).unwrap()) } diff --git a/src/metadata.rs b/src/metadata.rs index f1d317b..b2e2e67 100644 --- a/src/metadata.rs +++ b/src/metadata.rs @@ -9,7 +9,7 @@ use datafusion::arrow::datatypes::DataType; -use crate::error::{Error, Result}; +use crate::error::UnsupportedDataType; #[derive(Debug, serde::Serialize)] pub struct Edmx { @@ -172,9 +172,8 @@ pub struct EntitySet { /////////////////////////////////////////////////////////////////////////////// // See: https://www.odata.org/documentation/odata-version-3-0/common-schema-definition-language-csdl/ -pub fn to_edm_type(dt: &DataType) -> Result<&'static str> { +pub fn to_edm_type(dt: &DataType) -> std::result::Result<&'static str, UnsupportedDataType> { match dt { - DataType::Null => Err(Error::UnsupportedDataType(dt.clone())), DataType::Boolean => Ok("Edm.Boolean"), // TODO: Use Edm.Byte / Edm.SByte? DataType::Int8 => Ok("Edm.Int16"), @@ -188,32 +187,33 @@ pub fn to_edm_type(dt: &DataType) -> Result<&'static str> { DataType::UInt64 => Ok("Edm.Int64"), DataType::Utf8 => Ok("Edm.String"), DataType::LargeUtf8 => Ok("Edm.String"), - DataType::Utf8View => Err(Error::UnsupportedDataType(dt.clone())), DataType::Float16 => Ok("Edm.Single"), DataType::Float32 => Ok("Edm.Single"), DataType::Float64 => Ok("Edm.Double"), DataType::Timestamp(_, _) => Ok("Edm.DateTime"), DataType::Date32 => Ok("Edm.DateTime"), DataType::Date64 => Ok("Edm.DateTime"), - DataType::Time32(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Time64(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Duration(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Interval(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Binary => Err(Error::UnsupportedDataType(dt.clone())), - DataType::BinaryView => Err(Error::UnsupportedDataType(dt.clone())), - DataType::FixedSizeBinary(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::LargeBinary => Err(Error::UnsupportedDataType(dt.clone())), - DataType::List(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::FixedSizeList(_, _) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::LargeList(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::ListView(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::LargeListView(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Struct(_) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Union(_, _) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Dictionary(_, _) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Decimal128(_, _) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Decimal256(_, _) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::Map(_, _) => Err(Error::UnsupportedDataType(dt.clone())), - DataType::RunEndEncoded(_, _) => Err(Error::UnsupportedDataType(dt.clone())), + DataType::Null + | DataType::Utf8View + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Binary + | DataType::BinaryView + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::Struct(_) + | DataType::Union(_, _) + | DataType::Dictionary(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Map(_, _) + | DataType::RunEndEncoded(_, _) => Err(UnsupportedDataType::new(dt.clone())), } } diff --git a/tests/test_handlers.rs b/tests/test_handlers.rs index b43d63d..3b5face 100644 --- a/tests/test_handlers.rs +++ b/tests/test_handlers.rs @@ -5,7 +5,7 @@ use datafusion::{arrow::datatypes::SchemaRef, prelude::*, sql::TableReference}; use datafusion_odata::{ collection::{CollectionAddr, QueryParams, QueryParamsRaw}, context::*, - error::{Error, Result}, + error::ODataError, }; use indoc::indoc; @@ -286,7 +286,7 @@ impl ServiceContext for ODataContext { self.service_base_url.clone() } - async fn list_collections(&self) -> Result>> { + async fn list_collections(&self) -> Result>, ODataError> { let catalog_name = self.query_ctx.catalog_names().into_iter().next().unwrap(); let catalog = self.query_ctx.catalog(&catalog_name).unwrap(); @@ -318,21 +318,21 @@ impl ServiceContext for ODataContext { #[async_trait::async_trait] impl CollectionContext for ODataContext { - fn addr(&self) -> Result<&CollectionAddr> { + fn addr(&self) -> Result<&CollectionAddr, ODataError> { Ok(self.addr.as_ref().unwrap()) } - fn service_base_url(&self) -> Result { + fn service_base_url(&self) -> Result { Ok(self.service_base_url.clone()) } - fn collection_base_url(&self) -> Result { + fn collection_base_url(&self) -> Result { let service_base_url = &self.service_base_url; let collection_name = self.collection_name()?; Ok(format!("{service_base_url}{collection_name}")) } - fn collection_name(&self) -> Result { + fn collection_name(&self) -> Result { Ok(self.addr()?.name.clone()) } @@ -342,19 +342,31 @@ impl CollectionContext for ODataContext { .into() } - async fn schema(&self) -> Result { + async fn schema(&self) -> Result { Ok(self .query_ctx .table_provider(TableReference::bare(self.collection_name()?)) - .await? + .await + .map_err(|e| { + ODataError::handle_no_table_as_collection_not_found( + self.collection_name().unwrap(), + e, + ) + })? .schema()) } - async fn query(&self, query: QueryParams) -> Result { + async fn query(&self, query: QueryParams) -> Result { let df = self .query_ctx .table(TableReference::bare(self.collection_name()?)) - .await?; + .await + .map_err(|e| { + ODataError::handle_no_table_as_collection_not_found( + self.collection_name().unwrap(), + e, + ) + })?; query .apply( @@ -365,7 +377,7 @@ impl CollectionContext for ODataContext { 100, usize::MAX, ) - .map_err(Error::from) + .map_err(ODataError::internal) } fn on_unsupported_feature(&self) -> OnUnsupported {