Skip to content

Commit f82e67b

Browse files
committed
Clean pubsub test and separate ipc utils
1 parent 3f75212 commit f82e67b

File tree

3 files changed

+170
-179
lines changed

3 files changed

+170
-179
lines changed

pulpcore/tests/functional/test_pubsub.py

Lines changed: 22 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
11
from types import SimpleNamespace
22
from datetime import datetime
3-
import traceback
43
import select
54
import pytest
6-
import sys
7-
import os
8-
from typing import NamedTuple
9-
from functools import partial
10-
from contextlib import contextmanager
11-
from multiprocessing import Process, Pipe, Lock, SimpleQueue
12-
from multiprocessing.connection import Connection
5+
from pulpcore.tasking import pubsub
6+
from pulpcore.tests.functional.utils import IpcUtil
137

148

159
@pytest.fixture(autouse=True)
@@ -59,22 +53,13 @@ def test_postgres_pubsub():
5953
assert state.got_message is False
6054

6155

62-
class PubsubMessage(NamedTuple):
63-
channel: str
64-
payload: str
56+
M = pubsub.PubsubMessage
57+
PUBSUB_BACKENDS = [
58+
pubsub.PostgresPubSub,
59+
]
6560

6661

67-
M = PubsubMessage
68-
69-
70-
@pytest.fixture
71-
def pubsub_backend():
72-
from pulpcore.tasking import pubsub
73-
74-
return pubsub.PostgresPubSub
75-
76-
77-
# @pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
62+
@pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
7863
class TestPublish:
7964

8065
@pytest.mark.parametrize(
@@ -88,11 +73,11 @@ class TestPublish:
8873
pytest.param(True, id="bool"),
8974
),
9075
)
91-
def test_with_payload_as(self, pubsub_backend, payload):
76+
def test_with_payload_as(self, pubsub_backend: pubsub.BasePubSubBackend, payload):
9277
pubsub_backend.publish("channel", payload=payload)
9378

9479

95-
# @pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
80+
@pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
9681
@pytest.mark.parametrize(
9782
"messages",
9883
(
@@ -117,7 +102,9 @@ def publish_all(self, messages, publisher):
117102
for channel, payload in messages:
118103
publisher.publish(channel, payload=payload)
119104

120-
def test_with(self, pubsub_backend, messages):
105+
def test_with(
106+
self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage]
107+
):
121108
channels = {m.channel for m in messages}
122109
publisher = pubsub_backend
123110
with pubsub_backend() as subscriber:
@@ -128,7 +115,9 @@ def test_with(self, pubsub_backend, messages):
128115
self.unsubscribe_all(channels, subscriber)
129116
assert subscriber.fetch() == []
130117

131-
def test_select_readiness_with(self, pubsub_backend, messages):
118+
def test_select_readiness_with(
119+
self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage]
120+
):
132121
TIMEOUT = 0.1
133122
CHANNELS = {m.channel for m in messages}
134123
publisher = pubsub_backend
@@ -159,155 +148,6 @@ def test_select_readiness_with(self, pubsub_backend, messages):
159148
assert subscriber.fetch() == []
160149

161150

