Skip to content

Commit

Permalink
Merge pull request #53 from moeyensj/jm/nested-column-names
Browse files Browse the repository at this point in the history
Support nested column names in .select and .column
  • Loading branch information
spenczar authored Oct 2, 2023
2 parents dcc3051 + 697e278 commit 16e92b3
Show file tree
Hide file tree
Showing 2 changed files with 318 additions and 4 deletions.
44 changes: 41 additions & 3 deletions quivr/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,13 +589,29 @@ def flattened_table(self) -> pa.Table:
def select(self, column_name: str, value: Any) -> Self:
"""Select from the table by exact match, returning a new
Table which only contains rows for which the value in
column_name equals value.
column_name equals value. Column_name can be a nested column,
in which case, this function will recursively search for
the column through the nested subtables using dot-delimited notation.
:param column_name: The name of the column to select on.
Use dot-delimited notation for nested columns.
:param value: The value to match.
Examples:
>>> import quivr as qv
>>> import pyarrow.compute as pc
>>> class MySubTable(qv.Table):
... x = qv.Int64Column()
... y = qv.Int64Column()
>>> class MyWrapperTable(qv.Table):
... child = MySubTable.as_column()
>>> c = MySubTable.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
>>> p = MyWrapperTable.from_kwargs(child=c)
>>> p_select = p.select("child.x", 2)
>>> print(p_select.child.x.to_pylist())
[2]
"""
table = self.table.filter(pc.field(column_name) == value)
return self.__class__(table)
return self.apply_mask(pc.equal(self.column(column_name), value))

