diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index d6a327bab..4e5c1a162 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1356,13 +1356,12 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method= weight = module.weight.data if axis == 1: - weight = rot_func(weight, rot_mat, K) + _update_weights(module, rot_func(weight, rot_mat, K), 'weight') elif axis == 0: - weight = rot_func(weight.t(), rot_mat, K).t() + _update_weights(module, rot_func(weight.t(), rot_mat, K).t(), 'weight') else: raise RuntimeError("Not supported yet") - module.weight.data = weight if hasattr(module, 'offload_params'): module.offload_params(module)