From 984af3801c4951372f443d9f1075428d7af63a75 Mon Sep 17 00:00:00 2001
From: Dudu Moshe
Date: Wed, 18 Sep 2024 16:10:05 +0300
Subject: [PATCH] parallel_loader: Fix TPU memory leak when calling __iter__.

A TPU memory leak occurs when calling __iter__ on the MpDeviceLoader
object, even if it is called only once. However, the memory growth
becomes more noticeable and critical when __iter__ is called
repeatedly, eventually leading to a crash. This issue is caused by the
worker threads not being properly terminated, because the close()
method was not invoked.

This commit resolves the issue by ensuring that close() is called,
which properly shuts down the threads and prevents memory from
leaking.
---
 torch_xla/distributed/parallel_loader.py | 65 +++++++++++++++---------
 1 file changed, 41 insertions(+), 24 deletions(-)

diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py
index 3d98ff4a225a..0eada393e678 100644
--- a/torch_xla/distributed/parallel_loader.py
+++ b/torch_xla/distributed/parallel_loader.py
@@ -95,12 +95,14 @@ def __init__(self,
     self._done = False
     self._queues = dict()
     self._input_sharding = input_sharding
+    self._threads = []
     for device in self._devices:
       self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
                                             device_prefetch_size)
     thread = threading.Thread(target=self._loader_worker)
     thread.daemon = True
     thread.start()
+    self._threads.append(thread)
     for dqueue in self._queues.values():
       for i in range(host_to_device_transfer_threads):
         thread = threading.Thread(
             target=self._worker,
             args=(
                 dqueue,
                 host_to_device_transfer_threads,
             ))
         thread.daemon = True
         thread.start()
+        self._threads.append(thread)
 
   def per_device_loader(self, device):
     """Retrieves the loader iterator object for the given device.
@@ -139,6 +142,9 @@ def close(self):
       dqueue.queue.close()
       dqueue.loader_queue.close()
 
+    for thread in self._threads:
+      thread.join()
+
   @property
   def batches_per_execution(self):
     return self._batches_per_execution
@@ -147,18 +153,21 @@ def _loader_worker(self):
     queues = list(self._queues.values())
     data_iter = enumerate(self._loader)
     batch = []
-    while not self._done:
-      try:
-        _, data = next(data_iter)
-      except StopIteration:
-        break
-      batch.append(data)
-      if len(batch) == len(self._devices):
-        for queue_no, device_batch in enumerate(batch):
-          queues[queue_no].loader_queue.put(device_batch)
-        batch = []
-    for dqueue in queues:
-      dqueue.loader_queue.close_write()
+
+    try:
+      while not self._done:
+        try:
+          _, data = next(data_iter)
+        except StopIteration:
+          break
+        batch.append(data)
+        if len(batch) == len(self._devices):
+          for queue_no, device_batch in enumerate(batch):
+            queues[queue_no].loader_queue.put(device_batch)
+          batch = []
+    finally:
+      for dqueue in queues:
+        dqueue.loader_queue.close_write()
 
   def _get_batch(self, dqueue):
     batch = []
@@ -171,16 +180,20 @@ def _worker(self, dqueue, host_to_device_transfer_threads):
     device = torch.device(dqueue.device)
-    while True:
-      batch = self._get_batch(dqueue)
-      if not batch:
-        break
-      batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
-      for data in batch:
-        dqueue.queue.put(data)
-    close_queue_count = next(dqueue.close_queue_count)
-    if close_queue_count == host_to_device_transfer_threads - 1:
-      dqueue.queue.close_write()
+
+    try:
+      while True:
+        batch = self._get_batch(dqueue)
+        if not batch:
+          break
+        with torch.no_grad():
+          batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
+        for data in batch:
+          dqueue.queue.put(data)
+    finally:
+      close_queue_count = next(dqueue.close_queue_count)
+      if close_queue_count == host_to_device_transfer_threads - 1:
+        dqueue.queue.close_write()
 
 
 class MpDeviceLoader(object):
@@ -206,11 +219,15 @@ def __init__(self, loader, device, **kwargs):
     self._loader = loader
     self._device = device
     self._parallel_loader_kwargs = kwargs
+    self._parallel_loader = None
 
   def __iter__(self):
-    parallel_loader = ParallelLoader(self._loader, [self._device],
+    if self._parallel_loader is not None:
+      self._parallel_loader.close()
+      self._parallel_loader = None
+    self._parallel_loader = ParallelLoader(self._loader, [self._device],
                                      **self._parallel_loader_kwargs)
-    return parallel_loader.per_device_loader(self._device)
+    return self._parallel_loader.per_device_loader(self._device)
 
   def __len__(self):
     return len(self._loader)
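
For reference, a minimal usage sketch of the pattern this patch addresses. It is
not part of the change itself; it assumes a working torch_xla installation with
an XLA device available and uses a toy dataset purely for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
dataset = TensorDataset(torch.randn(512, 8), torch.randint(0, 2, (512,)))
loader = DataLoader(dataset, batch_size=64)
device_loader = pl.MpDeviceLoader(loader, device)

for epoch in range(100):
  # Each epoch calls __iter__ on the same MpDeviceLoader. Previously every call
  # created a new ParallelLoader whose daemon threads and queues were never
  # released; with this patch, the next call closes the previous loader first.
  for data, target in device_loader:
    pass  # training step would go here

With this change, MpDeviceLoader keeps a reference to the ParallelLoader it
creates and calls close() on it before building a new one, so the loader and
host-to-device transfer threads are joined and their per-device queues are
released instead of accumulating across epochs.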