def sort_by(self, by: Union[str, list[tuple[str, str]]]) -> Self:
"""Sorts the Table by the given column name (or multiple
Expand Down Expand Up @@ -668,9 +684,31 @@ def to_dataframe(self, flatten: bool = True) -> pd.DataFrame:
def column(self, column_name: str) -> pa.ChunkedArray:
"""
Returns the column with the given name as a raw pyarrow ChunkedArray.
Column_name can be a nested column, in which case, this function will recursively
search for the column through the nested subtables using dot-delimited notation.
:param column_name: The name of the column to return.
Use dot-delimited notation for nested columns.
Examples:
>>> import quivr as qv
>>> import pyarrow.compute as pc
>>> class MySubTable(qv.Table):
... x = qv.Int64Column()
... y = qv.Int64Column()
>>> class MyWrapperTable(qv.Table):
... child = MySubTable.as_column()
>>> c = MySubTable.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
>>> p = MyWrapperTable.from_kwargs(child=c)
>>> column_x = p.column("child.x")
>>> print(column_x.to_pylist())
[1, 2, 3]
"""
if "." in column_name:
column_name, subkey = column_name.split(".", 1)
subtable = getattr(self, column_name)
return subtable.column(subkey)

return self.table.column(column_name)

def __repr__(self) -> str:
Expand Down
278 changes: 277 additions & 1 deletion test/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,140 @@ def test_select():
assert have.y[0].as_py() == 6


def test_select_empty():
def test_select_nested():
pair = Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
wrapper = Wrapper.from_kwargs(
id=["1", "2", "3"],
pair=pair,
)
have = wrapper.select("pair.x", 3)
assert len(have) == 1
assert have.id[0].as_py() == "3"
assert have.pair.y[0].as_py() == 6


def test_select_nested_doubly():
class DoublyNested(qv.Table):
inner = Wrapper.as_column()

dn = DoublyNested.from_kwargs(
inner=Wrapper.from_kwargs(id=["a", "b", "c"], pair=Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6]))
)
have = dn.select("inner.pair.x", 3)
assert len(have) == 1
assert have.inner.id[0].as_py() == "c"
assert have.inner.pair.y[0].as_py() == 6


def test_select_attributes():
class PairWithAttributes(qv.Table):
x = qv.Int64Column()
y = qv.Int64Column()
label = qv.StringAttribute()

pair = PairWithAttributes.from_kwargs(x=[1, 2, 3], y=[4, 5, 6], label="foo")
have = pair.select("x", 3)
assert len(have) == 1
assert have.y[0].as_py() == 6
assert have.label == "foo"


def test_select_nested_attributes():
class PairWithAttributes(qv.Table):
x = qv.Int64Column()
y = qv.Int64Column()
label = qv.StringAttribute()

class WrapperWithAttributes(qv.Table):
pair = PairWithAttributes.as_column()
id = qv.StringColumn()
label = qv.StringAttribute()

pair = PairWithAttributes.from_kwargs(x=[1, 2, 3], y=[4, 5, 6], label="foo")
wrapper = WrapperWithAttributes.from_kwargs(
id=["1", "2", "3"],
pair=pair,
label="bar",
)
have = wrapper.select("pair.x", 3)
assert len(have) == 1
assert have.id[0].as_py() == "3"
assert have.pair.y[0].as_py() == 6
assert have.label == "bar"
assert have.pair.label == "foo"


def test_select_nested_doubly_attributes():
class PairWithAttributes(qv.Table):
x = qv.Int64Column()
y = qv.Int64Column()
label = qv.StringAttribute()

class WrapperWithAttributes(qv.Table):
pair = PairWithAttributes.as_column()
id = qv.StringColumn()
label = qv.StringAttribute()

class DoublyNestedWithAttributes(qv.Table):
inner = WrapperWithAttributes.as_column()
label = qv.StringAttribute()

dn = DoublyNestedWithAttributes.from_kwargs(
inner=WrapperWithAttributes.from_kwargs(
id=["a", "b", "c"],
pair=PairWithAttributes.from_kwargs(x=[1, 2, 3], y=[4, 5, 6], label="foo"),
label="bar",
),
label="baz",
)

have = dn.select("inner.pair.x", 3)
assert len(have) == 1
assert have.inner.id[0].as_py() == "c"
assert have.inner.pair.y[0].as_py() == 6
assert have.inner.label == "bar"
assert have.inner.pair.label == "foo"
assert have.label == "baz"


def test_select_invalid_value():
pair = Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
have = pair.select("x", 4)
assert len(have) == 0


def test_select_nested_invalid_value():
pair = Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
wrapper = Wrapper.from_kwargs(
id=["1", "2", "3"],
pair=pair,
)
have = wrapper.select("pair.x", 4)
assert len(have) == 0

have = wrapper.select("id", "4")
assert len(have) == 0


def test_select_nested_doubly_invalid_value():
class DoublyNested(qv.Table):
id = qv.StringColumn()
inner = Wrapper.as_column()

dn = DoublyNested.from_kwargs(
id=["1", "2", "3"],
inner=Wrapper.from_kwargs(id=["1", "2", "3"], pair=Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])),
)
have = dn.select("inner.pair.x", 4)
assert len(have) == 0

have = dn.select("inner.id", "4")
assert len(have) == 0

have = dn.select("id", "d")
assert len(have) == 0


def test_sort_by():
pair = Pair.from_kwargs(x=[1, 2, 3], y=[5, 1, 2])

Expand Down Expand Up @@ -778,6 +906,154 @@ class T5(qv.Table):
_column_validators = qv.StringColumn()


def test_column():
t = Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
assert pc.all(pc.equal(t.x, t.column("x")))


def test_column_nested():
w = Wrapper.from_kwargs(id=["a", "b", "c"], pair=Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6]))
assert pc.all(pc.equal(w.pair.x, w.column("pair.x")))
assert pc.all(pc.equal(w.pair.y, w.column("pair.y")))
assert pc.all(pc.equal(w.id, w.column("id")))


def test_column_nested_doubly():
class DoublyNested(qv.Table):
inner = Wrapper.as_column()

dn = DoublyNested.from_kwargs(
inner=Wrapper.from_kwargs(id=["a", "b", "c"], pair=Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6]))
)
assert pc.all(pc.equal(dn.inner.pair.x, dn.column("inner.pair.x")))
assert pc.all(pc.equal(dn.inner.pair.y, dn.column("inner.pair.y")))
assert pc.all(pc.equal(dn.inner.id, dn.column("inner.id")))


def test_column_nulls():
class PairWithNulls(qv.Table):
x = qv.Int64Column(nullable=True)
y = qv.Int64Column(nullable=True)

t = PairWithNulls.from_kwargs(y=[4, 5, 6])
assert pc.all(pc.equal(t.x, t.column("x")))
assert pc.all(pc.equal(t.y, t.column("y")))

t = PairWithNulls.from_kwargs(x=[1, 2, 3])
assert pc.all(pc.equal(t.x, t.column("x")))
assert pc.all(pc.equal(t.y, t.column("y")))


def test_column_nested_nulls():
class PairWithNulls(qv.Table):
x = qv.Int64Column(nullable=True)
y = qv.Int64Column(nullable=True)

class WrapperWithNulls(qv.Table):
id = qv.StringColumn()
pair = PairWithNulls.as_column(nullable=True)

# Null grandchild
w = WrapperWithNulls.from_kwargs(id=["a", "b", "c"], pair=PairWithNulls.from_kwargs(y=[4, 5, 6]))
assert pc.all(pc.equal(w.pair.x, w.column("pair.x")))
assert pc.all(pc.equal(w.pair.y, w.column("pair.y")))
assert pc.all(pc.equal(w.id, w.column("id")))

# Null child
w = WrapperWithNulls.from_kwargs(id=["a", "b", "c"])
assert pc.all(pc.equal(w.pair.x, w.column("pair.x")))
assert pc.all(pc.equal(w.pair.y, w.column("pair.y")))
assert pc.all(pc.equal(w.id, w.column("id")))


def test_column_nested_doubly_nulls():
class PairWithNulls(qv.Table):
x = qv.Int64Column(nullable=True)
y = qv.Int64Column(nullable=True)

class WrapperWithNulls(qv.Table):
id = qv.StringColumn(nullable=True)
pair = PairWithNulls.as_column(nullable=True)

class DoublyNestedWithNulls(qv.Table):
id = qv.StringColumn()
inner = WrapperWithNulls.as_column(nullable=True)

# Null great-grandchild
dn = DoublyNestedWithNulls.from_kwargs(
id=["a", "b", "c"],
inner=WrapperWithNulls.from_kwargs(id=["a", "b", "c"], pair=PairWithNulls.from_kwargs(y=[4, 5, 6])),
)
assert pc.all(pc.equal(dn.inner.pair.x, dn.column("inner.pair.x")))
assert pc.all(pc.equal(dn.inner.pair.y, dn.column("inner.pair.y")))
assert pc.all(pc.equal(dn.inner.id, dn.column("inner.id")))
assert pc.all(pc.equal(dn.id, dn.column("id")))

# Null grandchild
dn = DoublyNestedWithNulls.from_kwargs(
id=["a", "b", "c"], inner=WrapperWithNulls.from_kwargs(id=["a", "b", "c"])
)
assert pc.all(pc.equal(dn.inner.pair.x, dn.column("inner.pair.x")))
assert pc.all(pc.equal(dn.inner.pair.y, dn.column("inner.pair.y")))
assert pc.all(pc.equal(dn.inner.id, dn.column("inner.id")))
assert pc.all(pc.equal(dn.id, dn.column("id")))

# Null child
dn = DoublyNestedWithNulls.from_kwargs(id=["a", "b", "c"])
assert pc.all(pc.equal(dn.inner.pair.x, dn.column("inner.pair.x")))
assert pc.all(pc.equal(dn.inner.pair.y, dn.column("inner.pair.y")))
assert pc.all(pc.equal(dn.inner.id, dn.column("inner.id")))
assert pc.all(pc.equal(dn.id, dn.column("id")))


def test_column_empty():
t = Pair.empty()
assert len(t.column("x")) == 0
assert len(t.column("y")) == 0


def test_column_nested_empty():
w = Wrapper.empty()
assert len(w.column("pair.x")) == 0
assert len(w.column("pair.y")) == 0


def test_column_nested_doubly_empty():
class DoublyNested(qv.Table):
inner = Wrapper.as_column()

dn = DoublyNested.empty()
assert len(dn.column("inner.pair.x")) == 0
assert len(dn.column("inner.pair.y")) == 0


def test_column_invalid_name():
t = Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
with pytest.raises(KeyError):
t.column("z")


def test_column_nested_invalid_name():
w = Wrapper.from_kwargs(id=["a", "b", "c"], pair=Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6]))
with pytest.raises(KeyError):
w.column("pair.z")
with pytest.raises(AttributeError):
w.column("wrong.x")


def test_column_nested_doubly_invalid_name():
class DoublyNested(qv.Table):
inner = Wrapper.as_column()

dn = DoublyNested.from_kwargs(
inner=Wrapper.from_kwargs(id=["a", "b", "c"], pair=Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6]))
)
with pytest.raises(KeyError):
dn.column("inner.pair.z")
with pytest.raises(AttributeError):
dn.column("inner.wrong.x")


def test_set_column():
t = Pair.from_kwargs(x=[1, 2, 3], y=[4, 5, 6])
t2 = t.set_column("x", [7, 8, 9])
Expand Down

0 comments on commit 16e92b3

Please sign in to comment.