[DRAFT] Feat cailey sgd #1127

Closed
wants to merge 12 commits
src/brevitas/graph/base.py (156 additions & 1 deletion)
@@ -3,11 +3,16 @@

from abc import ABC
from abc import abstractmethod
from collections import OrderedDict
import inspect
from inspect import getcallargs
from typing import Any, Callable, Dict, Type, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn import Parameter
import torch.nn.utils.parametrize as parametrize
from torch.overrides import get_testing_overrides

from brevitas.fx import GraphModule
@@ -154,7 +159,62 @@ def _init_new_module(self, old_module: Module, name=None):
def _replace_old_module(self, model, old_module, new_module, load_state_dict=True):
replace_module(model, old_module, new_module)
if load_state_dict:
new_module.load_state_dict(old_module.state_dict())
            # State dict entries produced by parametrizations need special handling: the original
            # tensors are remapped to their base names and the parametrization-specific keys are
            # dropped, since the parametrizations are re-registered on the new module below.
old_module_state_dict = old_module.state_dict()

            # If the old module is parametrized, filter the state_dict appropriately
if parametrize.is_parametrized(old_module):
# Map the keys "parametrizations.tensor_name.original" to "tensor_name"
keys_to_remove = []
keys_value_to_add = []
for key, value in old_module_state_dict.items():
split_key = key.split(".")
if len(split_key) >= 3 and split_key[-3] == "parametrizations" and split_key[
-1] == "original":
tensor_name = split_key[-2]
keys_value_to_add.append((".".join(split_key[:-3] + [tensor_name]), value))
# We need to remove all the keys corresponding to the parametrizations added to the model
# to make sure the dictionary can be loaded with no missing/unused keys
                    # NOTE: For safety, an additional check could be added, as this logic would misbehave if a
                    # module without parametrizations had any key containing "parametrizations"
if "parametrizations" in split_key:
keys_to_remove.append(key)
# The modifications need to be reflected in old_module_state_dict
for key in keys_to_remove:
del old_module_state_dict[key]
for key, value in keys_value_to_add:
old_module_state_dict[key] = value

            # Note that strict can remain at its default (True), as the state dict was adapted above
            # to match the new module exactly
new_module.load_state_dict(old_module_state_dict)
            # If the old module is parametrized, its parametrizations need to be transferred to the
            # new module. We do not rely on transfer_parametrizations_and_params, as using it can
            # result in parameter ties being broken.
            # Note that unsafe is set to True for efficiency, as the checks were already performed
            # when the parametrizations were first registered on old_module
if parametrize.is_parametrized(old_module):
for tensor_name in old_module.parametrizations:
for param_func in old_module.parametrizations[tensor_name]:
parametrize.register_parametrization(
new_module, tensor_name, param_func, unsafe=True)

# TODO: Remove after debugging
def _replace_old_module_legacy(self, model, old_module, new_module, load_state_dict=True):
replace_module(model, old_module, new_module)
if load_state_dict:
# The dictionary entries relative to parametrizations need to be ignored, as these are passed
# when invoking transfer_parametrizations_and_params.
old_module_state_dict = OrderedDict({
k: v for k,
v in old_module.state_dict().items() if not k.startswith("parametrizations")})
# If the old module is parametrized, these need to be transferred to the new module. Strict needs to be set to False,
# as there will be missing keys for those parameters which have any parametrizations attached.
if parametrize.is_parametrized(old_module):
new_module.load_state_dict(old_module_state_dict, strict=False)
parametrize.transfer_parametrizations_and_params(old_module, new_module)
else:
new_module.load_state_dict(old_module_state_dict)
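
For reference, the key remapping performed by _replace_old_module can be reproduced on a plain torch.nn.Linear with a trivial identity parametrization. This is an illustrative sketch only; _Identity is not a Brevitas class.

import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _Identity(nn.Module):

    def forward(self, x):
        return x


old = nn.Linear(4, 4)
parametrize.register_parametrization(old, "weight", _Identity())
# The parametrized state dict now holds "parametrizations.weight.original" instead of "weight"
remapped = {}
for key, value in old.state_dict().items():
    split_key = key.split(".")
    if len(split_key) >= 3 and split_key[-3] == "parametrizations" and split_key[-1] == "original":
        remapped[".".join(split_key[:-3] + [split_key[-2]])] = value
    elif "parametrizations" not in split_key:
        remapped[key] = value

new = nn.Linear(4, 4)
new.load_state_dict(remapped)  # strict load succeeds: keys are exactly "weight" and "bias"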


class InsertModuleCallAfter(GraphTransform):
@@ -174,6 +234,76 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
return graph_model


class ModuleInstanceRegisterParametrization(Transform):

def __init__(
self, old_module_instance: Module, tensor_name: str,
parametrization_module: Module) -> None:
self.old_module_instance = old_module_instance
self.tensor_name = tensor_name
self.parametrization_module = parametrization_module

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
# register the parametrization in the old_module
parametrize.register_parametrization(
old_module, self.tensor_name, self.parametrization_module, unsafe=True)
break
return model
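
A minimal usage sketch for this transform, assuming it is importable from brevitas.graph.base once this change lands; the _Negate parametrization and the toy model below are illustrative only.

import torch.nn as nn

from brevitas.graph.base import ModuleInstanceRegisterParametrization


class _Negate(nn.Module):
    # Toy parametrization: the stored tensor is negated every time the attribute is accessed

    def forward(self, x):
        return -x


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
rewriter = ModuleInstanceRegisterParametrization(
    old_module_instance=model[0], tensor_name="weight", parametrization_module=_Negate())
model = rewriter.apply(model)
# model[0].weight is now computed through _Negate on every access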


class ModuleInstanceFuseRotationWeights(Transform):

def __init__(
self,
old_module_instance: Module,
rot_mat: Union[Parameter, Tensor],
rot_func: Callable,
K: int,
tensor_name: str,
axis: int,
is_source: bool,
):
self.old_module_instance = old_module_instance
self.rot_mat = rot_mat
self.rot_func = rot_func
self.K = K
self.tensor_name = tensor_name
self.axis = axis
self.is_source = is_source

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
if hasattr(old_module, 'allocate_params'):
old_module.allocate_params(old_module)
weight = getattr(old_module, self.tensor_name).data

if self.is_source:
if self.axis == 0:
weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
elif self.axis == 1:
weight = self.rot_func(weight, self.rot_mat, self.K)
else:
raise RuntimeError("Not supported yet")
# If not a source, the module is either a sink or an orphan
else:
if self.axis == 1:
weight = self.rot_func(weight, self.rot_mat, self.K)
elif self.axis == 0:
weight = self.rot_func(weight.t(), self.rot_mat, self.K).t()
else:
raise RuntimeError("Not supported yet")
# Modify the weights in-place
getattr(old_module, self.tensor_name).data = weight

if hasattr(old_module, 'offload_params'):
old_module.offload_params(old_module)
break
return model
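
A sketch of how this transform might be driven, again assuming it is importable from brevitas.graph.base. _matmul_rot is a stand-in for whatever rotation function the equalization code actually passes; it ignores K, which the real function may use (e.g. for block or Hadamard-style rotations).

import torch
import torch.nn as nn

from brevitas.graph.base import ModuleInstanceFuseRotationWeights


def _matmul_rot(weight, rot_mat, K):
    # Placeholder rotation: plain right-multiplication by rot_mat; K is unused here
    return weight @ rot_mat


layer = nn.Linear(4, 4, bias=False)
rot_mat = torch.linalg.qr(torch.randn(4, 4)).Q  # random orthogonal matrix
rewriter = ModuleInstanceFuseRotationWeights(
    old_module_instance=layer,
    rot_mat=rot_mat,
    rot_func=_matmul_rot,
    K=1,
    tensor_name="weight",
    axis=1,
    is_source=False)
layer = rewriter.apply(layer)  # layer.weight.data is rotated in place along axis 1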


class ModuleInstanceToModuleInstance(Transform):

def __init__(self, old_module_instance, new_module_instance):
@@ -189,6 +319,31 @@ def apply(self, model: GraphModule) -> GraphModule:
return model


class ModuleInstanceWrapModule(Transform):

def __init__(
self,
old_module_instance: Module,
wrapper_class: Type[Module],
module_attribute: str,
kwargs_wrapper: Dict[str, Any]):
self.old_module_instance = old_module_instance
self.wrapper_class = wrapper_class
self.module_attribute = module_attribute
self.kwargs_wrapper = kwargs_wrapper

def apply(self, model: GraphModule) -> GraphModule:
for old_module in model.modules():
if old_module is self.old_module_instance:
kwargs = {self.module_attribute: self.old_module_instance}
kwargs.update(self.kwargs_wrapper)
new_module_instance = self.wrapper_class(**kwargs)
                # Replace the old module in the model with the wrapper that now holds it
replace_module(model, old_module, new_module_instance)
break
return model
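
A usage sketch for the wrapper transform, assuming importability from brevitas.graph.base; _ScaledWrapper and _Net are made-up classes that only illustrate the calling convention.

import torch.nn as nn

from brevitas.graph.base import ModuleInstanceWrapModule


class _ScaledWrapper(nn.Module):
    # Toy wrapper: holds the wrapped layer under the attribute "layer" and rescales its output

    def __init__(self, layer, scale=1.0):
        super().__init__()
        self.layer = layer
        self.scale = scale

    def forward(self, x):
        return self.layer(x) * self.scale


class _Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)


model = _Net()
rewriter = ModuleInstanceWrapModule(
    old_module_instance=model.fc,
    wrapper_class=_ScaledWrapper,
    module_attribute="layer",
    kwargs_wrapper={"scale": 0.5})
model = rewriter.apply(model)
# model.fc is now a _ScaledWrapper whose "layer" attribute is the original Linear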


class ModuleToModuleByName(ModuleToModule):

def __init__(self, old_module_name, new_module_class, **kwargs):