diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 0775392b7d64..2932bfe39a85 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -160,11 +160,18 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true, // Utf8 to decimal (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, - (Struct(from_fields), Struct(to_fields)) => { - from_fields.len() == to_fields.len() && - from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { - can_cast_types(f1.data_type(), f2.data_type()) - }) + (Struct(from_fields), Struct(to_fields)) => { + from_fields.len() == to_fields.len() && + from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { + // Only allows cast nullable or non-nullable fields to nullable fields. + // Although it is generally allowed to cast non-nullable fields to non-nullable fields, + // but if casting f1.data_type to f2.data_type may return null, e.g., overflow, + // it is not allowed when f2 is non-nullable. Because `can_cast_types` doesn't + // take `CastOptions` so we cannot check `safe` option here. We take safer + // approach to assume `safe` is true, so disallowing the case of casting non-nullable + // fields to non-nullable fields. + f2.is_nullable() && can_cast_types(f1.data_type(), f2.data_type()) + }) } (Struct(_), _) => false, (_, Struct(_)) => false, @@ -9525,4 +9532,64 @@ mod tests { result.unwrap_err().to_string() ); } + + #[test] + fn test_cast_struct_to_struct_disallow_nullability() { + let from_type = DataType::Struct( + vec![ + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Int32, true), + ] + .into(), + ); + + // allow: nullable to nullable + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8, true), + ] + .into(), + ); + assert!(can_cast_types(&from_type, &to_type)); + + // disallow: nullable to non-nullable + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + assert!(!can_cast_types(&from_type, &to_type)); + + let from_type = DataType::Struct( + vec![ + Field::new("a", DataType::Boolean, false), + Field::new("b", DataType::Int32, false), + ] + .into(), + ); + + // allow: non-nullable to nullable + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8, true), + ] + .into(), + ); + assert!(can_cast_types(&from_type, &to_type)); + + // disallow: non-nullable to non-nullable + // because casting may return null when overflow + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + assert!(!can_cast_types(&from_type, &to_type)); + } }