diff --git a/interpreter/Cargo.toml b/interpreter/Cargo.toml index 45a1379..ba65b53 100644 --- a/interpreter/Cargo.toml +++ b/interpreter/Cargo.toml @@ -13,7 +13,7 @@ cel-parser = { path = "../parser", version = "0.8.0" } nom = "7.1.3" -chrono = { version = "0.4", default-features = false, features = ["alloc"], optional = true } +chrono = { version = "0.4", default-features = false, features = ["alloc", "serde"], optional = true } regex = { version = "1.10.5", optional = true } serde = "1.0" serde_json = { version = "1.0", optional = true } diff --git a/interpreter/src/lib.rs b/interpreter/src/lib.rs index b11ff47..7dfa87c 100644 --- a/interpreter/src/lib.rs +++ b/interpreter/src/lib.rs @@ -23,6 +23,8 @@ mod duration; mod ser; pub use ser::to_value; +#[cfg(feature = "chrono")] +pub use ser::{Duration, Timestamp}; #[cfg(feature = "json")] mod json; diff --git a/interpreter/src/ser.rs b/interpreter/src/ser.rs index 008cdbd..62c9e88 100644 --- a/interpreter/src/ser.rs +++ b/interpreter/src/ser.rs @@ -5,14 +5,170 @@ use crate::{objects::Key, Value}; use serde::{ - ser::{self, Impossible}, + ser::{self, Impossible, SerializeStruct}, Serialize, }; use std::{collections::HashMap, fmt::Display, iter::FromIterator, sync::Arc}; use thiserror::Error; + +#[cfg(feature = "chrono")] +use chrono::FixedOffset; + pub struct Serializer; pub struct KeySerializer; +/// A wrapper Duration type which allows conversion to [Value::Duration] for +/// types using automatic conversion with [serde::Serialize]. +/// +/// # Examples +/// +/// ``` +/// use cel_interpreter::{Context, Duration, Program}; +/// use serde::Serialize; +/// +/// #[derive(Serialize)] +/// struct MyStruct { +/// dur: Duration, +/// } +/// +/// let mut context = Context::default(); +/// +/// // MyStruct will be implicitly serialized into the CEL appropriate types +/// context +/// .add_variable( +/// "foo", +/// MyStruct { +/// dur: chrono::Duration::hours(2).into(), +/// }, +/// ) +/// .unwrap(); +/// +/// let program = Program::compile("foo.dur == duration('2h')").unwrap(); +/// let value = program.execute(&context).unwrap(); +/// assert_eq!(value, true.into()); +/// ``` +#[cfg(feature = "chrono")] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub struct Duration(pub chrono::Duration); + +#[cfg(feature = "chrono")] +impl Duration { + // Since serde can't natively represent durations, we serialize a special + // newtype to indicate we want to rebuild the duration in the result, while + // remaining compatible with most other Serializer implemenations. + const NAME: &str = "$__cel_private_Duration"; + const STRUCT_NAME: &str = "Duration"; + const SECS_FIELD: &str = "secs"; + const NANOS_FIELD: &str = "nanos"; +} + +#[cfg(feature = "chrono")] +impl From for chrono::Duration { + fn from(value: Duration) -> Self { + value.0 + } +} + +#[cfg(feature = "chrono")] +impl From for Duration { + fn from(value: chrono::Duration) -> Self { + Self(value) + } +} + +#[cfg(feature = "chrono")] +impl ser::Serialize for Duration { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: ser::Serializer, + { + // chrono::Duration's Serialize impl isn't stable yet and relies on + // private fields, so attempt to mimic serde's default impl for std + // Duration. + struct DurationProxy(chrono::Duration); + impl Serialize for DurationProxy { + fn serialize( + &self, + serializer: S, + ) -> std::result::Result { + let mut s = serializer.serialize_struct(Duration::STRUCT_NAME, 2)?; + s.serialize_field(Duration::SECS_FIELD, &self.0.num_seconds())?; + s.serialize_field(Duration::NANOS_FIELD, &self.0.subsec_nanos())?; + s.end() + } + } + serializer.serialize_newtype_struct(Self::NAME, &DurationProxy(self.0)) + } +} + +/// A wrapper Timestamp type which allows conversion to [Value::Timestamp] for +/// types using automatic conversion with [serde::Serialize]. +/// +/// # Examples +/// +/// ``` +/// use cel_interpreter::{Context, Timestamp, Program}; +/// use serde::Serialize; +/// +/// #[derive(Serialize)] +/// struct MyStruct { +/// ts: Timestamp, +/// } +/// +/// let mut context = Context::default(); +/// +/// // MyStruct will be implicitly serialized into the CEL appropriate types +/// context +/// .add_variable( +/// "foo", +/// MyStruct { +/// ts: chrono::DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z") +/// .unwrap() +/// .into(), +/// }, +/// ) +/// .unwrap(); +/// +/// let program = Program::compile("foo.ts == timestamp('2025-01-01T00:00:00Z')").unwrap(); +/// let value = program.execute(&context).unwrap(); +/// assert_eq!(value, true.into()); +/// ``` +#[cfg(feature = "chrono")] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)] +pub struct Timestamp(pub chrono::DateTime); + +#[cfg(feature = "chrono")] +impl Timestamp { + // Since serde can't natively represent timestamps, we serialize a special + // newtype to indicate we want to rebuild the timestamp in the result, + // while remaining compatible with most other Serializer implemenations. + const NAME: &str = "$__cel_private_Timestamp"; +} + +#[cfg(feature = "chrono")] +impl From for chrono::DateTime { + fn from(value: Timestamp) -> Self { + value.0 + } +} + +#[cfg(feature = "chrono")] +impl From> for Timestamp { + fn from(value: chrono::DateTime) -> Self { + Self(value) + } +} + +#[cfg(feature = "chrono")] +impl ser::Serialize for Timestamp { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: ser::Serializer, + { + serializer.serialize_newtype_struct(Self::NAME, &self.0) + } +} + #[derive(Error, Debug, PartialEq, Clone)] pub enum SerializationError { InvalidKey(String), @@ -142,11 +298,17 @@ impl ser::Serializer for Serializer { self.serialize_str(variant) } - fn serialize_newtype_struct(self, _name: &'static str, value: &T) -> Result + fn serialize_newtype_struct(self, name: &'static str, value: &T) -> Result where T: ?Sized + Serialize, { - value.serialize(self) + match name { + #[cfg(feature = "chrono")] + Duration::NAME => value.serialize(TimeSerializer::Duration), + #[cfg(feature = "chrono")] + Timestamp::NAME => value.serialize(TimeSerializer::Timestamp), + _ => value.serialize(self), + } } fn serialize_newtype_variant( @@ -237,6 +399,13 @@ pub struct SerializeStructVariant { map: HashMap, } +#[cfg(feature = "chrono")] +#[derive(Debug, Default)] +struct SerializeTimestamp { + secs: i64, + nanos: i32, +} + impl ser::SerializeSeq for SerializeVec { type Ok = Value; type Error = SerializationError; @@ -371,6 +540,55 @@ impl ser::SerializeStructVariant for SerializeStructVariant { } } +#[cfg(feature = "chrono")] +impl ser::SerializeStruct for SerializeTimestamp { + type Ok = Value; + type Error = SerializationError; + fn serialize_field( + &mut self, + key: &'static str, + value: &T, + ) -> std::result::Result<(), Self::Error> + where + T: ?Sized + Serialize, + { + match key { + Duration::SECS_FIELD => { + let Value::Int(val) = value.serialize(Serializer)? else { + return Err(SerializationError::SerdeError( + "invalid type of value in timestamp struct".to_owned(), + )); + }; + self.secs = val; + Ok(()) + } + Duration::NANOS_FIELD => { + let Value::Int(val) = value.serialize(Serializer)? else { + return Err(SerializationError::SerdeError( + "invalid type of value in timestamp struct".to_owned(), + )); + }; + self.nanos = val.try_into().map_err(|_| { + SerializationError::SerdeError( + "timestamp struct nanos field is invalid".to_owned(), + ) + })?; + Ok(()) + } + _ => Err(SerializationError::SerdeError( + "invalid field in duration struct".to_owned(), + )), + } + } + + fn end(self) -> std::result::Result { + Ok(chrono::Duration::seconds(self.secs) + .checked_add(&chrono::Duration::nanoseconds(self.nanos.into())) + .unwrap() + .into()) + } +} + impl ser::Serializer for KeySerializer { type Ok = Key; type Error = SerializationError; @@ -560,6 +778,194 @@ impl ser::Serializer for KeySerializer { } } +#[cfg(feature = "chrono")] +#[derive(Debug)] +enum TimeSerializer { + Duration, + Timestamp, +} + +#[cfg(feature = "chrono")] +impl ser::Serializer for TimeSerializer { + type Ok = Value; + type Error = SerializationError; + + type SerializeStruct = SerializeTimestamp; + + // Should never be used, so just reuse existing. + type SerializeSeq = SerializeVec; + type SerializeTuple = SerializeVec; + type SerializeTupleStruct = SerializeVec; + type SerializeTupleVariant = SerializeTupleVariant; + type SerializeMap = SerializeMap; + type SerializeStructVariant = SerializeStructVariant; + + fn serialize_struct(self, name: &'static str, len: usize) -> Result { + if !matches!(self, Self::Duration { .. }) || name != Duration::STRUCT_NAME { + return Err(SerializationError::SerdeError( + "expected Duration struct with Duration marker newtype struct".to_owned(), + )); + } + if len != 2 { + return Err(SerializationError::SerdeError( + "expected Duration struct to have 2 fields".to_owned(), + )); + } + Ok(SerializeTimestamp::default()) + } + + fn serialize_str(self, v: &str) -> Result { + if !matches!(self, Self::Timestamp) { + return Err(SerializationError::SerdeError( + "expected Timestamp string with Timestamp marker newtype struct".to_owned(), + )); + } + Ok(v.parse::>() + .map_err(|e| SerializationError::SerdeError(e.to_string()))? + .into()) + } + + fn serialize_bool(self, _v: bool) -> Result { + unreachable!() + } + + fn serialize_i8(self, _v: i8) -> Result { + unreachable!() + } + + fn serialize_i16(self, _v: i16) -> Result { + unreachable!() + } + + fn serialize_i32(self, _v: i32) -> Result { + unreachable!() + } + + fn serialize_i64(self, _v: i64) -> Result { + unreachable!() + } + + fn serialize_u8(self, _v: u8) -> Result { + unreachable!() + } + + fn serialize_u16(self, _v: u16) -> Result { + unreachable!() + } + + fn serialize_u32(self, _v: u32) -> Result { + unreachable!() + } + + fn serialize_u64(self, _v: u64) -> Result { + unreachable!() + } + + fn serialize_f32(self, _v: f32) -> Result { + unreachable!() + } + + fn serialize_f64(self, _v: f64) -> Result { + unreachable!() + } + + fn serialize_char(self, _v: char) -> Result { + unreachable!() + } + + fn serialize_bytes(self, _v: &[u8]) -> Result { + unreachable!() + } + + fn serialize_none(self) -> Result { + unreachable!() + } + + fn serialize_some(self, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + unreachable!() + } + + fn serialize_unit(self) -> Result { + unreachable!() + } + + fn serialize_unit_struct(self, _name: &'static str) -> Result { + unreachable!() + } + + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + ) -> Result { + unreachable!() + } + + fn serialize_newtype_struct(self, _name: &'static str, _value: &T) -> Result + where + T: ?Sized + Serialize, + { + unreachable!() + } + + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result + where + T: ?Sized + Serialize, + { + unreachable!() + } + + fn serialize_seq(self, _len: Option) -> Result { + unreachable!() + } + + fn serialize_tuple(self, _len: usize) -> Result { + unreachable!() + } + + fn serialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + ) -> Result { + unreachable!() + } + + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unreachable!() + } + + fn serialize_map(self, _len: Option) -> Result { + unreachable!() + } + + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + unreachable!() + } +} + #[cfg(test)] mod tests { use crate::{objects::Key, to_value, Value}; @@ -568,6 +974,9 @@ mod tests { use serde_bytes::Bytes; use std::{collections::HashMap, iter::FromIterator, sync::Arc}; + #[cfg(feature = "chrono")] + use super::{Duration, Timestamp}; + macro_rules! primitive_test { ($functionName:ident, $strValue: literal, $value: expr) => { #[test] @@ -806,4 +1215,158 @@ mod tests { .into(); assert_eq!(map, expected) } + + #[cfg(feature = "chrono")] + #[derive(Serialize)] + struct TestTimeTypes { + dur: Duration, + ts: Timestamp, + } + + #[cfg(feature = "chrono")] + #[test] + fn test_time_types() { + use chrono::FixedOffset; + + let tests = to_value([ + TestTimeTypes { + dur: chrono::Duration::milliseconds(1527).into(), + ts: chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00") + .unwrap() + .into(), + }, + // Let's test chrono::Duration's particular handling around math + // and negatives and timestamps from BCE. + TestTimeTypes { + dur: chrono::Duration::milliseconds(-1527).into(), + ts: "-0001-12-01T00:00:00-08:00" + .parse::>() + .unwrap() + .into(), + }, + TestTimeTypes { + dur: (chrono::Duration::seconds(1) - chrono::Duration::nanoseconds(1000000001)) + .into(), + ts: chrono::DateTime::parse_from_rfc3339("0001-12-01T00:00:00+08:00") + .unwrap() + .into(), + }, + TestTimeTypes { + dur: (chrono::Duration::seconds(-1) + chrono::Duration::nanoseconds(1000000001)) + .into(), + ts: chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00") + .unwrap() + .into(), + }, + ]) + .unwrap(); + let expected: Value = vec![ + Value::Map( + HashMap::<_, Value>::from([ + ("dur", chrono::Duration::milliseconds(1527).into()), + ( + "ts", + chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00") + .unwrap() + .into(), + ), + ]) + .into(), + ), + Value::Map( + HashMap::<_, Value>::from([ + ("dur", chrono::Duration::nanoseconds(-1527000000).into()), + ( + "ts", + "-0001-12-01T00:00:00-08:00" + .parse::>() + .unwrap() + .into(), + ), + ]) + .into(), + ), + Value::Map( + HashMap::<_, Value>::from([ + ("dur", chrono::Duration::nanoseconds(-1).into()), + ( + "ts", + chrono::DateTime::parse_from_rfc3339("0001-12-01T00:00:00+08:00") + .unwrap() + .into(), + ), + ]) + .into(), + ), + Value::Map( + HashMap::<_, Value>::from([ + ("dur", chrono::Duration::nanoseconds(1).into()), + ( + "ts", + chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00") + .unwrap() + .into(), + ), + ]) + .into(), + ), + ] + .into(); + assert_eq!(tests, expected); + + let program = Program::compile("test == expected").unwrap(); + let mut context = Context::default(); + context.add_variable("expected", expected).unwrap(); + context.add_variable("test", tests).unwrap(); + let value = program.execute(&context).unwrap(); + assert_eq!(value, true.into()); + } + + #[cfg(feature = "chrono")] + #[cfg(feature = "json")] + #[test] + fn test_time_json() { + use chrono::FixedOffset; + + // Test that Durations and Timestamps serialize correctly with + // serde_json. + let tests = [ + TestTimeTypes { + dur: chrono::Duration::milliseconds(1527).into(), + ts: chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00") + .unwrap() + .into(), + }, + TestTimeTypes { + dur: chrono::Duration::milliseconds(-1527).into(), + ts: "-0001-12-01T00:00:00-08:00" + .parse::>() + .unwrap() + .into(), + }, + TestTimeTypes { + dur: (chrono::Duration::seconds(1) - chrono::Duration::nanoseconds(1000000001)) + .into(), + ts: chrono::DateTime::parse_from_rfc3339("0001-12-01T00:00:00+08:00") + .unwrap() + .into(), + }, + TestTimeTypes { + dur: (chrono::Duration::seconds(-1) + chrono::Duration::nanoseconds(1000000001)) + .into(), + ts: chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00") + .unwrap() + .into(), + }, + ]; + + let expect = "[\ +{\"dur\":{\"secs\":1,\"nanos\":527000000},\"ts\":\"1996-12-19T16:39:57-08:00\"},\ +{\"dur\":{\"secs\":-1,\"nanos\":-527000000},\"ts\":\"-0001-12-01T00:00:00-08:00\"},\ +{\"dur\":{\"secs\":0,\"nanos\":-1},\"ts\":\"0001-12-01T00:00:00+08:00\"},\ +{\"dur\":{\"secs\":0,\"nanos\":1},\"ts\":\"1996-12-19T16:39:57-08:00\"}\ +]"; + let actual = serde_json::to_string(&tests).unwrap(); + assert_eq!(actual, expect); + } }