diff --git a/hail_search/test_search.py b/hail_search/test_search.py index c0c9d25da8..c74da6de0b 100644 --- a/hail_search/test_search.py +++ b/hail_search/test_search.py @@ -1,5 +1,8 @@ from aiohttp.test_utils import AioHTTPTestCase +import asyncio from copy import deepcopy +import time +from unittest import mock from hail_search.test_utils import get_hail_search_body, FAMILY_2_VARIANT_SAMPLE_DATA, FAMILY_2_MISSING_SAMPLE_DATA, \ VARIANT1, VARIANT2, VARIANT3, VARIANT4, MULTI_PROJECT_SAMPLE_DATA, MULTI_PROJECT_MISSING_SAMPLE_DATA, \ @@ -8,7 +11,7 @@ GCNV_MULTI_FAMILY_VARIANT1, GCNV_MULTI_FAMILY_VARIANT2, SV_WES_SAMPLE_DATA, EXPECTED_SAMPLE_DATA, \ FAMILY_2_MITO_SAMPLE_DATA, FAMILY_2_ALL_SAMPLE_DATA, MITO_VARIANT1, MITO_VARIANT2, MITO_VARIANT3, \ EXPECTED_SAMPLE_DATA_WITH_SEX, SV_WGS_SAMPLE_DATA_WITH_SEX, VARIANT_LOOKUP_VARIANT -from hail_search.web_app import init_web_app +from hail_search.web_app import init_web_app, sync_to_async_hail_query PROJECT_2_VARIANT = { 'variantId': '1-10146-ACC-A', @@ -193,6 +196,20 @@ class HailSearchTestCase(AioHTTPTestCase): async def get_application(self): return await init_web_app() + async def test_sync_to_async_hail_query(self): + request = mock.Mock() + request.app = await self.get_application() + # NB: request.json() is the first arg passed to the callable + request.json.return_value = asyncio.Future() + request.json.return_value.set_result(3) + with self.assertRaises(TimeoutError): + await sync_to_async_hail_query(request, time.sleep, timeout_s=1) + + with mock.patch('hail_search.web_app.ctypes.pythonapi.PyThreadState_SetAsyncExc') as mock_set_async_exc: + mock_set_async_exc.return_value = 2 + with self.assertRaises(SystemExit): + await sync_to_async_hail_query(request, time.sleep, timeout_s=1) + async def test_status(self): async with self.client.request('GET', '/status') as resp: self.assertEqual(resp.status, 200) diff --git a/hail_search/web_app.py b/hail_search/web_app.py index a4ef765c04..fc274d2c31 100644 --- a/hail_search/web_app.py +++ b/hail_search/web_app.py @@ -1,6 +1,7 @@ from aiohttp import web import asyncio import concurrent.futures +import ctypes import functools import json import os @@ -16,6 +17,7 @@ JAVA_OPTS_XSS = os.environ.get('JAVA_OPTS_XSS') MACHINE_MEM = os.environ.get('MACHINE_MEM') JVM_MEMORY_FRACTION = 0.9 +QUERY_TIMEOUT_S = 300 def _handle_exception(e, request): @@ -44,10 +46,32 @@ def _hl_json_default(o): def hl_json_dumps(obj): return json.dumps(obj, default=_hl_json_default) -async def sync_to_async_hail_query(request: web.Request, query: Callable, *args, **kwargs): +async def sync_to_async_hail_query(request: web.Request, query: Callable, *args, timeout_s=QUERY_TIMEOUT_S, **kwargs): loop = asyncio.get_running_loop() - return await loop.run_in_executor(request.app.pool, functools.partial(query, await request.json(), *args, **kwargs)) - + future = loop.run_in_executor(request.app.pool, functools.partial(query, await request.json(), *args, **kwargs)) + try: + return await asyncio.wait_for(future, timeout_s) + except asyncio.TimeoutError: + # Well documented issue with the "wait_for" approach.... the concurrent.Future is canceled but + # the underlying thread is not, allowing the Hail Query under the hood to keep running. + # https://stackoverflow.com/questions/34452590/timeout-handling-while-using-run-in-executor-and-asyncio + # This unsafe approach is taken from: + # https://stackoverflow.com/questions/323972/is-there-any-way-to-kill-a-thread + # + # A few other thoughts: + # - A ProcessPoolExecutor instead of a ThreadPoolExecutor would allow for safe worker termination + # and would potentially be a safer option in general (some portion of a hail query is cpu bound in python!) + # - A "timeout" decorator applied to the query function, catching a SIGALARM would also potentially + # suffice... but threads don't play well with signals. + # - We could also just kill the service/pod (which is fine). + for t in request.app.pool._threads: + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(t.ident), ctypes.py_object(TimeoutError)) + if res > 1: + # "if it returns a number greater than one, you're in trouble, + # and you should call it again with exc=NULL to revert the effect" + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(t.ident), None) + raise SystemExit('PyThreadState_SetAsyncExc failed') + raise TimeoutError('Hail Query Timeout Exceeded') async def gene_counts(request: web.Request) -> web.Response: hail_results = await sync_to_async_hail_query(request, search_hail_backend, gene_counts=True)