Skip to content

Commit

Permalink
Break up all-in-one error type
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiimk committed Sep 11, 2024
1 parent 6ad0f40 commit dc662aa
Show file tree
Hide file tree
Showing 8 changed files with 370 additions and 162 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ Query using [xh](https://github.com/ducaale/xh):

Service root:
```sh
xh GET 'http://localhost:3000/'
xh GET 'http://localhost:50051/'
```

Metadata:
```sh
xh GET 'http://localhost:3000/$metadata'
xh GET 'http://localhost:50051/$metadata'
```

Query collection:
```sh
xh GET 'http://localhost:3000/tickers.spy/?$select=offset,from_symbol,to_symbol,close&$top=5'
xh GET 'http://localhost:50051/tickers.spy/?$select=offset,from_symbol,to_symbol,close&$top=5'
```

## Status
Expand Down
47 changes: 28 additions & 19 deletions examples/simple_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ use chrono::{DateTime, Utc};
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::{prelude::*, sql::TableReference};

use axum::response::{Response, Result as AxumResult};
use axum::response::Response;

use datafusion_odata::{
collection::{CollectionAddr, QueryParams, QueryParamsRaw},
context::{CollectionContext, OnUnsupported, ServiceContext},
error::{Error, Result},
error::{CollectionNotFound, ODataError},
handlers::{MEDIA_TYPE_ATOM, MEDIA_TYPE_XML},
};

Expand All @@ -25,7 +25,7 @@ const DEFAULT_MAX_ROWS: usize = 100;
pub async fn odata_service_handler(
axum::extract::State(query_ctx): axum::extract::State<SessionContext>,
host: axum::extract::Host,
) -> AxumResult<Response<String>> {
) -> Result<Response<String>, ODataError> {
let ctx = Arc::new(ODataContext::new_service(query_ctx, host));
datafusion_odata::handlers::odata_service_handler(axum::Extension(ctx)).await
}
Expand All @@ -35,7 +35,7 @@ pub async fn odata_service_handler(
pub async fn odata_metadata_handler(
axum::extract::State(query_ctx): axum::extract::State<SessionContext>,
host: axum::extract::Host,
) -> AxumResult<Response<String>> {
) -> Result<Response<String>, ODataError> {
let ctx = ODataContext::new_service(query_ctx, host);
datafusion_odata::handlers::odata_metadata_handler(axum::Extension(Arc::new(ctx))).await
}
Expand All @@ -48,12 +48,9 @@ pub async fn odata_collection_handler(
axum::extract::Path(collection_path_element): axum::extract::Path<String>,
query: axum::extract::Query<QueryParamsRaw>,
headers: axum::http::HeaderMap,
) -> AxumResult<Response<String>> {
) -> Result<Response<String>, ODataError> {
let Some(addr) = CollectionAddr::decode(&collection_path_element) else {
return Ok(axum::response::Response::builder()
.status(http::StatusCode::NOT_FOUND)
.body("".into())
.map_err(Error::from)?);
Err(CollectionNotFound::new(collection_path_element))?
};

let ctx = Arc::new(ODataContext::new_collection(query_ctx, host, addr));
Expand Down Expand Up @@ -99,7 +96,7 @@ impl ServiceContext for ODataContext {
self.service_base_url.clone()
}

async fn list_collections(&self) -> Result<Vec<Arc<dyn CollectionContext>>> {
async fn list_collections(&self) -> Result<Vec<Arc<dyn CollectionContext>>, ODataError> {
let cnames = self.query_ctx.catalog_names();
assert_eq!(
cnames.len(),
Expand Down Expand Up @@ -142,41 +139,53 @@ impl ServiceContext for ODataContext {

#[async_trait::async_trait]
impl CollectionContext for ODataContext {
fn addr(&self) -> Result<&CollectionAddr> {
fn addr(&self) -> Result<&CollectionAddr, ODataError> {
Ok(self.addr.as_ref().unwrap())
}

fn service_base_url(&self) -> Result<String> {
fn service_base_url(&self) -> Result<String, ODataError> {
Ok(self.service_base_url.clone())
}

fn collection_base_url(&self) -> Result<String> {
fn collection_base_url(&self) -> Result<String, ODataError> {
let service_base_url = &self.service_base_url;
let collection_name = self.collection_name()?;
Ok(format!("{service_base_url}{collection_name}"))
}

fn collection_name(&self) -> Result<String> {
fn collection_name(&self) -> Result<String, ODataError> {
Ok(self.addr()?.name.clone())
}

async fn last_updated_time(&self) -> DateTime<Utc> {
Utc::now()
}

async fn schema(&self) -> Result<SchemaRef> {
async fn schema(&self) -> Result<SchemaRef, ODataError> {
Ok(self
.query_ctx
.table_provider(TableReference::bare(self.collection_name()?))
.await?
.await
.map_err(|e| {
ODataError::handle_no_table_as_collection_not_found(
self.collection_name().unwrap(),
e,
)
})?
.schema())
}

async fn query(&self, query: QueryParams) -> Result<DataFrame> {
async fn query(&self, query: QueryParams) -> Result<DataFrame, ODataError> {
let df = self
.query_ctx
.table(TableReference::bare(self.collection_name()?))
.await?;
.await
.map_err(|e| {
ODataError::handle_no_table_as_collection_not_found(
self.collection_name().unwrap(),
e,
)
})?;

query
.apply(
Expand All @@ -187,7 +196,7 @@ impl CollectionContext for ODataContext {
DEFAULT_MAX_ROWS,
usize::MAX,
)
.map_err(Error::from)
.map_err(ODataError::internal)
}

fn on_unsupported_feature(&self) -> OnUnsupported {
Expand Down
80 changes: 77 additions & 3 deletions src/atom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use quick_xml::events::*;

use crate::{
context::{CollectionContext, OnUnsupported},
error::Result,
error::ODataError,
};

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -72,7 +72,7 @@ pub fn write_atom_feed_from_records<W>(
updated_time: DateTime<Utc>,
on_unsupported: OnUnsupported,
writer: &mut quick_xml::Writer<W>,
) -> Result<()>
) -> Result<(), ODataError>
where
W: std::io::Write,
{
Expand All @@ -91,6 +91,42 @@ where
collection_base_url.pop();
}

write_atom_feed_from_records_impl(
schema,
record_batches,
ctx,
updated_time,
on_unsupported,
writer,
service_base_url,
collection_base_url,
collection_name,
type_namespace,
type_name,
)
.map_err(ODataError::internal)
}

// TODO: Use erased dyn Writer type
// TODO: Extract `CollectionInfo` type to avoid propagating
// a bunch of individual parameters
#[allow(clippy::too_many_arguments)]
fn write_atom_feed_from_records_impl<W>(
schema: &Schema,
record_batches: Vec<RecordBatch>,
ctx: &dyn CollectionContext,
updated_time: DateTime<Utc>,
on_unsupported: OnUnsupported,
writer: &mut quick_xml::Writer<W>,
service_base_url: String,
collection_base_url: String,
collection_name: String,
type_namespace: String,
type_name: String,
) -> std::result::Result<(), quick_xml::Error>
where
W: std::io::Write,
{
let fq_type = format!("{type_namespace}.{type_name}");

let mut columns = Vec::new();
Expand Down Expand Up @@ -283,7 +319,7 @@ pub fn write_atom_entry_from_record<W>(
updated_time: DateTime<Utc>,
on_unsupported: OnUnsupported,
writer: &mut quick_xml::Writer<W>,
) -> Result<()>
) -> Result<(), ODataError>
where
W: std::io::Write,
{
Expand All @@ -304,6 +340,42 @@ where
collection_base_url.pop();
}

write_atom_entry_from_record_impl(
schema,
batch,
ctx,
updated_time,
on_unsupported,
writer,
service_base_url,
collection_base_url,
collection_name,
type_namespace,
type_name,
)
.map_err(ODataError::internal)
}

// TODO: Use erased dyn Writer type
// TODO: Extract `CollectionInfo` type to avoid propagating
// a bunch of individual parameters
#[allow(clippy::too_many_arguments)]
fn write_atom_entry_from_record_impl<W>(
schema: &Schema,
batch: RecordBatch,
ctx: &dyn CollectionContext,
updated_time: DateTime<Utc>,
on_unsupported: OnUnsupported,
writer: &mut quick_xml::Writer<W>,
service_base_url: String,
collection_base_url: String,
collection_name: String,
type_namespace: String,
type_name: String,
) -> std::result::Result<(), quick_xml::Error>
where
W: std::io::Write,
{
let fq_type = format!("{type_namespace}.{type_name}");

let mut columns = Vec::new();
Expand Down Expand Up @@ -522,6 +594,8 @@ fn encode_date_time(dt: &DateTime<Utc>) -> BytesText<'static> {
BytesText::from_escaped(s)
}

///////////////////////////////////////////////////////////////////////////////

#[cfg(test)]
mod tests {
use super::*;
Expand Down
26 changes: 15 additions & 11 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use datafusion::{arrow::datatypes::SchemaRef, dataframe::DataFrame};

use crate::{
collection::{CollectionAddr, QueryParams},
error::{Error, Result},
error::{KeyColumnNotAssigned, ODataError},
};

///////////////////////////////////////////////////////////////////////////////
Expand All @@ -18,43 +18,47 @@ pub const DEFAULT_NAMESPACE: &str = "default";
pub trait ServiceContext: Send + Sync {
fn service_base_url(&self) -> String;

async fn list_collections(&self) -> Result<Vec<Arc<dyn CollectionContext>>>;
async fn list_collections(&self) -> Result<Vec<Arc<dyn CollectionContext>>, ODataError>;

fn on_unsupported_feature(&self) -> OnUnsupported;
}

///////////////////////////////////////////////////////////////////////////////

#[async_trait::async_trait]
pub trait CollectionContext: Send + Sync {
fn addr(&self) -> Result<&CollectionAddr>;
fn addr(&self) -> Result<&CollectionAddr, ODataError>;

fn service_base_url(&self) -> Result<String>;
fn service_base_url(&self) -> Result<String, ODataError>;

fn collection_base_url(&self) -> Result<String>;
fn collection_base_url(&self) -> Result<String, ODataError>;

fn collection_namespace(&self) -> Result<String> {
fn collection_namespace(&self) -> Result<String, ODataError> {
Ok(DEFAULT_NAMESPACE.to_string())
}

fn collection_name(&self) -> Result<String>;
fn collection_name(&self) -> Result<String, ODataError>;

// Synthetic column name that will be used to propagate entity IDs
fn key_column_alias(&self) -> String {
"__id__".to_string()
}

fn key_column(&self) -> Result<String> {
Err(Error::KeyColumnNotAssigned)
fn key_column(&self) -> Result<String, ODataError> {
Err(KeyColumnNotAssigned)?
}

async fn last_updated_time(&self) -> DateTime<Utc>;

async fn schema(&self) -> Result<SchemaRef>;
async fn schema(&self) -> Result<SchemaRef, ODataError>;

async fn query(&self, query: QueryParams) -> Result<DataFrame>;
async fn query(&self, query: QueryParams) -> Result<DataFrame, ODataError>;

fn on_unsupported_feature(&self) -> OnUnsupported;
}

///////////////////////////////////////////////////////////////////////////////

pub enum OnUnsupported {
/// Return an error or crash
Error,
Expand Down
Loading

0 comments on commit dc662aa

Please sign in to comment.