RFC: Add snapshot to MultiProcessingReadingService #1039

Open, wants to merge 1 commit into base: main
40 changes: 37 additions & 3 deletions test/dataloader2/test_mprs.py
@@ -276,10 +276,44 @@ def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt)
# cumulative_res.extend(res)
# self.assertEqual(list(range(n_elements)), sorted(cumulative_res))

    @mp_ctx_parametrize
    @dp_parametrize
    @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(2, 1, 1), (4, 1, 0), (4, 0, 0)])
    def test_reading_service_snapshot(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:
        # Functional Test: Confirms that `snapshot` does capture the state of the underlying DataPipes properly
        rs = MultiProcessingReadingService(
            num_workers=n_workers,
            worker_prefetch_cnt=worker_prefetch_cnt,
            main_prefetch_cnt=main_prefetch_cnt,
            multiprocessing_context=ctx,
        )
        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
        res = []
        stop_index = 3
        for i, x in enumerate(dl):
            res.append(x)
            if i == stop_index:
                snapshot = dl.reading_service.snapshot()
                break
        self.assertEqual(
            n_workers,
            len(snapshot),
            msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, "
            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
        )

        if worker_prefetch_cnt == 0 and main_prefetch_cnt == 0 and dp == dp1:
            for snapshot_worker in snapshot:
                self.assertAlmostEqual(
                    stop_index + worker_prefetch_cnt,
                    snapshot_worker["_number_of_samples_yielded"],
                    delta=2,
                    msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, "
                    f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
                )
        dl.shutdown()

    # TODO: Implemented in an upcoming PR
    # def test_reading_service_snapshot(self) -> None:
    #     pass
    #
    # def test_dataloader2_snapshot(self) -> None:
    #     pass

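For readers unfamiliar with the DataLoader2 API, the pattern the new test exercises looks roughly like this. A minimal sketch: the pipeline and worker count are illustrative, and `snapshot` is the method this PR adds to the reading service.

from torchdata.datapipes.iter import IterableWrapper
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

dp = IterableWrapper(range(100)).sharding_filter()
rs = MultiProcessingReadingService(num_workers=2)
dl = DataLoader2(dp, reading_service=rs)

for i, x in enumerate(dl):
    if i == 3:
        # Returns a list with one state dict per worker process.
        worker_states = dl.reading_service.snapshot()
        break
dl.shutdown()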
20 changes: 20 additions & 0 deletions torchdata/dataloader2/communication/iter.py
@@ -193,6 +193,15 @@ def DataPipeBehindQueues(
            forever = False
            protocol.response_terminate()

        elif isinstance(request, communication.messages.GetStateRequest):
            datapipe_state = source_datapipe.__getstate__()
            # Remove pickle-incompatible keys from the state
            datapipe_state = {
                k: v for k, v in datapipe_state.items() if not callable(v) and not isinstance(v, types.GeneratorType)
            }
            protocol.response_state(datapipe_state)
            yield True  # Return control

        elif isinstance(request, communication.messages.GetNextRequest):
            while forever:
                if protocol.is_paused():
@@ -273,6 +282,14 @@ def resume(self):
        if NonBlocking.not_available_hook is not None:
            NonBlocking.not_available_hook()

    def state_dict(self):
        self.protocol.request_state()
        try:
            response = self.protocol.get_response_state(block=True, timeout=self._response_wait_time)
        except communication.protocol.EmptyQueue:
            raise NotAvailable
        return response.value

    def nonblocking_next(self):
        if self._stop_iteration:
            raise Exception("`next` or `nonblocking_next` called after receiving StopIteration")
@@ -374,3 +391,6 @@ def request_pause(self):
    def request_resume(self):
        for dp in self.datapipes:
            dp.resume()

    def state_dict(self):
        return [dp.state_dict() for dp in self.datapipes]
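Taken together, the client-side `state_dict` and the `GetStateRequest` branch in `DataPipeBehindQueues` form a request/response round trip over the worker queues (the branch assumes a module-level `import types`). Below is a self-contained sketch of that pattern, using plain queues and illustrative names in place of the library's protocol classes.

import types
from queue import Queue

request_queue, response_queue = Queue(), Queue()

class ToyDataPipe:
    def __init__(self):
        self.buffer = [1, 2, 3]
        self._gen = (i for i in range(3))  # generators are not picklable

def serve_state(datapipe):
    # Mirrors the GetStateRequest branch: snapshot the instance state and
    # drop entries that cannot be sent across process boundaries.
    state = dict(vars(datapipe))
    state = {k: v for k, v in state.items() if not callable(v) and not isinstance(v, types.GeneratorType)}
    response_queue.put(state)

request_queue.put("GetStateRequest")  # client: protocol.request_state()
if request_queue.get() == "GetStateRequest":
    serve_state(ToyDataPipe())        # server: protocol.response_state(...)
print(response_queue.get())           # client receives {'buffer': [1, 2, 3]}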
9 changes: 9 additions & 0 deletions torchdata/dataloader2/communication/map.py
@@ -183,3 +183,12 @@ def nonblocking_len(self):
        except communication.protocol.EmptyQueue:
            raise NotAvailable
        return response.len

    def state_dict(self):
        if self.protocol.can_take_request():
            self.protocol.request_state()
        try:
            response = self.protocol.get_response_state(block=True, timeout=self._response_wait_time)
        except communication.protocol.EmptyQueue:
            raise NotAvailable
        return response.value
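As with the iter-pipe version above, `state_dict` raises `NotAvailable` when the worker has not replied within `_response_wait_time`. A caller that wants blocking semantics could retry along these lines; this is a sketch, the retry loop and interval are illustrative rather than library behavior, and the import location of `NotAvailable` is an assumption.

import time

from torchdata.dataloader2.communication.iter import NotAvailable  # location assumed

def blocking_state_dict(dp, retry_interval=0.1):
    # Poll until the worker's state response arrives on the queue.
    while True:
        try:
            return dp.state_dict()
        except NotAvailable:
            time.sleep(retry_interval)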
11 changes: 11 additions & 0 deletions torchdata/dataloader2/communication/messages.py
@@ -103,6 +103,17 @@ class StopIterationResponse(Response):
    pass


class GetStateRequest(Request):
    pass


class GetStateResponse(Response):
    __slots__ = "value"

    def __init__(self, value):
        self.value = value


class InvalidStateResponse(Response):
    """
    Returned by DataPipe when it is expecting to get reset request,
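The new messages pair up in the usual request/response fashion used throughout this file. A trivial sketch, given the classes above:

req = GetStateRequest()                      # enqueued by the main process
resp = GetStateResponse({"buffer": [1, 2]})  # enqueued back by the worker
assert resp.value == {"buffer": [1, 2]}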
35 changes: 35 additions & 0 deletions torchdata/dataloader2/communication/protocol.py
@@ -64,6 +64,13 @@ def request_resume(self):
        self.request_queue.put(request)
        self.request_sent(request)

    def request_state(self):
        if not self.can_take_request():
            raise Exception("Cannot request state while still waiting for a response to the previous request")
        request = communication.messages.GetStateRequest()
        self.request_queue.put(request)
        self.request_sent(request)


class ProtocolServer(Protocol):
"""
Expand Down Expand Up @@ -132,6 +139,12 @@ def response_resume(self):
        self.response_queue.put(communication.messages.ResumeResponse())
        self._req_received = None

    def response_state(self, value):
        if not self.have_pending_request():
            raise Exception("Attempting to reply without a pending request")
        self.response_queue.put(communication.messages.GetStateResponse(value))
        self._req_received = None

    def response_worker_exception(self, exception):
        if not self.have_pending_request():
            raise Exception("Attempting to reply without a pending request")
@@ -205,6 +218,17 @@ def get_response_item(self, block=False, timeout=None):
        # raise Exception('Invalid response received')
        return response

    def get_response_state(self, block=False, timeout=None):
        if not self.waiting_for_response():
            raise Exception("Cannot expect any response without a submitted request")
        try:
            response = self.response_queue.get(block=block, timeout=timeout)
        except EmptyException:
            raise EmptyQueue("queue is empty")
        self.request_served(response)

        return response


class EmptyQueue(Exception):
    pass

@@ -311,3 +335,14 @@ def get_response_next(self, block=False, timeout=None):

        # TODO(629): Add possible response types validation here
        return response

    def get_response_state(self, block=False, timeout=None):
        if not self.waiting_for_response():
            raise Exception("Cannot expect any response without a submitted request")
        try:
            response = self.response_queue.get(block=block, timeout=timeout)
        except EmptyException:
            raise EmptyQueue("queue is empty")
        self.request_served(response)

        return response
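On the client side, these methods are used in a request-then-poll pattern. A sketch: the queue objects are illustrative, and while the class name follows torchdata's existing protocol classes, the exact construction should be treated as an assumption.

protocol = IterDataPipeQueueProtocolClient(request_queue, response_queue)
protocol.request_state()  # enqueues a GetStateRequest for the worker
try:
    response = protocol.get_response_state(block=True, timeout=5)
    state = response.value
except EmptyQueue:
    pass  # the worker did not reply within the timeout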
16 changes: 16 additions & 0 deletions torchdata/dataloader2/reading_service.py
@@ -376,6 +376,22 @@ def clean_me(process, req_queue, res_queue):
        self._worker_processes = []
        self._dispatch_process = None

    def snapshot(self):
        """
        Captures the ``state_dict`` of the underlying worker DataPipes via the consumer DataPipe.
        Only the worker DataPipes' states are captured, not the prefetching DataPipe's.
        This is a proof of concept for now, so there is no corresponding restore operation to make it
        properly checkpointable.
        """
        if self.num_workers == 0:
            raise RuntimeError(
                "If you would like to use `snapshot` with `MultiProcessingReadingService`, please use more than 0 workers."
            )

        self._pause()
        result = self._worker_consumer_datapipe.state_dict()
        self._resume()
        return result

    def _pause(self):
        """
        Pauses DataPipes' activities, such as prefetching, in order to collect state.
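Putting it together: `snapshot` is a pause-collect-resume sequence whose result is a list with one state dict per worker. Inspecting it might look like the sketch below, where `dl` is a running DataLoader2 as in the test above and `_number_of_samples_yielded` is the key the new test asserts on.

states = dl.reading_service.snapshot()
assert len(states) == dl.reading_service.num_workers
for worker_id, state in enumerate(states):
    print(worker_id, state.get("_number_of_samples_yielded"))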