Skip to content

Commit fbd22a3

Browse files
authored
[aiogoogle] clean up storage client Task management (#14347)
Alright, so. 1. If you `await`, call `exception`, or call `result` on a cancelled task, you receive a `CancelledError`. 2. A cancelled task is not necessarily done. The task could catch the `CancelledError` and do something, including raise a different exception (e.g. because a resource close failed). 3. Nested try-finally sucks. This PR adds `_cleanup_future` which: 1. Cancels the future. 2. Waits for the future to receive its cancellation and then terminate. 3. Checks if the future is cleanly cancelled (in which case there is nothing more for us to do). 4. Retrieves any present exceptions from a not-cancelled (but done!) future. We then use this function, in combination with exit stacks, to simply and cleanly manage exceptions. I also added some missing retry_transient_errors invocations.
1 parent 7ccd47c commit fbd22a3

File tree

1 file changed

+49
-41
lines changed

1 file changed

+49
-41
lines changed

hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py

+49-41
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from typing import Tuple, Any, Set, Optional, MutableMapping, Dict, AsyncIterator, cast, Type, List, Coroutine
33
from types import TracebackType
44
from multidict import CIMultiDictProxy # pylint: disable=unused-import
5-
import sys
65
import logging
76
import asyncio
87
import urllib.parse
98
import aiohttp
109
import datetime
10+
from contextlib import AsyncExitStack
1111
from hailtop import timex
1212
from hailtop.utils import secret_alnum_string, OnlineBoundedGather2, TransientError, retry_transient_errors
1313
from hailtop.aiotools.fs import (
@@ -69,39 +69,53 @@ async def __anext__(self):
6969
raise StopAsyncIteration
7070

7171

72+
async def _cleanup_future(fut: asyncio.Future):
73+
if not fut.done():
74+
fut.cancel()
75+
await asyncio.wait([fut])
76+
if not fut.cancelled():
77+
if exc := fut.exception():
78+
raise exc
79+
80+
7281
class InsertObjectStream(WritableStream):
7382
def __init__(self, it: FeedableAsyncIterable[bytes], request_task: asyncio.Task[aiohttp.ClientResponse]):
7483
super().__init__()
7584
self._it = it
7685
self._request_task = request_task
7786
self._value = None
87+
self._exit_stack = AsyncExitStack()
88+
89+
async def cleanup_request_task():
90+
if not self._request_task.cancelled():
91+
try:
92+
async with await self._request_task as response:
93+
self._value = await response.json()
94+
except AttributeError as err:
95+
raise ValueError(repr(self._request_task)) from err
96+
await _cleanup_future(self._request_task)
97+
98+
self._exit_stack.push_async_callback(cleanup_request_task)
7899

79100
async def write(self, b):
80101
assert not self.closed
81102

82103
fut = asyncio.ensure_future(self._it.feed(b))
83-
try:
84-
await asyncio.wait([fut, self._request_task], return_when=asyncio.FIRST_COMPLETED)
85-
if fut.done() and not fut.cancelled():
86-
if exc := fut.exception():
87-
raise exc
88-
return len(b)
89-
raise ValueError('request task finished early')
90-
finally:
91-
fut.cancel()
104+
self._exit_stack.push_async_callback(_cleanup_future, fut)
105+
106+
await asyncio.wait([fut, self._request_task], return_when=asyncio.FIRST_COMPLETED)
107+
if fut.done():
108+
await fut
109+
return len(b)
110+
raise ValueError('request task finished early')
92111

93112
async def _wait_closed(self):
94-
fut = asyncio.ensure_future(self._it.stop())
95113
try:
114+
fut = asyncio.ensure_future(self._it.stop())
115+
self._exit_stack.push_async_callback(_cleanup_future, fut)
96116
await asyncio.wait([fut, self._request_task], return_when=asyncio.FIRST_COMPLETED)
97-
async with await self._request_task as resp:
98-
self._value = await resp.json()
99117
finally:
100-
if fut.done() and not fut.cancelled():
101-
if exc := fut.exception():
102-
raise exc
103-
else:
104-
fut.cancel()
118+
await self._exit_stack.aclose()
105119

106120

107121
class _TaskManager:
@@ -119,25 +133,10 @@ async def __aexit__(
119133
) -> None:
120134
assert self._task is not None
121135

122-
if not self._task.done():
123-
if exc_val:
124-
self._task.cancel()
125-
try:
126-
value = await self._task
127-
if self._closable:
128-
value.close()
129-
except:
130-
_, exc, _ = sys.exc_info()
131-
if exc is not exc_val:
132-
log.warning('dropping preempted task exception', exc_info=True)
133-
else:
134-
value = await self._task
135-
if self._closable:
136-
value.close()
136+
if self._closable and self._task.done() and not self._task.cancelled():
137+
(await self._task).close()
137138
else:
138-
value = await self._task
139-
if self._closable:
140-
value.close()
139+
await _cleanup_future(self._task)
141140

142141

143142
class ResumableInsertObjectStream(WritableStream):
@@ -526,7 +525,7 @@ async def __aenter__(self) -> 'GoogleStorageMultiPartCreate':
526525
return self
527526

528527
async def _compose(self, names: List[str], dest_name: str):
529-
await self._fs._storage_client.compose(self._bucket, names, dest_name)
528+
await retry_transient_errors(self._fs._storage_client.compose, self._bucket, names, dest_name)
530529

531530
async def __aexit__(
532531
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
@@ -560,18 +559,27 @@ async def tree_compose(names, dest_name):
560559

561560
chunk_names = [self._tmp_name(f'chunk-{secret_alnum_string()}') for _ in range(32)]
562561

563-
chunk_tasks = [pool.call(tree_compose, c, n) for c, n in zip(chunks, chunk_names)]
562+
async with AsyncExitStack() as stack:
563+
chunk_tasks = []
564+
for chunk, name in zip(chunks, chunk_names):
565+
fut = pool.call(tree_compose, chunk, name)
566+
stack.push_async_callback(_cleanup_future, fut)
567+
chunk_tasks.append(fut)
564568

565-
await pool.wait(chunk_tasks)
569+
await pool.wait(chunk_tasks)
566570

567571
await self._compose(chunk_names, dest_name)
568572

569573
for name in chunk_names:
570-
await pool.call(self._fs._remove_doesnt_exist_ok, f'gs://{self._bucket}/{name}')
574+
await pool.call(
575+
retry_transient_errors, self._fs._remove_doesnt_exist_ok, f'gs://{self._bucket}/{name}'
576+
)
571577

572578
await tree_compose([self._part_name(i) for i in range(self._num_parts)], self._dest_name)
573579
finally:
574-
await self._fs.rmtree(self._sema, f'gs://{self._bucket}/{self._dest_dirname}_/{self._token}')
580+
await retry_transient_errors(
581+
self._fs.rmtree, self._sema, f'gs://{self._bucket}/{self._dest_dirname}_/{self._token}'
582+
)
575583

576584

577585
class GoogleStorageAsyncFSURL(AsyncFSURL):

0 commit comments

Comments
 (0)