From cd20a3bbc7713908d7fb5fd7af4a91d52f126370 Mon Sep 17 00:00:00 2001 From: ChenWenbin Date: Fri, 22 Nov 2024 02:32:03 +0800 Subject: [PATCH] Fix potential memory issues when using DeepSpeed ZeRO-3 (#6726) I had an OOM problem when doing DPO training with ZeRO-3. DPO needs to call the module twice in one training step, and the second call is made under no_grad(). The problem is caused by two bugs: 1. "__n_available_params", which helps to control the number of fetched parameters, becomes negative after release_and_reset_all() is called. 2. module.ds_grads_remaining becomes negative in backward() if we call the module more than once in one training step. I created two patches to fix these issues. --------- Signed-off-by: Wenbin Chen Co-authored-by: Olatunji Ruwase Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com> --- deepspeed/runtime/zero/parameter_offload.py | 3 +- .../zero/partitioned_param_coordinator.py | 3 +- .../runtime/zero/test_zero_multiple_run.py | 53 +++++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) create mode 100644 tests/unit/runtime/zero/test_zero_multiple_run.py diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 082d7e874e4d..f945f5166190 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -392,7 +392,8 @@ def _run_before_forward_function(input): _run_after_backward_hook, inputs) def _post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 + if not hasattr(module, "ds_grads_remaining"): + module.ds_grads_remaining = 0 if not hasattr(module, "post_bwd_fn"): diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 49f477cc4a1b..596d0e9c20f9 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ 
b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -252,7 +252,6 @@ def reset_step(self) -> None: self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10)) self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque()) self.__step_id = 0 - self.__n_available_params = 0 self.__profiler.reset_events() def _dump_params(self, tag, sub_module, params, step_id=None): @@ -430,7 +429,7 @@ def release_and_reset_all(self, module: Module) -> None: # there's a hook execution issue param.ds_active_sub_modules.clear() self.__release_param(param) - + self.__n_available_params = 0 for param in iter_params(module, recurse=True): if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: raise RuntimeError(f"{param.ds_summary()} expected to be released") diff --git a/tests/unit/runtime/zero/test_zero_multiple_run.py b/tests/unit/runtime/zero/test_zero_multiple_run.py new file mode 100644 index 000000000000..d4eb3a578cc9 --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_multiple_run.py @@ -0,0 +1,53 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +import torch +from unit.common import DistributedTest, preferred_dtype +from unit.simple_model import SimpleModel, random_dataloader + + +class TestZ3MultipleModelCall(DistributedTest): + world_size = 1 + + def test_z3_multiple_model_call(self): + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": 3 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + hidden_dim, nlayers = 2048, 3 + model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers) + model_engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=model, + model_parameters=model.parameters()) + data_loader = iter( + random_dataloader(model=model_engine, total_samples=10, hidden_dim=hidden_dim, device=model_engine.device)) + + for n, batch in enumerate(data_loader): + loss1 = model_engine(batch[0], batch[1]) + with torch.no_grad(): + loss2 = model_engine(batch[0], batch[1]) + loss = loss1 + loss2 + model_engine.backward(loss) + for name, submodule in model_engine.module.linears._modules.items(): + assert hasattr(submodule, "ds_grads_remaining"), \ + f"linears.{name} does not have variable ds_grads_remaining" + assert submodule.ds_grads_remaining == 0, \ + f"ds_grads_remaining of linears.{name} is not 0 ({submodule.ds_grads_remaining})" + model_engine.step()