fix: timestamp truncation in stats parsed
Signed-off-by: Ion Koutsouris <[email protected]>
ion-elgreco committed Mar 2, 2025
1 parent 2468ad9 commit 365070c
Showing 2 changed files with 107 additions and 3 deletions.
40 changes: 38 additions & 2 deletions crates/core/src/kernel/snapshot/log_data.rs
@@ -4,7 +4,7 @@ use std::sync::Arc;

use arrow_array::{Array, Int32Array, Int64Array, MapArray, RecordBatch, StringArray, StructArray};
use chrono::{DateTime, Utc};
use delta_kernel::expressions::Scalar;
use delta_kernel::expressions::{Scalar, StructData};
use indexmap::IndexMap;
use object_store::path::Path;
use object_store::ObjectMeta;
@@ -276,12 +276,26 @@ impl LogicalFile<'_> {
.column_by_name(COL_MIN_VALUES)
.and_then(|c| Scalar::from_array(c.as_ref(), self.index))
}

/// Struct containing all available max values for the columns in this file.
pub fn max_values(&self) -> Option<Scalar> {
// With delta.checkpoint.writeStatsAsStruct, microsecond timestamps are truncated to
// millisecond precision in checkpoints, as defined by the protocol. Truncation floors the
// value, whereas stats parsed on the fly from JSON keep full microsecond precision.
// To stay correct for max values, we always round a truncated timestamp up by 1ms.
fn ceil_datetime(v: i64) -> i64 {
let remainder = v % 1000;
if remainder == 0 {
// If the sub-millisecond (microsecond) remainder is 0, we assume the value was truncated
// and round it up by one millisecond; otherwise we keep the exact stat.
((v as f64 / 1000.0).floor() as i64 + 1) * 1000
} else {
v
}
}

self.stats
.column_by_name(COL_MAX_VALUES)
.and_then(|c| Scalar::from_array(c.as_ref(), self.index))
.map(|s| round_ms_datetimes(s, &ceil_datetime))
}

pub fn add_action(&self) -> Add {
@@ -349,6 +363,28 @@ impl LogicalFile<'_> {
}
}

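/// Recursively applies `func` to `Timestamp` and `TimestampNtz` scalars, descending into
/// struct values so that nested timestamp stats are adjusted as well; all other scalar
/// types are returned unchanged.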
fn round_ms_datetimes<F>(value: Scalar, func: &F) -> Scalar
where
F: Fn(i64) -> i64,
{
match value {
Scalar::Timestamp(v) => Scalar::Timestamp(func(v)),
Scalar::TimestampNtz(v) => Scalar::TimestampNtz(func(v)),
Scalar::Struct(struct_data) => {
let mut fields = Vec::new();
let mut scalars = Vec::new();

for (field, value) in struct_data.fields().iter().zip(struct_data.values().iter()) {
fields.push(field.clone());
scalars.push(round_ms_datetimes(value.clone(), func));
}
let data = StructData::try_new(fields, scalars).unwrap();
Scalar::Struct(data)
}
value => value,
}
}

impl<'a> TryFrom<&LogicalFile<'a>> for ObjectMeta {
type Error = DeltaTableError;

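As a quick illustration of the rounding rule above, here is a standalone sketch (not part of the commit) using an equivalent, simplified `ceil_datetime`: a microsecond timestamp whose sub-millisecond part is zero is assumed to have been truncated in a checkpoint and is bumped up by one millisecond, while any other value is treated as an exact stat and passed through. The epoch-microsecond constants are derived from the timestamp written by the first batch in the test below.

```rust
// Equivalent simplification of the round-up rule: when `v % 1000 == 0`,
// `((v / 1000) + 1) * 1000` is just `v + 1000`.
fn ceil_datetime(v: i64) -> i64 {
    if v % 1000 == 0 {
        // Looks truncated to millisecond precision: round up by one millisecond.
        v + 1000
    } else {
        // Sub-millisecond digits survive, so this is an exact stat.
        v
    }
}

fn main() {
    // 2023-03-29T23:59:59.807126Z (the timestamp written by the first batch in the
    // test below), expressed in microseconds since the Unix epoch.
    let exact: i64 = 1_680_134_399_807_126;
    // The same instant after a checkpoint floors it to millisecond precision.
    let truncated: i64 = 1_680_134_399_807_000;

    assert_eq!(ceil_datetime(exact), exact); // passed through untouched
    assert_eq!(ceil_datetime(truncated), 1_680_134_399_808_000); // bumped up by 1 ms
}
```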
70 changes: 69 additions & 1 deletion python/tests/test_stats.py
@@ -1,5 +1,6 @@
from datetime import datetime
from datetime import datetime, timezone

import pyarrow as pa
import pytest

from deltalake import DeltaTable, write_deltalake
@@ -68,3 +69,70 @@ def test_stats_usage_3201(tmp_path):
data = dt.to_pyarrow_table(filters=[("values", ">=", 0)])

assert_frame_equal(excepted, pl.from_arrow(data), check_row_order=False)


@pytest.mark.parametrize("use_stats_struct", (True, False))
def test_microsecond_truncation_parquet_stats(tmp_path, use_stats_struct):
"""In checkpoints the min/max values get truncated to millisecond precision.
For min values this is not an issue, but for max values we need to round upwards.
This checks whether we can still read tables with truncated timestamp stats.
"""

batch1 = pa.Table.from_pydict(
{
"p": [1],
"dt": [datetime(2023, 3, 29, 23, 59, 59, 807126, tzinfo=timezone.utc)],
}
)

write_deltalake(
tmp_path,
batch1,
mode="error",
partition_by=["p"],
configuration={
"delta.checkpoint.writeStatsAsStruct": str(use_stats_struct).lower()
},
)

batch2 = pa.Table.from_pydict(
{
"p": [1],
"dt": [datetime(2023, 3, 30, 0, 0, 0, 902, tzinfo=timezone.utc)],
}
)

write_deltalake(
tmp_path,
batch2,
mode="append",
partition_by=["p"],
)

dt = DeltaTable(tmp_path)

result = dt.to_pyarrow_table(
filters=[
(
"dt",
"<=",
datetime(2023, 3, 30, 0, 0, 0, 0, tzinfo=timezone.utc),
),
]
)
assert batch1 == result

dt.optimize.compact()
dt.create_checkpoint()

result = dt.to_pyarrow_table(
filters=[
(
"dt",
"<=",
datetime(2023, 3, 30, 0, 0, 0, 0, tzinfo=timezone.utc),
),
]
)
assert batch1 == result
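The asymmetry the docstring describes can be made concrete with a simplified pruning check (a sketch under assumptions, not delta-rs's actual pruning logic): for a predicate `dt >= X` a file can be skipped only when its recorded max is below X. Flooring a min stat merely makes pruning more conservative, but flooring a max stat can skip a file that still contains matching rows; rounding the max up by one millisecond restores safety.

```rust
// Simplified pruning sketch (assumption: not delta-rs's real pruning code).
// Timestamps are microseconds since the Unix epoch.

/// Can a file whose max timestamp stat is `max_stat` contain rows with `dt >= lower`?
fn may_match_ge(max_stat: i64, lower: i64) -> bool {
    max_stat >= lower
}

fn main() {
    // True max of the first file in the test: 2023-03-29T23:59:59.807126Z.
    let true_max: i64 = 1_680_134_399_807_126;
    let floored_max = (true_max / 1000) * 1000; // what a checkpoint records
    let rounded_max = floored_max + 1000;       // what the fix uses when parsing

    // Hypothetical predicate between the floored and the true max:
    // dt >= 2023-03-29T23:59:59.807100Z.
    let lower: i64 = 1_680_134_399_807_100;

    assert!(may_match_ge(true_max, lower));     // correct: the file has a matching row
    assert!(!may_match_ge(floored_max, lower)); // floored max would wrongly skip the file
    assert!(may_match_ge(rounded_max, lower));  // rounded-up max keeps it in the scan
}
```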
