Skip to content

Commit

Permalink
[memory leak] Replace GradientState -> DataLoader reference with weakrefs (#3391)
Browse files Browse the repository at this point in the history

* Replace GradientState -> DataLoader reference with weakrefs

So they can be cleaned up. Otherwise, they will always stay in memory, leading to notable memory leaks. Note: even accelerator.free_memory() did not work!

* Add comments; initialize _dataloader_references_ref directly instead of indirectly
  • Loading branch information
tomaarsen authored Feb 11, 2025
1 parent 24f8d02 commit 526925b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
30 changes: 24 additions & 6 deletions src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import threading
import warnings
import weakref
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Optional
Expand Down Expand Up @@ -1164,8 +1165,7 @@ def __init__(self, gradient_accumulation_plugin: Optional[GradientAccumulationPl
self.__dict__ = self._shared_state
if not self.initialized:
self.sync_gradients = True
self.active_dataloader = None
self.dataloader_references = [None]
self._dataloader_references_ref = [None]
self.plugin_kwargs = (
gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {}
)
Expand Down Expand Up @@ -1242,13 +1242,31 @@ def _set_sync_gradients(self, sync_gradients):

def _add_dataloader(self, dataloader):
"Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. Users should not have to call this."
self.active_dataloader = dataloader
self.dataloader_references.append(self.active_dataloader)
# We explicitly use assignment to ensure that the property setter is triggered, which is required for garbage collection.
# Avoid using self.dataloader_references.append as it will not trigger the setter.
self.dataloader_references += [dataloader]

def _remove_dataloader(self, dataloader):
"Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this."
self.dataloader_references.remove(dataloader)
self.active_dataloader = self.dataloader_references[-1]
# We explicitly use assignment to ensure that the property setter is triggered.
self.dataloader_references = [
dataloader_ref for dataloader_ref in self.dataloader_references if dataloader_ref != dataloader
]

@property
def active_dataloader(self):
return self.dataloader_references[-1]

@property
def dataloader_references(self):
# We use a property getter and setter with weakrefs to avoid circular references that prevent garbage collection
return [reference() if reference is not None else reference for reference in self._dataloader_references_ref]

@dataloader_references.setter
def dataloader_references(self, references):
self._dataloader_references_ref = [
weakref.ref(dataloader) if dataloader is not None else dataloader for dataloader in references
]

@property
def in_dataloader(self) -> bool:
Expand Down
30 changes: 30 additions & 0 deletions tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import random
import unittest
import weakref

import pytest
import torch
Expand Down Expand Up @@ -498,6 +499,35 @@ def test_set_epoch_in_batch_sampler(self):
dataloader.set_epoch(1)
assert batch_sampler.epoch == 1

def test_ensure_dataloader_gets_cleaned_up(self):
    # Regression test: GradientState used to hold strong references to prepared
    # dataloaders, keeping them (and the Accelerator) alive indefinitely.
    class Dummy:
        def __init__(self):
            dataset = list(range(16))
            dataloader = DataLoader(dataset, batch_size=4)

            self.accelerator = Accelerator()
            self.dataloader = self.accelerator.prepare_data_loader(dataloader)

            self.iter = iter(self.dataloader)

        def __call__(self, *args, **kwds):
            return next(self.iter)

    instance = Dummy()
    assert instance().tolist() == [0, 1, 2, 3]

    # Create weak references to the objects that *should* be cleaned up if the instance is deleted
    accelerator_ref = weakref.ref(instance.accelerator)
    dataloader_ref = weakref.ref(instance.dataloader)
    gradient_state_ref = weakref.ref(instance.dataloader.gradient_state)

    del instance

    # Force a collection pass: CPython's refcounting usually frees these
    # immediately, but incidental reference cycles (or interpreters without
    # refcounting, e.g. PyPy) would make the bare `del` flaky. A leak caused by
    # lingering strong references still survives gc.collect(), so the test's
    # intent is preserved.
    import gc

    gc.collect()

    assert accelerator_ref() is None
    assert dataloader_ref() is None
    assert gradient_state_ref() is None


class StatefulDataLoaderTester(unittest.TestCase):
@require_torchdata_stateful_dataloader
Expand Down

0 comments on commit 526925b

Please sign in to comment.