Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Ensure Partition Specs can only contain primitive types #780

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 96 additions & 23 deletions crates/iceberg/src/spec/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
*/
use std::sync::Arc;

use serde::{Deserialize, Serialize};
use serde::de::Unexpected;
use serde::{Deserialize, Deserializer, Serialize};
use typed_builder::TypedBuilder;

use super::transform::Transform;
Expand Down Expand Up @@ -69,6 +70,44 @@ pub struct BoundPartitionSpec {
partition_type: StructType,
}

impl<'de> Deserialize<'de> for BoundPartitionSpec {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where D: Deserializer<'de> {
// Helper that carries the same four fields later copied into
// `BoundPartitionSpec`. It is deserialized through the derived impl first so
// the source-field checks below can run before the real spec is constructed.
#[derive(Deserialize)]
struct RawBoundPartitionSpec {
// Copied unchanged into `BoundPartitionSpec::spec_id`.
spec_id: i32,
// Partition fields; each one's `source_id` is looked up in `schema` below.
fields: Vec<PartitionField>,
// Schema used to resolve every partition field's source column.
schema: SchemaRef,
// Copied unchanged into `BoundPartitionSpec::partition_type`.
partition_type: StructType,
}

let raw_spec = RawBoundPartitionSpec::deserialize(deserializer)?;

for field in &raw_spec.fields {
let field_type = raw_spec
.schema
.field_by_id(field.source_id)
.ok_or_else(|| serde::de::Error::custom("Invalid value"))?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here: can we include the field ID that could not be found in the error message?

Suggested change
.ok_or_else(|| serde::de::Error::custom("Invalid value"))?
.ok_or_else(|| serde::de::Error::custom("Invalid value"))?

.field_type
.clone();

if !field_type.is_primitive() {
return Err(serde::de::Error::invalid_type(
Unexpected::Other("non-primitive field type"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we help the user here by adding the source-id to the error?

Suggested change
Unexpected::Other("non-primitive field type"),
Unexpected::Other("non-primitive field type"),

&"a primitive field type",
));
}
}

Ok(BoundPartitionSpec {
spec_id: raw_spec.spec_id,
fields: raw_spec.fields,
schema: raw_spec.schema,
partition_type: raw_spec.partition_type,
})
}
}

/// Reference to [`SchemalessPartitionSpec`].
pub type SchemalessPartitionSpecRef = Arc<SchemalessPartitionSpec>;
/// Partition spec that defines how to produce a tuple of partition values from a record.
Expand Down Expand Up @@ -645,33 +684,31 @@ impl PartitionSpecBuilder {
)
})?;

if field.transform != Transform::Void {
if !schema_field.field_type.is_primitive() {
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Cannot partition by non-primitive source field: '{}'.",
schema_field.field_type
),
));
}
if !schema_field.field_type.is_primitive() {
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Cannot partition by non-primitive source field: '{}'.",
schema_field.field_type
),
));
}

if field
if field.transform != Transform::Void
&& field
.transform
.result_type(&schema_field.field_type)
.is_err()
{
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Invalid source type: '{}' for transform: '{}'.",
schema_field.field_type,
field.transform.dedup_name()
),
));
}
{
return Err(Error::new(
ErrorKind::DataInvalid,
format!(
"Invalid source type: '{}' for transform: '{}'.",
schema_field.field_type,
field.transform.dedup_name()
),
));
}

Ok(())
}
}
Expand Down Expand Up @@ -1870,4 +1907,40 @@ mod tests {
assert_eq!(1002, spec.fields[1].field_id);
assert!(!spec.has_sequential_ids());
}

/// A partition field whose source is a primitive column nested inside a struct
/// is accepted, while a partition field whose source is the struct column
/// itself is rejected, even when the transform is `Void`.
#[test]
fn test_add_unbound_field_disallow_complex_types() {
    // Schema with a single top-level struct column `user { id: int, name: string }`.
    let user_type = Type::Struct(StructType::new(vec![
        NestedField::required(2, "id", Type::Primitive(PrimitiveType::Int)).into(),
        NestedField::required(3, "name", Type::Primitive(PrimitiveType::String)).into(),
    ]));
    let schema: Schema = Schema::builder()
        .with_fields(vec![NestedField::required(1, "user", user_type).into()])
        .build()
        .unwrap();

    // Source id 2 is the nested primitive `id` field.
    let on_primitive = UnboundPartitionField {
        source_id: 2,
        field_id: Some(1001),
        name: "user_id_partition".to_string(),
        transform: Transform::Identity,
    };
    let accepted = BoundPartitionSpec::builder(schema.clone()).add_unbound_field(on_primitive);
    assert!(accepted.is_ok());

    // Source id 1 is the struct column itself.
    let on_struct = UnboundPartitionField {
        source_id: 1,
        field_id: Some(1002),
        name: "user_struct_partition".to_string(),
        transform: Transform::Void,
    };
    let rejected = BoundPartitionSpec::builder(schema).add_unbound_field(on_struct);
    assert!(rejected.is_err());
}
}
Loading