2
2
from typing import Tuple , Any , Set , Optional , MutableMapping , Dict , AsyncIterator , cast , Type , List , Coroutine
3
3
from types import TracebackType
4
4
from multidict import CIMultiDictProxy # pylint: disable=unused-import
5
- import sys
6
5
import logging
7
6
import asyncio
8
7
import urllib .parse
9
8
import aiohttp
10
9
import datetime
10
+ from contextlib import AsyncExitStack
11
11
from hailtop import timex
12
12
from hailtop .utils import secret_alnum_string , OnlineBoundedGather2 , TransientError , retry_transient_errors
13
13
from hailtop .aiotools .fs import (
@@ -69,39 +69,53 @@ async def __anext__(self):
69
69
raise StopAsyncIteration
70
70
71
71
72
+ async def _cleanup_future (fut : asyncio .Future ):
73
+ if not fut .done ():
74
+ fut .cancel ()
75
+ await asyncio .wait ([fut ])
76
+ if not fut .cancelled ():
77
+ if exc := fut .exception ():
78
+ raise exc
79
+
80
+
72
81
class InsertObjectStream (WritableStream ):
73
82
def __init__ (self , it : FeedableAsyncIterable [bytes ], request_task : asyncio .Task [aiohttp .ClientResponse ]):
74
83
super ().__init__ ()
75
84
self ._it = it
76
85
self ._request_task = request_task
77
86
self ._value = None
87
+ self ._exit_stack = AsyncExitStack ()
88
+
89
+ async def cleanup_request_task ():
90
+ if not self ._request_task .cancelled ():
91
+ try :
92
+ async with await self ._request_task as response :
93
+ self ._value = await response .json ()
94
+ except AttributeError as err :
95
+ raise ValueError (repr (self ._request_task )) from err
96
+ await _cleanup_future (self ._request_task )
97
+
98
+ self ._exit_stack .push_async_callback (cleanup_request_task )
78
99
79
100
async def write (self , b ):
80
101
assert not self .closed
81
102
82
103
fut = asyncio .ensure_future (self ._it .feed (b ))
83
- try :
84
- await asyncio .wait ([fut , self ._request_task ], return_when = asyncio .FIRST_COMPLETED )
85
- if fut .done () and not fut .cancelled ():
86
- if exc := fut .exception ():
87
- raise exc
88
- return len (b )
89
- raise ValueError ('request task finished early' )
90
- finally :
91
- fut .cancel ()
104
+ self ._exit_stack .push_async_callback (_cleanup_future , fut )
105
+
106
+ await asyncio .wait ([fut , self ._request_task ], return_when = asyncio .FIRST_COMPLETED )
107
+ if fut .done ():
108
+ await fut
109
+ return len (b )
110
+ raise ValueError ('request task finished early' )
92
111
93
112
async def _wait_closed (self ):
94
- fut = asyncio .ensure_future (self ._it .stop ())
95
113
try :
114
+ fut = asyncio .ensure_future (self ._it .stop ())
115
+ self ._exit_stack .push_async_callback (_cleanup_future , fut )
96
116
await asyncio .wait ([fut , self ._request_task ], return_when = asyncio .FIRST_COMPLETED )
97
- async with await self ._request_task as resp :
98
- self ._value = await resp .json ()
99
117
finally :
100
- if fut .done () and not fut .cancelled ():
101
- if exc := fut .exception ():
102
- raise exc
103
- else :
104
- fut .cancel ()
118
+ await self ._exit_stack .aclose ()
105
119
106
120
107
121
class _TaskManager :
@@ -119,25 +133,10 @@ async def __aexit__(
119
133
) -> None :
120
134
assert self ._task is not None
121
135
122
- if not self ._task .done ():
123
- if exc_val :
124
- self ._task .cancel ()
125
- try :
126
- value = await self ._task
127
- if self ._closable :
128
- value .close ()
129
- except :
130
- _ , exc , _ = sys .exc_info ()
131
- if exc is not exc_val :
132
- log .warning ('dropping preempted task exception' , exc_info = True )
133
- else :
134
- value = await self ._task
135
- if self ._closable :
136
- value .close ()
136
+ if self ._closable and self ._task .done () and not self ._task .cancelled ():
137
+ (await self ._task ).close ()
137
138
else :
138
- value = await self ._task
139
- if self ._closable :
140
- value .close ()
139
+ await _cleanup_future (self ._task )
141
140
142
141
143
142
class ResumableInsertObjectStream (WritableStream ):
@@ -526,7 +525,7 @@ async def __aenter__(self) -> 'GoogleStorageMultiPartCreate':
526
525
return self
527
526
528
527
async def _compose (self , names : List [str ], dest_name : str ):
529
- await self ._fs ._storage_client .compose ( self ._bucket , names , dest_name )
528
+ await retry_transient_errors ( self ._fs ._storage_client .compose , self ._bucket , names , dest_name )
530
529
531
530
async def __aexit__ (
532
531
self , exc_type : Optional [Type [BaseException ]], exc_val : Optional [BaseException ], exc_tb : Optional [TracebackType ]
@@ -560,18 +559,27 @@ async def tree_compose(names, dest_name):
560
559
561
560
chunk_names = [self ._tmp_name (f'chunk-{ secret_alnum_string ()} ' ) for _ in range (32 )]
562
561
563
- chunk_tasks = [pool .call (tree_compose , c , n ) for c , n in zip (chunks , chunk_names )]
562
+ async with AsyncExitStack () as stack :
563
+ chunk_tasks = []
564
+ for chunk , name in zip (chunks , chunk_names ):
565
+ fut = pool .call (tree_compose , chunk , name )
566
+ stack .push_async_callback (_cleanup_future , fut )
567
+ chunk_tasks .append (fut )
564
568
565
- await pool .wait (chunk_tasks )
569
+ await pool .wait (chunk_tasks )
566
570
567
571
await self ._compose (chunk_names , dest_name )
568
572
569
573
for name in chunk_names :
570
- await pool .call (self ._fs ._remove_doesnt_exist_ok , f'gs://{ self ._bucket } /{ name } ' )
574
+ await pool .call (
575
+ retry_transient_errors , self ._fs ._remove_doesnt_exist_ok , f'gs://{ self ._bucket } /{ name } '
576
+ )
571
577
572
578
await tree_compose ([self ._part_name (i ) for i in range (self ._num_parts )], self ._dest_name )
573
579
finally :
574
- await self ._fs .rmtree (self ._sema , f'gs://{ self ._bucket } /{ self ._dest_dirname } _/{ self ._token } ' )
580
+ await retry_transient_errors (
581
+ self ._fs .rmtree , self ._sema , f'gs://{ self ._bucket } /{ self ._dest_dirname } _/{ self ._token } '
582
+ )
575
583
576
584
577
585
class GoogleStorageAsyncFSURL (AsyncFSURL ):
0 commit comments