diff --git a/src/brevitas_examples/common/accelerate_utils/accelerate.py b/src/brevitas_examples/common/accelerate_utils/accelerate.py index ead616ed2..f2c310ecc 100644 --- a/src/brevitas_examples/common/accelerate_utils/accelerate.py +++ b/src/brevitas_examples/common/accelerate_utils/accelerate.py @@ -2,13 +2,15 @@ # SPDX-License-Identifier: BSD-3-Clause import logging -from typing import Dict, Mapping, Optional, Union +from typing import Dict, List, Mapping, Optional, Union from accelerate import dispatch_model from accelerate import infer_auto_device_map from accelerate.hooks import add_hook_to_module from accelerate.hooks import AlignDevicesHook +from accelerate.hooks import ModelHook from accelerate.hooks import remove_hook_from_module +from accelerate.hooks import SequentialHook from accelerate.utils import check_tied_parameters_in_config from accelerate.utils import compute_module_sizes from accelerate.utils import find_tied_parameters @@ -18,6 +20,7 @@ from accelerate.utils.modeling import named_module_tensors from psutil import virtual_memory import torch +from torch import nn import brevitas.config as config from brevitas.graph.utils import get_module @@ -382,10 +385,205 @@ def calc_cpu_device_map(absolute_mem_margin: float = 2.0 * 1e9, return cpu_device_map +class UpdateStateDictHook(ModelHook): + """ + `ModelHook` that ensures that in-place operations during the model forward pass update the values + in the weights_maps, thus ensuring that future calls to offload_model result in the updated model + being retrieved. + + Args: + offload (`bool`, *optional*, defaults to `False`): + Whether or not the weights should be offloaded after the forward pass. + weights_map (`Mapping[str, torch.Tensor]`, *optional*): + When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values. + """ + + def __init__( + self, + execution_device: Optional[Union[int, str, torch.device]] = None, + offload: bool = False, + weights_map: Optional[Mapping] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, + ): + self.execution_device = execution_device + self.offload = offload + self.weights_map = weights_map + + # The hook pre_forward/post_forward need to have knowledge of this dictionary, as updating the values in the state + # dict should remove the old values that might have been cached in each device. + self.tied_params_map = tied_params_map + + def __repr__(self): + return (f"UpdateStateDictHook(offload={self.offload})") + + def post_forward(self, module, output): + if self.offload: + prefix = self.weights_map.prefix + for key in module.state_dict().keys(): + value = recurse_getattr(module, key) + # It might happen that we call an quantization's inner modules, and this cause some parameters to be + # already on meta device. This is not a problem for their value but we need to check here + curr_device = value.device + if str(curr_device) != "meta": + # Check if there is an old value that needs to be replaced + self.weights_map.dataset.state_dict[prefix + key].copy_(value.detach().cpu()) + + return output + + +# TODO: Remove depending on whether to go with the first option. Still additional logic needs to be incorporated +# to handle tied parameters. +class UpdateStateDictLegacyHook(ModelHook): + """ + `ModelHook` that ensures that in-place operations during the model forward pass update the values + in the weights_maps, thus ensuring that future calls to offload_model result in the updated model + being retrieved. + + Args: + offload (`bool`, *optional*, defaults to `False`): + Whether or not the weights should be offloaded after the forward pass. + weights_map (`Mapping[str, torch.Tensor]`, *optional*): + When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values. + """ + + def __init__( + self, + align_device_hook: AlignDevicesHook, + execution_device: Optional[Union[int, str, torch.device]] = None, + offload: bool = False, + weights_map: Optional[Mapping] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, + ): + self.execution_device = execution_device + self.offload = offload + self.weights_map = weights_map + + self.align_device_hook = align_device_hook + + # The hook pre_forward/post_forward need to have knowledge of this dictionary, as updating the values in the state + # dict should remove the old values that might have been cached in each device. + self.tied_params_map = tied_params_map + + def __repr__(self): + return (f"UpdateStateDictHook(offload={self.offload})") + + def post_forward(self, module, output): + if self.offload: + prefix = self.weights_map.prefix + for key in module.state_dict().keys(): + value = recurse_getattr(module, key) + # It might happen that we call an quantization's inner modules, and this cause some parameters to be + # already on meta device. This is not a problem for their value but we need to check here + curr_device = value.device + if str(curr_device) != "meta": + # Update tied_pointers_to_remove + update_tied_pointers_to_remove = False + # Check if there is an old value that needs to be replaced + if prefix + key in self.weights_map.dataset.state_dict: + old_value = self.weights_map.dataset.state_dict[prefix + key] + if (old_value is not None and self.tied_params_map is not None and + old_value.data_ptr() in self.tied_params_map): + # Remove from tied_params_map if present there + del self.tied_params_map[old_value.data_ptr()] + + if (old_value is not None and + self.align_device_hook.tied_pointers_to_remove is not None and + (old_value.data_ptr(), self.execution_device) + in self.align_device_hook.tied_pointers_to_remove): + self.align_device_hook.tied_pointers_to_remove.remove( + (old_value.data_ptr(), self.execution_device)) + # Ensure that the appropiate value is added + update_tied_pointers_to_remove = True + # Move to CPU before storing it in the weights_map + detached_value = value.detach().cpu() + # Reassign in tied_params_map to make sure that the tensor is re-used + self.tied_params_map[detached_value.data_ptr()] = {} + self.tied_params_map[detached_value.data_ptr()][self.execution_device] = value + + if update_tied_pointers_to_remove: + self.align_device_hook.tied_pointers_to_remove.add( + (detached_value.data_ptr(), self.execution_device)) + + # Reassign the tensor in the state_dict + self.weights_map.dataset.state_dict[prefix + key] = detached_value + + return output + + +def add_hook_to_module_with_pre_append( + module: nn.Module, hook: ModelHook, append: bool = False, pre_append: bool = False): + """ + Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove + this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks + together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + hook (`ModelHook`): + The hook to attach. + append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained after an existing one (if module already contains a hook) or not. + pre_append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained before an existing one (if module already contains a hook) or not. + + Returns: + `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can + be discarded). + """ + + if (append or pre_append) and (getattr(module, "_hf_hook", None) is not None): + old_hook = module._hf_hook + remove_hook_from_module(module) + if append and not pre_append: + hook = SequentialHook(old_hook, hook) + if not append and pre_append: + hook = SequentialHook(hook, old_hook) + else: + raise ValueError( + "Setting both append and pre_append to True is not allowed when adding a hook.") + + # Append is set to False as the appropiate SequentialHook is already attached + return add_hook_to_module(module, hook, append=False) + + +def attach_update_state_dict_hook_on_modules(module: nn.Module) -> None: + if hasattr(module, "_hf_hook"): + hf_hooks = module._hf_hook + align_device_hook = None + if isinstance(hf_hooks, SequentialHook): + for hook in hf_hooks.hooks: + if isinstance(hook, AlignDevicesHook): + align_device_hook = hook + break + elif isinstance(hf_hooks, AlignDevicesHook): + align_device_hook = hf_hooks + # If the align devices hook is present, include the update state dict hook + if align_device_hook is not None: + hook = UpdateStateDictHook( + execution_device=align_device_hook.execution_device, + offload=align_device_hook.offload, + weights_map=align_device_hook.weights_map, + tied_params_map=align_device_hook.tied_params_map, + ) + # Add hook so post-forward gets run first + add_hook_to_module_with_pre_append(module, hook, pre_append=True) + + for child in module.children(): + attach_update_state_dict_hook_on_modules(child) + + def offload_model( model: torch.nn.Module, gpu_device_map: Optional[Dict[int, float]] = None, cpu_device_map: Optional[Dict[str, float]] = None, + preload_module_classes: Optional[List[str]] = None, ) -> torch.nn.Module: """ Wraps accelerate's infer_auto_device_map and dispatch_model. @@ -408,7 +606,8 @@ def offload_model( device_map = infer_auto_device_map( model, memory_map, no_split_module_classes=model._no_split_modules) - model = dispatch_model(model, device_map) + model = dispatch_model( + model=model, device_map=device_map, preload_module_classes=preload_module_classes) # Fixes an asymetric behavior in Accelerate where hooks are not attached at all when a single device is used. # TODO: Fix directly in accelerate. diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 495c47919..71e8cdd50 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -20,6 +20,8 @@ from brevitas.graph.equalize import LayerwiseActivationRotation from brevitas.graph.quantize import layerwise_quantize from brevitas.graph.utils import get_module +from brevitas_examples.common.accelerate_utils.accelerate import \ + attach_update_state_dict_hook_on_modules from brevitas_examples.common.accelerate_utils.accelerate import offload_model from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks from brevitas_examples.common.generative.quantize import generate_quant_maps @@ -363,10 +365,14 @@ def main(args): model = add_zero_bias_to_linear(model) model = offload_model(model) + attach_update_state_dict_hook_on_modules(model) with torch.no_grad(): model(**calibration_loader[0]) + remove_hooks(model) + model = offload_model(model) + if args.act_calibration: print("Apply act calibration...") apply_calibration(model, calibration_loader) diff --git a/tests/brevitas_examples/test_accelerate.py b/tests/brevitas_examples/test_accelerate.py new file mode 100644 index 000000000..2b232ec40 --- /dev/null +++ b/tests/brevitas_examples/test_accelerate.py @@ -0,0 +1,240 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from dataclasses import dataclass +import functools +from functools import partial +from typing import Dict, List, Optional, Union +from unittest.mock import patch + +from accelerate import dispatch_model +import pytest +import pytest_cases +from pytest_cases import fixture +import torch +from torch import nn + +from brevitas_examples.common.accelerate_utils.accelerate import \ + attach_update_state_dict_hook_on_modules +from brevitas_examples.common.accelerate_utils.accelerate import offload_model +from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks +from brevitas_examples.common.accelerate_utils.accelerate import update_internal_dict + + +@dataclass +class ModelDataClass: + model_class: type[nn.Module] + output: torch.Tensor + block1_layer1_parameter: torch.Tensor + preload_module_classes: List + +class TestTiedLayer1(nn.Module): + def __init__(self, parameter: torch.Tensor): + super().__init__() + self.parameter = parameter + self.w = nn.Parameter(torch.tensor([2.0])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.parameter.detach().add_(x) + return self.w*x + +class TestTiedBlock1(nn.Module): + def __init__(self): + super().__init__() + self.tied_parameter = nn.Parameter(torch.tensor([1.0])) + self.layer1 = TestTiedLayer1(self.tied_parameter) + self.layer2 = TestTiedLayer1(self.tied_parameter) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.layer1(x) + out = self.layer2(out) + return out + +class TestBlock1(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.tensor([3.0])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w*x + +class TestModel1(nn.Module): + def __init__(self): + super().__init__() + self.block1 = TestTiedBlock1() + self.block2 = TestBlock1() + + self._no_split_modules = None + + def forward(self, x: torch.Tensor): + out = self.block1(x) + out = self.block2(out) + return out + +class TestTiedLayer2(nn.Module): + def __init__(self, parameter: torch.Tensor): + super().__init__() + self.parameter = parameter + self.w = nn.Parameter(torch.tensor([2.0])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.parameter.detach().add_(x) + return self.w*x + +class TestTiedBlock2(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = TestTiedLayer2(nn.Parameter(torch.tensor([1.0]))) + self.layer2 = TestTiedLayer2(self.layer1.parameter) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.layer1(x) + self.layer1.parameter.detach().add_(x) + out = self.layer2(out) + return out + +class TestBlock2(nn.Module): + def __init__(self): + super().__init__() + self.w = nn.Parameter(torch.tensor([3.0])) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w*x + +class TestModel2(nn.Module): + def __init__(self): + super().__init__() + self.block1 = TestTiedBlock2() + self.block2 = TestBlock2() + + self._no_split_modules = None + + def forward(self, x: torch.Tensor): + out = self.block1(x) + out = self.block2(out) + return out + +def dispatch_model_with_preload( + model: nn.Module, + device_map: Dict[str, Union[str, int, torch.device]], + ): + return dispatch_model( + model = model, + device_map = device_map, + preload_module_classes=["TestTiedBlock2"] + ) + + +class TestAccelerate: + + marker_model_dataclass = pytest_cases.parametrize( + "model_dataclass", [ + ModelDataClass( + model_class=TestModel1, + output=torch.tensor([24.0]), + block1_layer1_parameter=torch.tensor([7.0]), + preload_module_classes=[], + ), + ModelDataClass( + model_class=TestModel2, + output=torch.tensor([24.0]), + block1_layer1_parameter=torch.tensor([9.0]), + preload_module_classes=["TestTiedBlock2"], + ) + ] + ) + + marker_device_map = pytest_cases.parametrize("device_map", [{"": 0}, {"": "cpu"}, {"block1": "cpu", "block2": 0}]) + + @pytest.mark.xfail + @marker_model_dataclass + @marker_device_map + def test_accelerate_inplace_operation(self, model_dataclass, device_map): + with patch( + 'brevitas_examples.common.accelerate_utils.accelerate.infer_auto_device_map', + return_value=device_map, + ) as mock_infer: + test_model = model_dataclass.model_class() + test_model = offload_model(test_model, preload_module_classes=model_dataclass.preload_module_classes) + # Run forward pass through model + out = test_model(torch.tensor([2.0])).cpu() + # Hooks are removed and model moved to CPU, thus enabling + # to access the model parameters easily + remove_hooks(test_model) + # Verify that the mocked method was called once + mock_infer.assert_called_once() + # Verify that output is the expected + assert torch.allclose(out, model_dataclass.output) + # Verify that the inplace operations were performed correctly + assert torch.allclose(test_model.block1.layer1.parameter.detach(), model_dataclass.block1_layer1_parameter) + + @pytest.mark.xfail + @marker_model_dataclass + @marker_device_map + def test_accelerate_inplace_operation_post_forward_fix(self, model_dataclass, device_map): + with patch( + 'brevitas_examples.common.accelerate_utils.accelerate.infer_auto_device_map', + return_value=device_map, + ) as mock_infer: + test_model = model_dataclass.model_class() + test_model = offload_model(test_model, preload_module_classes=model_dataclass.preload_module_classes) + + dict_of_hooks = dict() + def hooked_on_a_function(function, prefunction): + @functools.wraps(function) + def run(*args, **kwargs): + prefunction() + return function(*args, **kwargs) + return run + + def update_params_post_init(module): + update_internal_dict(module) + + for m in test_model.modules(): + if hasattr(m, '_hf_hook'): + if m._hf_hook.weights_map is not None: + dict_of_hooks[m] = m._hf_hook.post_forward + new_funct = partial(update_params_post_init, m) + m._hf_hook.post_forward = hooked_on_a_function(m._hf_hook.post_forward, new_funct) + + # Run forward pass through model + out = test_model(torch.tensor([2.0])).cpu() + # Hooks are removed and model moved to CPU, thus enabling + # to access the model parameters easily + for k, v in dict_of_hooks.items(): + k._hf_hook.post_forward = v + + remove_hooks(test_model) + + # Verify that the mocked method was called once + mock_infer.assert_called_once() + # Verify that output is the expected + assert torch.allclose(out, model_dataclass.output) + # Verify that the inplace operations were performed correctly + assert torch.allclose(test_model.block1.layer1.parameter.detach(), model_dataclass.block1_layer1_parameter) + + @marker_model_dataclass + @marker_device_map + def test_accelerate_inplace_operation_hook_fix(self, model_dataclass, device_map): + with ( + patch('brevitas_examples.common.accelerate_utils.accelerate.infer_auto_device_map', return_value=device_map) as mock_infer, + ): + test_model = model_dataclass.model_class() + + test_model = offload_model(test_model, preload_module_classes=model_dataclass.preload_module_classes) + # Verify that the mocks were called + mock_infer.assert_called_once() + + attach_update_state_dict_hook_on_modules(test_model) + + # Run forward pass through model + out = test_model(torch.tensor([2.0])).cpu() + + remove_hooks(test_model) + + # Verify that the mocked method was called once + mock_infer.assert_called_once() + # Verify that output is the expected + assert torch.allclose(out, model_dataclass.output) + # Verify that the inplace operations were performed correctly + assert torch.allclose(test_model.block1.layer1.parameter.detach(), model_dataclass.block1_layer1_parameter)