From d1e7777b8df6fb949c4680e6770d3925995c5168 Mon Sep 17 00:00:00 2001 From: TrainCheck Team Date: Mon, 16 Dec 2024 15:59:25 -0500 Subject: [PATCH] fix: remove mark-time checking for non-existence of the flag as DeepSpeedEngine propagates flag from the internal model --- deepspeed/__init__.py | 10 +++---- tests/unit/runtime/test_ds_initialize.py | 33 +++++++++++++++++++----- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 6bc5642ec8ef..eb245b9492c3 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -68,10 +68,10 @@ def _parse_version(version_str): def _mark_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]): """Mark a trainobj as initialized by setting the ds_is_inited attribute to True.""" - # we shouldn't hit the assert below, but just in case - assert not hasattr( - trainobj, 'ds_is_inited' - ), "Model has already been initialized, please make sure to only call deepspeed.initialize on a model once." + if hasattr(trainobj, 'ds_is_inited'): + assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once." + return + trainobj.ds_is_inited = True @@ -79,7 +79,7 @@ def _is_initialized(trainobj: Union[torch.nn.Module, Optimizer, _LRScheduler]): """Check if a trainobj has been initialized by checking the ds_is_inited attribute.""" if hasattr(trainobj, 'ds_is_inited'): # we shouldn't hit the assert below, but just in case - assert trainobj.ds_is_inited, "Not expecting the model has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once." + assert trainobj.ds_is_inited, "Not expecting the training object has `ds_is_inited` to be False if it exists, make sure you didn't set it to False or called deepspeed.initialize on the model more than once." return True return False diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py index 2c9ad701bfff..0da24dc2ba32 100644 --- a/tests/unit/runtime/test_ds_initialize.py +++ b/tests/unit/runtime/test_ds_initialize.py @@ -445,17 +445,14 @@ def test_no_repeated_init(self): hidden_dim = 10 model = SimpleModel(hidden_dim) client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01) - - model = SimpleModel() # Initialize DeepSpeed configurations for fp16 config_dict = {'train_batch_size': 1} - client_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3) # Initialize DeepSpeed engine _assert_trainobjs_not_inited(model=model, optimizer=client_optimizer, lr_scheduler=None) - model_engine, optim, dataloader, scheduler = deepspeed.initialize(model=model, - optimizer=client_optimizer, - config_params=config_dict) + model_engine, optim, _, _ = deepspeed.initialize(model=model, + optimizer=client_optimizer, + config_params=config_dict) # arguments should be marked as initialized now assert _is_initialized(model), "Client model should be marked as initialized" @@ -464,7 +461,6 @@ def test_no_repeated_init(self): # return values should also be marked as initialized assert _is_initialized(model_engine), "Model engine should be marked as initialized" assert _is_initialized(optim), "Optimizer should be marked as initialized" - assert _is_initialized(scheduler), "Scheduler should be marked as initialized" exception_raised = False try: @@ -473,3 +469,26 @@ def test_no_repeated_init(self): exception_raised = True assert exception_raised, "Repeated initialization should raise an exception" + + exception_raised = False + try: + deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict) + except ValueError: + exception_raised = True + + assert exception_raised, "Initialization on ds types should raise an exception" + + exception_raised = False + try: + deepspeed.initialize(model=model, optimizer=client_optimizer, config_params=config_dict) + except ValueError: + exception_raised = True + + assert exception_raised, "Initialization on ds types should raise an exception" + + exception_raised = False + try: + deepspeed.initialize(model=model_engine, optimizer=client_optimizer, config_params=config_dict) + except ValueError: + exception_raised = True + assert exception_raised, "Initialization on ds types should raise an exception"