diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 12cc92d3e..5fecc7272 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -1533,7 +1533,9 @@ def update_shards(
             return
 
         current_state = self.state_dict()
-        has_optimizer = len(self._optim._optims) > 0
+        has_optimizer = len(self._optim._optims) > 0 and all(
+            len(i) > 0 for i in self._optim.state_dict()["state"].values()
+        )
 
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
         # TODO: Ensure lookup tensors are actually being deleted