From ff689f57aa111261e6c2a506a42479d99674b123 Mon Sep 17 00:00:00 2001 From: Benjamin Badger <54602201+blbadger@users.noreply.github.com> Date: Fri, 7 Jun 2024 07:50:35 -0400 Subject: [PATCH] Extend save_pretrained to offloaded models (#27412) * added hidden subset * debugged hidden subset contrastive search * added contrastive search compression * debugged compressed contrastive search * memory reduction for contrastive search * debugged mem red * added low memory option feature * debugged mem optmimization output stack * debugged mem optmimization output stack * debugged low mem * added low mem cache * fixed 2047 tensor view * debugged 2042 past key val inputs * reformatted tensors * changed low mem output * final clean * removed subset hidden csearch * fixed hidden device * fixed hidden device * changed compressor dtype * removed hstate compression * integrated csearch in generate * test csearch integration into generation exit() * fixed csearch kwarg integration with generation * final wrap and added doc * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante * added debug print * direct hstate cat * direct hstate cat * direct hstate cat debug * direct hstate cat debug * expanded full hidden state stack * expanded full hidden state stack * matched dims for hstates * matched dims for hstates * logits fix * equality test * equality hidden debug * debug * added prints for debug * added prints for debug * equality check * switched squeeze dim * input format debug * tracing top_k_ids * removed trace * added test context * added jitter * added jitter * added jitter * returned state * rebuilt past key value reconstruction * debugged * cleaned traces * added selection for pkv * changed output to dict * cleaned * cleaned * cleaned up contrastive search test * moved low_memory kwarg * debugged * changed low mem test batch size to 1 * removed output * debugged test input shape * reformatted csearch test * added trace * removed unsqueeze on final forward pass * replaced unsqueeze with view * removed traces * cleaned * debugged model kwargs * removed special models from test * ran make quality * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante * refactored * refactored * refactored * make fixup * renamed flag sequential * renamed flag sequential * iterative onloading * black style and test utils * added traces for integrated test * debugged * added traces * make style * removed traces, make style * included suggestions and added test * debugged test * added offload module check and make style * is_accelerate_available and make style * added test decorator * changed test model and config spec * added offload condition * added lazy loading for each shard * debugged * modified sharding * debugged * added traces * removed safe serialization * no index overload; * trace on safe save ptrs * added ptr condition * debugged * debugged ptr * moved module map init * remake shard only for offloaded modules * refactored * debugged * refactored * debugged * cleaned and make style * cleaned and make style * added trace * sparse module map * debugged * removed module map conditional * refactored * debug * debugged * added traces * added shard mem trace * added shard mem trace * removed underlying storage check * refactored * memory leak removal and make style * cleaned * swapped test decs and make style * added mem checks and make style * added free mem warning * implemented some suggestions * moved onloading to accelerate * refactored for accelerate integration * cleaned test * make style * debugged offload map name * cleaned and make style * replaced meta device check for sharding * cleaned and make style * implemented some suggestions * more suggestions * update warning Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * more suggestions * make style * new make style * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Joao Gante Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/modeling_utils.py | 64 +++++++++++++++++++++++++++--- tests/test_modeling_utils.py | 37 +++++++++++++++++ 2 files changed, 95 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a613fee62c42ab..34324560ae35cf 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -119,6 +119,10 @@ set_module_tensor_to_device, ) + accelerate_version = version.parse(importlib.metadata.version("accelerate")) + if accelerate_version >= version.parse("0.31"): + from accelerate.utils.modeling import get_state_dict_from_offload + if is_safetensors_available(): from safetensors import safe_open from safetensors.torch import load_file as safe_load_file @@ -374,13 +378,12 @@ def shard_checkpoint( storage_id = id_tensor_storage(weight) # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block` - if storage_id in storage_id_to_block: + if storage_id in storage_id_to_block and weight.device != torch.device("meta"): block_id = storage_id_to_block[storage_id] sharded_state_dicts[block_id][key] = weight continue weight_size = weight.numel() * dtype_byte_size(weight.dtype) - # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one # weight in the current shard. if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0: @@ -2504,8 +2507,26 @@ def save_pretrained( current_peft_config = self.peft_config[active_adapter] current_peft_config.save_pretrained(save_directory) + # for offloaded modules + module_map = {} + # Save the model if state_dict is None: + # if any model parameters are offloaded to the disk, make module map + if hasattr(self, "hf_device_map") and ( + "cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values() + ): + warnings.warn( + "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)" + ) + for name, module in model_to_save.named_modules(): + if name == "": + continue + module_state_dict = module.state_dict() + + for key in module_state_dict: + module_map[name + f".{key}"] = module + state_dict = model_to_save.state_dict() # Translate state_dict from smp to hf if saving with smp >= 1.10 @@ -2531,12 +2552,24 @@ def save_pretrained( # In the non-tensor case, fall back to the pointer of the object itself ptrs[id(tensor)].append(name) - # These are all the pointers of shared tensors. - shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} - error_names = [] - to_delete_names = set() + # These are all the pointers of shared tensors + if hasattr(self, "hf_device_map"): + # if the model has offloaded parameters, we must check using find_tied_parameters() + tied_params = find_tied_parameters(self) + if tied_params: + tied_names = tied_params[0] + shared_ptrs = { + ptr: names for ptr, names in ptrs.items() if any(name in tied_names for name in names) + } + else: + shared_ptrs = {} + else: + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + # Recursively descend to find tied weight keys _tied_weights_keys = _get_tied_weight_keys(self) + error_names = [] + to_delete_names = set() for names in shared_ptrs.values(): # Removing the keys which are declared as known duplicates on # load. This allows to make sure the name which is kept is consistent. @@ -2609,6 +2642,25 @@ def save_pretrained( # Save the model for shard_file, shard in shards.items(): + # remake shard with onloaded parameters if necessary + if module_map: + if accelerate_version < version.parse("0.31"): + raise ImportError( + f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. " + f"Please upgrade accelerate with `pip install -U accelerate`" + ) + # init state_dict for this shard + state_dict = {name: "" for name in shard} + for module_name in shard: + module = module_map[module_name] + # update state dict with onloaded parameters + state_dict = get_state_dict_from_offload(module, module_name, state_dict) + + # assign shard to be the completed state dict + shard = state_dict + del state_dict + gc.collect() + if safe_serialization: # At some point we will need to deal better with save_function (used for TPU and other distributed # joyfulness), but for now this enough. diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 01620724e739b1..8a2db45d9be60e 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1056,6 +1056,43 @@ def test_cached_files_are_used_when_internet_is_down(self): # This check we did call the fake head request mock_head.assert_called() + @require_accelerate + @mark.accelerate_tests + @require_torch_accelerator + def test_save_offloaded_model(self): + device_map = { + "transformer.wte": f"{torch_device}:0", + "transformer.wpe": f"{torch_device}:0", + "transformer.h.0": "cpu", + "transformer.h.1": "cpu", + "transformer.h.2": "cpu", + "transformer.h.3": "disk", + "transformer.h.4": "disk", + "transformer.ln_f": f"{torch_device}:0", + "lm_head": f"{torch_device}:0", + } + + # check_models_equal requires onloaded tensors + model_id = "hf-internal-testing/tiny-random-gpt2" + onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu") + inputs = torch.tensor([[1, 2, 3]]).to(f"{torch_device}:0") + cpu_output = onloaded_model(inputs)[0] + + with tempfile.TemporaryDirectory() as tmp_dir: + offload_folder = os.path.join(tmp_dir, "offload") + offloaded_model = AutoModelForCausalLM.from_pretrained( + model_id, device_map=device_map, offload_folder=offload_folder + ) + presaved_output = offloaded_model(inputs)[0] + offloaded_model.save_pretrained( + tmp_dir, max_shard_size="200KB" + ) # model is 1.6MB, max shard size is allocated to cpu by default + saved_model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map) + postsaved_output = saved_model(inputs)[0] + + self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4)) + self.assertTrue(torch.allclose(presaved_output, postsaved_output)) + @require_safetensors def test_use_safetensors(self): # Should not raise anymore