Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat auto round #1064

Merged
merged 48 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
dfb8579
AutoRound standalone implementation
pablomlago Oct 16, 2024
0a132c9
Initial implementation
pablomlago Oct 20, 2024
701f8be
Refactoring before removing legacy code
pablomlago Oct 23, 2024
f60c295
Remove legacy code
pablomlago Oct 23, 2024
d3d1d55
LLM learned round
pablomlago Oct 24, 2024
49c724b
Refactoring round methos
pablomlago Nov 4, 2024
1793575
Remove unused import
pablomlago Nov 4, 2024
1c0720c
Fix license and refactor benchmark
pablomlago Nov 4, 2024
d6ac8fd
Minor license change
pablomlago Nov 5, 2024
983eef2
Minor refactor in SignSGD
pablomlago Nov 5, 2024
401f176
Include appropiate licensing
pablomlago Nov 5, 2024
a650589
Address comments
pablomlago Nov 5, 2024
806661d
Add missing change
pablomlago Nov 6, 2024
2de7515
Fix progress bar
pablomlago Nov 6, 2024
ab97290
Minor improvements
pablomlago Nov 7, 2024
ea09841
Remove utils hierarchy
pablomlago Nov 11, 2024
b988fa2
Unify learned round methods
pablomlago Nov 11, 2024
f4bd984
Refactoring and fixes offload
pablomlago Nov 18, 2024
494a017
Initial implementation distributed training
pablomlago Nov 19, 2024
d79bfdc
Minor cleanup
pablomlago Nov 19, 2024
4042cc8
Fix tests
pablomlago Nov 20, 2024
a0d3d54
Enable scale tuning in learned round
pablomlago Nov 21, 2024
6b89ccc
Unified learned round methods
pablomlago Nov 21, 2024
2d1d9df
Improve argument parsing learned round
pablomlago Nov 27, 2024
7110e8b
Adress comments
pablomlago Nov 27, 2024
30cae34
Minor changes
pablomlago Nov 28, 2024
7825450
Enable passing block name in vision entrypoint
pablomlago Nov 29, 2024
d4f369c
Update vision requirements
pablomlago Nov 29, 2024
6558bd9
Remove cache to dataset methods
pablomlago Nov 29, 2024
56ecaf4
Update sign_sgd.py
Giuseppe5 Nov 30, 2024
183b3ae
Fix indentation
Giuseppe5 Nov 30, 2024
42462a5
Precommit fix
Giuseppe5 Nov 30, 2024
2c0b306
Remove references to LRScheduler for backwards compatibility
pablomlago Dec 2, 2024
14e15db
Fix import failing tests
pablomlago Dec 2, 2024
4e1fdab
Fix for PyTorch 1.11
pablomlago Dec 2, 2024
d748ce8
Remove depedency from SGD
pablomlago Dec 2, 2024
47e2f20
Remove default values
pablomlago Dec 2, 2024
5120cf9
Remove unused imports
pablomlago Dec 2, 2024
6c7e72d
Remove import and allow test to fail
pablomlago Dec 2, 2024
8e41ce5
Account for change in optimizer interface
pablomlago Dec 2, 2024
5e4217f
Transformers version check for sdpa attention
pablomlago Dec 2, 2024
8ae2222
Fix tests
pablomlago Dec 2, 2024
eb498bc
Relax test assertion
pablomlago Dec 2, 2024
6927700
Update test_sign_sgd.py
Giuseppe5 Dec 2, 2024
632e396
Update test_sign_sgd.py
Giuseppe5 Dec 2, 2024
1a2863b
Update learned_round.py
Giuseppe5 Dec 2, 2024
e5bc47c
Update learned_round.py
Giuseppe5 Dec 2, 2024
ba1344f
Fix tests
Giuseppe5 Dec 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/requirements-vision.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
accelerate
torchvision
tqdm
55 changes: 41 additions & 14 deletions src/brevitas/core/function_wrapper/learned_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

