parallel_loader: Fix TPU memory leak when calling __iter__.
A TPU memory leak occurs when calling __iter__ on the MpDeviceLoader
object, even if it is called only once. The memory growth becomes far
more noticeable, and eventually fatal, when __iter__ is called
repeatedly (for example, once per training epoch). The leak is caused
by the worker threads never being terminated, because the close()
method was not invoked.

This commit resolves the issue by ensuring that close() is called,
which properly shuts down the threads and prevents memory from leaking.
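Below is a minimal sketch, not part of the commit, of the usage pattern that triggers the leak: a per-epoch training loop re-enters __iter__ on the same MpDeviceLoader each epoch. The names train_loader, num_epochs and train_step are hypothetical placeholders; only MpDeviceLoader and xm.xla_device() come from torch_xla itself.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def train(train_loader, num_epochs, train_step):
  device = xm.xla_device()
  mp_loader = pl.MpDeviceLoader(train_loader, device)
  for epoch in range(num_epochs):
    # Each `for` statement calls mp_loader.__iter__(). Before this fix,
    # every call spawned a fresh ParallelLoader whose worker threads were
    # never joined; with the fix, the previous loader is close()d first.
    for batch in mp_loader:
      train_step(batch)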
dudulightricks committed Sep 18, 2024
1 parent 0415cc6 commit 984af38
Showing 1 changed file with 41 additions and 24 deletions.
65 changes: 41 additions & 24 deletions torch_xla/distributed/parallel_loader.py
@@ -95,12 +95,14 @@ def __init__(self,
     self._done = False
     self._queues = dict()
     self._input_sharding = input_sharding
+    self._threads = []
     for device in self._devices:
       self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
                                             device_prefetch_size)
     thread = threading.Thread(target=self._loader_worker)
     thread.daemon = True
     thread.start()
+    self._threads.append(thread)
     for dqueue in self._queues.values():
       for i in range(host_to_device_transfer_threads):
         thread = threading.Thread(
@@ -111,6 +113,7 @@ def __init__(self,
             ))
         thread.daemon = True
         thread.start()
+        self._threads.append(thread)
 
   def per_device_loader(self, device):
     """Retrieves the loader iterator object for the given device.
@@ -139,6 +142,9 @@ def close(self):
       dqueue.queue.close()
       dqueue.loader_queue.close()
 
+    for thread in self._threads:
+      thread.join()
+
   @property
   def batches_per_execution(self):
     return self._batches_per_execution
@@ -147,18 +153,21 @@ def _loader_worker(self):
     queues = list(self._queues.values())
     data_iter = enumerate(self._loader)
     batch = []
-    while not self._done:
-      try:
-        _, data = next(data_iter)
-      except StopIteration:
-        break
-      batch.append(data)
-      if len(batch) == len(self._devices):
-        for queue_no, device_batch in enumerate(batch):
-          queues[queue_no].loader_queue.put(device_batch)
-        batch = []
-    for dqueue in queues:
-      dqueue.loader_queue.close_write()
+
+    try:
+      while not self._done:
+        try:
+          _, data = next(data_iter)
+        except StopIteration:
+          break
+        batch.append(data)
+        if len(batch) == len(self._devices):
+          for queue_no, device_batch in enumerate(batch):
+            queues[queue_no].loader_queue.put(device_batch)
+          batch = []
+    finally:
+      for dqueue in queues:
+        dqueue.loader_queue.close_write()
 
   def _get_batch(self, dqueue):
     batch = []
@@ -171,16 +180,20 @@ def _get_batch(self, dqueue):
 
   def _worker(self, dqueue, host_to_device_transfer_threads):
     device = torch.device(dqueue.device)
-    while True:
-      batch = self._get_batch(dqueue)
-      if not batch:
-        break
-      batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
-      for data in batch:
-        dqueue.queue.put(data)
-    close_queue_count = next(dqueue.close_queue_count)
-    if close_queue_count == host_to_device_transfer_threads - 1:
-      dqueue.queue.close_write()
+
+    try:
+      while True:
+        batch = self._get_batch(dqueue)
+        if not batch:
+          break
+        with torch.no_grad():
+          batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
+        for data in batch:
+          dqueue.queue.put(data)
+    finally:
+      close_queue_count = next(dqueue.close_queue_count)
+      if close_queue_count == host_to_device_transfer_threads - 1:
+        dqueue.queue.close_write()
 
 
 class MpDeviceLoader(object):
@@ -206,11 +219,15 @@ def __init__(self, loader, device, **kwargs):
     self._loader = loader
     self._device = device
     self._parallel_loader_kwargs = kwargs
+    self._parallel_loader = None
 
   def __iter__(self):
-    parallel_loader = ParallelLoader(self._loader, [self._device],
+    if self._parallel_loader is not None:
+      self._parallel_loader.close()
+      self._parallel_loader = None
+    self._parallel_loader = ParallelLoader(self._loader, [self._device],
                                      **self._parallel_loader_kwargs)
-    return parallel_loader.per_device_loader(self._device)
+    return self._parallel_loader.per_device_loader(self._device)
 
   def __len__(self):
     return len(self._loader)
