Skip to content

Commit

Permalink
add split method
Browse files Browse the repository at this point in the history
  • Loading branch information
yymao committed Dec 12, 2021
1 parent e92abb2 commit 80cc22f
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
41 changes: 40 additions & 1 deletion easyquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numexpr as ne

__all__ = ['Query', 'QueryMaker']
__version__ = '0.3.0'
__version__ = '0.4.0'


def _is_string_like(obj):
Expand All @@ -34,6 +34,7 @@ class Query(object):
A Query object has three major methods: filter, count, and mask.
All of them operate on NumPy structured array and astropy Table:
- `filter` returns a new table that only has entries satisfying the query;
- `split` returns two new tables that has entries satisfying and not satisfying the query, respectively;
- `count` returns the number of entries satisfying the query;
- `mask` returns a bool array for masking the table;
- `where` returns a int array for the indices that select satisfying entries.
Expand Down Expand Up @@ -303,6 +304,26 @@ def where(self, table):

return np.flatnonzero(self.mask(table))

def split(self, table, column_slice=None):
"""
Split the `table` into two parts: satisfying and not satisfy the queries.
The function will return q.filter(table), (~q).filter(table)
where `q` is the current Query object.
Parameters
----------
table : NumPy structured array, astropy Table, etc.
Returns
-------
table_true : filtered table, satisfying the queries
table_false : filtered table, not satisfying the queries
"""
mask = self.mask(table)
if column_slice is not None:
table = self._get_table_column(table, column_slice)
return self._mask_table(table, mask), self._mask_table(table, ~mask)

def copy(self):
"""
Create a copy of the current Query object.
Expand Down Expand Up @@ -437,6 +458,24 @@ def where(table, *queries):
return _query_class(*queries).where(table)


def split(table, *queries):
"""
A convenient function to split `table` into satisfying and non-satisfying parts.
Equivalent to `Query(*queries).split(table)`
Parameters
----------
table : NumPy structured array, astropy Table, etc.
queries : string, tuple, callable
Returns
-------
table_true : filtered table, satisfying the queries
table_false : filtered table, not satisfying the queries
"""
return _query_class(*queries).split(table)


class QueryMaker():
"""
provides convenience functions to generate query objects
Expand Down
16 changes: 16 additions & 0 deletions test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ def check_query_on_table(table, query_object, true_mask=None):
if true_mask is None:
true_mask = np.ones(len(table), bool)

stable1, stable2 = query_object.split(table)

assert (query_object.filter(table) == table[true_mask]).all(), 'filter not correct'
assert (stable1 == table[true_mask]).all(), 'split not correct'
assert (stable2 == table[~true_mask]).all(), 'split not correct'
assert query_object.count(table) == np.count_nonzero(true_mask), 'count not correct'
assert (query_object.mask(table) == true_mask).all(), 'mask not correct'
assert (query_object.where(table) == np.flatnonzero(true_mask)).all(), 'where not correct'
Expand All @@ -62,8 +66,17 @@ def check_query_on_dict_table(table, query_object, true_mask=None):

ftable = query_object.filter(table)
ftable_true = {k: table[k][true_mask] for k in table}

stable1, stable2 = query_object.split(table)
stable1_true = ftable_true
stable2_true = {k: table[k][~true_mask] for k in table}

assert set(ftable) == set(ftable_true), 'filter not correct'
assert all((ftable[k] == ftable_true[k]).all() for k in ftable), 'filter not correct'
assert set(stable1) == set(stable1_true), 'split not correct'
assert all((stable1[k] == stable1_true[k]).all() for k in ftable), 'split not correct'
assert set(stable2) == set(stable2_true), 'split not correct'
assert all((stable2[k] == stable2_true[k]).all() for k in ftable), 'split not correct'
assert query_object.count(table) == np.count_nonzero(true_mask), 'count not correct'
assert (query_object.mask(table) == true_mask).all(), 'mask not correct'
assert (query_object.where(table) == np.flatnonzero(true_mask)).all(), 'where not correct'
Expand Down Expand Up @@ -159,10 +172,13 @@ def test_filter_column_slice():
t = gen_test_table()
q = Query('a > 2')
assert (q.filter(t, 'b') == t['b'][t['a'] > 2]).all()
assert (q.split(t, 'b')[1] == t['b'][~(t['a'] > 2)]).all()
q = Query('a > 2', 'b < 2')
assert (q.filter(t, 'c') == t['c'][(t['a'] > 2) & (t['b'] < 2)]).all()
assert (q.split(t, 'c')[1] == t['c'][~((t['a'] > 2) & (t['b'] < 2))]).all()
q = Query(None)
assert (q.filter(t, 'a') == t['a']).all()
assert len(q.split(t, 'a')[1]) == 0


def test_query_maker():
Expand Down

0 comments on commit 80cc22f

Please sign in to comment.