From 5c808e31777e5e37db247ee68520bbdc7385d21f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 26 Nov 2024 13:25:40 -0500 Subject: [PATCH 1/2] support for wrapped schedulefree optimizer when using deepspeed --- src/accelerate/optimizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/accelerate/optimizer.py b/src/accelerate/optimizer.py index acc238a1a99..f0fa6bdc175 100644 --- a/src/accelerate/optimizer.py +++ b/src/accelerate/optimizer.py @@ -126,6 +126,8 @@ def train(self): """ if hasattr(self.optimizer, "train") and callable(self.optimizer.train): self.optimizer.train() + elif hasattr(self.optimizer, "optimizer") and hasattr(self.optimizer.optimizer, "train") and callable(self.optimizer.optimizer.train): + self.optimizer.optimizer.train() def eval(self): """ From a666c6a5fbbb6de47821ca94dc421feb0b35b91f Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 28 Nov 2024 22:38:10 -0500 Subject: [PATCH 2/2] add comment and lint --- src/accelerate/optimizer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/accelerate/optimizer.py b/src/accelerate/optimizer.py index f0fa6bdc175..25e2b95d98e 100644 --- a/src/accelerate/optimizer.py +++ b/src/accelerate/optimizer.py @@ -126,7 +126,12 @@ def train(self): """ if hasattr(self.optimizer, "train") and callable(self.optimizer.train): self.optimizer.train() - elif hasattr(self.optimizer, "optimizer") and hasattr(self.optimizer.optimizer, "train") and callable(self.optimizer.optimizer.train): + elif ( + hasattr(self.optimizer, "optimizer") + and hasattr(self.optimizer.optimizer, "train") + and callable(self.optimizer.optimizer.train) + ): + # the deepspeed optimizer further wraps the optimizer self.optimizer.optimizer.train() def eval(self):