Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement scan for composite aggregation #225

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,43 @@ Asynchronous `scroll <https://www.elastic.co/guide/en/elasticsearch/reference/cu
loop.run_until_complete(go())
loop.close()

Asynchronous `scroll for composite aggregation <https://www.elastic.co/guide/en/elasticsearch/reference/current/search-aggregations-bucket-composite-aggregation.html#_pagination>`_.

.. code-block:: python

import asyncio

from aioelasticsearch import Elasticsearch
from aioelasticsearch.helpers import CompositeAggregationScan

QUERY = {
'aggs': {
'buckets': {
'composite': {
'sources': [
{'score': {'terms': {'field': 'score.keyword'}}},
],
'size': 5,
},
},
},
}

async def go():
async with Elasticsearch() as es:
async with CompositeAggregationScan(
es,
QUERY,
index='index',
) as scan:

async for doc in scan:
print(doc['doc_count'], doc['key'])

loop = asyncio.get_event_loop()
loop.run_until_complete(go())
loop.close()

Thanks
------

Expand Down
156 changes: 154 additions & 2 deletions aioelasticsearch/helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
import logging
from copy import deepcopy

from elasticsearch.helpers import ScanError

from aioelasticsearch import NotFoundError
from aioelasticsearch import ElasticsearchException, NotFoundError

__all__ = ('Scan', 'ScanError')
__all__ = ('CompositeAggregationScan', 'Scan', 'ScanError')


logger = logging.getLogger('elasticsearch')
Expand Down Expand Up @@ -140,3 +142,153 @@ def _update_state(self, resp):
self._successful_shards = resp['_shards']['successful']
self._total_shards = resp['_shards']['total']
self._done = not self._hits or self._scroll_id is None


class CompositeAggregationScan:
    """Asynchronous paginator over an Elasticsearch composite aggregation.

    Repeatedly issues search requests, feeding the ``after_key`` of each
    response back into the query's ``composite.after`` parameter, and
    yields one bucket dict per iteration.  Must be used as an async
    context manager::

        async with CompositeAggregationScan(es, query, index='idx') as scan:
            async for bucket in scan:
                ...

    :param es: Elasticsearch client instance.
    :param query: search body; must contain exactly one aggregation under
        ``query['aggs']`` and it must be a ``composite`` aggregation.
        The query is deep-copied, so the caller's dict is never mutated.
    :param loop: event loop; defaults to ``asyncio.get_event_loop()``.
    :param raise_on_error: raise ``ElasticsearchException`` when a page
        was computed on fewer shards than the total (partial results).
    :param prefetch_next_chunk: eagerly start the request for the next
        page while the current page is being consumed.
    :param kwargs: passed through to ``es.search()`` (e.g. ``index``).
    :raises RuntimeError: if the aggregation key cannot be extracted or
        the aggregation is not a composite one.
    """

    def __init__(
        self,
        es,
        query,
        loop=None,
        raise_on_error=True,
        prefetch_next_chunk=False,
        **kwargs
    ):
        self._es = es
        # Deep copy: pagination mutates the query ('after' key) in place.
        self._query = deepcopy(query)
        self._raise_on_error = raise_on_error
        self._prefetch_next_chunk = prefetch_next_chunk
        self._kwargs = kwargs

        if loop is None:
            loop = asyncio.get_event_loop()

        self._loop = loop

        self._aggs_key = self._extract_aggs_key()

        if 'composite' not in self._query['aggs'][self._aggs_key]:
            raise RuntimeError(
                'Scroll available only for composite aggregations.',
            )

        # 'after_key' of the last response; None until the first page
        # (and again on the final page, which terminates the scan).
        self._after_key = None

        self._initial = True
        self._done = False
        self._buckets = []
        self._buckets_idx = 0

        self._successful_shards = 0
        self._total_shards = 0
        # Task for the eagerly-started next-page request, if any.
        self._prefetched = None

    def _extract_aggs_key(self):
        """Return the single aggregation name from the query.

        :raises RuntimeError: if ``query['aggs']`` is missing or empty.
        """
        try:
            return list(self._query['aggs'].keys())[0]
        except (KeyError, IndexError):
            raise RuntimeError(
                "Can't get aggregation key from query {query}."
                .format(query=self._query),
            )

    async def __aenter__(self):  # noqa
        self._initial = False
        await self._fetch_results()

        return self

    async def __aexit__(self, *exc_info):  # noqa
        # Cancel any in-flight prefetch so we don't leak a pending task.
        self._reset_prefetched()

    def __aiter__(self):
        if self._initial:
            raise RuntimeError(
                'Scan operations should be done '
                'inside async context manager.',
            )

        return self

    async def __anext__(self):
        if self._done:
            raise StopAsyncIteration

        if self._buckets_idx >= len(self._buckets):
            if self._successful_shards < self._total_shards:
                # Shared printf-style template: logger interpolates it
                # lazily; for the exception we interpolate with '%'.
                # (Bug fix: the original called str.format() on this
                # %-style template, which left the placeholders
                # unfilled in the exception message.)
                msg = (
                    'Aggregation request has only succeeded '
                    'on %d shards out of %d.'
                )
                logger.warning(
                    msg,
                    self._successful_shards, self._total_shards,
                )
                if self._raise_on_error:
                    raise ElasticsearchException(
                        msg % (self._successful_shards, self._total_shards),
                    )

            await self._fetch_results()
            if self._done:
                raise StopAsyncIteration

        ret = self._buckets[self._buckets_idx]
        self._buckets_idx += 1

        return ret

    async def _search(self):
        """Execute one search request.

        :returns: ``(found, resp)`` — ``found`` is False when the target
            index does not exist (``NotFoundError``), in which case
            ``resp`` is None.
        """
        found, resp = True, None
        try:
            resp = await self._es.search(
                body=self._query,
                **self._kwargs,
            )
        except NotFoundError:
            found = False

        return found, resp

    def _reset_prefetched(self):
        """Cancel and drop the prefetch task, if one is outstanding."""
        if self._prefetched is not None and not self._prefetched.cancelled():  # noqa
            self._prefetched.cancel()

        self._prefetched = None

    async def _fetch_results(self):
        """Fetch the next page of buckets and update iterator state.

        Consumes the prefetched response when available; otherwise
        issues a fresh request.  Optionally schedules a prefetch of the
        following page (after the query's 'after' key was advanced).
        """
        if self._prefetched is not None:
            found, resp = await self._prefetched
            self._reset_prefetched()
        else:
            found, resp = await self._search()

        if not found:
            self._done = True

            return

        self._update_state(resp)

        if self._prefetch_next_chunk:
            self._prefetched = self._loop.create_task(
                self._search(),
            )

    def _update_query(self):
        """Advance the query's ``composite.after`` cursor for paging."""
        if self._after_key is None:
            return

        self._query['aggs'][self._aggs_key]['composite']['after'] = self._after_key  # noqa

    def _update_state(self, resp):
        """Absorb one search response: buckets, cursor, shard counts.

        The scan is done when the page has no buckets or the response
        carries no ``after_key`` (last page).
        """
        self._after_key = resp['aggregations'][self._aggs_key].get('after_key')
        self._buckets = resp['aggregations'][self._aggs_key]['buckets']
        self._buckets_idx = 0

        self._update_query()

        self._successful_shards = resp['_shards']['successful']
        self._total_shards = resp['_shards']['total']

        self._done = not self._buckets or self._after_key is None
Loading