diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py index 42ffebbc4386..6df61f7c8841 100644 --- a/deepspeed/inference/config.py +++ b/deepspeed/inference/config.py @@ -174,6 +174,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): values for :any:`DeepSpeedMoEConfig`. """ + keep_module_on_host: bool = False + """ + When loading checkpoints to model parameters, they are moved to the device. In very large models + this might fill the device and cause OOM. Setting this flag to true, will keep checkpoints on + host and not move them directly to the device (giving an option to quantize checkpoint data before + moving it to the device for example). + Set only for models with injection policies and auto TP. + """ + quant: QuantizationConfig = {} """ NOTE: only works for int8 dtype. diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index cfca1ff4fe4c..f9eb264adc7b 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -170,7 +170,7 @@ def __init__(self, model, config): is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta' if is_meta_device: self.module.to_empty(device=device) - else: + elif not config.keep_module_on_host: self.module.to(device) if config.tensor_parallel.tp_size > 1: diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 5441000e581d..01e766e671d0 100755 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -17,14 +17,14 @@ from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list -def move(tensor, device): +def move(tensor, device, copy=True): if tensor.is_meta: return torch.empty_like(tensor, device=device) else: # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). # Using copy=True instead of clone() will help in case of cpu --> cpu. # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. - return tensor.to(device, copy=True) + return tensor.to(device, copy=copy) class ReplaceWithTensorSlicing: @@ -188,7 +188,14 @@ def load(module, state_dict, prefix, mp_group=None): class AutoTP(): - def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl): + def __init__(self, + module, + all_reduce_linears, + prefix, + state_dict, + linear_layer_setting, + orig_layer_impl, + keep_module_on_host=False): self.module = module self.all_reduce_linears = all_reduce_linears self.prefix = prefix @@ -200,6 +207,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_ self.orig_layer_impl = orig_layer_impl self.linear_policies = None self.conv_linear_layer = False + self.keep_module_on_host = keep_module_on_host def in_module_list(module, module_list): for item in module_list: @@ -330,6 +338,10 @@ def set_tensor_parallel_config(self, mp_size, mp_group): def _replace(self, child, name, conv_linear_layer): if getattr(child, "replaced", False) == True: return + device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name() + # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some + # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy. + return_new_copy = not self.keep_module_on_host weight_shape = child.weight.shape mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) # For mixtral-7x8b, need to skip MoE gate linear replace. @@ -363,7 +375,7 @@ def _replace(self, child, name, conv_linear_layer): data = child.weight.data.split(get_shard_size_list( weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name), dim=1) - data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach() + data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach() del data setattr(child, "replaced", True) @@ -371,10 +383,9 @@ def _replace(self, child, name, conv_linear_layer): return LmHeadLinearAllreduce( torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(), child.bias if child.bias is None else torch.nn.parameter.Parameter( - move(child.bias, - get_accelerator().current_device_name())), self.mp_group) + move(child.bias, device_name, return_new_copy)), self.mp_group) return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \ - torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group) + torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group) else: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] @@ -387,22 +398,22 @@ def _replace(self, child, name, conv_linear_layer): #The copy is a regular copy, The shape of dst and src is the same data_dc = move( prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index), - get_accelerator().current_device_name()) + device_name, return_new_copy) bias_data_dc = None if child.bias is None else move( prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index), - get_accelerator().current_device_name()) + device_name, return_new_copy) else: data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name), dim=1 if self.conv_linear_layer else 0) - data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach() + data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach() del data if child.bias is not None: bias_data = child.bias.data.split(get_shard_size_list( weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name), dim=0) - bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name()) + bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy) bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False) del bias_data else: diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e59f84bc8453..c7fe9480ab43 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group) # 1. Create AutoTP object - _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl) + _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl, + config.keep_module_on_host) # 2. Set the tensor parallelism config _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 9b563523dbeb..fa57466d14b0 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty @pytest.mark.seq_inference +@pytest.mark.parametrize('keep_module_on_host', [True, False]) @pytest.mark.parametrize( "model_w_task", [("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")], @@ -570,6 +571,7 @@ def test( inf_kwargs, assert_fn, dtype, + keep_module_on_host, ): invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) if invalid_test_msg: @@ -592,7 +594,10 @@ def test( framework="pt") bs_output = pipe(query, **inf_kwargs) - pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=world_size, + dtype=dtype, + keep_module_on_host=keep_module_on_host) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) @@ -607,6 +612,7 @@ def test_odd_world_size( inf_kwargs, assert_fn, dtype, + keep_module_on_host, ): invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) if invalid_test_msg: @@ -624,7 +630,10 @@ def test_odd_world_size( framework="pt") bs_output = pipe(query, **inf_kwargs) - pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=world_size, + dtype=dtype, + keep_module_on_host=keep_module_on_host) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output)