diff --git a/remodel/object_handler.py b/remodel/object_handler.py index bcf9dfb..65703a7 100644 --- a/remodel/object_handler.py +++ b/remodel/object_handler.py @@ -41,7 +41,9 @@ def get_or_create(self, id_=None, **kwargs): return self.create(**kwargs), True def filter(self, ids=None, **kwargs): - if ids: + if callable(ids): + query = self.query.filter(ids).filter(kwargs) + elif ids: try: query = self.query.get_all(r.args(ids)).filter(kwargs) except AttributeError: diff --git a/tests/test_object_handler.py b/tests/test_object_handler.py index 2ebedf9..46b47fd 100644 --- a/tests/test_object_handler.py +++ b/tests/test_object_handler.py @@ -126,6 +126,9 @@ def test_returns_object_set(self): def test_by_ids_no_objects(self): assert len(self.Artist.filter(['id'])) == 0 + def test_by_lambda_no_objects(self): + assert len(self.Artist.filter(lambda artist: artist['id'] == 'id')) == 0 + def test_by_kwargs_no_objects(self): assert len(self.Artist.filter(id='id')) == 0 @@ -137,6 +140,14 @@ def test_by_ids_some_objects_valid_filter(self): assert isinstance(objs[0], self.Artist) assert objs[0]['id'] == a['id'] + def test_by_lambda_some_objects_valid_filter(self): + a = self.Artist.create() + self.Artist.create() + objs = self.Artist.filter(lambda artist: artist['id'] == a['id']) + assert len(objs) == 1 + assert isinstance(objs[0], self.Artist) + assert objs[0]['id'] == a['id'] + def test_by_kwargs_some_objects_valid_filter(self): a = self.Artist.create() self.Artist.create() @@ -145,6 +156,22 @@ def test_by_kwargs_some_objects_valid_filter(self): assert isinstance(objs[0], self.Artist) assert objs[0]['id'] == a['id'] + def test_by_ids_and_kwargs_some_objects_valid_filter(self): + a = self.Artist.create(name='Andrei') + self.Artist.create() + objs = self.Artist.filter([a['id']], name='Andrei') + assert len(objs) == 1 + assert isinstance(objs[0], self.Artist) + assert objs[0]['id'] == a['id'] + + def test_by_lambda_and_kwargs_some_objects_valid_filter(self): + a = self.Artist.create(name='Andrei') + self.Artist.create() + objs = self.Artist.filter(lambda artist: artist['id'] == a['id'], name='Andrei') + assert len(objs) == 1 + assert isinstance(objs[0], self.Artist) + assert objs[0]['id'] == a['id'] + def test_by_ids_some_objects_deleted_valid_filter(self): a = self.Artist.create() a_id = a['id'] @@ -152,6 +179,13 @@ def test_by_ids_some_objects_deleted_valid_filter(self): a.delete() assert len(self.Artist.filter([a_id])) == 0 + def test_by_lambda_some_objects_deleted_valid_filter(self): + a = self.Artist.create() + a_id = a['id'] + self.Artist.create() + a.delete() + assert len(self.Artist.filter(lambda artist: artist['id'] == a_id)) == 0 + def test_by_kwargs_some_objects_deleted_valid_filter(self): a = self.Artist.create() a_id = a['id'] @@ -164,11 +198,26 @@ def test_by_ids_some_objects_invalid_filter(self): self.Artist.create() assert len(self.Artist.filter(['id'])) == 0 + def test_by_lambda_some_objects_invalid_filter(self): + self.Artist.create() + self.Artist.create() + assert len(self.Artist.filter(lambda artist: artist['id'] == 'id')) == 0 + def test_by_kwargs_some_objects_invalid_filter(self): self.Artist.create() self.Artist.create() assert len(self.Artist.filter(id='id')) == 0 + def test_by_ids_and_kwargs_some_objects_invalid_filter(self): + self.Artist.create(name='Andrei') + self.Artist.create() + assert len(self.Artist.filter(['id'], name='Andrei')) == 0 + + def test_by_lambda_and_kwargs_some_objects_invalid_filter(self): + self.Artist.create(name='Andrei') + self.Artist.create() + assert len(self.Artist.filter(lambda artist: artist['id'] == 'id', name='Andrei')) == 0 + class CountTests(DbBaseTestCase): def setUp(self):