support async bulk api #90

Open
wants to merge 19 commits into master
123 changes: 122 additions & 1 deletion aioelasticsearch/helpers.py
@@ -1,9 +1,12 @@
import asyncio

import logging
from operator import methodcaller
from typing import List

from aioelasticsearch import NotFoundError
from elasticsearch.helpers import ScanError
from elasticsearch.helpers import ScanError, _chunk_actions, expand_action
from elasticsearch.exceptions import TransportError

from .compat import PY_352

@@ -147,3 +150,121 @@ def _update_state(self, resp):
        self._successful_shards = resp['_shards']['successful']
        self._total_shards = resp['_shards']['total']
        self._done = not self._hits or self._scroll_id is None


async def worker_bulk(client, datas: List[dict], actions: List[str], **kwargs):
Member
The project does not provide typings anywhere else yet, so it does not make sense to put them here; typings will be provided everywhere later.

    try:
        resp = await client.bulk("\n".join(actions) + '\n', **kwargs)
    except TransportError as e:
        return e, datas
    fail_actions = []
    finish_count = 0
    for data, (op_type, item) in zip(datas, map(methodcaller('popitem'),
                                                resp['items'])):
        ok = 200 <= item.get('status', 500) < 300
        if not ok:
            fail_actions.append(data)
        else:
            finish_count += 1
    return finish_count, fail_actions
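Aside (not part of the diff): the status check in worker_bulk walks entries of the standard Elasticsearch bulk response. One entry of resp['items'] looks roughly like the dict below; all values are purely illustrative.

# Illustrative shape of one entry in resp['items']; popitem() yields the
# op type ('index') and the per-document result, whose 'status' field
# drives the 200 <= status < 300 success check above. Values are made up.
example_item = {
    'index': {
        '_index': 'test_aioes',
        '_type': 'type_3',
        '_id': '1',
        '_version': 1,
        'result': 'created',
        'status': 201,
    },
}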


def _get_fail_data(results, serializer):
    finish_count = 0
    bulk_action = []
    bulk_data = []
    lazy_exception = None
    for result in results:
        if isinstance(result[0], int):
            finish_count += result[0]
        else:
            if lazy_exception is None:
                lazy_exception = result[0]

        for fail_data in result[1]:
            for _ in fail_data:
                bulk_data.append(_)
        if result[1]:
            bulk_action.extend(map(serializer.dumps, result[1]))
    return finish_count, bulk_data, bulk_action, lazy_exception


async def _retry_handler(client, futures, max_retries, initial_backoff,
                         max_backoff, **kwargs):
    finish = 0
    for attempt in range(max_retries + 1):
Member
Have you copied this implementation from the original library?

        if attempt:
            sleep = min(max_backoff, initial_backoff * 2 ** (attempt - 1))
            await asyncio.sleep(sleep)
Member
Please pass an explicit event loop here.


        results = await asyncio.gather(*futures,
Member
This gather can produce an unexpected load spike; not sure, but maybe it should be rewritten with an asyncio.Queue and a fixed pool of workers (see the sketch after this diff).

Member
Please pass an explicit event loop here as well.

                                       return_exceptions=True)
        futures = []

        count, fail_data, fail_action, lazy_exception = \
            _get_fail_data(results, client.transport.serializer)

        finish += count

        if not fail_action or attempt == max_retries:
            break

        coroutine = worker_bulk(client, fail_data, fail_action, **kwargs)
        futures.append(asyncio.ensure_future(coroutine))

    if lazy_exception:
        raise lazy_exception

    return finish, fail_data


async def bulk(client, actions, concurrency_limit=2, chunk_size=500,
               max_chunk_bytes=100 * 1024 * 1024,
               expand_action_callback=expand_action, max_retries=0,
               initial_backoff=2, max_backoff=600, stats_only=False, **kwargs):

    async def concurrency_wrapper(chunk_iter):

        partial_count = 0
        if stats_only:
            partial_fail = 0
        else:
            partial_fail = []
        for bulk_data, bulk_action in chunk_iter:
            futures = [worker_bulk(client, bulk_data, bulk_action, **kwargs)]
            count, fails = await _retry_handler(client,
                                                futures,
                                                max_retries,
                                                initial_backoff,
                                                max_backoff, **kwargs)
            partial_count += count
            if stats_only:
                partial_fail += len(fails)
            else:
                partial_fail.extend(fails)
        return partial_count, partial_fail

    actions = map(expand_action_callback, actions)
    finish_count = 0
    if stats_only:
        fail_datas = 0
    else:
        fail_datas = []

    chunk_action_iter = _chunk_actions(actions, chunk_size, max_chunk_bytes,
                                       client.transport.serializer)

    tasks = []
    concurrency_limit = concurrency_limit if concurrency_limit > 0 else 2
    for i in range(concurrency_limit):
        tasks.append(concurrency_wrapper(chunk_action_iter))

    results = await asyncio.gather(*tasks)
    for p_count, p_fails in results:
        finish_count += p_count
        if stats_only:
            fail_datas += p_fails
        else:
            fail_datas.extend(p_fails)

    return finish_count, fail_datas
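Responding to the review comments above, a rough sketch of the suggested asyncio.Queue / worker-pool shape with the event loop passed in explicitly. This is not part of the PR: it reuses worker_bulk and the (bulk_data, bulk_actions) chunks produced by _chunk_actions from this diff, the helper names _queue_worker and _bulk_with_workers are hypothetical, and the loop= keyword on asyncio primitives assumes the pre-3.8 asyncio API this project targets.

# Sketch only: bounded concurrency via a queue instead of one big gather.
async def _queue_worker(client, queue, results, **kwargs):
    # Each worker pulls one chunk at a time, so at most `concurrency_limit`
    # bulk requests are in flight at any moment.
    while True:
        chunk = await queue.get()
        if chunk is None:            # sentinel: no more work for this worker
            break
        bulk_data, bulk_actions = chunk
        results.append(await worker_bulk(client, bulk_data, bulk_actions,
                                         **kwargs))


async def _bulk_with_workers(client, chunks, concurrency_limit=2,
                             loop=None, **kwargs):
    loop = loop or asyncio.get_event_loop()
    # A bounded queue gives backpressure: the producer waits while workers lag.
    queue = asyncio.Queue(maxsize=concurrency_limit, loop=loop)
    results = []
    workers = [loop.create_task(_queue_worker(client, queue, results, **kwargs))
               for _ in range(concurrency_limit)]
    for chunk in chunks:             # e.g. the _chunk_actions(...) iterator
        await queue.put(chunk)
    for _ in workers:
        await queue.put(None)        # one sentinel per worker
    await asyncio.gather(*workers, loop=loop)
    return results

The bounded queue keeps at most concurrency_limit bulk requests running at once, instead of gathering every retry future in a single burst.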
53 changes: 53 additions & 0 deletions tests/test_bulk.py
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
import logging

from unittest import mock

import pytest

from aioelasticsearch import NotFoundError
from aioelasticsearch.helpers import bulk


logger = logging.getLogger('elasticsearch')

def gen_data1():
    for i in range(10):
        yield {"_index": "test_aioes",
               "_type": "type_3",
               "_id": str(i),
               "foo": "1"}


def gen_data2():
    for i in range(10, 20):
        yield {"_index": "test_aioes",
               "_type": "type_3",
               "_id": str(i),
               "_source": {"foo": "1"}
               }


@pytest.mark.run_loop
async def test_bulk_simple(es):
    success, fails = await bulk(es, gen_data1(),
                                concurrency_limit=2,
                                stats_only=True)
    assert success == 10
    assert fails == 0

    success, fails = await bulk(es, gen_data2(),
                                concurrency_limit=2,
                                stats_only=True)
    assert success == 10
    assert fails == 0

@pytest.mark.run_loop
async def test_bulk_fails(es):
    datas = [{'op_type': 'delete',
              '_index': 'test_aioes',
              '_type': 'type_3', '_id': "999"}
             ]
    success, fails = await bulk(es, datas, stats_only=True)
    assert success == 0
    assert fails == 1
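For completeness, a minimal usage sketch of the new helper from application code, assuming the usual aioelasticsearch Elasticsearch client and a locally running node; the host, index, and document values are placeholders and error handling is omitted.

import asyncio

from aioelasticsearch import Elasticsearch
from aioelasticsearch.helpers import bulk


async def main():
    es = Elasticsearch(hosts=['localhost:9200'])
    try:
        # Any iterable or generator of actions works, as in the tests above.
        actions = ({'_index': 'test_aioes', '_type': 'type_3',
                    '_id': str(i), 'foo': '1'} for i in range(100))
        success, fails = await bulk(es, actions,
                                    concurrency_limit=2,
                                    max_retries=3,
                                    stats_only=True)
        print('indexed:', success, 'failed:', fails)
    finally:
        await es.close()


asyncio.get_event_loop().run_until_complete(main())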