Skip to content

Commit

Permalink
early type checks in RowConverter (#3080)
Browse files Browse the repository at this point in the history
* refactor: remove duplicate code

Decimal types are already handled by `downcast_primitive`.

* refactor: check supported types when creating `RowConverter`

Check supported row format types when creating the converter instead of
during conversion. Also add an additional method
`RowConverter::supports_fields` to check types w/o relying on an error.

Closes #3077.

* Simplify

Co-authored-by: Raphael Taylor-Davies <[email protected]>
  • Loading branch information
crepererum and tustvold authored Nov 10, 2022
1 parent ed20bf1 commit 8d364fe
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 45 deletions.
2 changes: 1 addition & 1 deletion arrow/benches/lexsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn do_bench(c: &mut Criterion, columns: &[Column], len: usize) {
.iter()
.map(|a| SortField::new(a.data_type().clone()))
.collect();
let mut converter = RowConverter::new(fields);
let mut converter = RowConverter::new(fields).unwrap();
let rows = converter.convert_columns(&arrays).unwrap();
let mut sort: Vec<_> = rows.iter().enumerate().collect();
sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
Expand Down
4 changes: 2 additions & 2 deletions arrow/benches/row_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ fn do_bench(c: &mut Criterion, name: &str, cols: Vec<ArrayRef>) {

c.bench_function(&format!("convert_columns {}", name), |b| {
b.iter(|| {
let mut converter = RowConverter::new(fields.clone());
let mut converter = RowConverter::new(fields.clone()).unwrap();
black_box(converter.convert_columns(&cols).unwrap())
});
});

let mut converter = RowConverter::new(fields);
let mut converter = RowConverter::new(fields).unwrap();
let rows = converter.convert_columns(&cols).unwrap();
// using a pre-prepared row converter should be faster than the first time
c.bench_function(&format!("convert_columns_prepared {}", name), |b| {
Expand Down
17 changes: 5 additions & 12 deletions arrow/src/row/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ use std::collections::HashMap;
pub fn compute_dictionary_mapping(
interner: &mut OrderPreservingInterner,
values: &ArrayRef,
) -> Result<Vec<Option<Interned>>, ArrowError> {
Ok(downcast_primitive_array! {
) -> Vec<Option<Interned>> {
downcast_primitive_array! {
values => interner
.intern(values.iter().map(|x| x.map(|x| x.encode()))),
DataType::Binary => {
Expand All @@ -53,8 +53,8 @@ pub fn compute_dictionary_mapping(
let iter = as_largestring_array(values).iter().map(|x| x.map(|x| x.as_bytes()));
interner.intern(iter)
}
t => return Err(ArrowError::NotYetImplemented(format!("dictionary value {} is not supported", t))),
})
_ => unreachable!(),
}
}

/// Dictionary types are encoded as
Expand Down Expand Up @@ -173,18 +173,11 @@ pub unsafe fn decode_dictionary<K: ArrowDictionaryKeyType>(
value_type => (decode_primitive_helper, values, value_type),
DataType::Null => NullArray::new(values.len()).into_data(),
DataType::Boolean => decode_bool(&values),
DataType::Decimal128(_, _) => decode_primitive_helper!(Decimal128Type, values, value_type),
DataType::Decimal256(_, _) => decode_primitive_helper!(Decimal256Type, values, value_type),
DataType::Utf8 => decode_string::<i32>(&values),
DataType::LargeUtf8 => decode_string::<i64>(&values),
DataType::Binary => decode_binary::<i32>(&values),
DataType::LargeBinary => decode_binary::<i64>(&values),
_ => {
return Err(ArrowError::NotYetImplemented(format!(
"decoding dictionary values of {}",
value_type
)))
}
_ => unreachable!(),
};

let data_type =
Expand Down
83 changes: 53 additions & 30 deletions arrow/src/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
//! let mut converter = RowConverter::new(vec![
//! SortField::new(DataType::Int32),
//! SortField::new(DataType::Utf8),
//! ]);
//! ]).unwrap();
//! let rows = converter.convert_columns(&arrays).unwrap();
//!
//! // Compare rows
Expand Down Expand Up @@ -83,7 +83,7 @@
//! .iter()
//! .map(|a| SortField::new(a.data_type().clone()))
//! .collect();
//! let mut converter = RowConverter::new(fields);
//! let mut converter = RowConverter::new(fields).unwrap();
//! let rows = converter.convert_columns(&arrays).unwrap();
//! let mut sort: Vec<_> = rows.iter().enumerate().collect();
//! sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b));
Expand Down Expand Up @@ -231,12 +231,24 @@ impl SortField {

impl RowConverter {
/// Create a new [`RowConverter`] with the provided schema
pub fn new(fields: Vec<SortField>) -> Self {
pub fn new(fields: Vec<SortField>) -> Result<Self> {
if !Self::supports_fields(&fields) {
return Err(ArrowError::NotYetImplemented(format!(
"not yet implemented: {:?}",
fields
)));
}

let interners = (0..fields.len()).map(|_| None).collect();
Self {
Ok(Self {
fields: fields.into(),
interners,
}
})
}

/// Check if the given fields are supported by the row format.
pub fn supports_fields(fields: &[SortField]) -> bool {
fields.iter().all(|x| !DataType::is_nested(&x.data_type))
}

/// Convert [`ArrayRef`] columns into [`Rows`]
Expand Down Expand Up @@ -275,7 +287,7 @@ impl RowConverter {

let interner = interner.get_or_insert_with(Default::default);

let mapping: Vec<_> = compute_dictionary_mapping(interner, values)?
let mapping: Vec<_> = compute_dictionary_mapping(interner, values)
.into_iter()
.map(|maybe_interned| {
maybe_interned.map(|interned| interner.normalized_key(interned))
Expand All @@ -286,7 +298,7 @@ impl RowConverter {
})
.collect::<Result<Vec<_>>>()?;

let mut rows = new_empty_rows(columns, &dictionaries, Arc::clone(&self.fields))?;
let mut rows = new_empty_rows(columns, &dictionaries, Arc::clone(&self.fields));

for ((column, field), dictionary) in
columns.iter().zip(self.fields.iter()).zip(dictionaries)
Expand Down Expand Up @@ -492,7 +504,7 @@ fn new_empty_rows(
cols: &[ArrayRef],
dictionaries: &[Option<Vec<Option<&[u8]>>>],
fields: Arc<[SortField]>,
) -> Result<Rows> {
) -> Rows {
use fixed::FixedLengthEncoding;

let num_rows = cols.first().map(|x| x.len()).unwrap_or(0);
Expand Down Expand Up @@ -535,7 +547,7 @@ fn new_empty_rows(
}
_ => unreachable!(),
}
t => return Err(ArrowError::NotYetImplemented(format!("not yet implemented: {}", t)))
_ => unreachable!(),
}
}

Expand Down Expand Up @@ -565,11 +577,11 @@ fn new_empty_rows(

let buffer = vec![0_u8; cur_offset];

Ok(Rows {
Rows {
buffer: buffer.into(),
offsets: offsets.into(),
fields,
})
}
}

/// Encodes a column to the provided [`Rows`] incrementing the offsets as it progresses
Expand Down Expand Up @@ -605,7 +617,7 @@ fn encode_column(
column => encode_dictionary(out, column, dictionary.unwrap(), opts),
_ => unreachable!()
}
t => unimplemented!("not yet implemented: {}", t)
_ => unreachable!(),
}
}

Expand Down Expand Up @@ -747,7 +759,8 @@ mod tests {
let mut converter = RowConverter::new(vec![
SortField::new(DataType::Int16),
SortField::new(DataType::Float32),
]);
])
.unwrap();
let rows = converter.convert_columns(&cols).unwrap();

assert_eq!(rows.offsets.as_ref(), &[0, 8, 16, 24, 32, 40, 48, 56]);
Expand Down Expand Up @@ -787,7 +800,8 @@ mod tests {
fn test_decimal128() {
let mut converter = RowConverter::new(vec![SortField::new(
DataType::Decimal128(DECIMAL128_MAX_PRECISION, 7),
)]);
)])
.unwrap();
let col = Arc::new(
Decimal128Array::from_iter([
None,
Expand Down Expand Up @@ -815,7 +829,8 @@ mod tests {
fn test_decimal256() {
let mut converter = RowConverter::new(vec![SortField::new(
DataType::Decimal256(DECIMAL256_MAX_PRECISION, 7),
)]);
)])
.unwrap();
let col = Arc::new(
Decimal256Array::from_iter([
None,
Expand Down Expand Up @@ -843,7 +858,8 @@ mod tests {

#[test]
fn test_bool() {
let mut converter = RowConverter::new(vec![SortField::new(DataType::Boolean)]);
let mut converter =
RowConverter::new(vec![SortField::new(DataType::Boolean)]).unwrap();

let col = Arc::new(BooleanArray::from_iter([None, Some(false), Some(true)]))
as ArrayRef;
Expand All @@ -862,7 +878,8 @@ mod tests {
descending: true,
nulls_first: false,
},
)]);
)])
.unwrap();

let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();
assert!(rows.row(2) < rows.row(1));
Expand All @@ -879,7 +896,7 @@ mod tests {
let d = a.data_type().clone();

let mut converter =
RowConverter::new(vec![SortField::new(a.data_type().clone())]);
RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
let rows = converter.convert_columns(&[Arc::new(a) as _]).unwrap();
let back = converter.convert_rows(&rows).unwrap();
assert_eq!(back.len(), 1);
Expand All @@ -905,7 +922,7 @@ mod tests {
);

assert_eq!(dict_with_tz.data_type(), &d);
let mut converter = RowConverter::new(vec![SortField::new(d.clone())]);
let mut converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();
let rows = converter
.convert_columns(&[Arc::new(dict_with_tz) as _])
.unwrap();
Expand All @@ -917,7 +934,8 @@ mod tests {
#[test]
fn test_null_encoding() {
let col = Arc::new(NullArray::new(10));
let mut converter = RowConverter::new(vec![SortField::new(DataType::Null)]);
let mut converter =
RowConverter::new(vec![SortField::new(DataType::Null)]).unwrap();
let rows = converter.convert_columns(&[col]).unwrap();
assert_eq!(rows.num_rows(), 10);
assert_eq!(rows.row(1).data.len(), 0);
Expand All @@ -933,7 +951,8 @@ mod tests {
Some(""),
])) as ArrayRef;

let mut converter = RowConverter::new(vec![SortField::new(DataType::Utf8)]);
let mut converter =
RowConverter::new(vec![SortField::new(DataType::Utf8)]).unwrap();
let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

assert!(rows.row(1) < rows.row(0));
Expand All @@ -958,7 +977,8 @@ mod tests {
Some(vec![0xFF_u8; variable::BLOCK_SIZE + 1]),
])) as ArrayRef;

let mut converter = RowConverter::new(vec![SortField::new(DataType::Binary)]);
let mut converter =
RowConverter::new(vec![SortField::new(DataType::Binary)]).unwrap();
let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

for i in 0..rows.num_rows() {
Expand All @@ -983,7 +1003,8 @@ mod tests {
descending: true,
nulls_first: false,
},
)]);
)])
.unwrap();
let rows = converter.convert_columns(&[Arc::clone(&col)]).unwrap();