import brevitas
from brevitas import config
from brevitas.core.function_wrapper.ops_ste import TensorClampSte
from brevitas.core.utils import SliceTensor
from brevitas.function.ops_ste import floor_ste
from brevitas.function.ops_ste import round_ste


class LearnedRoundHardSigmoid(brevitas.jit.ScriptModule):
Expand All @@ -28,12 +30,17 @@ def __init__(self, learned_round_zeta: float = 1.1, learned_round_gamma: float =
self.learned_round_gamma = learned_round_gamma

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> torch.Tensor:
p = torch.sigmoid(x)
def forward(self, p: torch.Tensor) -> torch.Tensor:
p = torch.sigmoid(p)
p = p * (self.learned_round_zeta - self.learned_round_gamma) + self.learned_round_gamma
p = torch.clamp(p, 0.0, 1.0)
if not self.training:
return p > 0.5
return p

def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return floor_ste(x) + p


class LearnedRoundSigmoid(brevitas.jit.ScriptModule):
"""
Expand All @@ -47,10 +54,37 @@ def __init__(self, learned_round_temperature: float = 1.) -> None:
self.learned_round_temperature = learned_round_temperature

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> torch.Tensor:
p = torch.sigmoid(x / self.learned_round_temperature)
def forward(self, p: torch.Tensor) -> torch.Tensor:
if not self.training:
return p > 0
p = torch.sigmoid(p / self.learned_round_temperature)
return p

@brevitas.jit.script_method
def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return floor_ste(x) + p


class LearnedRoundIdentity(brevitas.jit.ScriptModule):
"""
Implementation for LearnedRound learned parameter
Adapted from https://arxiv.org/abs/2309.05516
"""

def __init__(self) -> None:
super(LearnedRoundIdentity, self).__init__()
self.tensor_clamp = TensorClampSte()
self.upper_lower_bound = brevitas.jit.Attribute(0.5, float)

def forward(self, p: torch.Tensor) -> torch.Tensor:
return self.tensor_clamp(
p,
min_val=torch.tensor(-self.upper_lower_bound).type_as(p),
max_val=torch.tensor(self.upper_lower_bound).type_as(p))

def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
return round_ste(x + p)


class LearnedRoundSte(brevitas.jit.ScriptModule):
"""
Expand All @@ -72,17 +106,10 @@ def __init__(

@brevitas.jit.script_method
def forward(self, x: torch.Tensor) -> torch.Tensor:
p = self.p_forward()
p = self.learned_round_impl(self.value)
p = self.tensor_slicer(p)
return floor_ste(x) + p.to(x.dtype)

def p_forward(self):
# In eval mode, performs true quantization, otherwise "soft" quantization
if not self.training:
p = (self.value > 0)
else:
p = self.learned_round_impl(self.value)
return p
p = (p.to(x.dtype)).view_as(x)
return self.learned_round_impl.round_forward(x, p)

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
from brevitas.graph.gpxq import SUPPORTED_TCONV_OP
import brevitas.nn as qnn
from brevitas.quant_tensor import _unpack_quant_tensor
from brevitas.utils.torch_utils import StopFwdException


class GPFQ(GPxQ):
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from brevitas import torch_version
from brevitas.graph.gpxq import GPxQ
from brevitas.graph.gpxq import gpxq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
from brevitas.utils.torch_utils import StopFwdException


class GPTQ(GPxQ):
Expand Down
4 changes: 0 additions & 4 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@
SUPPORTED_CONV_OP = (qnn.QuantConv1d, qnn.QuantConv2d, qnn.QuantConv3d, *SUPPORTED_TCONV_OP)


class StopFwdException(Exception):
pass


@dataclass
class LayerHandler:
layer_names: Set = field(default_factory=set)
Expand Down
1 change: 1 addition & 0 deletions src/brevitas/inject/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class LearnedRoundImplType(AutoName):
"""
HARD_SIGMOID = auto()
SIGMOID = auto()
IDENTITY = auto()


class ScalingImplType(AutoName):
Expand Down
Loading
Loading