From 27349fc85d93f60dab857e90939fc964ddcd6ce9 Mon Sep 17 00:00:00 2001 From: Lordworms Date: Wed, 11 Dec 2024 22:09:56 -0800 Subject: [PATCH] Ensure Partition Specs can only contain primitive types --- crates/iceberg/src/spec/partition.rs | 119 +++++++++++++++++++++------ 1 file changed, 96 insertions(+), 23 deletions(-) diff --git a/crates/iceberg/src/spec/partition.rs b/crates/iceberg/src/spec/partition.rs index 445e7d441..02ee7b1d9 100644 --- a/crates/iceberg/src/spec/partition.rs +++ b/crates/iceberg/src/spec/partition.rs @@ -20,7 +20,8 @@ */ use std::sync::Arc; -use serde::{Deserialize, Serialize}; +use serde::de::Unexpected; +use serde::{Deserialize, Deserializer, Serialize}; use typed_builder::TypedBuilder; use super::transform::Transform; @@ -69,6 +70,44 @@ pub struct BoundPartitionSpec { partition_type: StructType, } +impl<'de> Deserialize<'de> for BoundPartitionSpec { + fn deserialize(deserializer: D) -> std::result::Result + where D: Deserializer<'de> { + #[derive(Deserialize)] + struct RawBoundPartitionSpec { + spec_id: i32, + fields: Vec, + schema: SchemaRef, + partition_type: StructType, + } + + let raw_spec = RawBoundPartitionSpec::deserialize(deserializer)?; + + for field in &raw_spec.fields { + let field_type = raw_spec + .schema + .field_by_id(field.source_id) + .ok_or_else(|| serde::de::Error::custom("Invalid value"))? + .field_type + .clone(); + + if !field_type.is_primitive() { + return Err(serde::de::Error::invalid_type( + Unexpected::Other("non-primitive field type"), + &"a primitive field type", + )); + } + } + + Ok(BoundPartitionSpec { + spec_id: raw_spec.spec_id, + fields: raw_spec.fields, + schema: raw_spec.schema, + partition_type: raw_spec.partition_type, + }) + } +} + /// Reference to [`SchemalessPartitionSpec`]. pub type SchemalessPartitionSpecRef = Arc; /// Partition spec that defines how to produce a tuple of partition values from a record. @@ -645,33 +684,31 @@ impl PartitionSpecBuilder { ) })?; - if field.transform != Transform::Void { - if !schema_field.field_type.is_primitive() { - return Err(Error::new( - ErrorKind::DataInvalid, - format!( - "Cannot partition by non-primitive source field: '{}'.", - schema_field.field_type - ), - )); - } + if !schema_field.field_type.is_primitive() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Cannot partition by non-primitive source field: '{}'.", + schema_field.field_type + ), + )); + } - if field + if field.transform != Transform::Void + && field .transform .result_type(&schema_field.field_type) .is_err() - { - return Err(Error::new( - ErrorKind::DataInvalid, - format!( - "Invalid source type: '{}' for transform: '{}'.", - schema_field.field_type, - field.transform.dedup_name() - ), - )); - } + { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Invalid source type: '{}' for transform: '{}'.", + schema_field.field_type, + field.transform.dedup_name() + ), + )); } - Ok(()) } } @@ -1870,4 +1907,40 @@ mod tests { assert_eq!(1002, spec.fields[1].field_id); assert!(!spec.has_sequential_ids()); } + + #[test] + fn test_add_unbound_field_disallow_complex_types() { + let schema: Schema = Schema::builder() + .with_fields(vec![NestedField::required( + 1, + "user", + Type::Struct(StructType::new(vec![ + NestedField::required(2, "id", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::required(3, "name", Type::Primitive(PrimitiveType::String)).into(), + ])), + ) + .into()]) + .build() + .unwrap(); + + let builder = BoundPartitionSpec::builder(schema.clone()); + assert!(builder + .add_unbound_field(UnboundPartitionField { + source_id: 2, + field_id: Some(1001), + name: "user_id_partition".to_string(), + transform: Transform::Identity, + }) + .is_ok()); + + let builder = BoundPartitionSpec::builder(schema); + let result = builder.add_unbound_field(UnboundPartitionField { + source_id: 1, + field_id: Some(1002), + name: "user_struct_partition".to_string(), + transform: Transform::Void, + }); + + assert!(result.is_err()); + } }