Skip to content

Commit

Permalink
Subtables are validated when validate is called (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
akoumjian authored Sep 18, 2024
1 parent da80cee commit 9d41599
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 7 deletions.
4 changes: 2 additions & 2 deletions quivr/columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ def __get__(self, obj: Optional[tables.Table], objtype: type) -> Union[Self, T]:
schema = self.schema.with_metadata(metadata)

subtable = pa.Table.from_arrays(array.flatten(), schema=schema)
# We don't validate the subtable. If the parent table were validated,
# We don't validate the subtable. If the parent table was validated,
# then the subtable would have been validated as part of that process.
# If the parent table is intentionally not validated, then we don't
# want to validate the subtable either as it will throw an error
# want to validate the subtable at this time as it will throw an error
# when accessing the subtable as a column (e.g. during concatenation).
return self.table_type.from_pyarrow(subtable, permit_nulls=self.nullable, validate=False)

Expand Down
9 changes: 9 additions & 0 deletions quivr/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,9 @@ def validate(self) -> None:
validator.validate(self.table.column(name))
except errors.ValidationError as e:
raise errors.ValidationError(f"Column {name} failed validation: {str(e)}", e.failures) from e
# Validate subtables
for name in self._quivr_subtables.keys():
getattr(self, name).validate()

def invalid_mask(self) -> pa.Array:
"""Return a boolean mask indicating which rows are invalid."""
Expand All @@ -968,6 +971,12 @@ def invalid_mask(self) -> pa.Array:
for name, validator in self._column_validators.items():
indices, _ = validator.failures(self.table.column(name))
mask[indices.to_numpy()] = True

# Get invalid rows from subtables
for name in self._quivr_subtables.keys():
subtable = getattr(self, name)
subtable_mask = subtable.invalid_mask().to_numpy(zero_copy_only=False)
mask = mask | subtable_mask
return pa.array(mask, type=pa.bool_())

def separate_invalid(self) -> Tuple[Self, Self]:
Expand Down
3 changes: 3 additions & 0 deletions test/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,11 @@ class Outer(qv.Table):
name = qv.StringAttribute()

df = pd.DataFrame({"inner.x": [1, 2, 3], "y": [4, 5, 6]})
# Set the dataframe attributes
df.attrs["inner"] = {"id": 10}
table = Outer.from_flat_dataframe(df, name="foo")
assert table.name == "foo"
assert table.inner.id == 10

def test_from_kwargs(self):
table = self.MyTable.from_kwargs(name="foo", vals=[1, 2, 3])
Expand Down
8 changes: 3 additions & 5 deletions test/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,10 @@ class ValidationTable(qv.Table):
with pytest.raises(qv.ValidationError, match="Column x failed validation"):
qv.concatenate([invalid_x, valid])

have = qv.concatenate([invalid_x, valid], validate=False)
assert len(have) == 2
with pytest.raises(qv.ValidationError, match="Column y failed validation"):
qv.concatenate([valid, invalid_subtable])

# Subtables are not validated during concatenation, as the main table
# was initialized with validate=False
have = qv.concatenate([valid, invalid_subtable])
have = qv.concatenate([invalid_x, valid], validate=False)
assert len(have) == 2


Expand Down
37 changes: 37 additions & 0 deletions test/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,25 @@ class MyTable(qv.Table):
)


def test_invalid_mask_subtable():
class SubTable(qv.Table):
x = qv.Int8Column(validator=and_(lt(15), gt(10)))

class MyTable(qv.Table):
x = qv.Int8Column()
subtable = SubTable.as_column()

table = MyTable.from_kwargs(
x=[1, 2, 3, 4, 5, 6, 7, 8, 9],
subtable=SubTable.from_kwargs(x=[8, 9, 10, 11, 12, 13, 14, 15, 16], validate=False),
validate=False,
)
invalid = table.invalid_mask()
np.testing.assert_array_equal(
invalid.to_pylist(), [True, True, True, False, False, False, False, True, True]
)


def test_separate_invalid():
class MyTable(qv.Table):
x = qv.Int8Column(validator=and_(lt(15), gt(10)))
Expand All @@ -1005,6 +1024,24 @@ class MyTable(qv.Table):
np.testing.assert_array_equal(invalid.x, [8, 9, 10, 15, 16])


def test_separate_invalid_subtable():
class SubTable(qv.Table):
x = qv.Int8Column(validator=and_(lt(15), gt(10)))

class MyTable(qv.Table):
x = qv.Int8Column()
subtable = SubTable.as_column()

table = MyTable.from_kwargs(
x=[1, 2, 3, 4, 5, 6, 7, 8, 9],
subtable=SubTable.from_kwargs(x=[8, 9, 10, 11, 12, 13, 14, 15, 16], validate=False),
validate=False,
)
valid, invalid = table.separate_invalid()
np.testing.assert_array_equal(valid.x, [4, 5, 6, 7])
np.testing.assert_array_equal(invalid.x, [1, 2, 3, 8, 9])


class TestTableEqualityBenchmarks:
@pytest.mark.benchmark(group="table-equality")
def test_identical_small_tables(self, benchmark):
Expand Down

0 comments on commit 9d41599

Please sign in to comment.