Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/79 handle unseen fields correctly #82

Merged
merged 4 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
`TracingOptions::default().coerce_numbers(true)`)
- Add support for `Timestamp(Seconds, None)` and
`Timestamp(Seconds, Some("UTC"))`.
- Raise an error if resulting arrays are of unequal length (#78)
- Fix bug in bytecode serialization for missing fields (#79)
- Handle nullable top-level fields correctly in bytecode serialization
- Fix bug in bytecode serialization for out of order fields (#80)

## 0.7.1

Expand Down
12 changes: 12 additions & 0 deletions serde_arrow/src/arrow/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ impl Interpreter {
res.push(array);
}
self.buffers.clear();

// Reject results whose arrays differ in length: each array is compared
// against the longest one and the first mismatch is reported together with
// the name of the field it was built for.
let max_len = res.iter().map(|a| a.len()).max().unwrap_or_default();
for (arr, mapping) in res.iter().zip(&self.structure.array_mapping) {
    if arr.len() != max_len {
        // Keep the format string on a single line so the message does not
        // begin with an embedded newline.
        fail!(
            "Unbalanced array lengths: array {name} has length {len}, but expected {max_len}",
            name = mapping.get_field().name,
            len = arr.len(),
        );
    }
}

Ok(res)
}

Expand Down
12 changes: 12 additions & 0 deletions serde_arrow/src/arrow2/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,18 @@ impl Interpreter {
res.push(array);
}
self.buffers.clear();

// Reject results whose arrays differ in length: each array is compared
// against the longest one and the first mismatch is reported together with
// the name of the field it was built for.
let max_len = res.iter().map(|a| a.len()).max().unwrap_or_default();
for (arr, mapping) in res.iter().zip(&self.structure.array_mapping) {
if arr.len() != max_len {
fail!(
"Unbalanced array lengths: array {name} has length {len}, but expected {max_len}",
name = mapping.get_field().name,
len = arr.len(),
);
}
}

Ok(res)
}

Expand Down
22 changes: 20 additions & 2 deletions serde_arrow/src/internal/serialization/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,13 @@ define_bytecode!(
self_pos: usize,
struct_idx: usize,
field_name: String,
field_idx: usize,
seen: usize,
},
OuterRecordEnd {
self_pos: usize,
struct_idx: usize,
seen: usize,
},
LargeListItem {
list_idx: usize,
Expand Down Expand Up @@ -514,19 +517,25 @@ impl Program {
});
self.structure.large_lists[0].item = self.structure.program.len();

let seen: usize;
if self.options.wrap_with_struct {
seen = self.buffers.num_seen.next_value();
self.structure.structs.push(StructDefinition::default());
self.push_instr(OuterRecordStart { next: UNSET_INSTR });
}
} else {
seen = usize::MAX;
};

