From 16350657ee32765f1ad646e696a004a0627ff0bc Mon Sep 17 00:00:00 2001
From: Omar Elayan
Date: Tue, 10 Dec 2024 14:10:51 +0200
Subject: [PATCH 1/2] [inf] Add config var to enable keeping module on host

Using the keep_module_on_host config var lets us control whether checkpoints
loaded into model parameters are moved to the device or kept on the host.
---
 deepspeed/inference/config.py             |  9 +++++++
 deepspeed/inference/engine.py             |  2 +-
 deepspeed/module_inject/auto_tp.py        | 33 ++++++++++++++++-------
 deepspeed/module_inject/replace_module.py |  3 ++-
 tests/unit/inference/test_inference.py    | 13 +++++++--
 5 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index c7c7684fff79..365536b3dd85 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -171,6 +171,15 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
     values for :any:`DeepSpeedMoEConfig`.
     """
 
+    keep_module_on_host: bool = False
+    """
+    When loading checkpoints to model parameters, they are moved to the device. In large very models
+    this might fill the device and cause OOM. Setting this flag to true, will keep checkpoints on
+    host and not move them directly to the device (giving an option to quantize checkpoint data before
+    moving it to the device for example).
+    Set only for models with injection policies and auto TP.
+    """
+
     quant: QuantizationConfig = {}
     """
     NOTE: only works for int8 dtype.
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index cfca1ff4fe4c..f9eb264adc7b 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -170,7 +170,7 @@ def __init__(self, model, config):
         is_meta_device = hasattr(self.module, "device") and self.module.device.type == 'meta'
         if is_meta_device:
             self.module.to_empty(device=device)
-        else:
+        elif not config.keep_module_on_host:
             self.module.to(device)
 
         if config.tensor_parallel.tp_size > 1:
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 221d490a37d2..49a741a1e814 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -17,9 +17,11 @@
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
 
 
-def move(tensor, device):
+def move(tensor, device, keep_module_on_host=False):
     if tensor.is_meta:
-        return torch.empty_like(tensor, device=device)
+        return torch.empty_like(tensor, device='cpu' if keep_module_on_host else device)
+    elif keep_module_on_host:
+        return tensor.to('cpu') if device != 'cpu' else tensor
     else:
         # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
         # Using copy=True instead of clone() will help in case of cpu --> cpu.
@@ -188,7 +190,14 @@ def load(module, state_dict, prefix, mp_group=None): class AutoTP(): - def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl): + def __init__(self, + module, + all_reduce_linears, + prefix, + state_dict, + linear_layer_setting, + orig_layer_impl, + keep_module_on_host=False): self.module = module self.all_reduce_linears = all_reduce_linears self.prefix = prefix @@ -200,6 +209,7 @@ def __init__(self, module, all_reduce_linears, prefix, state_dict, linear_layer_ self.orig_layer_impl = orig_layer_impl self.linear_policies = None self.conv_linear_layer = False + self.keep_module_on_host = keep_module_on_host def in_module_list(module, module_list): for item in module_list: @@ -359,7 +369,8 @@ def _replace(self, child, name, conv_linear_layer): data = child.weight.data.split(get_shard_size_list( weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name), dim=1) - data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach() + data_dc = move(data[mp_replace.gpu_index], + get_accelerator().current_device_name(), self.keep_module_on_host).detach() del data setattr(child, "replaced", True) @@ -368,9 +379,9 @@ def _replace(self, child, name, conv_linear_layer): torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(), child.bias if child.bias is None else torch.nn.parameter.Parameter( move(child.bias, - get_accelerator().current_device_name())), self.mp_group) + get_accelerator().current_device_name(), self.keep_module_on_host)), self.mp_group) return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \ - torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name())), self.mp_group) + torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name(), self.keep_module_on_host)), self.mp_group) else: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] @@ -383,22 +394,24 @@ def _replace(self, child, name, conv_linear_layer): #The copy is a regular copy, The shape of dst and src is the same data_dc = move( prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index), - get_accelerator().current_device_name()) + get_accelerator().current_device_name(), self.keep_module_on_host) bias_data_dc = None if child.bias is None else move( prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index), - get_accelerator().current_device_name()) + get_accelerator().current_device_name(), self.keep_module_on_host) else: data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name), dim=1 if self.conv_linear_layer else 0) - data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach() + data_dc = move(data[mp_replace.gpu_index], + get_accelerator().current_device_name(), self.keep_module_on_host).detach() del data if child.bias is not None: bias_data = child.bias.data.split(get_shard_size_list( weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name), dim=0) - bias_data = move(bias_data[mp_replace.gpu_index], get_accelerator().current_device_name()) + bias_data = move(bias_data[mp_replace.gpu_index], + get_accelerator().current_device_name(), self.keep_module_on_host) bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False) del bias_data else: diff --git 
a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 7afe6ca903fb..f7129165756d 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -268,7 +268,8 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group) # 1. Create AutoTP object - _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl) + _autotp = AutoTP(module, all_reduce_linears, prefix, state_dict, linear_layer_setting, orig_layer_impl, + config.keep_module_on_host) # 2. Set the tensor parallelism config _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 9b563523dbeb..fa57466d14b0 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -554,6 +554,7 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty @pytest.mark.seq_inference +@pytest.mark.parametrize('keep_module_on_host', [True, False]) @pytest.mark.parametrize( "model_w_task", [("Helsinki-NLP/opus-mt-en-de", "translation"), ("Salesforce/codegen-350M-mono", "text-generation")], @@ -570,6 +571,7 @@ def test( inf_kwargs, assert_fn, dtype, + keep_module_on_host, ): invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) if invalid_test_msg: @@ -592,7 +594,10 @@ def test( framework="pt") bs_output = pipe(query, **inf_kwargs) - pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=world_size, + dtype=dtype, + keep_module_on_host=keep_module_on_host) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) @@ -607,6 +612,7 @@ def test_odd_world_size( inf_kwargs, assert_fn, dtype, + keep_module_on_host, ): invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) if invalid_test_msg: @@ -624,7 +630,10 @@ def test_odd_world_size( framework="pt") bs_output = pipe(query, **inf_kwargs) - pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) + pipe.model = deepspeed.init_inference(pipe.model, + mp_size=world_size, + dtype=dtype, + keep_module_on_host=keep_module_on_host) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) From 7d39f606207f0d0041f22b22f6716677c9757d3f Mon Sep 17 00:00:00 2001 From: Omar Elayan Date: Tue, 17 Dec 2024 09:45:56 +0200 Subject: [PATCH 2/2] Code review fixes --- deepspeed/inference/config.py | 2 +- deepspeed/module_inject/auto_tp.py | 30 ++++++++++++++---------------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py index 631321ceaf04..6df61f7c8841 100644 --- a/deepspeed/inference/config.py +++ b/deepspeed/inference/config.py @@ -176,7 +176,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): keep_module_on_host: bool = False """ - When loading checkpoints to model parameters, they are moved to the device. In large very models + When loading checkpoints to model parameters, they are moved to the device. In very large models this might fill the device and cause OOM. 
Setting this flag to true, will keep checkpoints on host and not move them directly to the device (giving an option to quantize checkpoint data before moving it to the device for example). diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 49a741a1e814..15b80621ca0f 100755 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -17,16 +17,14 @@ from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list -def move(tensor, device, keep_module_on_host=False): +def move(tensor, device, copy=True): if tensor.is_meta: - return torch.empty_like(tensor, device='cpu' if keep_module_on_host else device) - elif keep_module_on_host: - return tensor.to('cpu') if device != 'cpu' else tensor + return torch.empty_like(tensor, device=device) else: # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). # Using copy=True instead of clone() will help in case of cpu --> cpu. # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. - return tensor.to(device, copy=True) + return tensor.to(device, copy=copy) class ReplaceWithTensorSlicing: @@ -340,6 +338,10 @@ def set_tensor_parallel_config(self, mp_size, mp_group): def _replace(self, child, name, conv_linear_layer): if getattr(child, "replaced", False) == True: return + device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name() + # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some + # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy. + return_new_copy = not self.keep_module_on_host weight_shape = child.weight.shape mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) # For mixtral-7x8b, need to skip MoE gate linear replace. 
@@ -369,8 +371,7 @@ def _replace(self, child, name, conv_linear_layer): data = child.weight.data.split(get_shard_size_list( weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name), dim=1) - data_dc = move(data[mp_replace.gpu_index], - get_accelerator().current_device_name(), self.keep_module_on_host).detach() + data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach() del data setattr(child, "replaced", True) @@ -378,10 +379,9 @@ def _replace(self, child, name, conv_linear_layer): return LmHeadLinearAllreduce( torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(), child.bias if child.bias is None else torch.nn.parameter.Parameter( - move(child.bias, - get_accelerator().current_device_name(), self.keep_module_on_host)), self.mp_group) + move(child.bias, device_name, return_new_copy)), self.mp_group) return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \ - torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name(), self.keep_module_on_host)), self.mp_group) + torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group) else: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] @@ -394,24 +394,22 @@ def _replace(self, child, name, conv_linear_layer): #The copy is a regular copy, The shape of dst and src is the same data_dc = move( prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index), - get_accelerator().current_device_name(), self.keep_module_on_host) + device_name, return_new_copy) bias_data_dc = None if child.bias is None else move( prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index), - get_accelerator().current_device_name(), self.keep_module_on_host) + device_name, return_new_copy) else: data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name), dim=1 if self.conv_linear_layer else 0) - data_dc = move(data[mp_replace.gpu_index], - get_accelerator().current_device_name(), self.keep_module_on_host).detach() + data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach() del data if child.bias is not None: bias_data = child.bias.data.split(get_shard_size_list( weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name), dim=0) - bias_data = move(bias_data[mp_replace.gpu_index], - get_accelerator().current_device_name(), self.keep_module_on_host) + bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy) bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False) del bias_data else:
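
Usage sketch (not part of the patch): the snippet below is adapted from the updated
tests in tests/unit/inference/test_inference.py and only illustrates how the new
keep_module_on_host argument reaches deepspeed.init_inference. The model name,
mp_size, launch command, and generation kwargs are illustrative assumptions.

    # Sketch: HF pipeline + DeepSpeed AutoTP with the new flag. Launch with two
    # ranks so tensor parallelism kicks in, e.g. `deepspeed --num_gpus 2 example.py`.
    import deepspeed
    import torch
    from transformers import pipeline

    pipe = pipeline("text-generation",
                    model="Salesforce/codegen-350M-mono",  # model exercised by the patched tests
                    framework="pt")

    # keep_module_on_host=True keeps the sharded weights/biases on the host (see the
    # move()/device_name changes in auto_tp.py and the engine.py guard); the default
    # False preserves the old behavior of moving them to the accelerator device.
    pipe.model = deepspeed.init_inference(pipe.model,
                                          mp_size=2,
                                          dtype=torch.float32,
                                          keep_module_on_host=True)

    print(pipe("def hello_world():", max_new_tokens=16))  # generation kwargs are illustrative

With the flag set, checkpoint data stays on the host after injection, which leaves
room to, for example, quantize it before deciding what to transfer to the device.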