apply a schema to fix column names (#331)
Enforce/apply the given schema when we evaluate an expression.

Since we want to move to expression-based fixup and allow the final schema to dictate the output, we will need to do this.

This code will fix up all levels of the output, which is messy in Arrow since schemas are embedded all over the place. The schema is only applied if the output of the expression doesn't exactly match the passed schema.
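
For intuition, here is a minimal standalone sketch of the top-level renaming idea, in plain arrow-rs. It is not the code in this commit: `rename_struct` is a hypothetical helper, and it ignores nesting, nullability, and metadata. The point is that the struct's fields are rebuilt with the target names while the column buffers are reused as-is:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, StructArray};
use arrow_schema::{DataType, Field, Fields};

// Rebuild a StructArray so its field names match `target_names`;
// the underlying column data is reused untouched.
fn rename_struct(sa: &StructArray, target_names: &[&str]) -> StructArray {
    let (fields, columns, nulls) = sa.clone().into_parts();
    let renamed: Fields = fields
        .iter()
        .zip(target_names)
        .map(|(f, name)| Arc::new(f.as_ref().clone().with_name(*name)))
        .collect();
    StructArray::new(renamed, columns, nulls)
}

fn main() {
    // A batch whose physical (parquet) column name differs from the logical one.
    let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let physical = Arc::new(Field::new("col-0", DataType::Int32, false));
    let batch = StructArray::from(vec![(physical, col)]);

    let fixed = rename_struct(&batch, &["id"]);
    assert_eq!(fixed.column_names(), vec!["id"]);
}
```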

---------

Co-authored-by: Nick Lanham <[email protected]>
Co-authored-by: Ryan Johnson <[email protected]>
3 people authored Oct 21, 2024
1 parent 2b1c46f commit cd53bc1
Showing 10 changed files with 792 additions and 296 deletions.
9 changes: 9 additions & 0 deletions kernel/examples/read-table-single-threaded/src/main.rs
@@ -41,6 +41,10 @@ struct Cli {
/// to the aws metadata server, which will fail unless you're on an ec2 instance.
#[arg(long)]
public: bool,

/// Only print the schema of the table
#[arg(long)]
schema_only: bool,
}

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
@@ -90,6 +94,11 @@ fn try_main() -> DeltaResult<()> {

let snapshot = table.snapshot(engine.as_ref(), None)?;

if cli.schema_only {
println!("{:#?}", snapshot.schema());
return Ok(());
}

let read_schema_opt = cli
.columns
.map(|cols| -> DeltaResult<_> {
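
As a usage note: clap's derive maps the snake_case field to a kebab-case flag, so this surfaces on the command line as `--schema-only`. A minimal sketch of that derived behavior, assuming clap 4 with the `derive` feature (a stand-in, not this example's full CLI):

```rust
use clap::Parser;

/// Stripped-down stand-in for the example's CLI.
#[derive(Parser)]
struct Cli {
    /// Only print the schema of the table
    #[arg(long)]
    schema_only: bool,
}

fn main() {
    // `#[arg(long)]` exposes the field as `--schema-only`.
    let cli = Cli::parse_from(["read-table", "--schema-only"]);
    assert!(cli.schema_only);
}
```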
192 changes: 175 additions & 17 deletions kernel/src/engine/arrow_expression.rs
@@ -1,10 +1,12 @@
//! Expression handling based on arrow-rs compute kernels.
use std::borrow::Borrow;
use std::collections::HashMap;
use std::sync::Arc;

use arrow_arith::boolean::{and_kleene, is_null, not, or_kleene};
use arrow_arith::numeric::{add, div, mul, sub};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{types::*, MapArray};
use arrow_array::{
Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Datum, Decimal128Array, Float32Array,
Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch,
@@ -21,20 +23,21 @@ use arrow_select::concat::concat;
use itertools::Itertools;

use super::arrow_conversion::LIST_ARRAY_ROOT;
use super::arrow_utils::make_arrow_error;
use crate::engine::arrow_data::ArrowEngineData;
use crate::engine::arrow_utils::ensure_data_types;
use crate::engine::arrow_utils::prim_array_cmp;
use crate::engine::ensure_data_types::ensure_data_types;
use crate::error::{DeltaResult, Error};
use crate::expressions::{BinaryOperator, Expression, Scalar, UnaryOperator, VariadicOperator};
use crate::schema::{DataType, PrimitiveType, SchemaRef};
use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, Schema, SchemaRef, StructField};
use crate::{EngineData, ExpressionEvaluator, ExpressionHandler};

// TODO leverage scalars / Datum

fn downcast_to_bool(arr: &dyn Array) -> DeltaResult<&BooleanArray> {
arr.as_any()
.downcast_ref::<BooleanArray>()
.ok_or(Error::generic("expected boolean array"))
.ok_or_else(|| Error::generic("expected boolean array"))
}

impl Scalar {
@@ -128,21 +131,21 @@ impl Scalar {
}

fn wrap_comparison_result(arr: BooleanArray) -> ArrayRef {
Arc::new(arr) as Arc<dyn Array>
Arc::new(arr) as _
}

trait ProvidesColumnByName {
fn column_by_name(&self, name: &str) -> Option<&Arc<dyn Array>>;
fn column_by_name(&self, name: &str) -> Option<&ArrayRef>;
}

impl ProvidesColumnByName for RecordBatch {
fn column_by_name(&self, name: &str) -> Option<&Arc<dyn Array>> {
fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
self.column_by_name(name)
}
}

impl ProvidesColumnByName for StructArray {
fn column_by_name(&self, name: &str) -> Option<&Arc<dyn Array>> {
fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
self.column_by_name(name)
}
}
@@ -201,12 +204,11 @@ fn evaluate_expression(
.iter()
.zip(output_schema.fields())
.map(|(expr, field)| evaluate_expression(expr, batch, Some(field.data_type())));
let output_cols: Vec<Arc<dyn Array>> = columns.try_collect()?;
let output_cols: Vec<ArrayRef> = columns.try_collect()?;
let output_fields: Vec<ArrowField> = output_cols
.iter()
.zip(output_schema.fields())
.map(|(output_col, output_field)| -> DeltaResult<_> {
ensure_data_types(output_field.data_type(), output_col.data_type())?;
Ok(ArrowField::new(
output_field.name(),
output_col.data_type().clone(),
@@ -306,7 +308,7 @@ fn evaluate_expression(
let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
let right_arr = evaluate_expression(right.as_ref(), batch, None)?;

type Operation = fn(&dyn Datum, &dyn Datum) -> Result<Arc<dyn Array>, ArrowError>;
type Operation = fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>;
let eval: Operation = match op {
Plus => add,
Minus => sub,
@@ -350,6 +352,164 @@ fn evaluate_expression(
}
}

// Apply a schema to an array. The array _must_ be a `StructArray`. Returns a `RecordBatch` where
// the names, nullability, and metadata of the struct's fields have been transformed to match those
// in the specified `schema`.
fn apply_schema(array: &dyn Array, schema: &DataType) -> DeltaResult<RecordBatch> {
let DataType::Struct(struct_schema) = schema else {
return Err(Error::generic(
"apply_schema at top-level must be passed a struct schema",
));
};
let applied = apply_schema_to_struct(array, struct_schema)?;
Ok(applied.into())
}

// helper to build an `ArrowField` with the specified name, data type, and nullability, attaching
// the passed metadata to the field if any is given.
fn new_field_with_metadata(
field_name: &str,
data_type: &ArrowDataType,
nullable: bool,
metadata: Option<HashMap<String, String>>,
) -> ArrowField {
let mut field = ArrowField::new(field_name, data_type.clone(), nullable);
if let Some(metadata) = metadata {
field.set_metadata(metadata);
};
field
}

// A helper to transform a struct: take apart the passed struct, transform each column to its
// target field, and put the struct back together. The target name, type, nullability, and metadata
// for each column are given by `target_fields`. The number of elements in the `target_fields`
// iterator _must_ be the same as the number of columns in `struct_array`, and the transformation
// is ordinal: the order of fields in `target_fields` _must_ match the order of the columns in
// `struct_array`.
fn transform_struct(
struct_array: &StructArray,
target_fields: impl Iterator<Item = impl Borrow<StructField>>,
) -> DeltaResult<StructArray> {
let (_, arrow_cols, nulls) = struct_array.clone().into_parts();
let input_col_count = arrow_cols.len();
let result_iter =
arrow_cols
.into_iter()
.zip(target_fields)
.map(|(sa_col, target_field)| -> DeltaResult<_> {
let target_field = target_field.borrow();
let transformed_col = apply_schema_to(&sa_col, target_field.data_type())?;
let transformed_field = new_field_with_metadata(
&target_field.name,
transformed_col.data_type(),
target_field.nullable,
Some(target_field.metadata_with_string_values()),
);
Ok((transformed_field, transformed_col))
});
let (transformed_fields, transformed_cols): (Vec<ArrowField>, Vec<ArrayRef>) =
result_iter.process_results(|iter| iter.unzip())?;
if transformed_cols.len() != input_col_count {
return Err(Error::InternalError(format!(
"Passed struct had {input_col_count} columns, but transformed column has {}",
transformed_cols.len()
)));
}
Ok(StructArray::try_new(
transformed_fields.into(),
transformed_cols,
nulls,
)?)
}

// Transform a struct array. The data is in `array`, and the target fields are in `kernel_fields`.
fn apply_schema_to_struct(array: &dyn Array, kernel_fields: &Schema) -> DeltaResult<StructArray> {
let Some(sa) = array.as_struct_opt() else {
return Err(make_arrow_error(
"Arrow claimed to be a struct but isn't a StructArray",
));
};
transform_struct(sa, kernel_fields.fields())
}

// deconstruct the array, then rebuild the mapped version
fn apply_schema_to_list(
array: &dyn Array,
target_inner_type: &ArrayType,
) -> DeltaResult<ListArray> {
let Some(la) = array.as_list_opt() else {
return Err(make_arrow_error(
"Arrow claimed to be a list but isn't a ListArray",
));
};
let (field, offset_buffer, values, nulls) = la.clone().into_parts();

let transformed_values = apply_schema_to(&values, &target_inner_type.element_type)?;
let transformed_field = ArrowField::new(
field.name(),
transformed_values.data_type().clone(),
target_inner_type.contains_null,
);
Ok(ListArray::try_new(
Arc::new(transformed_field),
offset_buffer,
transformed_values,
nulls,
)?)
}

// deconstruct a map, and rebuild it with the specified target kernel type
fn apply_schema_to_map(array: &dyn Array, kernel_map_type: &MapType) -> DeltaResult<MapArray> {
let Some(ma) = array.as_map_opt() else {
return Err(make_arrow_error(
"Arrow claimed to be a map but isn't a MapArray",
));
};
let (map_field, offset_buffer, map_struct_array, nulls, ordered) = ma.clone().into_parts();
let target_fields = map_struct_array
.fields()
.iter()
.zip([&kernel_map_type.key_type, &kernel_map_type.value_type])
.zip([false, kernel_map_type.value_contains_null])
.map(|((arrow_field, target_type), nullable)| {
StructField::new(arrow_field.name(), target_type.clone(), nullable)
});

// Arrow puts the key as the first field/column and the value as the second, so we transform this
// like a 'normal' struct, except we know there are exactly two fields/columns and we specify the
// key/value types as the target field iterator.
let transformed_map_struct_array = transform_struct(&map_struct_array, target_fields)?;

let transformed_map_field = ArrowField::new(
map_field.name().clone(),
transformed_map_struct_array.data_type().clone(),
map_field.is_nullable(),
);
Ok(MapArray::try_new(
Arc::new(transformed_map_field),
offset_buffer,
transformed_map_struct_array,
nulls,
ordered,
)?)
}

// Apply `schema` to `array`. This handles renaming and adjusting nullability and metadata. If the
// actual data types don't match, this will return an error.
fn apply_schema_to(array: &ArrayRef, schema: &DataType) -> DeltaResult<ArrayRef> {
use DataType::*;
let array: ArrayRef = match schema {
Struct(stype) => Arc::new(apply_schema_to_struct(array, stype)?),
Array(atype) => Arc::new(apply_schema_to_list(array, atype)?),
Map(mtype) => Arc::new(apply_schema_to_map(array, mtype)?),
_ => {
ensure_data_types(schema, array.data_type(), true)?;
array.clone()
}
};
Ok(array)
}

#[derive(Debug)]
pub struct ArrowExpressionHandler;

@@ -380,7 +540,7 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {
let batch = batch
.as_any()
.downcast_ref::<ArrowEngineData>()
.ok_or(Error::engine_data_type("ArrowEngineData"))?
.ok_or_else(|| Error::engine_data_type("ArrowEngineData"))?
.record_batch();
let _input_schema: ArrowSchema = self.input_schema.as_ref().try_into()?;
// TODO: make sure we have matching schemas for validation
@@ -392,13 +552,11 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {
// )));
// };
let array_ref = evaluate_expression(&self.expression, batch, Some(&self.output_type))?;
let arrow_type: ArrowDataType = ArrowDataType::try_from(&self.output_type)?;
let batch: RecordBatch = if let DataType::Struct(_) = self.output_type {
array_ref
.as_struct_opt()
.ok_or(Error::unexpected_column_type("Expected a struct array"))?
.into()
apply_schema(&array_ref, &self.output_type)?
} else {
let array_ref = apply_schema_to(&array_ref, &self.output_type)?;
let arrow_type: ArrowDataType = ArrowDataType::try_from(&self.output_type)?;
let schema = ArrowSchema::new(vec![ArrowField::new("output", arrow_type, true)]);
RecordBatch::try_new(Arc::new(schema), vec![array_ref])?
};
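
To make the fixup concrete, here is a standalone sketch of the list case, in plain arrow-rs with a hypothetical `rename_list_element` helper (again, not the commit's code): Arrow embeds a `Field` inside every `ListArray`, so renaming the element means rebuilding the list around the same offsets and values. The map case is the same idea one level deeper: a two-column key/value struct lives inside every `MapArray`, and map keys are always non-nullable, which is why `apply_schema_to_map` zips in `false` for the key's nullability.

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int64Array, ListArray};
use arrow_buffer::OffsetBuffer;
use arrow_schema::{DataType, Field};

// Rebuild a ListArray with a renamed inner field, reusing the
// original offsets, values, and null buffer.
fn rename_list_element(la: &ListArray, new_name: &str) -> ListArray {
    let (field, offsets, values, nulls) = la.clone().into_parts();
    let renamed = Arc::new(field.as_ref().clone().with_name(new_name));
    ListArray::new(renamed, offsets, values, nulls)
}

fn main() {
    // The list [[1, 2], [3]], with the parquet-style inner name "item".
    let values: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    let field = Arc::new(Field::new("item", DataType::Int64, true));
    let offsets = OffsetBuffer::new(vec![0, 2, 3].into());
    let list = ListArray::new(field, offsets, values, None);

    // Rebuild with Delta's canonical element name; the data is untouched.
    let fixed = rename_list_element(&list, "element");
    assert_eq!(fixed.len(), 2);
}
```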