for i in 0..rows.num_rows() {
Expand Down Expand Up @@ -1017,7 +1038,7 @@ mod tests {
])) as ArrayRef;

let mut converter =
RowConverter::new(vec![SortField::new(a.data_type().clone())]);
RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
let rows_a = converter.convert_columns(&[Arc::clone(&a)]).unwrap();

assert!(rows_a.row(3) < rows_a.row(5));
Expand Down Expand Up @@ -1052,7 +1073,8 @@ mod tests {
descending: true,
nulls_first: false,
},
)]);
)])
.unwrap();

let rows_c = converter.convert_columns(&[Arc::clone(&a)]).unwrap();
assert!(rows_c.row(3) > rows_c.row(5));
Expand All @@ -1078,7 +1100,7 @@ mod tests {
let a = builder.finish();

let mut converter =
RowConverter::new(vec![SortField::new(a.data_type().clone())]);
RowConverter::new(vec![SortField::new(a.data_type().clone())]).unwrap();
let rows = converter.convert_columns(&[Arc::new(a)]).unwrap();
assert!(rows.row(0) < rows.row(1));
assert!(rows.row(2) < rows.row(0));
Expand All @@ -1104,7 +1126,7 @@ mod tests {
.build()
.unwrap();

let mut converter = RowConverter::new(vec![SortField::new(data_type)]);
let mut converter = RowConverter::new(vec![SortField::new(data_type)]).unwrap();
let rows = converter
.convert_columns(&[Arc::new(DictionaryArray::<Int32Type>::from(data))])
.unwrap();
Expand All @@ -1119,10 +1141,11 @@ mod tests {
#[should_panic(expected = "rows were not produced by this RowConverter")]
fn test_different_converter() {
let values = Arc::new(Int32Array::from_iter([Some(1), Some(-1)]));
let mut converter = RowConverter::new(vec![SortField::new(DataType::Int32)]);
let mut converter =
RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap();
let rows = converter.convert_columns(&[values]).unwrap();

let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]);
let converter = RowConverter::new(vec![SortField::new(DataType::Int32)]).unwrap();
let _ = converter.convert_rows(&rows);
}

Expand Down Expand Up @@ -1266,7 +1289,7 @@ mod tests {
.map(|(o, a)| SortField::new_with_options(a.data_type().clone(), o))
.collect();

let mut converter = RowConverter::new(columns);
let mut converter = RowConverter::new(columns).unwrap();
let rows = converter.convert_columns(&arrays).unwrap();

for i in 0..len {
Expand Down

0 comments on commit 8d364fe

Please sign in to comment.