diff --git a/.circleci/config.yml b/.circleci/config.yml index 321d947..6bae071 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -77,7 +77,13 @@ jobs: name: run tests command: | . venv/bin/activate - REDIS_PORT=6379 python test/test.py + REDIS_PORT=6379 python test/test.py + + - run: + name: run query builder tests + command: | + . venv/bin/activate + python test/test_builder.py # no need for store_artifacts on nightly builds diff --git a/redisearch/aggregation.py b/redisearch/aggregation.py index b767317..16e5eab 100644 --- a/redisearch/aggregation.py +++ b/redisearch/aggregation.py @@ -99,15 +99,58 @@ def __init__(self, fields, reducers): self.limit = Limit() def build_args(self): - ret = [str(len(self.fields))] + ret = ['GROUPBY', str(len(self.fields))] ret.extend(self.fields) for reducer in self.reducers: ret += ['REDUCE', reducer.NAME, str(len(reducer.args))] ret.extend(reducer.args) - if reducer._alias: + if reducer._alias is not None: ret += ['AS', reducer._alias] return ret +class Projection(object): + """ + This object is automatically created by `AggregateRequest.apply()` + """ + + def __init__(self, projector, alias=None ): + + self.alias = alias + self.projector = projector + + def build_args(self): + ret = ['APPLY', self.projector] + if self.alias is not None: + ret += ['AS', self.alias] + + return ret + +class SortBy(object): + """ + This object is automatically created by `AggregateRequest.sort_by()` + """ + + def __init__(self, fields, max=0): + self.fields = fields + self.max = max + + + + def build_args(self): + fields_args = [] + for f in self.fields: + if isinstance(f, SortDirection): + fields_args += [f.field, f.DIRSTRING] + else: + fields_args += [f] + + ret = ['SORTBY', str(len(fields_args))] + ret.extend(fields_args) + if self.max > 0: + ret += ['MAX', str(self.max)] + + return ret + class AggregateRequest(object): """ @@ -127,11 +170,9 @@ def __init__(self, query='*'): return the object itself, making them useful for chaining. 
""" self._query = query - self._groups = [] - self._projections = [] + self._aggregateplan = [] self._loadfields = [] self._limit = Limit() - self._sortby = [] self._max = 0 self._with_schema = False self._verbatim = False @@ -162,7 +203,7 @@ def group_by(self, fields, *reducers): `aggregation` module. """ group = Group(fields, reducers) - self._groups.append(group) + self._aggregateplan.extend(group.build_args()) return self @@ -177,7 +218,8 @@ def apply(self, **kwexpr): expression itself, for example `apply(square_root="sqrt(@foo)")` """ for alias, expr in kwexpr.items(): - self._projections.append([alias, expr]) + projection = Projection(expr, alias ) + self._aggregateplan.extend(projection.build_args()) return self @@ -224,10 +266,7 @@ def limit(self, offset, num): """ limit = Limit(offset, num) - if self._groups: - self._groups[-1].limit = limit - else: - self._limit = limit + self._limit = limit return self def sort_by(self, *fields, **kwargs): @@ -258,16 +297,34 @@ def sort_by(self, *fields, **kwargs): .sort_by(Desc('@paid'), max=10) ``` """ - self._max = kwargs.get('max', 0) if isinstance(fields, (string_types, SortDirection)): fields = [fields] - for f in fields: - if isinstance(f, SortDirection): - self._sortby += [f.field, f.DIRSTRING] - else: - self._sortby.append(f) + + max = kwargs.get('max', 0) + sortby = SortBy(fields, max) + + self._aggregateplan.extend(sortby.build_args()) + return self + + def filter(self, expressions): + """ + Specify filter for post-query results using predicates relating to values in the result set. + + ### Parameters + + - **expressions**: Filter expression(s) to apply. This can either be a single string, + or a list of strings. 
+ """ + if isinstance(expressions, (string_types)): + expressions = [expressions] + + for expression in expressions: + self._aggregateplan.extend(['FILTER', expression]) + return self + + def with_schema(self): """ If set, the `schema` property will contain a list of `[field, type]` @@ -312,18 +369,8 @@ def build_args(self): ret.append('LOAD') ret.append(str(len(self._loadfields))) ret.extend(self._loadfields) - for group in self._groups: - ret += ['GROUPBY'] + group.build_args() + group.limit.build_args() - for alias, projector in self._projections: - ret += ['APPLY', projector] - if alias: - ret += ['AS', alias] - - if self._sortby: - ret += ['SORTBY', str(len(self._sortby))] - ret += self._sortby - if self._max: - ret += ['MAX', str(self._max)] + + ret.extend(self._aggregateplan) ret += self._limit.build_args() diff --git a/test/test_builder.py b/test/test_builder.py index b621890..1936a81 100644 --- a/test/test_builder.py +++ b/test/test_builder.py @@ -1,9 +1,9 @@ -from unittest import TestCase +import unittest import redisearch.aggregation as a import redisearch.querystring as q import redisearch.reducers as r -class QueryBuilderTest(TestCase): +class QueryBuilderTest(unittest.TestCase): def testBetween(self): b = q.between(1, 10) self.assertEqual('[1 10]', str(b)) @@ -42,16 +42,16 @@ def testGroup(self): # Single field, single reducer g = a.Group('foo', r.count()) ret = g.build_args() - self.assertEqual(['1', 'foo', 'REDUCE', 'COUNT', '0'], ret) + self.assertEqual(['GROUPBY', '1', 'foo', 'REDUCE', 'COUNT', '0'], ret) # Multiple fields, single reducer g = a.Group(['foo', 'bar'], r.count()) - self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'], + self.assertEqual(['GROUPBY', '2', 'foo', 'bar', 'REDUCE', 'COUNT', '0'], g.build_args()) # Multiple fields, multiple reducers g = a.Group(['foo', 'bar'], [r.count(), r.count_distinct('@fld1')]) - self.assertEqual(['2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', '1', '@fld1'], + 
self.assertEqual(['GROUPBY', '2', 'foo', 'bar', 'REDUCE', 'COUNT', '0', 'REDUCE', 'COUNT_DISTINCT', '1', '@fld1'], g.build_args()) def testAggRequest(self): @@ -62,13 +62,38 @@ def testAggRequest(self): req = a.AggregateRequest().group_by('@foo', r.count()) self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'], req.build_args()) + # Test with group_by and alias on reducer + req = a.AggregateRequest().group_by('@foo', r.count().alias('foo_count')) + self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'AS', 'foo_count'], req.build_args()) + # Test with limit - req = a.AggregateRequest().\ - group_by('@foo', r.count()).\ + req = a.AggregateRequest(). \ + group_by('@foo', r.count()). \ sort_by('@foo') self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'SORTBY', '1', '@foo'], req.build_args()) + # Test with apply + req = a.AggregateRequest(). \ + apply(foo="@bar / 2"). \ + group_by('@foo', r.count()) + + self.assertEqual(['*', 'APPLY', '@bar / 2', 'AS', 'foo', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'], + req.build_args()) + + # Test with filter + req = a.AggregateRequest().group_by('@foo', r.count()).filter( "@foo=='bar'") + self.assertEqual(['*', 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0', 'FILTER', "@foo=='bar'" ], req.build_args()) + + # Test with filter on different state of the pipeline + req = a.AggregateRequest().filter("@foo=='bar'").group_by('@foo', r.count()) + self.assertEqual(['*', 'FILTER', "@foo=='bar'", 'GROUPBY', '1', '@foo','REDUCE', 'COUNT', '0' ], req.build_args()) + + # Test with a list of filters placed before group_by in the pipeline + req = a.AggregateRequest().filter(["@foo=='bar'","@foo2=='bar2'"]).group_by('@foo', r.count()) + self.assertEqual(['*', 'FILTER', "@foo=='bar'", 'FILTER', "@foo2=='bar2'", 'GROUPBY', '1', '@foo', 'REDUCE', 'COUNT', '0'], + req.build_args()) + # Test with sort_by req = a.AggregateRequest().group_by('@foo', r.count()).sort_by('@date') # print req.build_args() @@ -105,4 
+130,8 @@ def test_reducers(self): self.assertEqual(('f1', 'BY', 'f2', 'ASC'), r.first_value('f1', a.Asc('f2')).args) self.assertEqual(('f1', 'BY', 'f1', 'ASC'), r.first_value('f1', a.Asc).args) - self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args) \ No newline at end of file + self.assertEqual(('f1', '50'), r.random_sample('f1', 50).args) + +if __name__ == '__main__': + + unittest.main() \ No newline at end of file