Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix deadlocks and races in multi-threaded asyncio logic #747

Merged
merged 13 commits into from
Jan 18, 2024
7 changes: 4 additions & 3 deletions integration/test_collection_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,13 @@ def test_add_object_batch_with_tenant(batch_collection: BatchCollection) -> None

mt_collection.tenants.create([Tenant(name="tenant" + str(i)) for i in range(5)])
for i in range(5):
with mt_collection.with_tenant("tenant" + str(i % 5)).batch as batch:
col = mt_collection.with_tenant("tenant" + str(i % 5))
with col.batch as batch:
batch.add_object(
properties={"name": "tenant" + str(i % 5)},
)
assert len(mt_collection.batch.failed_objects()) == 0
assert len(mt_collection.batch.failed_references()) == 0
assert len(col.batch.failed_objects()) == 0
assert len(col.batch.failed_references()) == 0
objs = mt_collection.with_tenant("tenant1").query.fetch_objects().objects
assert len(objs) == 1
for obj in objs:
Expand Down
59 changes: 46 additions & 13 deletions weaviate/collections/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def __init__(
connection: ConnectionV4,
consistency_level: Optional[ConsistencyLevel],
results: _BatchDataWrapper,
event_loop: asyncio.AbstractEventLoop,
fixed_batch_size: Optional[int] = None, # dynamic by default
fixed_concurrent_requests: Optional[int] = None, # dynamic by default
objects_: Optional[ObjectsBatchRequest] = None,
Expand Down Expand Up @@ -186,15 +185,51 @@ def __init__(
self.__last_scale_up: float = 0
self.__max_observed_rate: int = 0

self.__loop = event_loop
self.__bg_thread = self.__start_bg_thread()

self.__start_bg_thread()
def __run_event_loop(self, loop: asyncio.AbstractEventLoop) -> None:
"""Run `loop` until it is stopped, then shut down its async generators and close it.

This is the thread target of the "eventLoop" thread created in `__start_new_event_loop`.
"""
loop.set_debug(True) # in case of errors, shows async errors in the terminal to users
try:
# Blocks this thread until loop.stop() takes effect.
loop.run_forever()
finally:
# Reached once run_forever() returns or raises. In the visible code loop.stop is scheduled
# with call_soon_threadsafe at the end of periodic_check, i.e. from the "BgBatchScheduler"
# thread rather than the main thread.
# NOTE(review): indentation is not visible in this view; loop.close() is presumably part of
# the finally block - confirm.
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()

def __start_new_event_loop(self) -> asyncio.AbstractEventLoop:
"""Create a new event loop, run it on a daemon thread and return it once it is running."""
loop = asyncio.new_event_loop()

# The loop is driven by __run_event_loop on its own daemon thread, so other threads can
# submit coroutines to it with asyncio.run_coroutine_threadsafe.
event_loop = threading.Thread(
target=self.__run_event_loop,
daemon=True,
args=(loop,),
name="eventLoop",
)
event_loop.start()

# Poll until the new thread has actually entered run_forever().
while not loop.is_running():
time.sleep(0.01)

# NOTE(review): the `def __start_bg_thread` header and its docstring below look like the removed
# side of the diff hunk (a `__start_bg_thread` returning threading.Thread appears further down);
# `return loop` is the end of this method - confirm.
def __start_bg_thread(self) -> None:
"""Create a background process that periodically checks how congested the batch queue is."""
return loop

def _shutdown(self) -> None:
"""Shutdown the current batch and wait for all requests to be finished."""
self.flush()

# we are done, shut bg threads down and end the event loop
# Setting the event ends the `while` loop in periodic_check, which then closes the async
# connection and schedules loop.stop on its event loop.
self.__shut_background_thread_down.set()
# Poll until the scheduler thread returned by __start_bg_thread has exited, so the
# connection close above has completed before this method returns.
while self.__bg_thread.is_alive():
time.sleep(0.01)

def __start_bg_thread(self) -> threading.Thread:
"""Create a background thread that periodically checks how congested the batch queue is."""
self.__shut_background_thread_down = threading.Event()

def periodic_check() -> None:
loop = self.__start_new_event_loop()
future = asyncio.run_coroutine_threadsafe(self.__connection.aopen(), loop)
future.result() # Wait for self.__connection.aopen() to finish

while (
self.__shut_background_thread_down is not None
and not self.__shut_background_thread_down.is_set()
Expand Down Expand Up @@ -287,17 +322,22 @@ def periodic_check() -> None:
self.__batch_objects.pop_items(self.__recommended_num_objects),
self.__batch_references.pop_items(self.__recommended_num_refs),
),
self.__loop,
loop,
)

time.sleep(refresh_time)

future = asyncio.run_coroutine_threadsafe(self.__connection.aclose(), loop)
future.result() # Wait for self.__connection.aclose() to finish
loop.call_soon_threadsafe(loop.stop)

demon = threading.Thread(
target=periodic_check,
daemon=True,
name="BgBatchScheduler",
)
demon.start()
return demon

async def __send_batch_async(
self, objs: List[_BatchObject], refs: List[_BatchReference]
Expand Down Expand Up @@ -363,13 +403,6 @@ def flush(self) -> None:
):
time.sleep(0.01)

def _shutdown(self) -> None:
"""Shutdown the current batch and wait for all requests to be finished."""
# NOTE(review): this copy only sets the shutdown event and does not wait for the background
# thread to exit; another `_shutdown` earlier in this diff also polls `__bg_thread.is_alive()`.
# This one looks like the removed side of the hunk - confirm.
self.flush()

# we are done, shut bg threads down and end the event loop
self.__shut_background_thread_down.set()

def _add_object(
self,
collection: str,
Expand Down
50 changes: 1 addition & 49 deletions weaviate/collections/batch/batch_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import asyncio
import threading
import time
from copy import copy
from typing import List, Optional, Any, cast
Expand All @@ -20,62 +18,16 @@ def __init__(
self._connection = connection
self._consistency_level = consistency_level
self._current_batch: Optional[_BatchBase] = None
self._current_loop: Optional[asyncio.AbstractEventLoop] = None
# config options
self._batch_size: Optional[int] = None
self._concurrent_requests: int = 2

self._batch_data = _BatchDataWrapper()
self.__shut_background_thread_down: Optional[threading.Event] = None

def __start_event_loop_thread(self, loop: asyncio.AbstractEventLoop) -> None:
"""Drive `loop` with run_forever() for as long as the shutdown event exists and is not set."""
while (
self.__shut_background_thread_down is not None
and not self.__shut_background_thread_down.is_set()
):
if loop.is_running():
# NOTE(review): there is no sleep on this path, so the thread spins on the loop condition
# while the loop is already running.
continue
else:
# Blocks until the loop is stopped; the outer condition is then re-checked.
loop.run_forever()

def _open_async_connection(self) -> asyncio.AbstractEventLoop:
"""Obtain an event loop, drive it from a daemon thread, open the async connection and return the loop.

Returns:
    The loop stored in `self._current_loop`, after `self._connection.aopen()` has completed on it.
"""
try:
# Reuse the loop that is already running in the calling thread, if any.
self._current_loop = asyncio.get_running_loop()
except RuntimeError:
# No running loop in this thread: create one and install it as the current loop.
self._current_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._current_loop)

# Event polled by __start_event_loop_thread to decide when to stop driving the loop.
self.__shut_background_thread_down = threading.Event()
event_loop = threading.Thread(
target=self.__start_event_loop_thread,
daemon=True,
args=(self._current_loop,),
name="eventLoop",
)
event_loop.start()

# Poll until the loop reports that it is running.
while not self._current_loop.is_running():
time.sleep(0.01)

# Submit aopen() to the loop from this thread and block on the returned future.
future = asyncio.run_coroutine_threadsafe(self._connection.aopen(), self._current_loop)
future.result() # Wait for self._connection.aopen() to finish

return self._current_loop

# enter is in inherited classes
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Shut down the current batch when the `with` block exits and clear `self._current_batch`."""
# NOTE(review): two versions of this body appear interleaved below (a three-condition assert
# next to a single-condition assert; an explicit aclose()/loop.stop() sequence next to the
# `_shutdown()` call), as expected from a diff rendered without +/- markers - confirm which
# lines belong to the final method.
assert (
self._current_batch is not None
and self._current_loop is not None
and self.__shut_background_thread_down is not None
)
assert self._current_batch is not None
self._current_batch._shutdown()
future = asyncio.run_coroutine_threadsafe(self._connection.aclose(), self._current_loop)
future.result() # Wait for self._connection.aclose() to finish

self.__shut_background_thread_down.set()
self._current_loop.stop()
self._current_loop = None
self._current_batch = None

def wait_for_vector_indexing(
Expand Down
3 changes: 0 additions & 3 deletions weaviate/collections/batch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,12 @@ def add_reference(

class _BatchClientWrapper(_BatchWrapper):
def __enter__(self) -> _BatchClient:
"""Create the `_BatchClient` for this `with` block, store it as the current batch and return it."""
# NOTE(review): this file is listed with "0 additions & 3 deletions"; the `loop = ...` line and
# the `event_loop=loop` argument look like deleted lines shown without markers - confirm.
loop = self._open_async_connection()

self._current_batch = _BatchClient(
connection=self._connection,
consistency_level=self._consistency_level,
results=self._batch_data,
fixed_batch_size=self._batch_size,
fixed_concurrent_requests=self._concurrent_requests,
event_loop=loop,
)
return self._current_batch

Expand Down
6 changes: 0 additions & 6 deletions weaviate/collections/batch/collection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from typing import Generic, List, Optional, Sequence, Union

from weaviate.collections.batch.base import _BatchBase, _BatchDataWrapper
Expand All @@ -16,7 +15,6 @@ def __init__(
connection: ConnectionV4,
consistency_level: Optional[ConsistencyLevel],
results: _BatchDataWrapper,
event_loop: asyncio.AbstractEventLoop,
fixed_batch_size: Optional[int],
fixed_concurrent_requests: Optional[int],
name: str,
Expand All @@ -26,7 +24,6 @@ def __init__(
connection=connection,
consistency_level=consistency_level,
results=results,
event_loop=event_loop,
fixed_batch_size=fixed_batch_size,
fixed_concurrent_requests=fixed_concurrent_requests,
)
Expand Down Expand Up @@ -112,8 +109,6 @@ def __init__(
self.__tenant = tenant

def __enter__(self) -> _BatchCollection[Properties]:
loop = self._open_async_connection()

self._current_batch = _BatchCollection[Properties](
connection=self._connection,
consistency_level=self._consistency_level,
Expand All @@ -122,7 +117,6 @@ def __enter__(self) -> _BatchCollection[Properties]:
fixed_concurrent_requests=self._concurrent_requests,
name=self.__name,
tenant=self.__tenant,
event_loop=loop,
)
return self._current_batch

Expand Down
1 change: 1 addition & 0 deletions weaviate/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import grpc # type: ignore
from grpc import Channel, ssl_channel_credentials
from grpc.aio import Channel as AsyncChannel # type: ignore

from pydantic import BaseModel, field_validator, model_validator

from weaviate.types import NUMBER
Expand Down