Skip to content

Commit

Permalink
ic: for table provider (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
tshauck authored Mar 30, 2024
1 parent 09bf6d2 commit bca8335
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@
pub mod config;
pub mod file_opener;
pub mod scanner;
pub mod table_factory;
pub mod table_provider;
106 changes: 106 additions & 0 deletions src/datafusion/table_factory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use async_trait::async_trait;
use datafusion::{
datasource::{listing::ListingTableUrl, provider::TableProviderFactory, TableProvider},
error::DataFusionError,
execution::context::SessionState,
logical_expr::CreateExternalTable,
};

use super::table_provider::{ListingZarrTableConfig, ListingZarrTableOptions, ZarrTableProvider};

struct ZarrListingTableFactory {}

#[async_trait]
impl TableProviderFactory for ZarrListingTableFactory {
async fn create(
&self,
state: &SessionState,
cmd: &CreateExternalTable,
) -> datafusion::common::Result<Arc<dyn TableProvider>> {
if cmd.file_type != "ZARR" {
return Err(datafusion::error::DataFusionError::Execution(
"Invalid file type".to_string(),
));
}

let table_path = ListingTableUrl::parse(&cmd.location)?;

let options = ListingZarrTableOptions {};
let schema = options
.infer_schema(state, &table_path)
.await
.map_err(|e| DataFusionError::Execution(format!("infer error: {:?}", e)))?;

let table_provider =
ZarrTableProvider::new(ListingZarrTableConfig::new(table_path), schema);

Ok(Arc::new(table_provider))
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use datafusion::execution::{
config::SessionConfig,
context::{SessionContext, SessionState},
runtime_env::RuntimeEnv,
};

use crate::tests::get_test_v2_data_path;

#[tokio::test]
async fn test_create() -> Result<(), Box<dyn std::error::Error>> {
let mut state = SessionState::new_with_config_rt(
SessionConfig::default(),
Arc::new(RuntimeEnv::default()),
);

state
.table_factories_mut()
.insert("ZARR".into(), Arc::new(super::ZarrListingTableFactory {}));

let test_data = get_test_v2_data_path("lat_lon_example.zarr".to_string());

let sql = format!(
"CREATE EXTERNAL TABLE zarr_table STORED AS ZARR LOCATION '{}'",
test_data.display(),
);

let session = SessionContext::new_with_state(state);
session.sql(&sql).await?;

let sql = "SELECT lat, lon FROM zarr_table LIMIT 10";
let df = session.sql(sql).await?;

let batches = df.collect().await?;
assert_eq!(batches.len(), 1);

let batch = &batches[0];

assert_eq!(batch.num_columns(), 2);
assert_eq!(batch.num_rows(), 10);

Ok(())
}
}
136 changes: 136 additions & 0 deletions src/datafusion/table_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow_schema::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion::{
common::Statistics,
datasource::{
listing::{ListingTableUrl, PartitionedFile},
physical_plan::FileScanConfig,
TableProvider, TableType,
},
execution::context::SessionState,
logical_expr::{Expr, TableProviderFilterPushDown},
physical_plan::ExecutionPlan,
};

use crate::{
async_reader::{ZarrPath, ZarrReadAsync},
reader::ZarrResult,
};

use super::scanner::ZarrScan;

pub struct ListingZarrTableOptions {}

impl ListingZarrTableOptions {
pub async fn infer_schema(
&self,
state: &SessionState,
table_path: &ListingTableUrl,
) -> ZarrResult<Schema> {
let store = state.runtime_env().object_store(table_path)?;

let zarr_path = ZarrPath::new(store, table_path.prefix().clone());
let schema = zarr_path.get_zarr_metadata().await?.arrow_schema()?;

Ok(schema)
}
}

pub struct ListingZarrTableConfig {
/// The inner listing table configuration
table_path: ListingTableUrl,
}

impl ListingZarrTableConfig {
/// Create a new ListingZarrTableConfig
pub fn new(table_path: ListingTableUrl) -> Self {
Self { table_path }
}
}

pub struct ZarrTableProvider {
table_schema: Schema,
config: ListingZarrTableConfig,
}

impl ZarrTableProvider {
pub fn new(config: ListingZarrTableConfig, table_schema: Schema) -> Self {
Self {
table_schema,
config,
}
}
}

#[async_trait]
impl TableProvider for ZarrTableProvider {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn schema(&self) -> SchemaRef {
Arc::new(self.table_schema.clone())
}

fn table_type(&self) -> TableType {
TableType::Base
}

fn supports_filters_pushdown(
&self,
filters: &[&Expr],
) -> datafusion::error::Result<Vec<TableProviderFilterPushDown>> {
// TODO: which filters can we push down?
Ok(filters
.iter()
.map(|_| TableProviderFilterPushDown::Unsupported)
.collect())
}

async fn scan(
&self,
_state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
limit: Option<usize>,
) -> datafusion::error::Result<Arc<dyn ExecutionPlan>> {
let object_store_url = self.config.table_path.object_store();

let pf = PartitionedFile::new(self.config.table_path.prefix().clone(), 0);
let file_groups = vec![vec![pf]];

let file_scan_config = FileScanConfig {
object_store_url,
file_schema: Arc::new(self.table_schema.clone()), // TODO differentiate between file and table schema
file_groups,
statistics: Statistics::new_unknown(&self.table_schema),
projection: projection.cloned(),
limit,
table_partition_cols: vec![],
output_ordering: vec![],
};

let scanner = ZarrScan::new(file_scan_config);

Ok(Arc::new(scanner))
}
}

0 comments on commit bca8335

Please sign in to comment.