Skip to content

Commit

Permalink
feat: allow inserting subschemas (#3041)
Browse files Browse the repository at this point in the history
Allow inserting a subset of the columns in the schema, as long as the
missing columns are nullable. Missing columns will be filled with null
values. This even works with nested fields.

For example:

```python
import lance
import pyarrow as pa

data = [
    {"vec": [1.0, 2.0, 3.0], "metadata": {"x": 1, "y": 2}},
    {"metadata": {"x": 3}},
    {"vec": [2.0, 3.0, 5.0], "metadata": {"y": 4}},
]
table = pa.Table.from_pylist(data)
ds = lance.write_dataset(table, "./demo")
ds.to_table().to_pandas()
```
```
               vec               metadata
0  [1.0, 2.0, 3.0]   {'x': 1.0, 'y': 2.0}
1             None  {'x': 3.0, 'y': None}
2  [2.0, 3.0, 5.0]  {'x': None, 'y': 4.0}
```

```python
new_data = [
    {"metadata": {"y": 6}, "vec": [1.0, 2.0, 3.0]},
]
new_table = pa.Table.from_pylist(new_data)
ds = lance.write_dataset(new_table, "./demo", mode="append")
ds.to_table().to_pandas()
```
```
               vec               metadata
0  [1.0, 2.0, 3.0]   {'x': 1.0, 'y': 2.0}
1             None  {'x': 3.0, 'y': None}
2  [2.0, 3.0, 5.0]  {'x': None, 'y': 4.0}
3  [1.0, 2.0, 3.0]    {'x': None, 'y': 6}
```

Closes #3016
  • Loading branch information
wjones127 authored Nov 7, 2024
1 parent 0abf7d4 commit 6d24d84
Show file tree
Hide file tree
Showing 19 changed files with 842 additions and 239 deletions.
18 changes: 0 additions & 18 deletions python/python/tests/test_balanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,21 +271,3 @@ def test_unsupported(balanced_dataset, big_val):
balanced_dataset.merge_insert("idx").when_not_matched_insert_all().execute(
make_table(0, 1, big_val)
)


# TODO: Once https://github.com/lancedb/lance/pull/3041 merges we will
# want to test partial appends. We need to make sure an append of
# non-blob data is supported. In order to do this we need to make
# sure a blob tx is created that marks the row ids as used so that
# the two row id sequences stay in sync.
#
# def test_one_sided_append(balanced_dataset, tmp_path):
# # Write new data, but only to the idx column
# ds = lance.write_dataset(
# pa.table({"idx": pa.array(range(128, 256), pa.uint64())}),
# tmp_path / "test_ds",
# max_bytes_per_file=32 * 1024 * 1024,
# mode="append",
# )

# print(ds.to_table())
7 changes: 7 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ def test_dataset_append(tmp_path: Path):
with pytest.raises(OSError):
lance.write_dataset(table2, base_dir, mode="append")

# But we can append subschemas
table3 = pa.Table.from_pydict({"colA": [4, 5, 6]})
dataset = lance.write_dataset(table3, base_dir, mode="append")
assert dataset.to_table() == pa.table(
{"colA": [1, 2, 3, 4, 5, 6], "colB": [4, 5, 6, None, None, None]}
)


def test_dataset_from_record_batch_iterable(tmp_path: Path):
base_dir = tmp_path / "test"
Expand Down
2 changes: 1 addition & 1 deletion python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ impl Dataset {
let dict = PyDict::new(py);
let schema = self_.ds.schema();

let idx_schema = schema.project_by_ids(idx.fields.as_slice());
let idx_schema = schema.project_by_ids(idx.fields.as_slice(), true);

let is_vector = idx_schema
.fields
Expand Down
2 changes: 1 addition & 1 deletion python/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl PrettyPrintableFragment {
.files
.iter()
.map(|file| {
let schema = schema.project_by_ids(&file.fields);
let schema = schema.project_by_ids(&file.fields, false);
PrettyPrintableDataFile {
path: file.path.clone(),
fields: file.fields.clone(),
Expand Down
5 changes: 4 additions & 1 deletion rust/lance-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ mod field;
mod schema;

use crate::{Error, Result};
pub use field::{Encoding, Field, NullabilityComparison, SchemaCompareOptions, StorageClass};
pub use field::{
Encoding, Field, NullabilityComparison, SchemaCompareOptions, StorageClass,
LANCE_STORAGE_CLASS_SCHEMA_META_KEY,
};
pub use schema::Schema;

pub const COMPRESSION_META_KEY: &str = "lance-encoding:compression";
Expand Down
127 changes: 45 additions & 82 deletions rust/lance-core/src/datatypes/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use std::{
cmp::max,
collections::{HashMap, HashSet},
collections::HashMap,
fmt::{self, Display},
str::FromStr,
sync::Arc,
Expand All @@ -23,7 +23,10 @@ use deepsize::DeepSizeOf;
use lance_arrow::{bfloat16::ARROW_EXT_NAME_KEY, *};
use snafu::{location, Location};

use super::{Dictionary, LogicalType};
use super::{
schema::{compare_fields, explain_fields_difference},
Dictionary, LogicalType,
};
use crate::{Error, Result};

pub const LANCE_STORAGE_CLASS_SCHEMA_META_KEY: &str = "lance-schema:storage-class";
Expand All @@ -49,6 +52,14 @@ pub struct SchemaCompareOptions {
pub compare_field_ids: bool,
/// Should nullability be compared (default Strict)
pub compare_nullability: NullabilityComparison,
/// Allow fields in the expected schema to be missing from the schema being tested if
/// they are nullable (default false)
///
/// Fields in the schema being tested must always be present in the expected schema
/// regardless of this flag.
pub allow_missing_if_nullable: bool,
/// Allow out of order fields (default false)
pub ignore_field_order: bool,
}
/// Encoding enum.
#[derive(Debug, Clone, PartialEq, Eq, DeepSizeOf)]
Expand Down Expand Up @@ -151,7 +162,7 @@ impl Field {
self.storage_class
}

fn explain_differences(
pub(crate) fn explain_differences(
&self,
expected: &Self,
options: &SchemaCompareOptions,
Expand Down Expand Up @@ -210,61 +221,19 @@ impl Field {
self_name
));
}
if self.children.len() != expected.children.len()
|| !self
.children
.iter()
.zip(expected.children.iter())
.all(|(child, expected)| child.name == expected.name)
{
let self_children = self
.children
.iter()
.map(|child| child.name.clone())
.collect::<HashSet<_>>();
let expected_children = expected
.children
.iter()
.map(|child| child.name.clone())
.collect::<HashSet<_>>();
let missing = expected_children
.difference(&self_children)
.cloned()
.collect::<Vec<_>>();
let unexpected = self_children
.difference(&expected_children)
.cloned()
.collect::<Vec<_>>();
if missing.is_empty() && unexpected.is_empty() {
differences.push(format!(
"`{}` field order mismatch, expected [{}] but was [{}]",
self_name,
expected
.children
.iter()
.map(|child| child.name.clone())
.collect::<Vec<_>>()
.join(", "),
self.children
.iter()
.map(|child| child.name.clone())
.collect::<Vec<_>>()
.join(", "),
));
} else {
differences.push(format!(
"`{}` had mismatched children, missing=[{}] unexpected=[{}]",
self_name,
missing.join(", "),
unexpected.join(", ")
));
}
} else {
differences.extend(self.children.iter().zip(expected.children.iter()).flat_map(
|(child, expected_child)| {
child.explain_differences(expected_child, options, Some(&self_name))
},
));
let children_differences = explain_fields_difference(
&self.children,
&expected.children,
options,
Some(&self_name),
);
if !children_differences.is_empty() {
let children_differences = format!(
"`{}` had mismatched children: {}",
self_name,
children_differences.join(", ")
);
differences.push(children_differences);
}
differences
}
Expand Down Expand Up @@ -295,22 +264,13 @@ impl Field {
}

pub fn compare_with_options(&self, expected: &Self, options: &SchemaCompareOptions) -> bool {
if self.children.len() != expected.children.len() {
false
} else {
self.name == expected.name
&& self.logical_type == expected.logical_type
&& Self::compare_nullability(expected.nullable, self.nullable, options)
&& self.children.len() == expected.children.len()
&& self
.children
.iter()
.zip(&expected.children)
.all(|(left, right)| left.compare_with_options(right, options))
&& (!options.compare_field_ids || self.id == expected.id)
&& (!options.compare_dictionary || self.dictionary == expected.dictionary)
&& (!options.compare_metadata || self.metadata == expected.metadata)
}
self.name == expected.name
&& self.logical_type == expected.logical_type
&& Self::compare_nullability(expected.nullable, self.nullable, options)
&& compare_fields(&self.children, &expected.children, options)
&& (!options.compare_field_ids || self.id == expected.id)
&& (!options.compare_dictionary || self.dictionary == expected.dictionary)
&& (!options.compare_metadata || self.metadata == expected.metadata)
}

pub fn extension_name(&self) -> Option<&str> {
Expand Down Expand Up @@ -476,13 +436,13 @@ impl Field {
///
/// If the ids are `[2]`, then this will include the parent `0` and the
/// child `3`.
pub(crate) fn project_by_ids(&self, ids: &[i32]) -> Option<Self> {
pub(crate) fn project_by_ids(&self, ids: &[i32], include_all_children: bool) -> Option<Self> {
let children = self
.children
.iter()
.filter_map(|c| c.project_by_ids(ids))
.filter_map(|c| c.project_by_ids(ids, include_all_children))
.collect::<Vec<_>>();
if ids.contains(&self.id) {
if ids.contains(&self.id) && (children.is_empty() || include_all_children) {
Some(self.clone())
} else if !children.is_empty() {
Some(Self {
Expand Down Expand Up @@ -1177,7 +1137,10 @@ mod tests {
.unwrap();
assert_eq!(
wrong_child.explain_difference(&expected, &opts),
Some("`a.b` should have nullable=true but nullable=false".to_string())
Some(
"`a` had mismatched children: `a.b` should have nullable=true but nullable=false"
.to_string()
)
);

let mismatched_children: Field = ArrowField::new(
Expand All @@ -1192,13 +1155,13 @@ mod tests {
.unwrap();
assert_eq!(
mismatched_children.explain_difference(&expected, &opts),
Some("`a` had mismatched children, missing=[c] unexpected=[d]".to_string())
Some("`a` had mismatched children: fields did not match, missing=[a.c], unexpected=[a.d]".to_string())
);

let reordered_children: Field = ArrowField::new(
"a",
DataType::Struct(Fields::from(vec![
ArrowField::new("c", DataType::Int32, false),
ArrowField::new("c", DataType::Int32, true),
ArrowField::new("b", DataType::Int32, true),
])),
true,
Expand All @@ -1207,7 +1170,7 @@ mod tests {
.unwrap();
assert_eq!(
reordered_children.explain_difference(&expected, &opts),
Some("`a` field order mismatch, expected [b, c] but was [c, b]".to_string())
Some("`a` had mismatched children: fields in different order, expected: [b, c], actual: [c, b]".to_string())
);

let multiple_wrongs: Field = ArrowField::new(
Expand All @@ -1223,7 +1186,7 @@ mod tests {
assert_eq!(
multiple_wrongs.explain_difference(&expected, &opts),
Some(
"expected name 'a' but name was 'c', `c.c` should have type int32 but type was float"
"expected name 'a' but name was 'c', `c` had mismatched children: `c.c` should have type int32 but type was float"
.to_string()
)
);
Expand Down
Loading

0 comments on commit 6d24d84

Please sign in to comment.