162-
class ProcessErrorData(NamedTuple):
163-
error: Exception
164-
stack_trace: str
165-
166-
167-
class RemoteTracebackError(Exception):
168-
"""An exception that wraps another exception and its remote traceback string."""
169-
170-
def __init__(self, message, remote_traceback):
171-
super().__init__(message)
172-
self.remote_traceback = remote_traceback
173-
174-
def __str__(self):
175-
"""Override __str__ to include the remote traceback when printed."""
176-
return f"{super().__str__()}\n\n--- Remote Traceback ---\n{self.remote_traceback}"
177-
178-
179-
class IpcUtil:
180-
181-
@staticmethod
182-
def run(host_act, child_act) -> list:
183-
# ensures a connection from one run doesn't interfere with the other
184-
conn_1, conn_2 = Pipe()
185-
log = SimpleQueue()
186-
lock = Lock()
187-
turn_1 = partial(IpcUtil._actor_turn, conn_1, starts=True, log=log, lock=lock)
188-
turn_2 = partial(IpcUtil._actor_turn, conn_2, starts=False, log=log, lock=lock)
189-
proc_1 = Process(target=host_act, args=(turn_1, log))
190-
proc_2 = Process(target=child_act, args=(turn_2, log))
191-
proc_1.start()
192-
proc_2.start()
193-
try:
194-
proc_1.join()
195-
finally:
196-
conn_1.send("1")
197-
try:
198-
proc_2.join()
199-
finally:
200-
conn_2.send("1")
201-
conn_1.close()
202-
conn_2.close()
203-
result = IpcUtil.read_log(log)
204-
log.close()
205-
if proc_1.exitcode != 0 or proc_2.exitcode != 0:
206-
error = Exception("General exception")
207-
for item in result:
208-
if isinstance(item, ProcessErrorData):
209-
error, stacktrace = item
210-
break
211-
raise Exception(stacktrace) from error
212-
return result
213-
214-
@staticmethod
215-
@contextmanager
216-
def _actor_turn(conn: Connection, starts: bool, log, lock: Lock, done: bool = False):
217-
TIMEOUT = 1
218-
219-
try:
220-
221-
def flush_conn(conn):
222-
if not conn.poll(TIMEOUT):
223-
err_msg = (
224-
"Tip: make sure the last 'with turn()' (in execution order) "
225-
"is called with 'actor_turn(done=True)', otherwise it may hang."
226-
)
227-
raise TimeoutError(err_msg)
228-
conn.recv()
229-
230-
if starts:
231-
with lock:
232-
conn.send("done")
233-
yield
234-
if not done:
235-
flush_conn(conn)
236-
else:
237-
flush_conn(conn)
238-
with lock:
239-
yield
240-
conn.send("done")
241-
except Exception as e:
242-
traceback.print_exc(file=sys.stderr)
243-
err_header = f"Error from sub-process (pid={os.getpid()}) on test using IpcUtil"
244-
traceback_str = f"{err_header}\n\n{traceback.format_exc()}"
245-
246-
error = ProcessErrorData(e, traceback_str)
247-
log.put(error)
248-
exit(1)
249-
250-
@staticmethod
251-
def read_log(log: SimpleQueue) -> list:
252-
result = []
253-
while not log.empty():
254-
result.append(log.get())
255-
return result
256-
257-
258-
def test_ipc_utils_error_catching():
259-
260-
def host_act(host_turn, log):
261-
with host_turn():
262-
log.put(0)
263-
264-
def child_act(child_turn, log):
265-
with child_turn():
266-
log.put(1)
267-
assert 1 == 0
268-
269-
error_msg = "AssertionError: assert 1 == 0"
270-
with pytest.raises(Exception, match=error_msg):
271-
IpcUtil.run(host_act, child_act)
272-
273-
274-
def test_ipc_utils_correctness():
275-
RUNS = 1000
276-
errors = 0
277-
278-
def host_act(host_turn, log):
279-
with host_turn():
280-
log.put(0)
281-
282-
with host_turn():
283-
log.put(2)
284-
285-
with host_turn():
286-
log.put(4)
287-
288-
def child_act(child_turn, log):
289-
with child_turn():
290-
log.put(1)
291-
292-
with child_turn():
293-
log.put(3)
294-
295-
with child_turn():
296-
log.put(5)
297-
298-
def run():
299-
log = IpcUtil.run(host_act, child_act)
300-
if log != [0, 1, 2, 3, 4, 5]:
301-
return 1
302-
return 0
303-
304-
for _ in range(RUNS):
305-
errors += run()
306-
307-
error_rate = errors / RUNS
308-
assert error_rate == 0
309-
310-
311151
def test_postgres_backend_ipc():
312152
"""Asserts that we are really testing two different connections.
313153
@@ -338,7 +178,7 @@ def child_act(child_turn, log):
338178
assert host_connection_pid != child_connection_pid
339179

340180

341-
# @pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
181+
@pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS)
342182
@pytest.mark.parametrize(
343183
"messages",
344184
(
@@ -353,7 +193,9 @@ def child_act(child_turn, log):
353193
)
354194
class TestIpcSubscribeFetch:
355195

356-
def test_with(self, pubsub_backend, messages):
196+
def test_with(
197+
self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage]
198+
):
357199
CHANNELS = {m.channel for m in messages}
358200
EXPECTED_LOG = [
359201
"subscribe",
@@ -411,7 +253,9 @@ def publisher_act(publisher_turn, log):
411253
log = IpcUtil.run(subscriber_act, publisher_act)
412254
assert log == EXPECTED_LOG
413255

414-
def test_select_readiness_with(self, pubsub_backend, messages):
256+
def test_select_readiness_with(
257+
self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage]
258+
):
415259
TIMEOUT = 0.1
416260
CHANNELS = {m.channel for m in messages}
417261
EXPECTED_LOG = [
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
from pulpcore.tests.functional.utils import IpcUtil
3+
4+
5+
class TestIpcUtil:
6+
7+
def test_catch_subprocess_errors(self):
8+
9+
def host_act(host_turn, log):
10+
with host_turn():
11+
log.put(0)
12+
13+
def child_act(child_turn, log):
14+
with child_turn():
15+
log.put(1)
16+
assert 1 == 0
17+
18+
error_msg = "AssertionError: assert 1 == 0"
19+
with pytest.raises(Exception, match=error_msg):
20+
IpcUtil.run(host_act, child_act)
21+
22+
def test_turns_are_respected(self):
23+
RUNS = 1000
24+
errors = 0
25+
26+
def host_act(host_turn, log):
27+
with host_turn():
28+
log.put(0)
29+
30+
with host_turn():
31+
log.put(2)
32+
33+
with host_turn():
34+
log.put(4)
35+
36+
def child_act(child_turn, log):
37+
with child_turn():
38+
log.put(1)
39+
40+
with child_turn():
41+
log.put(3)
42+
43+
with child_turn():
44+
log.put(5)
45+
46+
def run():
47+
log = IpcUtil.run(host_act, child_act)
48+
if log != [0, 1, 2, 3, 4, 5]:
49+
return 1
50+
return 0
51+
52+
for _ in range(RUNS):
53+
errors += run()
54+
55+
error_rate = errors / RUNS
56+
assert error_rate == 0

0 commit comments

Comments
 (0)