From 9bc33ea36129fce4d70ac6a332d29ed10d923eca Mon Sep 17 00:00:00 2001 From: Moshe Island Date: Wed, 28 Feb 2024 16:52:24 +0200 Subject: [PATCH] Fix loading universal checkpoint for BF16_Optimizer PR#5104 (Remove optimizer step on initialization) breaks loading universal checkpoint for BF16_Optimizer. This is because universal checkpoint attempts to load the optimizer states into the lp._hp_mapping.optim_state dictionary before they are initialized (by step). As a workaround for loading universal checkpoint, perform a step and initialize the hp params' optimizer states before loading from the universal checkpoint files. Signed-off-by: Moshe Island --- deepspeed/runtime/bf16_optimizer.py | 2 ++ deepspeed/runtime/engine.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index aaa836bf1c31..406888bd5030 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -445,6 +445,8 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l self._link_all_hp_params() def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights): + self.optimizer.step() + self._lazy_init_hp_params_optimizer_state() self._load_hp_checkpoint_state(checkpoint_folder) @property diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5c1202ba06ae..664ff1a89c0a 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2785,7 +2785,7 @@ def load_checkpoint(self, if self.load_universal_checkpoint(): self.optimizer.update_lp_params() if load_zero_checkpoint: - self.update_optimizer_step(step=client_states['iteration'] + 1) + self.update_optimizer_step(step=client_states['iteration']) return load_path, client_states