diff --git a/Cargo.toml b/Cargo.toml index 025249708..38d369eb2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,7 @@ byte-slice-cast = { version = "1.2.1" } clap = { version = "4.5.4" } criterion = { version = "0.5.1" } chrono-tz = {version = "0.9.0", features = ["serde"]} +chrono = { version = "0.4.38" } curve25519-dalek = { version = "4", features = ["rand_core"] } derive_more = { version = "0.99" } dyn_partial_eq = { version = "0.1.2" } diff --git a/crates/proof-of-sql-parser/Cargo.toml b/crates/proof-of-sql-parser/Cargo.toml index 3ffbff9e3..60a1c089e 100644 --- a/crates/proof-of-sql-parser/Cargo.toml +++ b/crates/proof-of-sql-parser/Cargo.toml @@ -17,6 +17,7 @@ test = true [dependencies] arrayvec = { workspace = true, features = ["serde"] } bigdecimal = { workspace = true } +chrono = { workspace = true } lalrpop-util = { workspace = true, features = ["lexer", "unicode"] } serde = { workspace = true, features = ["serde_derive"] } thiserror = { workspace = true } diff --git a/crates/proof-of-sql-parser/src/intermediate_ast.rs b/crates/proof-of-sql-parser/src/intermediate_ast.rs index 93df7b9a7..77d757cac 100644 --- a/crates/proof-of-sql-parser/src/intermediate_ast.rs +++ b/crates/proof-of-sql-parser/src/intermediate_ast.rs @@ -4,7 +4,9 @@ * https://docs.rs/vervolg/latest/vervolg/ast/enum.Statement.html ***/ -use crate::{intermediate_decimal::IntermediateDecimal, Identifier}; +use crate::{ + intermediate_decimal::IntermediateDecimal, intermediate_time::IntermediateTimestamp, Identifier, +}; use serde::{Deserialize, Serialize}; /// Representation of a SetExpression, a collection of rows, each having one or more columns. @@ -328,6 +330,8 @@ pub enum Literal { VarChar(String), /// Decimal Literal Decimal(IntermediateDecimal), + /// Timestamp Literal + TimestampTZ(IntermediateTimestamp), } impl From for Literal { diff --git a/crates/proof-of-sql-parser/src/intermediate_decimal.rs b/crates/proof-of-sql-parser/src/intermediate_decimal.rs index 850637e16..a99e961a4 100644 --- a/crates/proof-of-sql-parser/src/intermediate_decimal.rs +++ b/crates/proof-of-sql-parser/src/intermediate_decimal.rs @@ -12,7 +12,7 @@ use thiserror::Error; /// Errors related to the processing of decimal values in proof-of-sql #[derive(Error, Debug, PartialEq)] -pub enum DecimalError { +pub enum IntermediateDecimalError { /// Represents an error encountered during the parsing of a decimal string. #[error(transparent)] ParseError(#[from] ParseBigDecimalError), @@ -27,6 +27,8 @@ pub enum DecimalError { ConversionFailure, } +impl Eq for IntermediateDecimalError {} + /// An intermediate placeholder for a decimal #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] pub struct IntermediateDecimal { @@ -55,10 +57,10 @@ impl IntermediateDecimal { &self, precision: u8, scale: i8, - ) -> Result { + ) -> Result { let scaled_decimal = self.value.with_scale(scale.into()); if scaled_decimal.digits() > precision.into() { - return Err(DecimalError::LossyCast); + return Err(IntermediateDecimalError::LossyCast); } let (d, _) = scaled_decimal.into_bigint_and_exponent(); Ok(d) @@ -72,14 +74,14 @@ impl fmt::Display for IntermediateDecimal { } impl FromStr for IntermediateDecimal { - type Err = DecimalError; + type Err = IntermediateDecimalError; fn from_str(decimal_string: &str) -> Result { BigDecimal::from_str(decimal_string) .map(|value| IntermediateDecimal { value: value.normalized(), }) - .map_err(DecimalError::ParseError) + .map_err(IntermediateDecimalError::ParseError) } } @@ -100,7 +102,7 @@ impl From for IntermediateDecimal { } impl TryFrom<&str> for IntermediateDecimal { - type Error = DecimalError; + type Error = IntermediateDecimalError; fn try_from(s: &str) -> Result { IntermediateDecimal::from_str(s) @@ -108,7 +110,7 @@ impl TryFrom<&str> for IntermediateDecimal { } impl TryFrom for IntermediateDecimal { - type Error = DecimalError; + type Error = IntermediateDecimalError; fn try_from(s: String) -> Result { IntermediateDecimal::from_str(&s) @@ -116,31 +118,31 @@ impl TryFrom for IntermediateDecimal { } impl TryFrom for i128 { - type Error = DecimalError; + type Error = IntermediateDecimalError; fn try_from(decimal: IntermediateDecimal) -> Result { if !decimal.value.is_integer() { - return Err(DecimalError::LossyCast); + return Err(IntermediateDecimalError::LossyCast); } match decimal.value.to_i128() { Some(value) if (i128::MIN..=i128::MAX).contains(&value) => Ok(value), - _ => Err(DecimalError::OutOfRange), + _ => Err(IntermediateDecimalError::OutOfRange), } } } impl TryFrom for i64 { - type Error = DecimalError; + type Error = IntermediateDecimalError; fn try_from(decimal: IntermediateDecimal) -> Result { if !decimal.value.is_integer() { - return Err(DecimalError::LossyCast); + return Err(IntermediateDecimalError::LossyCast); } match decimal.value.to_i64() { Some(value) if (i64::MIN..=i64::MAX).contains(&value) => Ok(value), - _ => Err(DecimalError::OutOfRange), + _ => Err(IntermediateDecimalError::OutOfRange), } } } @@ -197,7 +199,7 @@ mod tests { }; assert_eq!( i128::try_from(overflow_decimal), - Err(DecimalError::OutOfRange) + Err(IntermediateDecimalError::OutOfRange) ); let valid_decimal_negative = IntermediateDecimal { @@ -211,7 +213,10 @@ mod tests { let non_integer = IntermediateDecimal { value: BigDecimal::from_str("100.5").unwrap(), }; - assert_eq!(i128::try_from(non_integer), Err(DecimalError::LossyCast)); + assert_eq!( + i128::try_from(non_integer), + Err(IntermediateDecimalError::LossyCast) + ); } #[test] @@ -231,7 +236,7 @@ mod tests { }; assert_eq!( i64::try_from(overflow_decimal), - Err(DecimalError::OutOfRange) + Err(IntermediateDecimalError::OutOfRange) ); let valid_decimal_negative = IntermediateDecimal { @@ -245,6 +250,9 @@ mod tests { let non_integer = IntermediateDecimal { value: BigDecimal::from_str("100.5").unwrap(), }; - assert_eq!(i64::try_from(non_integer), Err(DecimalError::LossyCast)); + assert_eq!( + i64::try_from(non_integer), + Err(IntermediateDecimalError::LossyCast) + ); } } diff --git a/crates/proof-of-sql-parser/src/intermediate_time.rs b/crates/proof-of-sql-parser/src/intermediate_time.rs new file mode 100644 index 000000000..11ef9956f --- /dev/null +++ b/crates/proof-of-sql-parser/src/intermediate_time.rs @@ -0,0 +1,344 @@ +use chrono::{DateTime, NaiveDateTime, Offset, TimeZone, Utc}; +use core::fmt; +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +/// Errors from converting an intermediate AST into a provable AST. +#[derive(Error, Debug, PartialEq, Eq)] +pub enum IntermediateTimestampError { + #[error("Invalid timeunit")] + /// Error converting intermediate time units to PoSQL time units + InvalidTimeUnit, + + #[error("Invalid timezone")] + /// Error converting intermediate time zones to PoSQL timezones + InvalidTimeZone, + + /// Could not parse a timestamp from string + #[error("Invalid timestamp format")] + InvalidFormat, +} + +/// An initermediate type of components extracted from a timestamp string. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub enum IntermediateTimeUnit { + /// Represents seconds with precision 0: ex "2024-06-20 12:34:56" + Second, + /// Represents milliseconds with precision 3: ex "2024-06-20 12:34:56.123" + Millisecond, + /// Represents microseconds with precision 6: ex "2024-06-20 12:34:56.123456" + Microsecond, + /// Represents nanoseconds with precision 9: ex "2024-06-20 12:34:56.123456789" + Nanosecond, +} + +impl fmt::Display for IntermediateTimeUnit { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + IntermediateTimeUnit::Second => write!(f, "Second"), + IntermediateTimeUnit::Millisecond => write!(f, "Millisecond"), + IntermediateTimeUnit::Microsecond => write!(f, "Microsecond"), + IntermediateTimeUnit::Nanosecond => write!(f, "Nanosecond"), + } + } +} + +/// Captures a timezone from a timestamp query +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub enum IntermediateTimeZone { + /// Default variant for UTC timezoen + Utc, + /// TImezone offset in seconds + FixedOffset(i32), +} + +impl IntermediateTimeZone { + /// Parse a timezone from a count of seconds + pub fn from_offset(offset: i32) -> Self { + if offset == 0 { + IntermediateTimeZone::Utc + } else { + IntermediateTimeZone::FixedOffset(offset) + } + } +} + +impl fmt::Display for IntermediateTimeZone { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + IntermediateTimeZone::Utc => write!(f, "Z"), + IntermediateTimeZone::FixedOffset(offset) => { + if *offset == 0 { + write!(f, "Z") + } else { + let total_minutes = offset / 60; + let hours = total_minutes / 60; + let minutes = total_minutes.abs() % 60; + write!(f, "{:+03}:{:02}", hours, minutes) + } + } + } + } +} + +/// Intermediate Time +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub struct IntermediateTimestamp { + /// Count of time units since the unix epoch + pub timestamp: i64, + /// Seconds, milliseconds, microseconds, or nanoseconds + pub unit: IntermediateTimeUnit, + /// Timezone captured from parsed string + pub timezone: IntermediateTimeZone, +} + +impl TryFrom<&str> for IntermediateTimestamp { + type Error = IntermediateTimestampError; + + fn try_from(value: &str) -> Result { + parse_intermediate_timestamp(value).map_err(|_| IntermediateTimestampError::InvalidFormat) + } +} + +/// Parses a timestamp from valid strings obtained from the lexer +pub fn parse_intermediate_timestamp(ts: &str) -> Result { + let format_with_tz = "%Y-%m-%d %H:%M:%S%.f%:z"; + let format_without_tz = "%Y-%m-%d %H:%M:%S%.f"; + + // Helper function to determine the precision of the fractional seconds + fn determine_precision(fraction: &str) -> IntermediateTimeUnit { + match fraction.len() { + 0 => IntermediateTimeUnit::Second, + 1..=3 => IntermediateTimeUnit::Millisecond, + 4..=6 => IntermediateTimeUnit::Microsecond, + _ => IntermediateTimeUnit::Nanosecond, + } + } + + // Extract the fractional part correctly + fn extract_fraction(ts: &str) -> &str { + if let Some((_, fractional)) = ts.split_once('.') { + if let Some((fractional, _)) = fractional.split_once(|c| c == '+' || c == '-') { + return fractional; + } + return fractional; + } + "" + } + + // First try parsing with timezone + if let Ok(dt) = DateTime::parse_from_str(ts, format_with_tz) { + if let Some(timestamp_nanos) = dt.timestamp_nanos_opt() { + let offset_seconds = dt.offset().fix().local_minus_utc(); + let fraction = extract_fraction(ts); + let unit = determine_precision(fraction); + return Ok(IntermediateTimestamp { + timestamp: timestamp_nanos, + unit, + timezone: IntermediateTimeZone::from_offset(offset_seconds), + }); + } else { + return Err("Failed to convert datetime to nanoseconds"); + } + } + + // If that fails, try parsing without timezone and assume UTC + if let Ok(naive_dt) = NaiveDateTime::parse_from_str(ts, format_without_tz) { + let datetime_utc = Utc.from_utc_datetime(&naive_dt); + if let Some(timestamp_nanos) = datetime_utc.timestamp_nanos_opt() { + let fraction = extract_fraction(ts); + let unit = determine_precision(fraction); + return Ok(IntermediateTimestamp { + timestamp: timestamp_nanos, + unit, + timezone: IntermediateTimeZone::Utc, + }); + } else { + return Err("Failed to convert datetime to nanoseconds"); + } + } + + Err("Invalid timestamp format") +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::{FixedOffset, TimeZone, Timelike, Utc}; + + #[test] + fn test_display_intermediate_timezone() { + // Test Utc + let tz_utc = IntermediateTimeZone::Utc; + assert_eq!(format!("{}", tz_utc), "Z"); + + // Test positive offsets + let tz_offset_1 = IntermediateTimeZone::FixedOffset(3600); // +01:00 + assert_eq!(format!("{}", tz_offset_1), "+01:00"); + + let tz_offset_2 = IntermediateTimeZone::FixedOffset(19800); // +05:30 + assert_eq!(format!("{}", tz_offset_2), "+05:30"); + + let tz_offset_3 = IntermediateTimeZone::FixedOffset(3600 * 12); // +12:00 + assert_eq!(format!("{}", tz_offset_3), "+12:00"); + + // Test negative offsets + let tz_offset_4 = IntermediateTimeZone::FixedOffset(-3600); // -01:00 + assert_eq!(format!("{}", tz_offset_4), "-01:00"); + + let tz_offset_5 = IntermediateTimeZone::FixedOffset(-12600); // -03:30 + assert_eq!(format!("{}", tz_offset_5), "-03:30"); + + let tz_offset_6 = IntermediateTimeZone::FixedOffset(-3600 * 12); // -12:00 + assert_eq!(format!("{}", tz_offset_6), "-12:00"); + + // Test edge cases + let tz_offset_7 = IntermediateTimeZone::FixedOffset(0); // +00:00 + assert_eq!(format!("{}", tz_offset_7), "Z"); + + let tz_offset_8 = IntermediateTimeZone::FixedOffset(3600 * 14); // +14:00 + assert_eq!(format!("{}", tz_offset_8), "+14:00"); + + let tz_offset_9 = IntermediateTimeZone::FixedOffset(-3600 * 14); // -14:00 + assert_eq!(format!("{}", tz_offset_9), "-14:00"); + } + + #[test] + fn test_parse_with_timezone() { + let ts_with_tz = "2024-06-20 12:34:56+02:00"; + let result = parse_intermediate_timestamp(ts_with_tz) + .expect("Failed to parse timestamp with timezone"); + + assert_eq!(result.unit, IntermediateTimeUnit::Second); + + let ts_with_tz = "2024-06-20 12:34:56.123+02:00"; + let result = parse_intermediate_timestamp(ts_with_tz) + .expect("Failed to parse timestamp with timezone"); + + assert_eq!(result.unit, IntermediateTimeUnit::Millisecond); + + let ts_with_tz = "2024-06-20 12:34:56.123456+02:00"; + let result = parse_intermediate_timestamp(ts_with_tz) + .expect("Failed to parse timestamp with timezone"); + + assert_eq!(result.unit, IntermediateTimeUnit::Microsecond); + + let ts_with_tz = "2024-06-20 12:34:56.123456789+02:00"; + let result = parse_intermediate_timestamp(ts_with_tz) + .expect("Failed to parse timestamp with timezone"); + + assert_eq!(result.unit, IntermediateTimeUnit::Nanosecond); + assert_eq!(result.timezone, IntermediateTimeZone::FixedOffset(7200)); // +02:00 is 7200 seconds + let expected_timestamp: DateTime = FixedOffset::east_opt(7200) + .unwrap() + .with_ymd_and_hms(2024, 6, 20, 12, 34, 56) + .unwrap() + .with_nanosecond(123_456_789) + .unwrap(); + assert_eq!( + result.timestamp, + expected_timestamp.timestamp_nanos_opt().unwrap() + ); + } + + #[test] + fn test_parse_without_timezone() { + let ts_without_tz = "2024-06-20 12:34:56"; + let result = parse_intermediate_timestamp(ts_without_tz) + .expect("Failed to parse timestamp without timezone"); + + assert_eq!(result.unit, IntermediateTimeUnit::Second); + assert_eq!(result.timezone, IntermediateTimeZone::Utc); + + let ts_without_tz = "2024-06-20 12:34:56.123"; + let result = parse_intermediate_timestamp(ts_without_tz) + .expect("Failed to parse timestamp without timezone"); + + assert_eq!(result.unit, IntermediateTimeUnit::Millisecond); + assert_eq!(result.timezone, IntermediateTimeZone::Utc); + + let ts_without_tz = "2024-06-20 12:34:56.123456"; + let result = parse_intermediate_timestamp(ts_without_tz) + .expect("Failed to parse timestamp without timezone"); + + assert_eq!(result.unit, IntermediateTimeUnit::Microsecond); + assert_eq!(result.timezone, IntermediateTimeZone::Utc); + + let ts_without_tz = "2024-06-20 12:34:56.123456789"; + let result = parse_intermediate_timestamp(ts_without_tz) + .expect("Failed to parse timestamp without timezone"); + + assert_eq!(result.unit, IntermediateTimeUnit::Nanosecond); + assert_eq!(result.timezone, IntermediateTimeZone::Utc); + let expected_timestamp = Utc + .with_ymd_and_hms(2024, 6, 20, 12, 34, 56) + .unwrap() + .with_nanosecond(123_456_789) + .unwrap(); + assert_eq!( + result.timestamp, + expected_timestamp.timestamp_nanos_opt().unwrap() + ); + } + + #[test] + fn test_parse_invalid_format() { + let invalid_ts = "invalid timestamp"; + let result = parse_intermediate_timestamp(invalid_ts); + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), "Invalid timestamp format"); + } + + #[test] + fn test_parse_missing_fractional_seconds() { + let ts_missing_fractional = "2024-06-20 12:34:56+02:00"; + let result = parse_intermediate_timestamp(ts_missing_fractional) + .expect("Failed to parse timestamp without fractional seconds"); + + assert_eq!(result.unit, IntermediateTimeUnit::Second); + assert_eq!(result.timezone, IntermediateTimeZone::FixedOffset(7200)); + let expected_timestamp: DateTime = FixedOffset::east_opt(7200) + .unwrap() + .with_ymd_and_hms(2024, 6, 20, 12, 34, 56) + .unwrap(); + assert_eq!( + result.timestamp, + expected_timestamp.timestamp_nanos_opt().unwrap() + ); + } + + #[test] + fn test_parse_different_timezones() { + let timezones = [ + ("2024-06-20 12:34:56.123456789-05:00", -18000), // -05:00 is -18000 seconds + ("2024-06-20 12:34:56.123456789+00:00", 0), // +00:00 is 0 seconds + ("2024-06-20 12:34:56.123456789+05:30", 19800), // +05:30 is 19800 seconds + ("2024-06-20 12:34:56.123456789-08:00", -28800), // -08:00 is -28800 seconds + ("2024-06-20 12:34:56.123456789+09:00", 32400), // +09:00 is 32400 seconds + ("2024-06-20 12:34:56.123456789-03:30", -12600), // -03:30 is -12600 seconds + ("2024-06-20 12:34:56.123456789+12:00", 43200), // +12:00 is 43200 seconds + ("2024-06-20 12:34:56.123456789-12:00", -43200), // -12:00 is -43200 seconds + ]; + + for (ts, offset_seconds) in &timezones { + let result = parse_intermediate_timestamp(ts) + .unwrap_or_else(|_| panic!("Failed to parse timestamp with timezone {}", ts)); + + assert_eq!(result.unit, IntermediateTimeUnit::Nanosecond); + assert_eq!( + result.timezone, + IntermediateTimeZone::from_offset(*offset_seconds) + ); + let expected_timestamp: DateTime = FixedOffset::east_opt(*offset_seconds) + .unwrap() + .with_ymd_and_hms(2024, 6, 20, 12, 34, 56) + .unwrap() + .with_nanosecond(123_456_789) + .unwrap(); + assert_eq!( + result.timestamp, + expected_timestamp.timestamp_nanos_opt().unwrap() + ); + } + } +} diff --git a/crates/proof-of-sql-parser/src/lib.rs b/crates/proof-of-sql-parser/src/lib.rs index 1bd38217a..7fd6be5ea 100644 --- a/crates/proof-of-sql-parser/src/lib.rs +++ b/crates/proof-of-sql-parser/src/lib.rs @@ -2,6 +2,8 @@ /// Module for handling an intermediate decimal type received from the lexer. pub mod intermediate_decimal; +/// Module for handling an intermediate timestamp type received from the lexer. +pub mod intermediate_time; #[macro_use] extern crate lalrpop_util; diff --git a/crates/proof-of-sql-parser/src/sql.lalrpop b/crates/proof-of-sql-parser/src/sql.lalrpop index dec34a067..fbe023eb9 100644 --- a/crates/proof-of-sql-parser/src/sql.lalrpop +++ b/crates/proof-of-sql-parser/src/sql.lalrpop @@ -3,6 +3,7 @@ use crate::select_statement; use crate::identifier; use lalrpop_util::ParseError::User; use crate::intermediate_decimal::IntermediateDecimal; +use crate::intermediate_time::IntermediateTimestamp; grammar; @@ -337,6 +338,11 @@ LiteralValue: Box = { }, => Box::new(intermediate_ast::Literal::Decimal(value)), + => Box::new(intermediate_ast::Literal::TimestampTZ(value)), +}; + +TimestampLiteral: IntermediateTimestamp = { + =>? IntermediateTimestamp::try_from(ts).map_err(|_| User {error: "Invalid timestamp format"}), }; Int128UnaryNumericLiteral: i128 = { @@ -435,4 +441,6 @@ match { // Integer numbers (without a fractional part) r"[+-]?[0-9]+" => INTEGER_LIT, r"'(?s)(?:''|[^'])*'" => STRING_LITERAL, + // Timestamp literals with optional fractional seconds and time zone offset + r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(?:\.\d{1,9})?(?:Z|[+-]\d{2}:\d{2})?" => TIMESTAMP_LIT, } diff --git a/crates/proof-of-sql/src/base/commitment/column_bounds.rs b/crates/proof-of-sql/src/base/commitment/column_bounds.rs index 22106c930..e0fa208d1 100644 --- a/crates/proof-of-sql/src/base/commitment/column_bounds.rs +++ b/crates/proof-of-sql/src/base/commitment/column_bounds.rs @@ -292,7 +292,7 @@ mod tests { database::OwnedColumn, math::decimal::Precision, scalar::Curve25519Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; use itertools::Itertools; diff --git a/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs b/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs index 9d3fce8cd..d37ec23df 100644 --- a/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs +++ b/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs @@ -169,7 +169,7 @@ mod tests { database::OwnedColumn, math::decimal::Precision, scalar::Curve25519Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; #[test] diff --git a/crates/proof-of-sql/src/base/commitment/committable_column.rs b/crates/proof-of-sql/src/base/commitment/committable_column.rs index a6fcb89dd..33477367d 100644 --- a/crates/proof-of-sql/src/base/commitment/committable_column.rs +++ b/crates/proof-of-sql/src/base/commitment/committable_column.rs @@ -3,7 +3,7 @@ use crate::base::{ math::decimal::Precision, ref_into::RefInto, scalar::Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; #[cfg(feature = "blitzar")] use blitzar::sequence::Sequence; @@ -194,7 +194,10 @@ impl<'a, 'b> From<&'a CommittableColumn<'b>> for Sequence<'a> { #[cfg(all(test, feature = "blitzar"))] mod tests { use super::*; - use crate::{base::scalar::Curve25519Scalar, proof_primitive::dory::DoryScalar}; + use crate::{ + base::{scalar::Curve25519Scalar, time::timezone::PoSQLTimeZone}, + proof_primitive::dory::DoryScalar, + }; use blitzar::compute::compute_curve25519_commitments; use curve25519_dalek::ristretto::CompressedRistretto; diff --git a/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs b/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs index 91bb4b6c1..00d278f68 100644 --- a/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs +++ b/crates/proof-of-sql/src/base/database/arrow_array_to_column_conversion.rs @@ -4,7 +4,7 @@ use crate::{ database::Column, math::decimal::Precision, scalar::Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }, sql::parse::ConversionError, }; diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index 4bba7b38c..7477fdfb9 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -2,7 +2,7 @@ use super::{LiteralValue, TableRef}; use crate::base::{ math::decimal::{scale_scalar, Precision}, scalar::Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; use arrow::datatypes::{DataType, Field, TimeUnit as ArrowTimeUnit}; use bumpalo::Bump; @@ -113,7 +113,7 @@ impl<'a, S: Scalar> Column<'a, S> { *scale, alloc.alloc_slice_fill_copy(length, *value), ), - LiteralValue::TimeStampTZ(tu, tz, value) => { + LiteralValue::TimestampTZ(tu, tz, value) => { Column::TimestampTZ(*tu, *tz, alloc.alloc_slice_fill_copy(length, *value)) } LiteralValue::VarChar((string, scalar)) => Column::VarChar(( diff --git a/crates/proof-of-sql/src/base/database/literal_value.rs b/crates/proof-of-sql/src/base/database/literal_value.rs index 76bc41865..7d0adc8a4 100644 --- a/crates/proof-of-sql/src/base/database/literal_value.rs +++ b/crates/proof-of-sql/src/base/database/literal_value.rs @@ -2,7 +2,7 @@ use crate::base::{ database::ColumnType, math::decimal::Precision, scalar::Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; use serde::{Deserialize, Serialize}; @@ -36,7 +36,7 @@ pub enum LiteralValue { Scalar(S), /// TimeStamp defined over a unit (s, ms, ns, etc) and timezone with backing store /// mapped to i64, which is time units since unix epoch - TimeStampTZ(PoSQLTimeUnit, PoSQLTimeZone, i64), + TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone, i64), } impl LiteralValue { @@ -51,7 +51,7 @@ impl LiteralValue { Self::Int128(_) => ColumnType::Int128, Self::Scalar(_) => ColumnType::Scalar, Self::Decimal75(precision, scale, _) => ColumnType::Decimal75(*precision, *scale), - Self::TimeStampTZ(tu, tz, _) => ColumnType::TimestampTZ(*tu, *tz), + Self::TimestampTZ(tu, tz, _) => ColumnType::TimestampTZ(*tu, *tz), } } @@ -66,7 +66,7 @@ impl LiteralValue { Self::Int128(i) => i.into(), Self::Decimal75(_, _, s) => *s, Self::Scalar(scalar) => *scalar, - Self::TimeStampTZ(_, _, time) => time.into(), + Self::TimestampTZ(_, _, time) => time.into(), } } } diff --git a/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs b/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs index 3e551b89e..d2286d167 100644 --- a/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs +++ b/crates/proof-of-sql/src/base/database/owned_and_arrow_conversions.rs @@ -20,7 +20,7 @@ use crate::base::{ }, math::decimal::Precision, scalar::Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; use arrow::{ array::{ diff --git a/crates/proof-of-sql/src/base/database/owned_column.rs b/crates/proof-of-sql/src/base/database/owned_column.rs index 466c40871..2704d8595 100644 --- a/crates/proof-of-sql/src/base/database/owned_column.rs +++ b/crates/proof-of-sql/src/base/database/owned_column.rs @@ -6,7 +6,7 @@ use super::ColumnType; use crate::base::{ math::decimal::Precision, scalar::Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; #[derive(Debug, PartialEq, Clone, Eq)] #[non_exhaustive] diff --git a/crates/proof-of-sql/src/base/database/owned_table_test.rs b/crates/proof-of-sql/src/base/database/owned_table_test.rs index 617fda933..4c6b8e092 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_test.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_test.rs @@ -2,7 +2,7 @@ use crate::{ base::{ database::{owned_table_utility::*, OwnedColumn, OwnedTable, OwnedTableError}, scalar::Curve25519Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }, proof_primitive::dory::DoryScalar, }; diff --git a/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs b/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs index 0c70a1ade..6a795d18d 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs @@ -5,7 +5,7 @@ use super::{ use crate::base::{ database::owned_table_utility::*, scalar::{compute_commitment_for_testing, Curve25519Scalar}, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; use blitzar::proof::InnerProductProof; diff --git a/crates/proof-of-sql/src/base/database/owned_table_utility.rs b/crates/proof-of-sql/src/base/database/owned_table_utility.rs index 0b2131b30..f6661d6f1 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_utility.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_utility.rs @@ -16,7 +16,7 @@ use super::{OwnedColumn, OwnedTable}; use crate::base::{ scalar::Scalar, - time::timestamp::{PoSQLTimeUnit, PoSQLTimeZone}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }; use core::ops::Deref; use proof_of_sql_parser::Identifier; diff --git a/crates/proof-of-sql/src/base/database/record_batch_utility.rs b/crates/proof-of-sql/src/base/database/record_batch_utility.rs index 7c67c8f7c..43974aa01 100644 --- a/crates/proof-of-sql/src/base/database/record_batch_utility.rs +++ b/crates/proof-of-sql/src/base/database/record_batch_utility.rs @@ -1,4 +1,4 @@ -use crate::base::time::timestamp::{PoSQLTimeUnit, Time}; +use crate::base::time::timeunit::{PoSQLTimeUnit, Time}; use arrow::array::{ TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, diff --git a/crates/proof-of-sql/src/base/database/test_accessor_utility.rs b/crates/proof-of-sql/src/base/database/test_accessor_utility.rs index f14397c4d..c9538a0f3 100644 --- a/crates/proof-of-sql/src/base/database/test_accessor_utility.rs +++ b/crates/proof-of-sql/src/base/database/test_accessor_utility.rs @@ -1,4 +1,4 @@ -use crate::base::{database::ColumnType, time::timestamp::PoSQLTimeUnit}; +use crate::base::{database::ColumnType, time::timeunit::PoSQLTimeUnit}; use arrow::{ array::{ Array, BooleanArray, Decimal128Array, Decimal256Array, Int16Array, Int32Array, Int64Array, diff --git a/crates/proof-of-sql/src/base/math/decimal.rs b/crates/proof-of-sql/src/base/math/decimal.rs index 38f19013b..b2dcd9020 100644 --- a/crates/proof-of-sql/src/base/math/decimal.rs +++ b/crates/proof-of-sql/src/base/math/decimal.rs @@ -5,6 +5,27 @@ use crate::{ }; use proof_of_sql_parser::intermediate_decimal::IntermediateDecimal; use serde::{Deserialize, Deserializer, Serialize}; +use thiserror::Error; + +/// Errors related to decimal operations. +#[derive(Error, Debug, PartialEq, Eq)] +pub enum DecimalError { + #[error("Invalid decimal format or value: {0}")] + /// Error when a decimal format or value is incorrect + InvalidDecimal(String), + + #[error("Unsupported operation: cannot round decimal: {0}")] + /// Decimal rounding is not supported + DecimalRoundingError(String), + + #[error("Error while parsing precision from query: {0}")] + /// Error in parsing precision in a query + PrecisionParseError(String), + + #[error("Decimal precision is not valid: {0}")] + /// Decimal precision exceeds the allowed limit + InvalidPrecision(u8), +} #[derive(Eq, PartialEq, Debug, Clone, Hash, Serialize, Copy)] /// limit-enforced precision @@ -15,9 +36,11 @@ impl Precision { /// Constructor for creating a Precision instance pub fn new(value: u8) -> Result { if value > MAX_SUPPORTED_PRECISION || value == 0 { - Err(ConversionError::PrecisionParseError(format!( - "Failed to parse precision. Value of {} exceeds max supported precision of {}", - value, MAX_SUPPORTED_PRECISION + Err(ConversionError::Decimal(DecimalError::PrecisionParseError( + format!( + "Failed to parse precision. Value of {} exceeds max supported precision of {}", + value, MAX_SUPPORTED_PRECISION + ), ))) } else { Ok(Precision(value)) @@ -73,8 +96,8 @@ impl Decimal { ) -> ConversionResult> { let scale_factor = new_scale - self.scale; if scale_factor < 0 || new_precision.value() < self.precision.value() + scale_factor as u8 { - return Err(ConversionError::DecimalRoundingError( - "Scale factor must be non-negative".to_string(), + return Err(ConversionError::Decimal( + DecimalError::DecimalRoundingError("Scale factor must be non-negative".to_string()), )); } let scaled_value = scale_scalar(self.value, scale_factor)?; @@ -86,13 +109,13 @@ impl Decimal { const MINIMAL_PRECISION: u8 = 19; let raw_precision = precision.value(); if raw_precision < MINIMAL_PRECISION { - return Err(ConversionError::DecimalRoundingError( - "Precision must be at least 19".to_string(), + return Err(ConversionError::Decimal( + DecimalError::DecimalRoundingError("Precision must be at least 19".to_string()), )); } if scale < 0 || raw_precision < MINIMAL_PRECISION + scale as u8 { - return Err(ConversionError::DecimalRoundingError( - "Can not scale down a decimal".to_string(), + return Err(ConversionError::Decimal( + DecimalError::DecimalRoundingError("Can not scale down a decimal".to_string()), )); } let scaled_value = scale_scalar(S::from(&value), scale)?; @@ -104,13 +127,13 @@ impl Decimal { const MINIMAL_PRECISION: u8 = 39; let raw_precision = precision.value(); if raw_precision < MINIMAL_PRECISION { - return Err(ConversionError::DecimalRoundingError( - "Precision must be at least 19".to_string(), + return Err(ConversionError::Decimal( + DecimalError::DecimalRoundingError("Precision must be at least 19".to_string()), )); } if scale < 0 || raw_precision < MINIMAL_PRECISION + scale as u8 { - return Err(ConversionError::DecimalRoundingError( - "Can not scale down a decimal".to_string(), + return Err(ConversionError::Decimal( + DecimalError::DecimalRoundingError("Can not scale down a decimal".to_string()), )); } let scaled_value = scale_scalar(S::from(&value), scale)?; @@ -147,8 +170,8 @@ pub(crate) fn try_into_to_scalar( /// Note that we do not check for overflow. pub(crate) fn scale_scalar(s: S, scale: i8) -> ConversionResult { if scale < 0 { - return Err(ConversionError::DecimalRoundingError( - "Scale factor must be non-negative".to_string(), + return Err(ConversionError::Decimal( + DecimalError::DecimalRoundingError("Scale factor must be non-negative".to_string()), )); } let ten = S::from(10); diff --git a/crates/proof-of-sql/src/base/scalar/mont_scalar.rs b/crates/proof-of-sql/src/base/scalar/mont_scalar.rs index 2ef60e34c..f2da03261 100644 --- a/crates/proof-of-sql/src/base/scalar/mont_scalar.rs +++ b/crates/proof-of-sql/src/base/scalar/mont_scalar.rs @@ -1,5 +1,8 @@ use super::{scalar_conversion_to_int, Scalar, ScalarConversionError}; -use crate::{base::math::decimal::MAX_SUPPORTED_PRECISION, sql::parse::ConversionError}; +use crate::{ + base::math::decimal::{DecimalError, MAX_SUPPORTED_PRECISION}, + sql::parse::ConversionError, +}; use ark_ff::{BigInteger, Field, Fp, Fp256, MontBackend, MontConfig, PrimeField}; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use bytemuck::TransparentWrapper; @@ -163,11 +166,11 @@ impl> TryFrom for MontScalar { // Check if the number of digits exceeds the maximum precision allowed if digits.len() > MAX_SUPPORTED_PRECISION.into() { - return Err(ConversionError::InvalidDecimal(format!( + return Err(ConversionError::Decimal(DecimalError::InvalidDecimal(format!( "Attempted to parse a number with {} digits, which exceeds the max supported precision of {}", digits.len(), MAX_SUPPORTED_PRECISION - ))); + )))); } // Continue with the previous logic diff --git a/crates/proof-of-sql/src/base/time/mod.rs b/crates/proof-of-sql/src/base/time/mod.rs index 2775c0048..c58422bf3 100644 --- a/crates/proof-of-sql/src/base/time/mod.rs +++ b/crates/proof-of-sql/src/base/time/mod.rs @@ -1,2 +1,6 @@ -/// Stores all functionality relelvant to timestamps +/// Native timunit type for proof-of-sql +pub mod timeunit; +/// Typed timezone type for proof-of-sql +pub mod timezone; +/// Native timestamp type for proof-of-sql pub mod timestamp; diff --git a/crates/proof-of-sql/src/base/time/timestamp.rs b/crates/proof-of-sql/src/base/time/timestamp.rs index 3ae5d5b8c..09df72632 100644 --- a/crates/proof-of-sql/src/base/time/timestamp.rs +++ b/crates/proof-of-sql/src/base/time/timestamp.rs @@ -1,241 +1,13 @@ -use crate::base::database::{ArrowArrayToColumnConversionError, OwnedArrowConversionError}; -use arrow::datatypes::TimeUnit as ArrowTimeUnit; -use chrono_tz::Tz; -use core::fmt; +use super::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}; use serde::{Deserialize, Serialize}; -use std::{str::FromStr, sync::Arc}; -/// A wrapper around i64 to mitigate conflicting From -/// implementations -#[derive(Clone, Copy)] -pub struct Time { - /// i64 count of timeunits since unix epoch +/// Intermediate Time +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +pub struct Timestamp { + /// Count of time units since the unix epoch pub timestamp: i64, - /// Timeunit of this time + /// Seconds, milliseconds, microseconds, or nanoseconds pub unit: PoSQLTimeUnit, -} - -/// A typed TimeZone for a [`TimeStamp`]. It is optionally -/// used to define a timezone other than UTC for a new TimeStamp. -/// It exists as a wrapper around chrono-tz because chrono-tz does -/// not implement uniform bit distribution -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)] -pub struct PoSQLTimeZone(Tz); - -impl PoSQLTimeZone { - /// Convenience constant for the UTC timezone - pub const UTC: PoSQLTimeZone = PoSQLTimeZone(Tz::UTC); -} - -impl PoSQLTimeZone { - /// Create a new ProofsTimeZone from a chrono TimeZone - pub fn new(tz: Tz) -> Self { - PoSQLTimeZone(tz) - } -} - -impl From<&PoSQLTimeZone> for Arc { - fn from(timezone: &PoSQLTimeZone) -> Self { - Arc::from(timezone.0.name()) - } -} - -impl From for PoSQLTimeZone { - fn from(tz: Tz) -> Self { - PoSQLTimeZone(tz) - } -} - -impl fmt::Display for PoSQLTimeZone { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - -impl TryFrom>> for PoSQLTimeZone { - type Error = &'static str; - - fn try_from(value: Option>) -> Result { - match value { - Some(arc_str) => Tz::from_str(&arc_str) - .map(PoSQLTimeZone) - .map_err(|_| "Invalid timezone string"), - None => Ok(PoSQLTimeZone(Tz::UTC)), // Default to UTC - } - } -} - -impl TryFrom<&str> for PoSQLTimeZone { - type Error = &'static str; - - fn try_from(value: &str) -> Result { - Tz::from_str(value) - .map(PoSQLTimeZone) - .map_err(|_| "Invalid timezone string") - } -} - -/// Specifies different units of time measurement relative to the Unix epoch. It is essentially -/// a wrapper over [arrow::datatypes::TimeUnit] so that we can derive Copy and implement custom traits -/// such as bit distribution and Hash. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize, Hash)] -pub enum PoSQLTimeUnit { - /// Represents a time unit of one second. - Second, - /// Represents a time unit of one millisecond (1/1,000 of a second). - Millisecond, - /// Represents a time unit of one microsecond (1/1,000,000 of a second). - Microsecond, - /// Represents a time unit of one nanosecond (1/1,000,000,000 of a second). - Nanosecond, -} - -impl From for ArrowTimeUnit { - fn from(unit: PoSQLTimeUnit) -> Self { - match unit { - PoSQLTimeUnit::Second => ArrowTimeUnit::Second, - PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond, - PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond, - PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond, - } - } -} - -impl fmt::Display for PoSQLTimeUnit { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - PoSQLTimeUnit::Second => write!(f, "Second"), - PoSQLTimeUnit::Millisecond => write!(f, "Millisecond"), - PoSQLTimeUnit::Microsecond => write!(f, "Microsecond"), - PoSQLTimeUnit::Nanosecond => write!(f, "Nanosecond"), - } - } -} - -impl From for PoSQLTimeUnit { - fn from(unit: ArrowTimeUnit) -> Self { - match unit { - ArrowTimeUnit::Second => PoSQLTimeUnit::Second, - ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond, - ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond, - ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond, - } - } -} - -impl From<&'static str> for OwnedArrowConversionError { - fn from(error: &'static str) -> Self { - OwnedArrowConversionError::InvalidTimezone(error.to_string()) - } -} - -impl From<&'static str> for ArrowArrayToColumnConversionError { - fn from(error: &'static str) -> Self { - ArrowArrayToColumnConversionError::TimezoneConversionError(error.to_string()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use chrono_tz::Tz; - - #[test] - fn valid_timezones_convert_correctly() { - let valid_timezones = ["Europe/London", "America/New_York", "Asia/Tokyo", "UTC"]; - - for tz_str in &valid_timezones { - let arc_tz = Arc::new(tz_str.to_string()); - // Convert Arc to Arc by dereferencing to &str then creating a new Arc - let arc_tz_str: Arc = Arc::from(&**arc_tz); - let timezone = PoSQLTimeZone::try_from(Some(arc_tz_str)); - assert!(timezone.is_ok(), "Timezone should be valid: {}", tz_str); - assert_eq!( - timezone.unwrap().0, - Tz::from_str(tz_str).unwrap(), - "Timezone mismatch for {}", - tz_str - ); - } - } - - #[test] - fn test_edge_timezone_strings() { - let edge_timezones = ["Etc/GMT+12", "Etc/GMT-14", "America/Argentina/Ushuaia"]; - for tz_str in &edge_timezones { - let arc_tz = Arc::from(*tz_str); - let result = PoSQLTimeZone::try_from(Some(arc_tz)); - assert!(result.is_ok(), "Edge timezone should be valid: {}", tz_str); - assert_eq!( - result.unwrap().0, - Tz::from_str(tz_str).unwrap(), - "Mismatch for edge timezone {}", - tz_str - ); - } - } - - #[test] - fn test_empty_timezone_string() { - let empty_tz = Arc::from(""); - let result = PoSQLTimeZone::try_from(Some(empty_tz)); - assert!(result.is_err(), "Empty timezone string should fail"); - } - - #[test] - fn test_unicode_timezone_strings() { - let unicode_tz = Arc::from("Europe/Paris\u{00A0}"); // Non-breaking space character - let result = PoSQLTimeZone::try_from(Some(unicode_tz)); - assert!( - result.is_err(), - "Unicode characters should not be valid in timezone strings" - ); - } - - #[test] - fn test_null_option() { - let result = PoSQLTimeZone::try_from(None); - assert!(result.is_ok(), "None should convert without error"); - assert_eq!(result.unwrap().0, Tz::UTC, "None should default to UTC"); - } - - #[test] - fn we_can_convert_from_arrow_time_units() { - assert_eq!( - PoSQLTimeUnit::from(ArrowTimeUnit::Second), - PoSQLTimeUnit::Second - ); - assert_eq!( - PoSQLTimeUnit::from(ArrowTimeUnit::Millisecond), - PoSQLTimeUnit::Millisecond - ); - assert_eq!( - PoSQLTimeUnit::from(ArrowTimeUnit::Microsecond), - PoSQLTimeUnit::Microsecond - ); - assert_eq!( - PoSQLTimeUnit::from(ArrowTimeUnit::Nanosecond), - PoSQLTimeUnit::Nanosecond - ); - } - - #[test] - fn we_can_convert_to_arrow_time_units() { - assert_eq!( - ArrowTimeUnit::from(PoSQLTimeUnit::Second), - ArrowTimeUnit::Second - ); - assert_eq!( - ArrowTimeUnit::from(PoSQLTimeUnit::Millisecond), - ArrowTimeUnit::Millisecond - ); - assert_eq!( - ArrowTimeUnit::from(PoSQLTimeUnit::Microsecond), - ArrowTimeUnit::Microsecond - ); - assert_eq!( - ArrowTimeUnit::from(PoSQLTimeUnit::Nanosecond), - ArrowTimeUnit::Nanosecond - ); - } -} + /// Timezone captured from parsed string + pub timezone: PoSQLTimeZone, +} \ No newline at end of file diff --git a/crates/proof-of-sql/src/base/time/timeunit.rs b/crates/proof-of-sql/src/base/time/timeunit.rs new file mode 100644 index 000000000..51aec094f --- /dev/null +++ b/crates/proof-of-sql/src/base/time/timeunit.rs @@ -0,0 +1,134 @@ +use crate::base::database::{ArrowArrayToColumnConversionError, OwnedArrowConversionError}; +use arrow::datatypes::TimeUnit as ArrowTimeUnit; +use core::fmt; +use serde::{Deserialize, Serialize}; +use std::str::FromStr; + +/// A wrapper around i64 to mitigate conflicting From +/// implementations +#[derive(Clone, Copy)] +pub struct Time { + /// i64 count of timeunits since unix epoch + pub timestamp: i64, + /// Timeunit of this time + pub unit: PoSQLTimeUnit, +} + +/// Specifies different units of time measurement relative to the Unix epoch. It is essentially +/// a wrapper over [arrow::datatypes::TimeUnit] so that we can derive Copy and implement custom traits +/// such as bit distribution and Hash. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize, Serialize, Hash)] +pub enum PoSQLTimeUnit { + /// Represents a time unit of one second. + Second, + /// Represents a time unit of one millisecond (1/1,000 of a second). + Millisecond, + /// Represents a time unit of one microsecond (1/1,000,000 of a second). + Microsecond, + /// Represents a time unit of one nanosecond (1/1,000,000,000 of a second). + Nanosecond, +} + +impl From for ArrowTimeUnit { + fn from(unit: PoSQLTimeUnit) -> Self { + match unit { + PoSQLTimeUnit::Second => ArrowTimeUnit::Second, + PoSQLTimeUnit::Millisecond => ArrowTimeUnit::Millisecond, + PoSQLTimeUnit::Microsecond => ArrowTimeUnit::Microsecond, + PoSQLTimeUnit::Nanosecond => ArrowTimeUnit::Nanosecond, + } + } +} + +impl fmt::Display for PoSQLTimeUnit { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PoSQLTimeUnit::Second => write!(f, "Second"), + PoSQLTimeUnit::Millisecond => write!(f, "Millisecond"), + PoSQLTimeUnit::Microsecond => write!(f, "Microsecond"), + PoSQLTimeUnit::Nanosecond => write!(f, "Nanosecond"), + } + } +} + +impl FromStr for PoSQLTimeUnit { + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "Second" => Ok(PoSQLTimeUnit::Second), + "Millisecond" => Ok(PoSQLTimeUnit::Millisecond), + "Microsecond" => Ok(PoSQLTimeUnit::Microsecond), + "Nanosecond" => Ok(PoSQLTimeUnit::Nanosecond), + _ => Err(()), + } + } +} + +impl From for PoSQLTimeUnit { + fn from(unit: ArrowTimeUnit) -> Self { + match unit { + ArrowTimeUnit::Second => PoSQLTimeUnit::Second, + ArrowTimeUnit::Millisecond => PoSQLTimeUnit::Millisecond, + ArrowTimeUnit::Microsecond => PoSQLTimeUnit::Microsecond, + ArrowTimeUnit::Nanosecond => PoSQLTimeUnit::Nanosecond, + } + } +} + +impl From<&'static str> for OwnedArrowConversionError { + fn from(error: &'static str) -> Self { + OwnedArrowConversionError::InvalidTimezone(error.to_string()) + } +} + +impl From<&'static str> for ArrowArrayToColumnConversionError { + fn from(error: &'static str) -> Self { + ArrowArrayToColumnConversionError::TimezoneConversionError(error.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn we_can_convert_from_arrow_time_units() { + assert_eq!( + PoSQLTimeUnit::from(ArrowTimeUnit::Second), + PoSQLTimeUnit::Second + ); + assert_eq!( + PoSQLTimeUnit::from(ArrowTimeUnit::Millisecond), + PoSQLTimeUnit::Millisecond + ); + assert_eq!( + PoSQLTimeUnit::from(ArrowTimeUnit::Microsecond), + PoSQLTimeUnit::Microsecond + ); + assert_eq!( + PoSQLTimeUnit::from(ArrowTimeUnit::Nanosecond), + PoSQLTimeUnit::Nanosecond + ); + } + + #[test] + fn we_can_convert_to_arrow_time_units() { + assert_eq!( + ArrowTimeUnit::from(PoSQLTimeUnit::Second), + ArrowTimeUnit::Second + ); + assert_eq!( + ArrowTimeUnit::from(PoSQLTimeUnit::Millisecond), + ArrowTimeUnit::Millisecond + ); + assert_eq!( + ArrowTimeUnit::from(PoSQLTimeUnit::Microsecond), + ArrowTimeUnit::Microsecond + ); + assert_eq!( + ArrowTimeUnit::from(PoSQLTimeUnit::Nanosecond), + ArrowTimeUnit::Nanosecond + ); + } +} diff --git a/crates/proof-of-sql/src/base/time/timezone.rs b/crates/proof-of-sql/src/base/time/timezone.rs new file mode 100644 index 000000000..f7a9888b1 --- /dev/null +++ b/crates/proof-of-sql/src/base/time/timezone.rs @@ -0,0 +1,140 @@ +use chrono_tz::Tz; +use core::fmt; +use serde::{Deserialize, Serialize}; +use std::{str::FromStr, sync::Arc}; + +/// A typed TimeZone for a [`TimeStamp`]. It is optionally +/// used to define a timezone other than UTC for a new TimeStamp. +/// It exists as a wrapper around chrono-tz because chrono-tz does +/// not implement uniform bit distribution +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct PoSQLTimeZone(Tz); + +impl PoSQLTimeZone { + /// Convenience constant for the UTC timezone + pub const UTC: PoSQLTimeZone = PoSQLTimeZone(Tz::UTC); +} + +impl PoSQLTimeZone { + /// Create a new ProofsTimeZone from a chrono TimeZone + pub fn new(tz: Tz) -> Self { + PoSQLTimeZone(tz) + } +} + +impl From<&PoSQLTimeZone> for Arc { + fn from(timezone: &PoSQLTimeZone) -> Self { + Arc::from(timezone.0.name()) + } +} + +impl From for PoSQLTimeZone { + fn from(tz: Tz) -> Self { + PoSQLTimeZone(tz) + } +} + +impl fmt::Display for PoSQLTimeZone { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl TryFrom>> for PoSQLTimeZone { + type Error = &'static str; + + fn try_from(value: Option>) -> Result { + match value { + Some(arc_str) => Tz::from_str(&arc_str) + .map(PoSQLTimeZone) + .map_err(|_| "Invalid timezone string"), + None => Ok(PoSQLTimeZone(Tz::UTC)), // Default to UTC + } + } +} + +impl TryFrom<&str> for PoSQLTimeZone { + type Error = &'static str; + + fn try_from(value: &str) -> Result { + Tz::from_str(value) + .map(PoSQLTimeZone) + .map_err(|_| "Invalid timezone string") + } +} + +impl FromStr for PoSQLTimeZone { + type Err = &'static str; + + fn from_str(value: &str) -> Result { + Tz::from_str(value) + .map(PoSQLTimeZone) + .map_err(|_| "Invalid timezone string") + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono_tz::Tz; + use std::sync::Arc; + + #[test] + fn valid_timezones_convert_correctly() { + let valid_timezones = ["Europe/London", "America/New_York", "Asia/Tokyo", "UTC"]; + + for tz_str in &valid_timezones { + let arc_tz = Arc::new(tz_str.to_string()); + // Convert Arc to Arc by dereferencing to &str then creating a new Arc + let arc_tz_str: Arc = Arc::from(&**arc_tz); + let timezone = PoSQLTimeZone::try_from(Some(arc_tz_str)); + assert!(timezone.is_ok(), "Timezone should be valid: {}", tz_str); + assert_eq!( + timezone.unwrap().0, + Tz::from_str(tz_str).unwrap(), + "Timezone mismatch for {}", + tz_str + ); + } + } + + #[test] + fn test_edge_timezone_strings() { + let edge_timezones = ["Etc/GMT+12", "Etc/GMT-14", "America/Argentina/Ushuaia"]; + for tz_str in &edge_timezones { + let arc_tz = Arc::from(*tz_str); + let result = PoSQLTimeZone::try_from(Some(arc_tz)); + assert!(result.is_ok(), "Edge timezone should be valid: {}", tz_str); + assert_eq!( + result.unwrap().0, + Tz::from_str(tz_str).unwrap(), + "Mismatch for edge timezone {}", + tz_str + ); + } + } + + #[test] + fn test_empty_timezone_string() { + let empty_tz = Arc::from(""); + let result = PoSQLTimeZone::try_from(Some(empty_tz)); + assert!(result.is_err(), "Empty timezone string should fail"); + } + + #[test] + fn test_unicode_timezone_strings() { + let unicode_tz = Arc::from("Europe/Paris\u{00A0}"); // Non-breaking space character + let result = PoSQLTimeZone::try_from(Some(unicode_tz)); + assert!( + result.is_err(), + "Unicode characters should not be valid in timezone strings" + ); + } + + #[test] + fn test_null_option() { + let result = PoSQLTimeZone::try_from(None); + assert!(result.is_ok(), "None should convert without error"); + assert_eq!(result.unwrap().0, Tz::UTC, "None should default to UTC"); + } +} diff --git a/crates/proof-of-sql/src/sql/ast/comparison_util.rs b/crates/proof-of-sql/src/sql/ast/comparison_util.rs index 4b1d28461..04ac57e8e 100644 --- a/crates/proof-of-sql/src/sql/ast/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/ast/comparison_util.rs @@ -1,5 +1,9 @@ use crate::{ - base::{database::Column, math::decimal::Precision, scalar::Scalar}, + base::{ + database::Column, + math::decimal::{scale_scalar, DecimalError, Precision}, + scalar::Scalar, + }, sql::parse::{type_check_binary_operation, ConversionError, ConversionResult}, }; use bumpalo::Bump; @@ -67,8 +71,9 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( rhs_precision_value + (max_scale - rhs_scale) as u8, ); // Check if the precision is valid - let _max_precision = Precision::new(max_precision_value) - .map_err(|_| ConversionError::InvalidPrecision(max_precision_value as i16))?; + let _max_precision = Precision::new(max_precision_value).map_err(|_| { + ConversionError::Decimal(DecimalError::InvalidPrecision(max_precision_value)) + })?; } unchecked_subtract_impl( alloc, diff --git a/crates/proof-of-sql/src/sql/parse/error.rs b/crates/proof-of-sql/src/sql/parse/error.rs index 255522b3e..68c59f465 100644 --- a/crates/proof-of-sql/src/sql/parse/error.rs +++ b/crates/proof-of-sql/src/sql/parse/error.rs @@ -1,5 +1,8 @@ -use crate::base::database::ColumnType; -use proof_of_sql_parser::{intermediate_decimal::DecimalError, Identifier, ResourceId}; +use crate::base::{database::ColumnType, math::decimal::DecimalError}; +use proof_of_sql_parser::{ + intermediate_decimal::IntermediateDecimalError, intermediate_time::IntermediateTimestampError, + Identifier, ResourceId, +}; use thiserror::Error; /// Errors from converting an intermediate AST into a provable AST. @@ -50,18 +53,6 @@ pub enum ConversionError { /// General error for invalid expressions InvalidExpression(String), - #[error("Unsupported operation: cannot round decimal: {0}")] - /// Decimal rounding is not supported - DecimalRoundingError(String), - - #[error("Error while parsing precision from query: {0}")] - /// Error in parsing precision in a query - PrecisionParseError(String), - - #[error("Decimal precision is not valid: {0}")] - /// Decimal precision is an integer but exceeds the allowed limit. We use i16 here to include all kinds of invalid precision values. - InvalidPrecision(i16), - #[error("Encountered parsing error: {0}")] /// General parsing error ParseError(String), @@ -74,26 +65,17 @@ pub enum ConversionError { /// Query requires unprovable feature Unprovable(String), - #[error("Invalid decimal format or value: {0}")] - /// Error when a decimal format or value is incorrect - InvalidDecimal(String), -} + #[error(transparent)] + /// Errors related to decimal operations + Decimal(#[from] DecimalError), -impl From for ConversionError { - fn from(error: DecimalError) -> Self { - match error { - DecimalError::ParseError(e) => ConversionError::ParseError(e.to_string()), - DecimalError::OutOfRange => ConversionError::ParseError( - "Intermediate decimal cannot be cast to primitive".into(), - ), - DecimalError::LossyCast => ConversionError::ParseError( - "Intermediate decimal has non-zero fractional part".into(), - ), - DecimalError::ConversionFailure => { - ConversionError::ParseError("Could not cast into intermediate decimal.".into()) - } - } - } + #[error(transparent)] + /// Errors related to the processing of intermediate decimal values + IntermediateDecimal(#[from] IntermediateDecimalError), + + #[error(transparent)] + /// Errors from intermediate timestamp conversion + IntermediateTimestamp(#[from] IntermediateTimestampError), } impl From for ConversionError { diff --git a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs index d33637533..70aa660f7 100644 --- a/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/provable_expr_plan_builder.rs @@ -3,12 +3,14 @@ use crate::{ base::{ commitment::Commitment, database::{ColumnRef, LiteralValue}, - math::decimal::{try_into_to_scalar, Precision}, + math::decimal::{try_into_to_scalar, DecimalError, Precision}, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }, sql::ast::{ColumnExpr, ProvableExprPlan}, }; use proof_of_sql_parser::{ intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator}, + intermediate_time::IntermediateTimestampError, Identifier, }; use std::collections::HashMap; @@ -72,8 +74,9 @@ impl ProvableExprPlanBuilder<'_> { Literal::Int128(i) => Ok(ProvableExprPlan::new_literal(LiteralValue::Int128(*i))), Literal::Decimal(d) => { let scale = d.scale(); - let precision = Precision::new(d.precision()) - .map_err(|_| ConversionError::InvalidPrecision(d.precision() as i16))?; + let precision = Precision::new(d.precision()).map_err(|_| { + ConversionError::Decimal(DecimalError::InvalidPrecision(d.precision())) + })?; Ok(ProvableExprPlan::new_literal(LiteralValue::Decimal75( precision, scale, @@ -84,6 +87,25 @@ impl ProvableExprPlanBuilder<'_> { s.clone(), s.into(), )))), + Literal::TimestampTZ(its) => { + let posql_tu = its + .unit + .to_string() + .parse::() + .map_err(|_| IntermediateTimestampError::InvalidTimeUnit)?; + + let posql_tz = its + .timezone + .to_string() + .parse::() + .map_err(|_| IntermediateTimestampError::InvalidTimeUnit)?; + + Ok(ProvableExprPlan::new_literal(LiteralValue::TimestampTZ( + posql_tu, + posql_tz, + its.timestamp, + ))) + } } } diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 0b870b1cc..dfaa1652c 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -3,6 +3,7 @@ use crate::{ base::{ database::{ColumnRef, ColumnType, SchemaAccessor, TableRef}, math::decimal::Precision, + time::{timeunit::PoSQLTimeUnit, timezone::PoSQLTimeZone}, }, sql::ast::try_add_subtract_column_types, }; @@ -11,6 +12,7 @@ use proof_of_sql_parser::{ AggregationOperator, AliasedResultExpr, BinaryOperator, Expression, Literal, OrderBy, SelectResultExpr, Slice, TableExpression, UnaryOperator, }, + intermediate_time::IntermediateTimestampError, Identifier, ResourceId, }; use std::ops::Deref; @@ -249,6 +251,20 @@ impl<'a> QueryContextBuilder<'a> { let precision = Precision::new(d.precision())?; Ok(ColumnType::Decimal75(precision, d.scale())) } + Literal::TimestampTZ(its) => { + let posql_tu = its + .unit + .to_string() + .parse::() + .map_err(|_| IntermediateTimestampError::InvalidTimeUnit)?; + + let posql_tz = its + .timezone + .to_string() + .parse::() + .map_err(|_| IntermediateTimestampError::InvalidTimeZone)?; + Ok(ColumnType::TimestampTZ(posql_tu, posql_tz)) + } } } diff --git a/crates/proof-of-sql/src/sql/transform/to_polars_expr.rs b/crates/proof-of-sql/src/sql/transform/to_polars_expr.rs index 3204a7417..683b8dc58 100644 --- a/crates/proof-of-sql/src/sql/transform/to_polars_expr.rs +++ b/crates/proof-of-sql/src/sql/transform/to_polars_expr.rs @@ -30,6 +30,7 @@ impl ToPolarsExpr for Expression { Literal::Int128(value) => value.to_lit(), Literal::VarChar(_) => panic!("Expression not supported"), Literal::Decimal(_) => todo!(), + Literal::TimestampTZ(_) => todo!(), }, Expression::Column(identifier) => col(identifier.as_str()), Expression::Binary { op, left, right } => {