11from types import SimpleNamespace
22from datetime import datetime
3- import traceback
43import select
54import pytest
6- import sys
7- import os
8- from typing import NamedTuple
9- from functools import partial
10- from contextlib import contextmanager
11- from multiprocessing import Process , Pipe , Lock , SimpleQueue
12- from multiprocessing .connection import Connection
5+ from pulpcore .tasking import pubsub
6+ from pulpcore .tests .functional .utils import IpcUtil
137
148
159@pytest .fixture (autouse = True )
@@ -59,22 +53,13 @@ def test_postgres_pubsub():
5953 assert state .got_message is False
6054
6155
62- class PubsubMessage (NamedTuple ):
63- channel : str
64- payload : str
56+ M = pubsub .PubsubMessage
57+ PUBSUB_BACKENDS = [
58+ pubsub .PostgresPubSub ,
59+ ]
6560
6661
67- M = PubsubMessage
68-
69-
70- @pytest .fixture
71- def pubsub_backend ():
72- from pulpcore .tasking import pubsub
73-
74- return pubsub .PostgresPubSub
75-
76-
77- # @pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
62+ @pytest .mark .parametrize ("pubsub_backend" , PUBSUB_BACKENDS )
7863class TestPublish :
7964
8065 @pytest .mark .parametrize (
@@ -88,11 +73,11 @@ class TestPublish:
8873 pytest .param (True , id = "bool" ),
8974 ),
9075 )
91- def test_with_payload_as (self , pubsub_backend , payload ):
76+ def test_with_payload_as (self , pubsub_backend : pubsub . BasePubSubBackend , payload ):
9277 pubsub_backend .publish ("channel" , payload = payload )
9378
9479
95- # @pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
80+ @pytest .mark .parametrize ("pubsub_backend" , PUBSUB_BACKENDS )
9681@pytest .mark .parametrize (
9782 "messages" ,
9883 (
@@ -117,7 +102,9 @@ def publish_all(self, messages, publisher):
117102 for channel , payload in messages :
118103 publisher .publish (channel , payload = payload )
119104
120- def test_with (self , pubsub_backend , messages ):
105+ def test_with (
106+ self , pubsub_backend : pubsub .BasePubSubBackend , messages : list [pubsub .PubsubMessage ]
107+ ):
121108 channels = {m .channel for m in messages }
122109 publisher = pubsub_backend
123110 with pubsub_backend () as subscriber :
@@ -128,7 +115,9 @@ def test_with(self, pubsub_backend, messages):
128115 self .unsubscribe_all (channels , subscriber )
129116 assert subscriber .fetch () == []
130117
131- def test_select_readiness_with (self , pubsub_backend , messages ):
118+ def test_select_readiness_with (
119+ self , pubsub_backend : pubsub .BasePubSubBackend , messages : list [pubsub .PubsubMessage ]
120+ ):
132121 TIMEOUT = 0.1
133122 CHANNELS = {m .channel for m in messages }
134123 publisher = pubsub_backend
@@ -159,155 +148,6 @@ def test_select_readiness_with(self, pubsub_backend, messages):
159148 assert subscriber .fetch () == []
160149
161150
162- class ProcessErrorData (NamedTuple ):
163- error : Exception
164- stack_trace : str
165-
166-
167- class RemoteTracebackError (Exception ):
168- """An exception that wraps another exception and its remote traceback string."""
169-
170- def __init__ (self , message , remote_traceback ):
171- super ().__init__ (message )
172- self .remote_traceback = remote_traceback
173-
174- def __str__ (self ):
175- """Override __str__ to include the remote traceback when printed."""
176- return f"{ super ().__str__ ()} \n \n --- Remote Traceback ---\n { self .remote_traceback } "
177-
178-
179- class IpcUtil :
180-
181- @staticmethod
182- def run (host_act , child_act ) -> list :
183- # ensures a connection from one run doesn't interfere with the other
184- conn_1 , conn_2 = Pipe ()
185- log = SimpleQueue ()
186- lock = Lock ()
187- turn_1 = partial (IpcUtil ._actor_turn , conn_1 , starts = True , log = log , lock = lock )
188- turn_2 = partial (IpcUtil ._actor_turn , conn_2 , starts = False , log = log , lock = lock )
189- proc_1 = Process (target = host_act , args = (turn_1 , log ))
190- proc_2 = Process (target = child_act , args = (turn_2 , log ))
191- proc_1 .start ()
192- proc_2 .start ()
193- try :
194- proc_1 .join ()
195- finally :
196- conn_1 .send ("1" )
197- try :
198- proc_2 .join ()
199- finally :
200- conn_2 .send ("1" )
201- conn_1 .close ()
202- conn_2 .close ()
203- result = IpcUtil .read_log (log )
204- log .close ()
205- if proc_1 .exitcode != 0 or proc_2 .exitcode != 0 :
206- error = Exception ("General exception" )
207- for item in result :
208- if isinstance (item , ProcessErrorData ):
209- error , stacktrace = item
210- break
211- raise Exception (stacktrace ) from error
212- return result
213-
214- @staticmethod
215- @contextmanager
216- def _actor_turn (conn : Connection , starts : bool , log , lock : Lock , done : bool = False ):
217- TIMEOUT = 1
218-
219- try :
220-
221- def flush_conn (conn ):
222- if not conn .poll (TIMEOUT ):
223- err_msg = (
224- "Tip: make sure the last 'with turn()' (in execution order) "
225- "is called with 'actor_turn(done=True)', otherwise it may hang."
226- )
227- raise TimeoutError (err_msg )
228- conn .recv ()
229-
230- if starts :
231- with lock :
232- conn .send ("done" )
233- yield
234- if not done :
235- flush_conn (conn )
236- else :
237- flush_conn (conn )
238- with lock :
239- yield
240- conn .send ("done" )
241- except Exception as e :
242- traceback .print_exc (file = sys .stderr )
243- err_header = f"Error from sub-process (pid={ os .getpid ()} ) on test using IpcUtil"
244- traceback_str = f"{ err_header } \n \n { traceback .format_exc ()} "
245-
246- error = ProcessErrorData (e , traceback_str )
247- log .put (error )
248- exit (1 )
249-
250- @staticmethod
251- def read_log (log : SimpleQueue ) -> list :
252- result = []
253- while not log .empty ():
254- result .append (log .get ())
255- return result
256-
257-
258- def test_ipc_utils_error_catching ():
259-
260- def host_act (host_turn , log ):
261- with host_turn ():
262- log .put (0 )
263-
264- def child_act (child_turn , log ):
265- with child_turn ():
266- log .put (1 )
267- assert 1 == 0
268-
269- error_msg = "AssertionError: assert 1 == 0"
270- with pytest .raises (Exception , match = error_msg ):
271- IpcUtil .run (host_act , child_act )
272-
273-
274- def test_ipc_utils_correctness ():
275- RUNS = 1000
276- errors = 0
277-
278- def host_act (host_turn , log ):
279- with host_turn ():
280- log .put (0 )
281-
282- with host_turn ():
283- log .put (2 )
284-
285- with host_turn ():
286- log .put (4 )
287-
288- def child_act (child_turn , log ):
289- with child_turn ():
290- log .put (1 )
291-
292- with child_turn ():
293- log .put (3 )
294-
295- with child_turn ():
296- log .put (5 )
297-
298- def run ():
299- log = IpcUtil .run (host_act , child_act )
300- if log != [0 , 1 , 2 , 3 , 4 , 5 ]:
301- return 1
302- return 0
303-
304- for _ in range (RUNS ):
305- errors += run ()
306-
307- error_rate = errors / RUNS
308- assert error_rate == 0
309-
310-
311151def test_postgres_backend_ipc ():
312152 """Asserts that we are really testing two different connections.
313153
@@ -338,7 +178,7 @@ def child_act(child_turn, log):
338178 assert host_connection_pid != child_connection_pid
339179
340180
341- # @pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
181+ @pytest .mark .parametrize ("pubsub_backend" , PUBSUB_BACKENDS )
342182@pytest .mark .parametrize (
343183 "messages" ,
344184 (
@@ -353,7 +193,9 @@ def child_act(child_turn, log):
353193)
354194class TestIpcSubscribeFetch :
355195
356- def test_with (self , pubsub_backend , messages ):
196+ def test_with (
197+ self , pubsub_backend : pubsub .BasePubSubBackend , messages : list [pubsub .PubsubMessage ]
198+ ):
357199 CHANNELS = {m .channel for m in messages }
358200 EXPECTED_LOG = [
359201 "subscribe" ,
@@ -411,7 +253,9 @@ def publisher_act(publisher_turn, log):
411253 log = IpcUtil .run (subscriber_act , publisher_act )
412254 assert log == EXPECTED_LOG
413255
414- def test_select_readiness_with (self , pubsub_backend , messages ):
256+ def test_select_readiness_with (
257+ self , pubsub_backend : pubsub .BasePubSubBackend , messages : list [pubsub .PubsubMessage ]
258+ ):
415259 TIMEOUT = 0.1
416260 CHANNELS = {m .channel for m in messages }
417261 EXPECTED_LOG = [
0 commit comments