Skip to content

Commit fdedefc

Browse files
committed
address comments
1 parent e7cbc94 commit fdedefc

File tree

2 files changed

+54
-32
lines changed

2 files changed

+54
-32
lines changed

flink-python/pyflink/datastream/tests/test_async_function.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ def timeout(self, value: Row):
126126
# note that we use assert_equals instead of assert_equals_sorted
127127
self.assert_equals(expected, results)
128128

129-
def test_complete_async_function_with_non_iterable_result(self):
129+
def test_non_iterable_result(self):
130130
self.env.set_parallelism(1)
131131
ds = self.env.from_collection(
132132
[(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)],
@@ -147,12 +147,12 @@ def timeout(self, value: Row):
147147
ds.add_sink(self.test_sink)
148148
try:
149149
self.env.execute()
150+
self.fail()
150151
except Exception as e:
151152
message = str(e)
152-
self.assertTrue("The 'result_future' of AsyncFunction should be completed with data of "
153-
"list type" in message)
153+
self.assertTrue("The result of AsyncFunction should be of list type" in message)
154154

155-
def test_complete_async_function_with_exception(self):
155+
def test_none_result(self):
156156
self.env.set_parallelism(1)
157157
ds = self.env.from_collection(
158158
[(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)],
@@ -162,20 +162,21 @@ def test_complete_async_function_with_exception(self):
162162
class MyAsyncFunction(AsyncFunction):
163163

164164
async def async_invoke(self, value: Row):
165-
raise Exception("encountered an exception")
165+
await asyncio.sleep(10)
166+
return None
166167

167168
def timeout(self, value: Row):
168-
# raise the same exception to make sure test case is stable in all cases
169-
raise Exception("encountered an exception")
169+
return None
170170

171171
ds = AsyncDataStream.unordered_wait(
172-
ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT())
172+
ds, MyAsyncFunction(), Time.seconds(1), 2, Types.INT())
173173
ds.add_sink(self.test_sink)
174174
try:
175175
self.env.execute()
176+
self.fail()
176177
except Exception as e:
177178
message = str(e)
178-
self.assertTrue("Could not complete the element" in message)
179+
self.assertTrue("The result of AsyncFunction cannot be none" in message)
179180

180181
def test_raise_exception_in_async_invoke(self):
181182
self.env.set_parallelism(1)
@@ -198,6 +199,7 @@ def timeout(self, value: Row):
198199
ds.add_sink(self.test_sink)
199200
try:
200201
self.env.execute()
202+
self.fail()
201203
except Exception as e:
202204
message = str(e)
203205
self.assertTrue("encountered an exception" in message)
@@ -223,6 +225,7 @@ def timeout(self, value: Row):
223225
ds.add_sink(self.test_sink)
224226
try:
225227
self.env.execute()
228+
self.fail()
226229
except Exception as e:
227230
message = str(e)
228231
self.assertTrue("encountered an exception" in message)
@@ -319,6 +322,7 @@ async def async_invoke(self, value: Row):
319322
try:
320323
AsyncDataStream.unordered_wait(
321324
ds, MyAsyncFunction(), Time.seconds(5), 2, Types.INT())
325+
self.fail()
322326
except Exception as e:
323327
message = str(e)
324328
self.assertTrue("AsyncFunction is still not supported for 'thread' mode" in message)

flink-python/pyflink/fn_execution/datastream/process/async_function/operation.py

Lines changed: 41 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -85,23 +85,37 @@ def complete(self, result: List[OUT]):
8585
if not self._completed.compare_and_set(False, True):
8686
return
8787

88+
self._complete_internal(result)
89+
90+
def _complete_internal(self, result: List[OUT]):
8891
if isinstance(result, Iterable):
8992
self._process_results(result)
9093
else:
94+
if result is None:
95+
self._exception_handler(
96+
RuntimeError("The result of AsyncFunction cannot be none, "
97+
"please check the methods 'async_invoke' and "
98+
"'timeout' of class '%s'." % self._classname))
99+
else:
100+
self._exception_handler(
101+
RuntimeError("The result of AsyncFunction should be of list type, "
102+
"please check the methods 'async_invoke' and "
103+
"'timeout' of class '%s'." % self._classname))
104+
91105
# complete with empty result, so that we remove timer and move ahead processing
92106
self._process_results([])
93107

94-
raise RuntimeError("The 'result_future' of AsyncFunction should be completed with "
95-
"data of list type, please check the methods 'async_invoke' and "
96-
"'timeout' of class '%s'." % self._classname)
97-
98108
def complete_exceptionally(self, error: Exception):
99109
# already completed, so ignore exception
100110
if not self._completed.compare_and_set(False, True):
101111
return
102112

103-
self._exception_handler(
104-
Exception("Could not complete the element:" + str(self._record), error))
113+
self._complete_exceptionally_internal(error)
114+
115+
def _complete_exceptionally_internal(self, error: Exception):
116+
self._exception_handler(Exception(
117+
"Error happens inside the class '%s' during handling input '%s'"
118+
% (self._classname, str(self._record)), error))
105119

106120
# complete with empty result, so that we remove timer and move ahead processing
107121
self._process_results([])
@@ -114,12 +128,14 @@ def _process_results(self, result: List[OUT]):
114128
self._result_future.complete(result)
115129

116130
def _timer_triggered(self):
117-
if not self._completed.get():
118-
try:
119-
result = self._timeout_func(self._record)
120-
self._result_future.complete(result)
121-
except Exception as e:
122-
self._result_future.complete_exceptionally(e)
131+
if not self._completed.compare_and_set(False, True):
132+
return
133+
134+
try:
135+
result = self._timeout_func(self._record)
136+
self._complete_internal(result)
137+
except Exception as error:
138+
self._complete_exceptionally_internal(error)
123139

124140

125141
class RetryableResultHandler(ResultFuture, Generic[IN, OUT]):
@@ -166,7 +182,7 @@ def _process_retry(self, result: Optional[List[OUT]], error: Optional[Exception]
166182
self._delayed_retry_timer = threading.Timer(next_backoff_time_sec, self._do_retry)
167183
self._delayed_retry_timer.start()
168184
else:
169-
if result is not None:
185+
if error is None:
170186
self._result_handler.complete(result)
171187
else:
172188
self._result_handler.complete_exceptionally(error)
@@ -189,18 +205,20 @@ def _timer_triggered(self):
189205
"""
190206
Rewrite the timeout process to deal with retry state.
191207
"""
192-
if not self._result_handler._completed.get():
193-
# cancel delayed retry timer first
194-
self._cancel_retry_timer()
208+
if not self._result_handler._completed.compare_and_set(False, True):
209+
return
195210

196-
# force reset _retry_awaiting to prevent the handler to trigger retry unnecessarily
197-
self._retry_awaiting.set(False)
211+
# cancel delayed retry timer first
212+
self._cancel_retry_timer()
198213

199-
try:
200-
result = self._result_handler._timeout_func(self._result_handler._record)
201-
self._result_handler.complete(result)
202-
except Exception as e:
203-
self._result_handler.complete_exceptionally(e)
214+
# force reset _retry_awaiting to prevent the handler to trigger retry unnecessarily
215+
self._retry_awaiting.set(False)
216+
217+
try:
218+
result = self._result_handler._timeout_func(self._result_handler._record)
219+
self._result_handler._complete_internal(result)
220+
except Exception as e:
221+
self._result_handler._complete_exceptionally_internal(e)
204222

205223

206224
class Emitter(threading.Thread):

0 commit comments

Comments
 (0)