diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh index 675dc4e527d0..0415090665d2 100755 --- a/ci/scripts/rust_example.sh +++ b/ci/scripts/rust_example.sh @@ -29,5 +29,6 @@ do # Skip tests that rely on external storage and flight if [ ! -d $filename ]; then cargo run --example $example_name + cargo clean -p datafusion-examples fi done diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs new file mode 100644 index 000000000000..fe936418bce4 --- /dev/null +++ b/datafusion-examples/examples/custom_file_format.rs @@ -0,0 +1,234 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{AsArray, RecordBatch, StringArray, UInt8Array}, + datatypes::UInt64Type, +}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::{ + datasource::{ + file_format::{ + csv::CsvFormatFactory, file_compression_type::FileCompressionType, + FileFormat, FileFormatFactory, + }, + physical_plan::{FileScanConfig, FileSinkConfig}, + MemTable, + }, + error::Result, + execution::{context::SessionState, runtime_env::RuntimeEnv}, + physical_plan::ExecutionPlan, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_common::{GetExt, Statistics}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use object_store::{ObjectMeta, ObjectStore}; +use tempfile::tempdir; + +/// Example of a custom file format that reads and writes TSV files. +/// +/// TSVFileFormatFactory is responsible for creating instances of TSVFileFormat. +/// The former, once registered with the SessionState, will then be used +/// to facilitate SQL operations on TSV files, such as `COPY TO` shown here. + +#[derive(Debug)] +/// Custom file format that reads and writes TSV files +/// +/// This file format is a wrapper around the CSV file format +/// for demonstration purposes. +struct TSVFileFormat { + csv_file_format: Arc, +} + +impl TSVFileFormat { + pub fn new(csv_file_format: Arc) -> Self { + Self { csv_file_format } + } +} + +#[async_trait::async_trait] +impl FileFormat for TSVFileFormat { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_ext(&self) -> String { + "tsv".to_string() + } + + fn get_ext_with_compression( + &self, + c: &FileCompressionType, + ) -> datafusion::error::Result { + if c == &FileCompressionType::UNCOMPRESSED { + Ok("tsv".to_string()) + } else { + todo!("Compression not supported") + } + } + + async fn infer_schema( + &self, + state: &SessionState, + store: &Arc, + objects: &[ObjectMeta], + ) -> Result { + self.csv_file_format + .infer_schema(state, store, objects) + .await + } + + async fn infer_stats( + &self, + state: &SessionState, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result { + self.csv_file_format + .infer_stats(state, store, table_schema, object) + .await + } + + async fn create_physical_plan( + &self, + state: &SessionState, + conf: FileScanConfig, + filters: Option<&Arc>, + ) -> Result> { + self.csv_file_format + .create_physical_plan(state, conf, filters) + .await + } + + async fn create_writer_physical_plan( + &self, + input: Arc, + state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option>, + ) -> Result> { + self.csv_file_format + .create_writer_physical_plan(input, state, conf, order_requirements) + .await + } +} + +#[derive(Default)] +/// Factory for creating TSV file formats +/// +/// This factory is a wrapper around the CSV file format factory +/// for demonstration purposes. +pub struct TSVFileFactory { + csv_file_factory: CsvFormatFactory, +} + +impl TSVFileFactory { + pub fn new() -> Self { + Self { + csv_file_factory: CsvFormatFactory::new(), + } + } +} + +impl FileFormatFactory for TSVFileFactory { + fn create( + &self, + state: &SessionState, + format_options: &std::collections::HashMap, + ) -> Result> { + let mut new_options = format_options.clone(); + new_options.insert("format.delimiter".to_string(), "\t".to_string()); + + let csv_file_format = self.csv_file_factory.create(state, &new_options)?; + let tsv_file_format = Arc::new(TSVFileFormat::new(csv_file_format)); + + Ok(tsv_file_format) + } + + fn default(&self) -> std::sync::Arc { + todo!() + } +} + +impl GetExt for TSVFileFactory { + fn get_ext(&self) -> String { + "tsv".to_string() + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Create a new context with the default configuration + let config = SessionConfig::new(); + let runtime = RuntimeEnv::default(); + let mut state = SessionState::new_with_config_rt(config, Arc::new(runtime)); + + // Register the custom file format + let file_format = Arc::new(TSVFileFactory::new()); + state.register_file_format(file_format, true).unwrap(); + + // Create a new context with the custom file format + let ctx = SessionContext::new_with_state(state); + + let mem_table = create_mem_table(); + ctx.register_table("mem_table", mem_table).unwrap(); + + let temp_dir = tempdir().unwrap(); + let table_save_path = temp_dir.path().join("mem_table.tsv"); + + let d = ctx + .sql(&format!( + "COPY mem_table TO '{}' STORED AS TSV;", + table_save_path.display(), + )) + .await?; + + let results = d.collect().await?; + println!( + "Number of inserted rows: {:?}", + (results[0] + .column_by_name("count") + .unwrap() + .as_primitive::() + .value(0)) + ); + + Ok(()) +} + +// create a simple mem table +fn create_mem_table() -> Arc { + let fields = vec![ + Field::new("id", DataType::UInt8, false), + Field::new("data", DataType::Utf8, false), + ]; + let schema = Arc::new(Schema::new(fields)); + + let partitions = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt8Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["foo", "bar"])), + ], + ) + .unwrap(); + + Arc::new(MemTable::try_new(schema, vec![vec![partitions]]).unwrap()) +}