From e9fc6569a161ce59c4465949c5231b216e39ff74 Mon Sep 17 00:00:00 2001 From: Robert Pack Date: Sat, 25 May 2024 08:10:13 +0200 Subject: [PATCH] fix: simplify to_array for scalars --- Cargo.toml | 2 +- crates/core/src/kernel/scalars.rs | 23 ++++++++--------------- python/src/lib.rs | 20 ++++++++++---------- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 928e31ca93..5d89671d30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ debug = true debug = "line-tables-only" [workspace.dependencies] -delta_kernel = { git = "https://github.com/roeap/delta-kernel-rs", rev = "400a6c6091b225f814894426b322360509f7df5c" } +delta_kernel = { git = "https://github.com/roeap/delta-kernel-rs", rev = "06f767af3fe510a367eb63e35b76e99d1b3898e9" } # delta_kernel = { path = "../delta-kernel-rs/kernel" } # arrow diff --git a/crates/core/src/kernel/scalars.rs b/crates/core/src/kernel/scalars.rs index ee5a603f5c..dad2958c95 100644 --- a/crates/core/src/kernel/scalars.rs +++ b/crates/core/src/kernel/scalars.rs @@ -172,18 +172,14 @@ impl ScalarExt for Scalar { .as_any() .downcast_ref::() .map(|v| Self::Date(v.value(index))), - // TODO handle timezones when implementing timestamp ntz feature. - Timestamp(TimeUnit::Microsecond, tz) => match tz { - None => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Timestamp(v.value(index))), - Some(tz_str) if tz_str.as_ref() == "UTC" => arr - .as_any() - .downcast_ref::() - .map(|v| Self::Timestamp(v.clone().with_timezone("UTC").value(index))), - _ => None, - }, + Timestamp(TimeUnit::Microsecond, None) => arr + .as_any() + .downcast_ref::() + .map(|v| Self::TimestampNtz(v.value(index))), + Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.eq_ignore_ascii_case("utc") => arr + .as_any() + .downcast_ref::() + .map(|v| Self::Timestamp(v.clone().value(index))), Struct(fields) => { let struct_fields = fields .iter() @@ -202,9 +198,6 @@ impl ScalarExt for Scalar { }) .collect::>>() })?; - if struct_fields.len() != values.len() { - return None; - } Some(Self::Struct( StructData::try_new(struct_fields, values).ok()?, )) diff --git a/python/src/lib.rs b/python/src/lib.rs index 0ce463026f..9dbd2cbd7c 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1365,9 +1365,9 @@ fn scalar_to_py(value: &Scalar, py_date: &PyAny, py: Python) -> PyResult value.serialize().to_object(py), - Struct(values, fields) => { + Struct(data) => { let py_struct = PyDict::new(py); - for (field, value) in fields.iter().zip(values.iter()) { + for (field, value) in data.fields().iter().zip(data.values().iter()) { py_struct.set_item(field.name(), scalar_to_py(value, py_date, py)?)?; } py_struct.to_object(py) @@ -1434,8 +1434,8 @@ fn filestats_to_expression_next<'py>( let mut has_nulls_set: HashSet = HashSet::new(); // NOTE: null_counts should always return a struct scalar. - if let Some(Scalar::Struct(values, fields)) = file_info.null_counts() { - for (field, value) in fields.iter().zip(values.iter()) { + if let Some(Scalar::Struct(data)) = file_info.null_counts() { + for (field, value) in data.fields().iter().zip(data.values().iter()) { if let Scalar::Long(val) = value { if *val == 0 { expressions.push(py_field.call1((field.name(),))?.call_method0("is_valid")); @@ -1449,11 +1449,11 @@ fn filestats_to_expression_next<'py>( } // NOTE: min_values should always return a struct scalar. - if let Some(Scalar::Struct(values, fields)) = file_info.min_values() { - for (field, value) in fields.iter().zip(values.iter()) { + if let Some(Scalar::Struct(data)) = file_info.min_values() { + for (field, value) in data.fields().iter().zip(data.values().iter()) { match value { // TODO: Handle nested field statistics. - Scalar::Struct(_, _) => {} + Scalar::Struct(_) => {} _ => { let maybe_minimum = cast_to_type(field.name(), scalar_to_py(value, py_date, py)?, &schema.0); @@ -1476,11 +1476,11 @@ fn filestats_to_expression_next<'py>( } // NOTE: max_values should always return a struct scalar. - if let Some(Scalar::Struct(values, fields)) = file_info.max_values() { - for (field, value) in fields.iter().zip(values.iter()) { + if let Some(Scalar::Struct(data)) = file_info.max_values() { + for (field, value) in data.fields().iter().zip(data.values().iter()) { match value { // TODO: Handle nested field statistics. - Scalar::Struct(_, _) => {} + Scalar::Struct(_) => {} _ => { let maybe_maximum = cast_to_type(field.name(), scalar_to_py(value, py_date, py)?, &schema.0);