Skip to content

Commit

Permalink
fix: simplify to_array for scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
roeap committed May 25, 2024
1 parent 4ab555d commit e9fc656
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ debug = true
debug = "line-tables-only"

[workspace.dependencies]
delta_kernel = { git = "https://github.com/roeap/delta-kernel-rs", rev = "400a6c6091b225f814894426b322360509f7df5c" }
delta_kernel = { git = "https://github.com/roeap/delta-kernel-rs", rev = "06f767af3fe510a367eb63e35b76e99d1b3898e9" }
# delta_kernel = { path = "../delta-kernel-rs/kernel" }

# arrow
Expand Down
23 changes: 8 additions & 15 deletions crates/core/src/kernel/scalars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,14 @@ impl ScalarExt for Scalar {
.as_any()
.downcast_ref::<Date32Array>()
.map(|v| Self::Date(v.value(index))),
// TODO handle timezones when implementing timestamp ntz feature.
Timestamp(TimeUnit::Microsecond, tz) => match tz {
None => arr
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.map(|v| Self::Timestamp(v.value(index))),
Some(tz_str) if tz_str.as_ref() == "UTC" => arr
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.map(|v| Self::Timestamp(v.clone().with_timezone("UTC").value(index))),
_ => None,
},
Timestamp(TimeUnit::Microsecond, None) => arr
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.map(|v| Self::TimestampNtz(v.value(index))),
Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.eq_ignore_ascii_case("utc") => arr
.as_any()
.downcast_ref::<TimestampMicrosecondArray>()
.map(|v| Self::Timestamp(v.clone().value(index))),
Struct(fields) => {
let struct_fields = fields
.iter()
Expand All @@ -202,9 +198,6 @@ impl ScalarExt for Scalar {
})
.collect::<Option<Vec<_>>>()
})?;
if struct_fields.len() != values.len() {
return None;
}
Some(Self::Struct(
StructData::try_new(struct_fields, values).ok()?,
))
Expand Down
20 changes: 10 additions & 10 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1365,9 +1365,9 @@ fn scalar_to_py(value: &Scalar, py_date: &PyAny, py: Python) -> PyResult<PyObjec
date.to_object(py)
}
Decimal(_, _, _) => value.serialize().to_object(py),
Struct(values, fields) => {
Struct(data) => {
let py_struct = PyDict::new(py);
for (field, value) in fields.iter().zip(values.iter()) {
for (field, value) in data.fields().iter().zip(data.values().iter()) {
py_struct.set_item(field.name(), scalar_to_py(value, py_date, py)?)?;
}
py_struct.to_object(py)
Expand Down Expand Up @@ -1434,8 +1434,8 @@ fn filestats_to_expression_next<'py>(
let mut has_nulls_set: HashSet<String> = HashSet::new();

// NOTE: null_counts should always return a struct scalar.
if let Some(Scalar::Struct(values, fields)) = file_info.null_counts() {
for (field, value) in fields.iter().zip(values.iter()) {
if let Some(Scalar::Struct(data)) = file_info.null_counts() {
for (field, value) in data.fields().iter().zip(data.values().iter()) {
if let Scalar::Long(val) = value {
if *val == 0 {
expressions.push(py_field.call1((field.name(),))?.call_method0("is_valid"));
Expand All @@ -1449,11 +1449,11 @@ fn filestats_to_expression_next<'py>(
}

// NOTE: min_values should always return a struct scalar.
if let Some(Scalar::Struct(values, fields)) = file_info.min_values() {
for (field, value) in fields.iter().zip(values.iter()) {
if let Some(Scalar::Struct(data)) = file_info.min_values() {
for (field, value) in data.fields().iter().zip(data.values().iter()) {
match value {
// TODO: Handle nested field statistics.
Scalar::Struct(_, _) => {}
Scalar::Struct(_) => {}
_ => {
let maybe_minimum =
cast_to_type(field.name(), scalar_to_py(value, py_date, py)?, &schema.0);
Expand All @@ -1476,11 +1476,11 @@ fn filestats_to_expression_next<'py>(
}

// NOTE: max_values should always return a struct scalar.
if let Some(Scalar::Struct(values, fields)) = file_info.max_values() {
for (field, value) in fields.iter().zip(values.iter()) {
if let Some(Scalar::Struct(data)) = file_info.max_values() {
for (field, value) in data.fields().iter().zip(data.values().iter()) {
match value {
// TODO: Handle nested field statistics.
Scalar::Struct(_, _) => {}
Scalar::Struct(_) => {}
_ => {
let maybe_maximum =
cast_to_type(field.name(), scalar_to_py(value, py_date, py)?, &schema.0);
Expand Down

0 comments on commit e9fc656

Please sign in to comment.