diff --git a/kernel/examples/read-table-single-threaded/src/main.rs b/kernel/examples/read-table-single-threaded/src/main.rs
index bc4145755..e11c74428 100644
--- a/kernel/examples/read-table-single-threaded/src/main.rs
+++ b/kernel/examples/read-table-single-threaded/src/main.rs
@@ -41,6 +41,10 @@ struct Cli {
     /// to the aws metadata server, which will fail unless you're on an ec2 instance.
     #[arg(long)]
     public: bool,
+
+    /// Only print the schema of the table
+    #[arg(long)]
+    schema_only: bool,
 }
 
 #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
@@ -90,6 +94,11 @@ fn try_main() -> DeltaResult<()> {
 
     let snapshot = table.snapshot(engine.as_ref(), None)?;
 
+    if cli.schema_only {
+        println!("{:#?}", snapshot.schema());
+        return Ok(());
+    }
+
     let read_schema_opt = cli
         .columns
         .map(|cols| -> DeltaResult<_> {
diff --git a/kernel/src/engine/arrow_expression.rs b/kernel/src/engine/arrow_expression.rs
index eab128158..fa87cb8bb 100644
--- a/kernel/src/engine/arrow_expression.rs
+++ b/kernel/src/engine/arrow_expression.rs
@@ -1,10 +1,12 @@
 //! Expression handling based on arrow-rs compute kernels.
+use std::borrow::Borrow;
+use std::collections::HashMap;
 use std::sync::Arc;
 
 use arrow_arith::boolean::{and_kleene, is_null, not, or_kleene};
 use arrow_arith::numeric::{add, div, mul, sub};
 use arrow_array::cast::AsArray;
-use arrow_array::types::*;
+use arrow_array::{types::*, MapArray};
 use arrow_array::{
     Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Datum, Decimal128Array, Float32Array,
     Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, RecordBatch,
@@ -21,12 +23,13 @@ use arrow_select::concat::concat;
 use itertools::Itertools;
 
 use super::arrow_conversion::LIST_ARRAY_ROOT;
+use super::arrow_utils::make_arrow_error;
 use crate::engine::arrow_data::ArrowEngineData;
-use crate::engine::arrow_utils::ensure_data_types;
 use crate::engine::arrow_utils::prim_array_cmp;
+use crate::engine::ensure_data_types::ensure_data_types;
 use crate::error::{DeltaResult, Error};
 use crate::expressions::{BinaryOperator, Expression, Scalar, UnaryOperator, VariadicOperator};
-use crate::schema::{DataType, PrimitiveType, SchemaRef};
+use crate::schema::{ArrayType, DataType, MapType, PrimitiveType, Schema, SchemaRef, StructField};
 use crate::{EngineData, ExpressionEvaluator, ExpressionHandler};
 
 // TODO leverage scalars / Datum
@@ -34,7 +37,7 @@ use crate::{EngineData, ExpressionEvaluator, ExpressionHandler};
 fn downcast_to_bool(arr: &dyn Array) -> DeltaResult<&BooleanArray> {
     arr.as_any()
         .downcast_ref::<BooleanArray>()
-        .ok_or(Error::generic("expected boolean array"))
+        .ok_or_else(|| Error::generic("expected boolean array"))
 }
 
 impl Scalar {
@@ -128,21 +131,21 @@ impl Scalar {
 }
 
 fn wrap_comparison_result(arr: BooleanArray) -> ArrayRef {
-    Arc::new(arr) as Arc<dyn Array>
+    Arc::new(arr) as _
 }
 
 trait ProvidesColumnByName {
-    fn column_by_name(&self, name: &str) -> Option<&Arc<dyn Array>>;
+    fn column_by_name(&self, name: &str) -> Option<&ArrayRef>;
 }
 
 impl ProvidesColumnByName for RecordBatch {
-    fn column_by_name(&self, name: &str) -> Option<&Arc<dyn Array>> {
+    fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
         self.column_by_name(name)
     }
 }
 
 impl ProvidesColumnByName for StructArray {
-    fn column_by_name(&self, name: &str) -> Option<&Arc<dyn Array>> {
+    fn column_by_name(&self, name: &str) -> Option<&ArrayRef> {
         self.column_by_name(name)
     }
 }
@@ -201,12 +204,11 @@ fn evaluate_expression(
             .iter()
             .zip(output_schema.fields())
             .map(|(expr, field)| evaluate_expression(expr, batch, Some(field.data_type())));
-        let output_cols: Vec<Arc<dyn Array>> = columns.try_collect()?;
+        let output_cols: Vec<ArrayRef> = columns.try_collect()?;
         let output_fields: Vec<ArrowField> = output_cols
             .iter()
             .zip(output_schema.fields())
             .map(|(output_col, output_field)| -> DeltaResult<_> {
-                ensure_data_types(output_field.data_type(), output_col.data_type())?;
                 Ok(ArrowField::new(
                     output_field.name(),
                     output_col.data_type().clone(),
@@ -306,7 +308,7 @@ fn evaluate_expression(
             let left_arr = evaluate_expression(left.as_ref(), batch, None)?;
             let right_arr = evaluate_expression(right.as_ref(), batch, None)?;
 
-            type Operation = fn(&dyn Datum, &dyn Datum) -> Result<Arc<dyn Array>, ArrowError>;
+            type Operation = fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, ArrowError>;
             let eval: Operation = match op {
                 Plus => add,
                 Minus => sub,
@@ -350,6 +352,164 @@ fn evaluate_expression(
     }
 }
 
+// Apply a schema to an array. The array _must_ be a `StructArray`. Returns a `RecordBatch` where
+// the names of fields, nullability, and metadata in the struct have been transformed to match
+// those in the schema specified by `schema`.
+fn apply_schema(array: &dyn Array, schema: &DataType) -> DeltaResult<RecordBatch> {
+    let DataType::Struct(struct_schema) = schema else {
+        return Err(Error::generic(
+            "apply_schema at top-level must be passed a struct schema",
+        ));
+    };
+    let applied = apply_schema_to_struct(array, struct_schema)?;
+    Ok(applied.into())
+}
+
+// Helper to build an arrow field with the specified name, data type, and nullability, setting the
+// field's metadata if any is passed in.
+fn new_field_with_metadata(
+    field_name: &str,
+    data_type: &ArrowDataType,
+    nullable: bool,
+    metadata: Option<HashMap<String, String>>,
+) -> ArrowField {
+    let mut field = ArrowField::new(field_name, data_type.clone(), nullable);
+    if let Some(metadata) = metadata {
+        field.set_metadata(metadata);
+    };
+    field
+}
+
+// A helper that takes apart the passed struct, transforms each column (and its field) to the
+// corresponding target field, and then puts the struct back together. The target fields for each
+// column are passed in `target_fields`. The number of elements in the `target_fields` iterator
+// _must_ be the same as the number of columns in `struct_array`. The transformation is ordinal.
+// That is, the order of fields in `target_fields` _must_ match the order of the columns in
+// `struct_array`.
+fn transform_struct(
+    struct_array: &StructArray,
+    target_fields: impl Iterator<Item = impl Borrow<StructField>>,
+) -> DeltaResult<StructArray> {
+    let (_, arrow_cols, nulls) = struct_array.clone().into_parts();
+    let input_col_count = arrow_cols.len();
+    let result_iter =
+        arrow_cols
+            .into_iter()
+            .zip(target_fields)
+            .map(|(sa_col, target_field)| -> DeltaResult<_> {
+                let target_field = target_field.borrow();
+                let transformed_col = apply_schema_to(&sa_col, target_field.data_type())?;
+                let transformed_field = new_field_with_metadata(
+                    &target_field.name,
+                    transformed_col.data_type(),
+                    target_field.nullable,
+                    Some(target_field.metadata_with_string_values()),
+                );
+                Ok((transformed_field, transformed_col))
+            });
+    let (transformed_fields, transformed_cols): (Vec<ArrowField>, Vec<ArrayRef>) =
+        result_iter.process_results(|iter| iter.unzip())?;
+    if transformed_cols.len() != input_col_count {
+        return Err(Error::InternalError(format!(
+            "Passed struct had {input_col_count} columns, but transformed column has {}",
+            transformed_cols.len()
+        )));
+    }
+    Ok(StructArray::try_new(
+        transformed_fields.into(),
+        transformed_cols,
+        nulls,
+    )?)
+}
+
+// Transform a struct array. The data is in `array`, and the target fields are in `kernel_fields`.
+fn apply_schema_to_struct(array: &dyn Array, kernel_fields: &Schema) -> DeltaResult<StructArray> {
+    let Some(sa) = array.as_struct_opt() else {
+        return Err(make_arrow_error(
+            "Arrow claimed to be a struct but isn't a StructArray",
+        ));
+    };
+    transform_struct(sa, kernel_fields.fields())
+}
+
+// deconstruct the array, then rebuild the mapped version
+fn apply_schema_to_list(
+    array: &dyn Array,
+    target_inner_type: &ArrayType,
+) -> DeltaResult<ListArray> {
+    let Some(la) = array.as_list_opt() else {
+        return Err(make_arrow_error(
+            "Arrow claimed to be a list but isn't a ListArray",
+        ));
+    };
+    let (field, offset_buffer, values, nulls) = la.clone().into_parts();
+
+    let transformed_values = apply_schema_to(&values, &target_inner_type.element_type)?;
+    let transformed_field = ArrowField::new(
+        field.name(),
+        transformed_values.data_type().clone(),
+        target_inner_type.contains_null,
+    );
+    Ok(ListArray::try_new(
+        Arc::new(transformed_field),
+        offset_buffer,
+        transformed_values,
+        nulls,
+    )?)
+}
+
+// deconstruct a map, and rebuild it with the specified target kernel type
+fn apply_schema_to_map(array: &dyn Array, kernel_map_type: &MapType) -> DeltaResult<MapArray> {
+    let Some(ma) = array.as_map_opt() else {
+        return Err(make_arrow_error(
+            "Arrow claimed to be a map but isn't a MapArray",
+        ));
+    };
+    let (map_field, offset_buffer, map_struct_array, nulls, ordered) = ma.clone().into_parts();
+    let target_fields = map_struct_array
+        .fields()
+        .iter()
+        .zip([&kernel_map_type.key_type, &kernel_map_type.value_type])
+        .zip([false, kernel_map_type.value_contains_null])
+        .map(|((arrow_field, target_type), nullable)| {
+            StructField::new(arrow_field.name(), target_type.clone(), nullable)
+        });
+
+    // Arrow puts the key type/val as the first field/col and the value type/val as the second. So
+    // we just transform like a 'normal' struct, but we know there are two fields/cols and we
+    // specify the key/value types as the target type iterator.
+    let transformed_map_struct_array = transform_struct(&map_struct_array, target_fields)?;
+
+    let transformed_map_field = ArrowField::new(
+        map_field.name().clone(),
+        transformed_map_struct_array.data_type().clone(),
+        map_field.is_nullable(),
+    );
+    Ok(MapArray::try_new(
+        Arc::new(transformed_map_field),
+        offset_buffer,
+        transformed_map_struct_array,
+        nulls,
+        ordered,
+    )?)
+}
+
+// Apply `schema` to `array`. This handles renaming and adjusting nullability and metadata. If the
+// actual data types don't match, this will return an error.
+fn apply_schema_to(array: &ArrayRef, schema: &DataType) -> DeltaResult<ArrayRef> {
+    use DataType::*;
+    let array: ArrayRef = match schema {
+        Struct(stype) => Arc::new(apply_schema_to_struct(array, stype)?),
+        Array(atype) => Arc::new(apply_schema_to_list(array, atype)?),
+        Map(mtype) => Arc::new(apply_schema_to_map(array, mtype)?),
+        _ => {
+            ensure_data_types(schema, array.data_type(), true)?;
+            array.clone()
+        }
+    };
+    Ok(array)
+}
+
 #[derive(Debug)]
 pub struct ArrowExpressionHandler;
 
@@ -380,7 +540,7 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {
         let batch = batch
             .as_any()
             .downcast_ref::<ArrowEngineData>()
-            .ok_or(Error::engine_data_type("ArrowEngineData"))?
+            .ok_or_else(|| Error::engine_data_type("ArrowEngineData"))?
             .record_batch();
         let _input_schema: ArrowSchema = self.input_schema.as_ref().try_into()?;
         // TODO: make sure we have matching schemas for validation
@@ -392,13 +552,11 @@ impl ExpressionEvaluator for DefaultExpressionEvaluator {
         //     )));
         // };
         let array_ref = evaluate_expression(&self.expression, batch, Some(&self.output_type))?;
-        let arrow_type: ArrowDataType = ArrowDataType::try_from(&self.output_type)?;
         let batch: RecordBatch = if let DataType::Struct(_) = self.output_type {
-            array_ref
-                .as_struct_opt()
-                .ok_or(Error::unexpected_column_type("Expected a struct array"))?
-                .into()
+            apply_schema(&array_ref, &self.output_type)?
         } else {
+            let array_ref = apply_schema_to(&array_ref, &self.output_type)?;
+            let arrow_type: ArrowDataType = ArrowDataType::try_from(&self.output_type)?;
             let schema = ArrowSchema::new(vec![ArrowField::new("output", arrow_type, true)]);
             RecordBatch::try_new(Arc::new(schema), vec![array_ref])?
         };
diff --git a/kernel/src/engine/arrow_utils.rs b/kernel/src/engine/arrow_utils.rs
index 7edbe2828..f8680b403 100644
--- a/kernel/src/engine/arrow_utils.rs
+++ b/kernel/src/engine/arrow_utils.rs
@@ -4,9 +4,10 @@ use std::collections::HashSet;
 use std::io::{BufRead, BufReader};
 use std::sync::Arc;
 
+use crate::engine::ensure_data_types::DataTypeCompat;
 use crate::{
     engine::arrow_data::ArrowEngineData,
-    schema::{DataType, PrimitiveType, Schema, SchemaRef, StructField, StructType},
+    schema::{DataType, Schema, SchemaRef, StructField, StructType},
     utils::require,
     DeltaResult, EngineData, Error,
 };
@@ -58,169 +59,8 @@ pub(crate) use prim_array_cmp;
 /// returns a tuples of (mask_indicies: Vec<usize>, reorder_indicies:
 /// Vec<ReorderIndex>). `mask_indicies` is used for generating the mask for reading from the
-fn make_arrow_error(s: String) -> Error {
-    Error::Arrow(arrow_schema::ArrowError::InvalidArgumentError(s))
-}
-
-/// Capture the compatibility between two data-types, as passed to [`ensure_data_types`]
-pub(crate) enum DataTypeCompat {
-    /// The two types are the same
-    Identical,
-    /// What is read from parquet needs to be cast to the associated type
-    NeedsCast(ArrowDataType),
-    /// Types are compatible, but are nested types. This is used when comparing types where casting
-    /// is not desired (i.e. in the expression evaluator)
-    Nested,
-}
-
-// Check if two types can be cast
-fn check_cast_compat(
-    target_type: ArrowDataType,
-    source_type: &ArrowDataType,
-) -> DeltaResult<DataTypeCompat> {
-    use ArrowDataType::*;
-
-    match (source_type, &target_type) {
-        (source_type, target_type) if source_type == target_type => Ok(DataTypeCompat::Identical),
-        (&ArrowDataType::Timestamp(_, _), &ArrowDataType::Timestamp(_, _)) => {
-            // timestamps are able to be cast between each other
-            Ok(DataTypeCompat::NeedsCast(target_type))
-        }
-        // Allow up-casting to a larger type if it's safe and can't cause overflow or loss of precision.
-        (Int8, Int16 | Int32 | Int64 | Float64) => Ok(DataTypeCompat::NeedsCast(target_type)),
-        (Int16, Int32 | Int64 | Float64) => Ok(DataTypeCompat::NeedsCast(target_type)),
-        (Int32, Int64 | Float64) => Ok(DataTypeCompat::NeedsCast(target_type)),
-        (Float32, Float64) => Ok(DataTypeCompat::NeedsCast(target_type)),
-        (_, Decimal128(p, s)) if can_upcast_to_decimal(source_type, *p, *s) => {
-            Ok(DataTypeCompat::NeedsCast(target_type))
-        }
-        (Date32, Timestamp(_, None)) => Ok(DataTypeCompat::NeedsCast(target_type)),
-        _ => Err(make_arrow_error(format!(
-            "Incorrect datatype. Expected {}, got {}",
-            target_type, source_type
-        ))),
-    }
-}
-
-// Returns whether the given source type can be safely cast to a decimal with the given precision and scale without
-// loss of information.
-fn can_upcast_to_decimal(
-    source_type: &ArrowDataType,
-    target_precision: u8,
-    target_scale: i8,
-) -> bool {
-    use ArrowDataType::*;
-
-    let (source_precision, source_scale) = match source_type {
-        Decimal128(p, s) => (*p, *s),
-        // Allow converting integers to a decimal that can hold all possible values.
-        Int8 => (3u8, 0i8),
-        Int16 => (5u8, 0i8),
-        Int32 => (10u8, 0i8),
-        Int64 => (20u8, 0i8),
-        _ => return false,
-    };
-
-    target_precision >= source_precision
-        && target_scale >= source_scale
-        && target_precision - source_precision >= (target_scale - source_scale) as u8
-}
-
-/// Ensure a kernel data type matches an arrow data type. This only ensures that the actual "type"
-/// is the same, but does so recursively into structs, and ensures lists and maps have the correct
-/// associated types as well. This returns an `Ok(DataTypeCompat)` if the types are compatible, and
-/// will indicate what kind of compatibility they have, or an error if the types do not match. If
-/// there is a `struct` type included, we only ensure that the named fields that the kernel is
-/// asking for exist, and that for those fields the types match. Un-selected fields are ignored.
-pub(crate) fn ensure_data_types(
-    kernel_type: &DataType,
-    arrow_type: &ArrowDataType,
-) -> DeltaResult<DataTypeCompat> {
-    match (kernel_type, arrow_type) {
-        (DataType::Primitive(_), _) if arrow_type.is_primitive() => {
-            check_cast_compat(kernel_type.try_into()?, arrow_type)
-        }
-        (&DataType::BOOLEAN, ArrowDataType::Boolean)
-        | (&DataType::STRING, ArrowDataType::Utf8)
-        | (&DataType::BINARY, ArrowDataType::Binary) => {
-            // strings, bools, and binary aren't primitive in arrow
-            Ok(DataTypeCompat::Identical)
-        }
-        (
-            DataType::Primitive(PrimitiveType::Decimal(kernel_prec, kernel_scale)),
-            ArrowDataType::Decimal128(arrow_prec, arrow_scale),
-        ) if arrow_prec == kernel_prec && *arrow_scale == *kernel_scale as i8 => {
-            // decimal isn't primitive in arrow. cast above is okay as we limit range
-            Ok(DataTypeCompat::Identical)
-        }
-        (DataType::Array(inner_type), ArrowDataType::List(arrow_list_type)) => {
-            let kernel_array_type = &inner_type.element_type;
-            let arrow_list_type = arrow_list_type.data_type();
-            ensure_data_types(kernel_array_type, arrow_list_type)
-        }
-        (DataType::Map(kernel_map_type), ArrowDataType::Map(arrow_map_type, _)) => {
-            if let ArrowDataType::Struct(fields) = arrow_map_type.data_type() {
-                let mut fields = fields.iter();
-                if let Some(key_type) = fields.next() {
-                    ensure_data_types(&kernel_map_type.key_type, key_type.data_type())?;
-                } else {
-                    return Err(make_arrow_error(
-                        "Arrow map struct didn't have a key type".to_string(),
-                    ));
-                }
-                if let Some(value_type) = fields.next() {
-                    ensure_data_types(&kernel_map_type.value_type, value_type.data_type())?;
-                } else {
-                    return Err(make_arrow_error(
-                        "Arrow map struct didn't have a value type".to_string(),
-                    ));
-                }
-                if fields.next().is_some() {
-                    return Err(Error::generic("map fields had more than 2 members"));
-                }
-                Ok(DataTypeCompat::Nested)
-            } else {
-                Err(make_arrow_error(
-                    "Arrow map type wasn't a struct.".to_string(),
-                ))
-            }
-        }
-        (DataType::Struct(kernel_fields), ArrowDataType::Struct(arrow_fields)) => {
-            // build a list of kernel fields that matches the order of the arrow fields
-            let mapped_fields = arrow_fields
-                .iter()
-                .filter_map(|f| kernel_fields.fields.get(f.name()));
-
-            // keep track of how many fields we matched up
-            let mut found_fields = 0;
-            // ensure that for the fields that we found, the types match
-            for (kernel_field, arrow_field) in mapped_fields.zip(arrow_fields) {
-                ensure_data_types(&kernel_field.data_type, arrow_field.data_type())?;
-                found_fields += 1;
-            }
-
-            // require that we found the number of fields that we requested.
-            require!(kernel_fields.fields.len() == found_fields, {
-                let arrow_field_map: HashSet<&String> =
-                    HashSet::from_iter(arrow_fields.iter().map(|f| f.name()));
-                let missing_field_names = kernel_fields
-                    .fields
-                    .keys()
-                    .filter(|kernel_field| !arrow_field_map.contains(kernel_field))
-                    .take(5)
-                    .join(", ");
-                make_arrow_error(format!(
-                    "Missing Struct fields {} (Up to five missing fields shown)",
-                    missing_field_names
-                ))
-            });
-            Ok(DataTypeCompat::Nested)
-        }
-        _ => Err(make_arrow_error(format!(
-            "Incorrect datatype. Expected {}, got {}",
-            kernel_type, arrow_type
-        ))),
-    }
+pub(crate) fn make_arrow_error(s: impl Into<String>) -> Error {
+    Error::Arrow(arrow_schema::ArrowError::InvalidArgumentError(s.into())).with_backtrace()
 }
 
 /*
@@ -516,7 +356,15 @@ fn get_indices(
                 }
             }
             _ => {
-                match ensure_data_types(&requested_field.data_type, field.data_type())? {
+                // we don't care about matching on nullability or metadata here so pass `false`
+                // as the final argument. These can differ between the delta schema and the
+                // parquet schema without causing issues in reading the data. We fix them up in
+                // expression evaluation later.
+                match super::ensure_data_types::ensure_data_types(
+                    &requested_field.data_type,
+                    field.data_type(),
+                    false,
+                )? {
                    DataTypeCompat::Identical => {
                        reorder_indices.push(ReorderIndex::identity(index))
                    }
@@ -636,6 +484,7 @@ pub(crate) fn reorder_struct_array(
     input_data: StructArray,
     requested_ordering: &[ReorderIndex],
 ) -> DeltaResult<StructArray> {
+    debug!("Reordering {input_data:?} with ordering: {requested_ordering:?}");
     if !ordering_needs_transform(requested_ordering) {
         // indices is already sorted, meaning we requested in the order that the columns were
         // stored in the parquet
@@ -944,7 +793,6 @@ mod tests {
             ArrowField::new("s", ArrowDataType::Int32, true),
         ]));
         let res = get_requested_indices(&requested_schema, &parquet_schema);
-        println!("{res:#?}");
        assert!(res.is_err());
     }
 
@@ -1560,61 +1408,4 @@ mod tests {
         assert_eq!(mask_indices, expect_mask);
         assert_eq!(reorder_indices, expect_reorder);
     }
-
-    #[test]
-    fn accepts_safe_decimal_casts() {
-        use super::can_upcast_to_decimal;
-        use ArrowDataType::*;
-
-        assert!(can_upcast_to_decimal(&Decimal128(1, 0), 2u8, 0i8));
-        assert!(can_upcast_to_decimal(&Decimal128(1, 0), 2u8, 1i8));
-        assert!(can_upcast_to_decimal(&Decimal128(5, -2), 6u8, -2i8));
-        assert!(can_upcast_to_decimal(&Decimal128(5, -2), 6u8, -1i8));
-        assert!(can_upcast_to_decimal(&Decimal128(5, 1), 6u8, 1i8));
-        assert!(can_upcast_to_decimal(&Decimal128(5, 1), 6u8, 2i8));
-        assert!(can_upcast_to_decimal(
-            &Decimal128(10, 5),
-            arrow_schema::DECIMAL128_MAX_PRECISION,
-            arrow_schema::DECIMAL128_MAX_SCALE - 5
-        ));
-
-        assert!(can_upcast_to_decimal(&Int8, 3u8, 0i8));
-        assert!(can_upcast_to_decimal(&Int8, 4u8, 0i8));
-        assert!(can_upcast_to_decimal(&Int8, 4u8, 1i8));
-        assert!(can_upcast_to_decimal(&Int8, 7u8, 2i8));
-
-        assert!(can_upcast_to_decimal(&Int16, 5u8, 0i8));
-        assert!(can_upcast_to_decimal(&Int16, 6u8, 0i8));
-        assert!(can_upcast_to_decimal(&Int16, 6u8, 1i8));
-        assert!(can_upcast_to_decimal(&Int16, 9u8, 2i8));
-
-        assert!(can_upcast_to_decimal(&Int32, 10u8, 0i8));
-        assert!(can_upcast_to_decimal(&Int32, 11u8, 0i8));
-        assert!(can_upcast_to_decimal(&Int32, 11u8, 1i8));
-        assert!(can_upcast_to_decimal(&Int32, 14u8, 2i8));
-
-        assert!(can_upcast_to_decimal(&Int64, 20u8, 0i8));
-        assert!(can_upcast_to_decimal(&Int64, 21u8, 0i8));
-        assert!(can_upcast_to_decimal(&Int64, 21u8, 1i8));
-        assert!(can_upcast_to_decimal(&Int64, 24u8, 2i8));
-    }
-
-    #[test]
-    fn rejects_unsafe_decimal_casts() {
-        use super::can_upcast_to_decimal;
-        use ArrowDataType::*;
-
-        assert!(!can_upcast_to_decimal(&Decimal128(2, 0), 2u8, 1i8));
-        assert!(!can_upcast_to_decimal(&Decimal128(2, 0), 2u8, -1i8));
-        assert!(!can_upcast_to_decimal(&Decimal128(5, 2), 6u8, 4i8));
-
-        assert!(!can_upcast_to_decimal(&Int8, 2u8, 0i8));
-        assert!(!can_upcast_to_decimal(&Int8, 3u8, 1i8));
-        assert!(!can_upcast_to_decimal(&Int16, 4u8, 0i8));
-        assert!(!can_upcast_to_decimal(&Int16, 5u8, 1i8));
-        assert!(!can_upcast_to_decimal(&Int32, 9u8, 0i8));
-        assert!(!can_upcast_to_decimal(&Int32, 10u8, 1i8));
-        assert!(!can_upcast_to_decimal(&Int64, 19u8, 0i8));
-        assert!(!can_upcast_to_decimal(&Int64, 20u8, 1i8));
-    }
 }
diff --git a/kernel/src/engine/ensure_data_types.rs b/kernel/src/engine/ensure_data_types.rs
new file mode 100644
index 000000000..9b7ea7819
--- /dev/null
+++ b/kernel/src/engine/ensure_data_types.rs
@@ -0,0 +1,467 @@
+//! Helpers to ensure that kernel data types match arrow data types
+
+use std::{collections::{HashMap, HashSet}, ops::Deref};
+
+use arrow_schema::{DataType as ArrowDataType, Field as ArrowField};
+use itertools::Itertools;
+
+use crate::{
+    engine::arrow_utils::make_arrow_error,
+    schema::{DataType, MetadataValue, StructField},
+    utils::require,
+    DeltaResult, Error,
+};
+
+/// Ensure a kernel data type matches an arrow data type. This only ensures that the actual "type"
+/// is the same, but does so recursively into structs, and ensures lists and maps have the correct
+/// associated types as well.
+///
+/// If `check_nullability_and_metadata` is true, this will also return an error if it finds a struct
+/// field that differs in nullability or metadata between the kernel and arrow schema. If it is
+/// false, no checks on nullability or metadata are performed.
+///
+/// This returns an `Ok(DataTypeCompat)` if the types are compatible, and
+/// will indicate what kind of compatibility they have, or an error if the types do not match. If
+/// there is a `struct` type included, we only ensure that the named fields that the kernel is
+/// asking for exist, and that for those fields the types match. Un-selected fields are ignored.
+pub(crate) fn ensure_data_types(
+    kernel_type: &DataType,
+    arrow_type: &ArrowDataType,
+    check_nullability_and_metadata: bool,
+) -> DeltaResult<DataTypeCompat> {
+    let check = EnsureDataTypes { check_nullability_and_metadata };
+    check.ensure_data_types(kernel_type, arrow_type)
+}
+
+struct EnsureDataTypes {
+    check_nullability_and_metadata: bool,
+}
+
+/// Capture the compatibility between two data-types, as passed to [`ensure_data_types`]
+pub(crate) enum DataTypeCompat {
+    /// The two types are the same
+    Identical,
+    /// What is read from parquet needs to be cast to the associated type
+    NeedsCast(ArrowDataType),
+    /// Types are compatible, but are nested types. This is used when comparing types where casting
+    /// is not desired (i.e. in the expression evaluator)
+    Nested,
+}
+
+impl EnsureDataTypes {
+    // Perform the check. See documentation for the `ensure_data_types` entry point method above.
+    fn ensure_data_types(
+        &self,
+        kernel_type: &DataType,
+        arrow_type: &ArrowDataType,
+    ) -> DeltaResult<DataTypeCompat> {
+        match (kernel_type, arrow_type) {
+            (DataType::Primitive(_), _) if arrow_type.is_primitive() => {
+                check_cast_compat(kernel_type.try_into()?, arrow_type)
+            }
+            // strings, bools, and binary aren't primitive in arrow
+            (&DataType::BOOLEAN, ArrowDataType::Boolean)
+            | (&DataType::STRING, ArrowDataType::Utf8)
+            | (&DataType::BINARY, ArrowDataType::Binary) => {
+                Ok(DataTypeCompat::Identical)
+            }
+            (DataType::Array(inner_type), ArrowDataType::List(arrow_list_field)) => {
+                self.ensure_nullability(
+                    "List",
+                    inner_type.contains_null,
+                    arrow_list_field.is_nullable(),
+                )?;
+                self.ensure_data_types(
+                    &inner_type.element_type,
+                    arrow_list_field.data_type(),
+                )
+            }
+            (DataType::Map(kernel_map_type), ArrowDataType::Map(arrow_map_type, _)) => {
+                let ArrowDataType::Struct(fields) = arrow_map_type.data_type() else {
+                    return Err(make_arrow_error("Arrow map type wasn't a struct."));
+                };
+                let [key_type, value_type] = fields.deref() else {
+                    return Err(make_arrow_error("Arrow map type didn't have expected key/value fields"));
+                };
+                self.ensure_data_types(
+                    &kernel_map_type.key_type,
+                    key_type.data_type(),
+                )?;
+                self.ensure_nullability(
+                    "Map",
+                    kernel_map_type.value_contains_null,
+                    value_type.is_nullable(),
+                )?;
+                self.ensure_data_types(
+                    &kernel_map_type.value_type,
+                    value_type.data_type(),
+                )?;
+                Ok(DataTypeCompat::Nested)
+            }
+            (DataType::Struct(kernel_fields), ArrowDataType::Struct(arrow_fields)) => {
+                // build a list of kernel fields that matches the order of the arrow fields
+                let mapped_fields = arrow_fields
+                    .iter()
+                    .filter_map(|f| kernel_fields.fields.get(f.name()));
+
+                // keep track of how many fields we matched up
+                let mut found_fields = 0;
+                // ensure that for the fields that we found, the types match
+                for (kernel_field, arrow_field) in mapped_fields.zip(arrow_fields) {
+                    self.ensure_nullability_and_metadata(kernel_field, arrow_field)?;
+                    self.ensure_data_types(
+                        &kernel_field.data_type,
+                        arrow_field.data_type(),
+                    )?;
+                    found_fields += 1;
+                }
+
+                // require that we found the number of fields that we requested.
+                require!(kernel_fields.fields.len() == found_fields, {
+                    let arrow_field_map: HashSet<&String> =
+                        HashSet::from_iter(arrow_fields.iter().map(|f| f.name()));
+                    let missing_field_names = kernel_fields
+                        .fields
+                        .keys()
+                        .filter(|kernel_field| !arrow_field_map.contains(kernel_field))
+                        .take(5)
+                        .join(", ");
+                    make_arrow_error(format!(
+                        "Missing Struct fields {} (Up to five missing fields shown)",
+                        missing_field_names
+                    ))
+                });
+                Ok(DataTypeCompat::Nested)
+            }
+            _ => Err(make_arrow_error(format!(
+                "Incorrect datatype. Expected {}, got {}",
+                kernel_type, arrow_type
+            ))),
+        }
+    }
+
+    fn ensure_nullability(
+        &self,
+        desc: &str,
+        kernel_field_is_nullable: bool,
+        arrow_field_is_nullable: bool,
+    ) -> DeltaResult<()> {
+        if self.check_nullability_and_metadata && kernel_field_is_nullable != arrow_field_is_nullable {
+            Err(Error::Generic(format!(
+                "{desc} has nullability {} in kernel and {} in arrow",
+                kernel_field_is_nullable,
+                arrow_field_is_nullable,
+            )))
+        } else {
+            Ok(())
+        }
+    }
+
+    fn ensure_nullability_and_metadata(
+        &self,
+        kernel_field: &StructField,
+        arrow_field: &ArrowField
+    ) -> DeltaResult<()> {
+        self.ensure_nullability(&kernel_field.name, kernel_field.nullable, arrow_field.is_nullable())?;
+        if self.check_nullability_and_metadata && !metadata_eq(&kernel_field.metadata, arrow_field.metadata()) {
+            Err(Error::Generic(format!(
+                "Field {} has metadata {:?} in kernel and {:?} in arrow",
+                kernel_field.name,
+                kernel_field.metadata,
+                arrow_field.metadata(),
+            )))
+        } else {
+            Ok(())
+        }
+    }
+}
+
+// Check if two types can be cast
+fn check_cast_compat(
+    target_type: ArrowDataType,
+    source_type: &ArrowDataType,
+) -> DeltaResult<DataTypeCompat> {
+    use ArrowDataType::*;
+
+    match (source_type, &target_type) {
+        (source_type, target_type) if source_type == target_type => Ok(DataTypeCompat::Identical),
+        (&ArrowDataType::Timestamp(_, _), &ArrowDataType::Timestamp(_, _)) => {
+            // timestamps are able to be cast between each other
+            Ok(DataTypeCompat::NeedsCast(target_type))
+        }
+        // Allow up-casting to a larger type if it's safe and can't cause overflow or loss of precision.
+        (Int8, Int16 | Int32 | Int64 | Float64) => Ok(DataTypeCompat::NeedsCast(target_type)),
+        (Int16, Int32 | Int64 | Float64) => Ok(DataTypeCompat::NeedsCast(target_type)),
+        (Int32, Int64 | Float64) => Ok(DataTypeCompat::NeedsCast(target_type)),
+        (Float32, Float64) => Ok(DataTypeCompat::NeedsCast(target_type)),
+        (_, Decimal128(p, s)) if can_upcast_to_decimal(source_type, *p, *s) => {
+            Ok(DataTypeCompat::NeedsCast(target_type))
+        }
+        (Date32, Timestamp(_, None)) => Ok(DataTypeCompat::NeedsCast(target_type)),
+        _ => Err(make_arrow_error(format!(
+            "Incorrect datatype. Expected {}, got {}",
+            target_type, source_type
+        ))),
+    }
+}
+
+// Returns whether the given source type can be safely cast to a decimal with the given precision and scale without
+// loss of information.
+fn can_upcast_to_decimal(
+    source_type: &ArrowDataType,
+    target_precision: u8,
+    target_scale: i8,
+) -> bool {
+    use ArrowDataType::*;
+
+    let (source_precision, source_scale) = match source_type {
+        Decimal128(p, s) => (*p, *s),
+        // Allow converting integers to a decimal that can hold all possible values.
+        Int8 => (3u8, 0i8),
+        Int16 => (5u8, 0i8),
+        Int32 => (10u8, 0i8),
+        Int64 => (20u8, 0i8),
+        _ => return false,
+    };
+
+    target_precision >= source_precision
+        && target_scale >= source_scale
+        && target_precision - source_precision >= (target_scale - source_scale) as u8
+}
+
+impl PartialEq<String> for MetadataValue {
+    fn eq(&self, other: &String) -> bool {
+        self.to_string().eq(other)
+    }
+}
+
+// allow for comparing our metadata maps to arrow ones. We can't implement PartialEq because both
+// are HashMaps which aren't defined in this crate
+fn metadata_eq(
+    kernel_metadata: &HashMap<String, MetadataValue>,
+    arrow_metadata: &HashMap<String, String>,
+) -> bool {
+    let kernel_len = kernel_metadata.len();
+    if kernel_len != arrow_metadata.len() {
+        return false;
+    }
+    if kernel_len == 0 {
+        // lens are equal, so two empty maps are equal
+        return true;
+    }
+    kernel_metadata
+        .iter()
+        .all(|(key, value)| arrow_metadata.get(key).is_some_and(|v| *value == *v))
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow_schema::{DataType as ArrowDataType, Field as ArrowField, Fields};
+
+    use crate::{
+        engine::ensure_data_types::ensure_data_types,
+        schema::{ArrayType, DataType, MapType, StructField},
+    };
+
+    #[test]
+    fn accepts_safe_decimal_casts() {
+        use super::can_upcast_to_decimal;
+        use ArrowDataType::*;
+
+        assert!(can_upcast_to_decimal(&Decimal128(1, 0), 2u8, 0i8));
+        assert!(can_upcast_to_decimal(&Decimal128(1, 0), 2u8, 1i8));
+        assert!(can_upcast_to_decimal(&Decimal128(5, -2), 6u8, -2i8));
+        assert!(can_upcast_to_decimal(&Decimal128(5, -2), 6u8, -1i8));
+        assert!(can_upcast_to_decimal(&Decimal128(5, 1), 6u8, 1i8));
+        assert!(can_upcast_to_decimal(&Decimal128(5, 1), 6u8, 2i8));
+        assert!(can_upcast_to_decimal(
+            &Decimal128(10, 5),
+            arrow_schema::DECIMAL128_MAX_PRECISION,
+            arrow_schema::DECIMAL128_MAX_SCALE - 5
+        ));
+
+        assert!(can_upcast_to_decimal(&Int8, 3u8, 0i8));
+        assert!(can_upcast_to_decimal(&Int8, 4u8, 0i8));
+        assert!(can_upcast_to_decimal(&Int8, 4u8, 1i8));
+        assert!(can_upcast_to_decimal(&Int8, 7u8, 2i8));
+
+        assert!(can_upcast_to_decimal(&Int16, 5u8, 0i8));
+        assert!(can_upcast_to_decimal(&Int16, 6u8, 0i8));
+        assert!(can_upcast_to_decimal(&Int16, 6u8, 1i8));
+        assert!(can_upcast_to_decimal(&Int16, 9u8, 2i8));
+
+        assert!(can_upcast_to_decimal(&Int32, 10u8, 0i8));
+        assert!(can_upcast_to_decimal(&Int32, 11u8, 0i8));
+        assert!(can_upcast_to_decimal(&Int32, 11u8, 1i8));
+        assert!(can_upcast_to_decimal(&Int32, 14u8, 2i8));
+
+        assert!(can_upcast_to_decimal(&Int64, 20u8, 0i8));
+        assert!(can_upcast_to_decimal(&Int64, 21u8, 0i8));
+        assert!(can_upcast_to_decimal(&Int64, 21u8, 1i8));
+        assert!(can_upcast_to_decimal(&Int64, 24u8, 2i8));
+    }
+
+    #[test]
+    fn rejects_unsafe_decimal_casts() {
+        use super::can_upcast_to_decimal;
+        use ArrowDataType::*;
+
+        assert!(!can_upcast_to_decimal(&Decimal128(2, 0), 2u8, 1i8));
+        assert!(!can_upcast_to_decimal(&Decimal128(2, 0), 2u8, -1i8));
+        assert!(!can_upcast_to_decimal(&Decimal128(5, 2), 6u8, 4i8));
+
+        assert!(!can_upcast_to_decimal(&Int8, 2u8, 0i8));
+        assert!(!can_upcast_to_decimal(&Int8, 3u8, 1i8));
+        assert!(!can_upcast_to_decimal(&Int16, 4u8, 0i8));
+        assert!(!can_upcast_to_decimal(&Int16, 5u8, 1i8));
+        assert!(!can_upcast_to_decimal(&Int32, 9u8, 0i8));
+        assert!(!can_upcast_to_decimal(&Int32, 10u8, 1i8));
+        assert!(!can_upcast_to_decimal(&Int64, 19u8, 0i8));
+        assert!(!can_upcast_to_decimal(&Int64, 20u8, 1i8));
+    }
+
+    #[test]
+    fn ensure_decimals() {
+        assert!(ensure_data_types(
+            &DataType::decimal_unchecked(5, 2),
+            &ArrowDataType::Decimal128(5, 2),
+            false
+        )
+        .is_ok());
+        assert!(ensure_data_types(
+            &DataType::decimal_unchecked(5, 2),
+            &ArrowDataType::Decimal128(5, 3),
+            false
+        )
+        .is_err());
+    }
+
+    #[test]
+    fn ensure_map() {
+        let arrow_field = ArrowField::new_map(
+            "map",
+            "entries",
+            ArrowField::new("key", ArrowDataType::Int64, false),
+            ArrowField::new("val", ArrowDataType::Utf8, true),
+            false,
+            false,
+        );
+        assert!(ensure_data_types(
+            &DataType::Map(Box::new(MapType::new(
+                DataType::LONG,
+                DataType::STRING,
+                true
+            ))),
+            arrow_field.data_type(),
+            false
+        )
+        .is_ok());
+
+        assert!(ensure_data_types(
+            &DataType::Map(Box::new(MapType::new(
+                DataType::LONG,
+                DataType::STRING,
+                false
+            ))),
+            arrow_field.data_type(),
+            true
+        )
+        .is_err());
+
+        assert!(ensure_data_types(
+            &DataType::Map(Box::new(MapType::new(DataType::LONG, DataType::LONG, true))),
+            arrow_field.data_type(),
+            false
+        )
+        .is_err());
+    }
+
+    #[test]
+    fn ensure_list() {
+        assert!(ensure_data_types(
+            &DataType::Array(Box::new(ArrayType::new(DataType::LONG, true))),
+            &ArrowDataType::new_list(ArrowDataType::Int64, true),
+            false
+        )
+        .is_ok());
+        assert!(ensure_data_types(
+            &DataType::Array(Box::new(ArrayType::new(DataType::STRING, true))),
+            &ArrowDataType::new_list(ArrowDataType::Int64, true),
+            false
+        )
+        .is_err());
+        assert!(ensure_data_types(
+            &DataType::Array(Box::new(ArrayType::new(DataType::LONG, true))),
+            &ArrowDataType::new_list(ArrowDataType::Int64, false),
+            true
+        )
+        .is_err());
+    }
+
+    #[test]
+    fn ensure_struct() {
+        let schema = DataType::struct_type([StructField::new(
+            "a",
+            ArrayType::new(
+                DataType::struct_type([
+                    StructField::new("w", DataType::LONG, true),
+                    StructField::new("x", ArrayType::new(DataType::LONG, true), true),
+                    StructField::new(
+                        "y",
+                        MapType::new(DataType::LONG, DataType::STRING, true),
+                        true,
+                    ),
+                    StructField::new(
+                        "z",
+                        DataType::struct_type([
+                            StructField::new("n", DataType::LONG, true),
+                            StructField::new("m", DataType::STRING, true),
+                        ]),
+                        true,
+                    ),
+                ]),
+                true,
+            ),
+            true,
+        )]);
+        let arrow_struct: ArrowDataType = (&schema).try_into().unwrap();
+        assert!(ensure_data_types(&schema, &arrow_struct, true).is_ok());
+
+        let kernel_simple = DataType::struct_type([
+            StructField::new("w", DataType::LONG, true),
+            StructField::new("x", DataType::LONG, true),
+        ]);
+
+        let arrow_simple_ok = ArrowField::new_struct(
+            "arrow_struct",
+            Fields::from(vec![
+                ArrowField::new("w", ArrowDataType::Int64, true),
+                ArrowField::new("x", ArrowDataType::Int64, true),
+            ]),
+            true,
+        );
+        assert!(ensure_data_types(&kernel_simple, arrow_simple_ok.data_type(), true).is_ok());
+
+        let arrow_missing_simple = ArrowField::new_struct(
+            "arrow_struct",
+            Fields::from(vec![ArrowField::new("w", ArrowDataType::Int64, true)]),
+            true,
+        );
+        assert!(ensure_data_types(&kernel_simple, arrow_missing_simple.data_type(), true).is_err());
+
+        let arrow_nullable_mismatch_simple = ArrowField::new_struct(
+            "arrow_struct",
+            Fields::from(vec![
+                ArrowField::new("w", ArrowDataType::Int64, false),
+                ArrowField::new("x", ArrowDataType::Int64, true),
+            ]),
+            true,
+        );
+        assert!(ensure_data_types(
+            &kernel_simple,
+            arrow_nullable_mismatch_simple.data_type(),
+            true
+        )
+        .is_err());
+    }
+}
diff --git a/kernel/src/engine/mod.rs b/kernel/src/engine/mod.rs
index 626bc134a..c19bc8fc6 100644
--- a/kernel/src/engine/mod.rs
+++ b/kernel/src/engine/mod.rs
@@ -5,26 +5,32 @@
 #[cfg(feature = "arrow-conversion")]
 pub(crate) mod arrow_conversion;
 
-#[cfg(feature = "arrow-expression")]
+#[cfg(all(
+    feature = "arrow-expression",
+    any(feature = "default-engine", feature = "sync-engine")
+))]
 pub mod arrow_expression;
 
-#[cfg(any(feature = "default-engine", feature = "sync-engine"))]
-pub mod arrow_data;
-
-#[cfg(any(feature = "default-engine", feature = "sync-engine"))]
-pub mod parquet_row_group_skipping;
-
-#[cfg(any(feature = "default-engine", feature = "sync-engine"))]
-pub mod parquet_stats_skipping;
-
-#[cfg(any(feature = "default-engine", feature = "sync-engine"))]
-pub(crate) mod arrow_get_data;
-
-#[cfg(any(feature = "default-engine", feature = "sync-engine"))]
-pub(crate) mod arrow_utils;
-
 #[cfg(feature = "default-engine")]
 pub mod default;
 
 #[cfg(feature = "sync-engine")]
 pub mod sync;
+
+macro_rules! declare_modules {
+    ( $(($vis:vis, $module:ident)),*) => {
+        $(
+            $vis mod $module;
+        )*
+    };
+}
+
+#[cfg(any(feature = "default-engine", feature = "sync-engine"))]
+declare_modules!(
+    (pub, arrow_data),
+    (pub, parquet_row_group_skipping),
+    (pub, parquet_stats_skipping),
+    (pub(crate), arrow_get_data),
+    (pub(crate), arrow_utils),
+    (pub(crate), ensure_data_types)
+);
diff --git a/kernel/src/error.rs b/kernel/src/error.rs
index b35cf8319..78cab4ad6 100644
--- a/kernel/src/error.rs
+++ b/kernel/src/error.rs
@@ -175,7 +175,7 @@ impl Error {
         Self::FileNotFound(path.to_string())
     }
     pub fn missing_column(name: impl ToString) -> Self {
-        Self::MissingColumn(name.to_string())
+        Self::MissingColumn(name.to_string()).with_backtrace()
     }
     pub fn unexpected_column_type(name: impl ToString) -> Self {
         Self::UnexpectedColumnType(name.to_string())
diff --git a/kernel/src/scan/log_replay.rs b/kernel/src/scan/log_replay.rs
index 32503246f..3c52e5e2c 100644
--- a/kernel/src/scan/log_replay.rs
+++ b/kernel/src/scan/log_replay.rs
@@ -80,19 +80,20 @@ impl DataVisitor for AddRemoveVisitor {
 // for `scan_row_schema` in scan/mod.rs! You'll also need to update ScanFileVisitor as the
 // indexes will be off
 pub(crate) static SCAN_ROW_SCHEMA: LazyLock<Arc<StructType>> = LazyLock::new(|| {
+    // Note that fields projected out of a nullable struct must be nullable
     Arc::new(StructType::new([
-        StructField::new("path", DataType::STRING, false),
+        StructField::new("path", DataType::STRING, true),
         StructField::new("size", DataType::LONG, true),
         StructField::new("modificationTime", DataType::LONG, true),
         StructField::new("stats", DataType::STRING, true),
         StructField::new(
             "deletionVector",
             StructType::new([
-                StructField::new("storageType", DataType::STRING, false),
-                StructField::new("pathOrInlineDv", DataType::STRING, false),
+                StructField::new("storageType", DataType::STRING, true),
+                StructField::new("pathOrInlineDv", DataType::STRING, true),
                 StructField::new("offset", DataType::INTEGER, true),
-                StructField::new("sizeInBytes", DataType::INTEGER, false),
-                StructField::new("cardinality", DataType::LONG, false),
+                StructField::new("sizeInBytes", DataType::INTEGER, true),
+                StructField::new("cardinality", DataType::LONG, true),
             ]),
             true,
         ),
@@ -100,7 +101,7 @@ pub(crate) static SCAN_ROW_SCHEMA: LazyLock<Arc<StructType>> = LazyLock::new(||
             "fileConstantValues",
             StructType::new([StructField::new(
                 "partitionValues",
-                MapType::new(DataType::STRING, DataType::STRING, false),
+                MapType::new(DataType::STRING, DataType::STRING, true),
                 true,
             )]),
             true,
diff --git a/kernel/src/scan/mod.rs b/kernel/src/scan/mod.rs
index 91ee54c4e..b0bbfbfba 100644
--- a/kernel/src/scan/mod.rs
+++ b/kernel/src/scan/mod.rs
@@ -418,6 +418,7 @@ fn get_state_info(
             // Add to read schema, store field so we can build a `Column` expression later
             // if needed (i.e. if we have partition columns)
             let physical_field = logical_field.make_physical(column_mapping_mode)?;
+            debug!("\n\n{logical_field:#?}\nAfter mapping: {physical_field:#?}\n\n");
             let physical_name = physical_field.name.clone();
             read_fields.push(physical_field);
             Ok(ColumnType::Selected(physical_name))
diff --git a/kernel/src/schema.rs b/kernel/src/schema.rs
index e41db305e..61dd6e05b 100644
--- a/kernel/src/schema.rs
+++ b/kernel/src/schema.rs
@@ -29,6 +29,17 @@ pub enum MetadataValue {
     Other(serde_json::Value),
 }
 
+impl std::fmt::Display for MetadataValue {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            MetadataValue::Number(n) => write!(f, "{n}"),
+            MetadataValue::String(s) => write!(f, "{s}"),
+            MetadataValue::Boolean(b) => write!(f, "{b}"),
+            MetadataValue::Other(v) => write!(f, "{v}"), // just write the json back
+        }
+    }
+}
+
 impl From<String> for MetadataValue {
     fn from(value: String) -> Self {
         Self::String(value)
@@ -167,6 +178,15 @@ impl StructField {
         &self.metadata
     }
 
+    /// Convert our metadata into a HashMap. Note this copies all the data so can be
+    /// expensive for large metadata
+    pub fn metadata_with_string_values(&self) -> HashMap<String, String> {
+        self.metadata
+            .iter()
+            .map(|(key, val)| (key.clone(), val.to_string()))
+            .collect()
+    }
+
     pub fn make_physical(&self, mapping_mode: ColumnMappingMode) -> DeltaResult<StructField> {
         use ColumnMappingMode::*;
         match mapping_mode {
@@ -1091,4 +1111,25 @@ mod tests {
         assert_eq!(check_with_call_count(7), (7, 32));
         assert_eq!(check_with_call_count(8), (7, 32));
     }
+
+    #[test]
+    fn test_metadata_value_to_string() {
+        assert_eq!(MetadataValue::Number(0).to_string(), "0");
+        assert_eq!(
+            MetadataValue::String("hello".to_string()).to_string(),
+            "hello"
+        );
+        assert_eq!(MetadataValue::Boolean(true).to_string(), "true");
+        assert_eq!(MetadataValue::Boolean(false).to_string(), "false");
+        let object_json = serde_json::json!({ "an": "object" });
+        assert_eq!(
+            MetadataValue::Other(object_json).to_string(),
+            "{\"an\":\"object\"}"
+        );
+        let array_json = serde_json::json!(["an", "array"]);
+        assert_eq!(
+            MetadataValue::Other(array_json).to_string(),
+            "[\"an\",\"array\"]"
+        );
+    }
 }
diff --git a/kernel/tests/golden_tables.rs b/kernel/tests/golden_tables.rs
index 806171373..ea17deb70 100644
--- a/kernel/tests/golden_tables.rs
+++ b/kernel/tests/golden_tables.rs
@@ -3,9 +3,10 @@
 //! Data (golden tables) are stored in tests/golden_data/<table_name>.tar.zst
 //! Each table directory has a table/ and expected/ subdirectory with the input/output respectively
 
+use arrow::array::AsArray;
 use arrow::{compute::filter_record_batch, record_batch::RecordBatch};
 use arrow_ord::sort::{lexsort_to_indices, SortColumn};
-use arrow_schema::Schema;
+use arrow_schema::{FieldRef, Schema};
 use arrow_select::{concat::concat_batches, take::take};
 use itertools::Itertools;
 use paste::paste;
@@ -17,7 +18,7 @@ use futures::{stream::TryStreamExt, StreamExt};
 use object_store::{local::LocalFileSystem, ObjectStore};
 use parquet::arrow::async_reader::{ParquetObjectReader, ParquetRecordBatchStreamBuilder};
 
-use arrow_array::Array;
+use arrow_array::{Array, StructArray};
 use arrow_schema::DataType;
 use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor;
 use delta_kernel::engine::default::DefaultEngine;
@@ -99,52 +100,76 @@ fn sort_record_batch(batch: RecordBatch) -> DeltaResult<RecordBatch> {
     Ok(RecordBatch::try_new(batch.schema(), columns)?)
 }
 
-// copied from DAT
-// Ensure that two schema have the same field names, and dict_id/ordering.
+// Ensure that two sets of fields have the same names, and dict_id/ordering.
 // We ignore:
 // - data type: This is checked already in `assert_columns_match`
 // - nullability: parquet marks many things as nullable that we don't in our schema
 // - metadata: because that diverges from the real data to the golden tabled data
-fn assert_schema_fields_match(schema: &Schema, golden: &Schema) {
-    for (schema_field, golden_field) in schema.fields.iter().zip(golden.fields.iter()) {
+fn assert_fields_match<'a>(
+    actual: impl Iterator<Item = &'a FieldRef>,
+    expected: impl Iterator<Item = &'a FieldRef>,
+) {
+    for (actual_field, expected_field) in actual.zip(expected) {
         assert!(
-            schema_field.name() == golden_field.name(),
+            actual_field.name() == expected_field.name(),
             "Field names don't match"
         );
         assert!(
-            schema_field.dict_id() == golden_field.dict_id(),
+            actual_field.dict_id() == expected_field.dict_id(),
             "Field dict_id doesn't match"
         );
         assert!(
-            schema_field.dict_is_ordered() == golden_field.dict_is_ordered(),
+            actual_field.dict_is_ordered() == expected_field.dict_is_ordered(),
            "Field dict_is_ordered doesn't match"
         );
     }
 }
 
-// copied from DAT
-// some things are equivalent, but don't show up as equivalent for `==`, so we normalize here
-fn normalize_col(col: Arc<dyn Array>) -> Arc<dyn Array> {
-    if let DataType::Timestamp(unit, Some(zone)) = col.data_type() {
-        if **zone == *"+00:00" {
-            arrow_cast::cast::cast(&col, &DataType::Timestamp(*unit, Some("UTC".into())))
-                .expect("Could not cast to UTC")
-        } else {
-            col
+fn assert_cols_eq(actual: &dyn Array, expected: &dyn Array) {
+    // Our testing only exercises these nested types so far. In the future we may need to expand
+    // this to more types. Any `DataType` with a nested `Field` is a candidate for needing to be
+    // compared this way.
+    match actual.data_type() {
+        DataType::Struct(_) => {
+            let actual_sa = actual.as_struct();
+            let expected_sa = expected.as_struct();
+            assert_eq(actual_sa, expected_sa);
+        }
+        DataType::List(_) => {
+            let actual_la = actual.as_list::<i32>();
+            let expected_la = expected.as_list::<i32>();
+            assert_cols_eq(actual_la.values(), expected_la.values());
+        }
+        DataType::Map(_, _) => {
+            let actual_ma = actual.as_map();
+            let expected_ma = expected.as_map();
+            assert_cols_eq(actual_ma.keys(), expected_ma.keys());
+            assert_cols_eq(actual_ma.values(), expected_ma.values());
+        }
+        _ => {
+            assert_eq!(actual, expected, "Column data didn't match.");
         }
-    } else {
-        col
     }
 }
 
-// copied from DAT
-fn assert_columns_match(actual: &[Arc<dyn Array>], expected: &[Arc<dyn Array>]) {
-    for (actual, expected) in actual.iter().zip(expected) {
-        let actual = normalize_col(actual.clone());
-        let expected = normalize_col(expected.clone());
-        // note that array equality includes data_type equality
-        // See: https://arrow.apache.org/rust/arrow_data/equal/fn.equal.html
-        assert_eq!(&actual, &expected, "Column data didn't match.");
+fn assert_eq(actual: &StructArray, expected: &StructArray) {
+    let actual_fields = actual.fields();
+    let expected_fields = expected.fields();
+    assert_eq!(
+        actual_fields.len(),
+        expected_fields.len(),
+        "Number of fields differed"
+    );
+    assert_fields_match(actual_fields.iter(), expected_fields.iter());
+    let actual_cols = actual.columns();
+    let expected_cols = expected.columns();
+    assert_eq!(
+        actual_cols.len(),
+        expected_cols.len(),
+        "Number of columns differed"
+    );
+    for (actual_col, expected_col) in actual_cols.iter().zip(expected_cols) {
+        assert_cols_eq(actual_col, expected_col);
     }
 }
 
@@ -175,16 +200,14 @@ async fn latest_snapshot_test(
     let expected = read_expected(&expected_path.expect("expect an expected dir")).await?;
     let schema: Arc<Schema> = Arc::new(scan.schema().as_ref().try_into()?);
-
     let result = concat_batches(&schema, &batches)?;
     let result = sort_record_batch(result)?;
     let expected = sort_record_batch(expected)?;
-    assert_columns_match(result.columns(), expected.columns());
-    assert_schema_fields_match(expected.schema().as_ref(), result.schema().as_ref());
     assert!(
         expected.num_rows() == result.num_rows(),
         "Didn't have same number of rows"
     );
+    assert_eq(&result.into(), &expected.into());
 
     Ok(())
 }
 
@@ -389,10 +412,9 @@ golden_test!("snapshot-data3", latest_snapshot_test);
 golden_test!("snapshot-repartitioned", latest_snapshot_test);
 golden_test!("snapshot-vacuumed", latest_snapshot_test);
 
+golden_test!("table-with-columnmapping-mode-name", latest_snapshot_test);
 // TODO fix column mapping
 skip_test!("table-with-columnmapping-mode-id": "id column mapping mode not supported");
-skip_test!("table-with-columnmapping-mode-name":
-    "BUG: Parquet(General('partial projection of MapArray is not supported'))");
 
 // TODO scan at different versions
 golden_test!("time-travel-partition-changes-a", latest_snapshot_test);
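
Reviewer note (not part of the diff): a minimal sketch of the lax vs. strict checking enabled by the new `ensure_data_types` signature. It mirrors the `ensure_list` test above; the function name below is hypothetical, and since `ensure_data_types` is `pub(crate)`, a snippet like this only compiles from inside the kernel crate.

use arrow_schema::DataType as ArrowDataType;

use crate::engine::ensure_data_types::ensure_data_types;
use crate::schema::{ArrayType, DataType};

// Parquet frequently marks list elements as nullable even when the Delta schema says they are
// not. The parquet read path passes `false` so only the element types have to line up; the
// expression evaluator passes `true`, which also requires nullability (and struct-field
// metadata) to match.
fn nullability_checked_only_in_strict_mode() {
    let kernel = DataType::Array(Box::new(ArrayType::new(DataType::LONG, false)));
    let arrow = ArrowDataType::new_list(ArrowDataType::Int64, true);

    assert!(ensure_data_types(&kernel, &arrow, false).is_ok()); // lax: element types match
    assert!(ensure_data_types(&kernel, &arrow, true).is_err()); // strict: nullability differs
}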