for (field_idx, field) in fields.iter().enumerate() {
if self.options.wrap_with_struct {
let self_pos = self.structure.program.len();
self.push_instr(OuterRecordField {
next: UNSET_INSTR,
self_pos,
seen,
struct_idx: 0,
field_name: field.name.to_string(),
field_idx,
});
self.structure.structs[0].fields.insert(
field.name.to_string(),
Expand All @@ -537,7 +546,15 @@ impl Program {
},
);
}
let (f, _) = self.compile_field(field)?;
let (f, null_definition) = self.compile_field(field)?;

if self.options.wrap_with_struct {
let field_def = self.structure.structs[0]
.fields
.get_mut(&field.name)
.ok_or_else(|| error!("compile error: could not read struct field"))?;
field_def.null_definition = null_definition;
}

self.structure.array_mapping.push(f);
}
Expand All @@ -548,6 +565,7 @@ impl Program {
next: UNSET_INSTR,
struct_idx: 0,
self_pos,
seen,
});
self.structure.structs[0].r#return = self.structure.program.len();
}
Expand Down
15 changes: 10 additions & 5 deletions serde_arrow/src/internal/serialization/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,9 @@ impl Instruction for OuterRecordField {
fn accept_end_struct(
&self,
structure: &Structure,
_buffers: &mut MutableBuffers,
buffers: &mut MutableBuffers,
) -> Result<usize> {
struct_end(structure, buffers, self.struct_idx, self.seen)?;
Ok(structure.structs[self.struct_idx].r#return)
}

Expand All @@ -641,15 +642,17 @@ impl Instruction for OuterRecordField {
fn accept_str(
&self,
structure: &Structure,
_buffers: &mut MutableBuffers,
buffers: &mut MutableBuffers,
val: &str,
) -> Result<usize> {
if self.field_name == val {
buffers.seen[self.seen].insert(self.field_idx);
Ok(self.next)
} else {
let Some(field_def) = structure.structs[self.struct_idx].fields.get(val) else {
fail!("Cannot find field {val} in struct {idx}", idx=self.struct_idx);
};
buffers.seen[self.seen].insert(field_def.index);
Ok(field_def.jump)
}
}
Expand All @@ -661,9 +664,10 @@ impl Instruction for OuterRecordEnd {

fn accept_end_struct(
&self,
_structure: &Structure,
_buffers: &mut MutableBuffers,
structure: &Structure,
buffers: &mut MutableBuffers,
) -> Result<usize> {
struct_end(structure, buffers, self.struct_idx, self.seen)?;
Ok(self.next)
}

Expand All @@ -674,12 +678,13 @@ impl Instruction for OuterRecordEnd {
fn accept_str(
&self,
structure: &Structure,
_buffers: &mut MutableBuffers,
buffers: &mut MutableBuffers,
val: &str,
) -> Result<usize> {
let Some(field_def) = structure.structs[self.struct_idx].fields.get(val) else {
fail!("cannot find field {val:?} in struct {idx}", idx=self.struct_idx);
};
buffers.seen[self.seen].insert(field_def.index);
Ok(field_def.jump)
}

Expand Down
52 changes: 52 additions & 0 deletions serde_arrow/src/test_impls/issue_79.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use super::macros::test_generic;

test_generic!(
    fn declared_but_missing_fields() {
        use serde::Serialize;

        // The record type only carries `a`; the schema below additionally
        // declares a nullable `b` that is never serialized.
        #[derive(Serialize)]
        struct S {
            a: u8,
        }

        let records = [S { a: 0 }, S { a: 1 }];

        let schema = vec![
            Field::try_from(&GenericField::new("a", GenericDataType::U8, false)).unwrap(),
            Field::try_from(&GenericField::new("b", GenericDataType::U8, true)).unwrap(),
        ];

        let columns = serialize_into_arrays(&schema, &records).unwrap();

        // One column per declared field, each with one entry per record.
        assert_eq!(columns.len(), 2);
        for column in &columns {
            assert_eq!(column.len(), 2);
        }
    }
);

test_generic!(
    fn declared_but_missing_fields_non_nullable() {
        use serde::Serialize;

        // The record type only carries `a`; the schema below additionally
        // declares a non-nullable `b` that is never serialized.
        #[derive(Serialize)]
        struct S {
            a: u8,
        }

        let records = [S { a: 0 }, S { a: 1 }];

        let schema = vec![
            Field::try_from(&GenericField::new("a", GenericDataType::U8, false)).unwrap(),
            Field::try_from(&GenericField::new("b", GenericDataType::U8, false)).unwrap(),
        ];

        // Serialization is expected to fail and to name the missing field.
        let err = match serialize_into_arrays(&schema, &records) {
            Ok(_) => panic!("Expected error"),
            Err(err) => err,
        };
        let message = err.to_string();
        assert!(
            message.contains("missing non-nullable field b in struct"),
            "unexpected error: {err}"
        );
    }
);
2 changes: 2 additions & 0 deletions serde_arrow/src/test_impls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ mod tuple;
mod r#union;
mod utils;
mod wrappers;

mod issue_79;