Skip to content

Commit

Permalink
Add RecordBatchOptions::skip_schema_check
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Dec 7, 2024
1 parent 63ad87a commit 4a27246
Showing 1 changed file with 51 additions and 19 deletions.
70 changes: 51 additions & 19 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,27 +309,31 @@ impl RecordBatch {
return Err(ArrowError::InvalidArgumentError(err.to_string()));
}

// function for comparing column type and field type
// return true if 2 types are not matched
let type_not_match = if options.match_field_names {
|(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| col_type != field_type
} else {
|(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
!col_type.equals_datatype(field_type)
}
};
if !options.skip_schema_check {
// function for comparing column type and field type
// return true if 2 types are not matched
let type_not_match = if options.match_field_names {
|(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
col_type != field_type
}
} else {
|(_, (col_type, field_type)): &(usize, (&DataType, &DataType))| {
!col_type.equals_datatype(field_type)
}
};

// check that all columns match the schema
let not_match = columns
.iter()
.zip(schema.fields().iter())
.map(|(col, field)| (col.data_type(), field.data_type()))
.enumerate()
.find(type_not_match);
// check that all columns match the schema
let not_match = columns
.iter()
.zip(schema.fields().iter())
.map(|(col, field)| (col.data_type(), field.data_type()))
.enumerate()
.find(type_not_match);

if let Some((i, (col_type, field_type))) = not_match {
return Err(ArrowError::InvalidArgumentError(format!(
"column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
if let Some((i, (col_type, field_type))) = not_match {
return Err(ArrowError::InvalidArgumentError(format!(
"column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}")));
}
}

Ok(RecordBatch {
Expand Down Expand Up @@ -390,6 +394,7 @@ impl RecordBatch {
&RecordBatchOptions {
match_field_names: true,
row_count: Some(self.row_count),
skip_schema_check: false,
},
)
}
Expand Down Expand Up @@ -631,6 +636,13 @@ pub struct RecordBatchOptions {

/// Optional row count, useful for specifying a row count for a RecordBatch with no columns
pub row_count: Option<usize>,

/// If `true`, skip the check that each column's data type matches the corresponding
/// schema field's data type when creating new record batches. This is intended for cases
/// where the schema has already been checked, or where more flexibility is required in
/// downstream projects, such as allowing either `Utf8` or `Dictionary<_, Utf8>` columns
/// for a schema field of type `Utf8`. This option should not be used within DataFusion,
/// since it is an invariant there that all batches must have the same physical schema
/// during execution.
pub skip_schema_check: bool,
}

impl RecordBatchOptions {
Expand All @@ -639,6 +651,7 @@ impl RecordBatchOptions {
Self {
match_field_names: true,
row_count: None,
skip_schema_check: false,
}
}
/// Sets the row_count of RecordBatchOptions and returns self
Expand All @@ -651,6 +664,11 @@ impl RecordBatchOptions {
self.match_field_names = match_field_names;
self
}
/// Sets the skip_schema_check of RecordBatchOptions and returns self
pub fn with_skip_schema_check(mut self, skip_schema_check: bool) -> Self {
self.skip_schema_check = skip_schema_check;
self
}
}
impl Default for RecordBatchOptions {
fn default() -> Self {
Expand Down Expand Up @@ -942,6 +960,18 @@ mod tests {
assert!(batch.is_err());
}

#[test]
fn create_record_batch_schema_mismatch_allowed() {
    // The schema declares an Int32 field, but the supplied column is an Int64 array.
    let int32_schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let int64_values = Int64Array::from(vec![1, 2, 3, 4, 5]);

    // With the schema check skipped, the type mismatch must not be reported as an error.
    let skip_check = RecordBatchOptions::new().with_skip_schema_check(true);
    let result =
        RecordBatch::try_new_with_options(int32_schema, vec![Arc::new(int64_values)], &skip_check);

    assert!(result.is_ok());
}

#[test]
fn create_record_batch_field_name_mismatch() {
let fields = vec![
Expand Down Expand Up @@ -982,6 +1012,7 @@ mod tests {
let options = RecordBatchOptions {
match_field_names: false,
row_count: None,
skip_schema_check: false,
};
let batch = RecordBatch::try_new_with_options(schema, vec![a], &options);
assert!(batch.is_ok());
Expand Down Expand Up @@ -1226,6 +1257,7 @@ mod tests {
&RecordBatchOptions {
match_field_names: true,
row_count: Some(3),
skip_schema_check: false,
},
)
.expect("valid conversion");
Expand Down

0 comments on commit 4a27246

Please sign in to comment.