Skip to content

Commit

Permalink
Merge branch 'dev' of https://github.com/pablomlago/brevitas into fea…
Browse files Browse the repository at this point in the history
…t-auto-round
  • Loading branch information
pablomlago committed Nov 4, 2024
2 parents c3713b3 + 52e0059 commit 75c45d9
Show file tree
Hide file tree
Showing 31 changed files with 2,121 additions and 386 deletions.
92 changes: 92 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
We are more than happy to help you contributing to Brevitas.

Please follow the steps below and be sure that your contribution complies with our guidelines.

1. If you are looking for some issues to get started with, we have a list of <a href="https://github.com/Xilinx/brevitas/labels/good%20first%20issue">good first issues</a> in the issue tracker.

2. If you have some suggestion for features or have encoutered any bugs, don't hesitate to reach out through <a href="https://github.com/Xilinx/brevitas/issues">Brevitas Issue</a>

We welcome submissions for:

* New features like novel PTQ algorithms. Keep in mind that Brevitas tends to integrate new algorithms within the existing infrastructure rather than having standalone implementations
* Support for new quantized layers
* Support for new quantized topologies under brevitas_examples
* Contributions to the documentation and Jupyter notebooks/tutorials
* Bugfixes


2. Submitting your pull request:

1. Fork this repository to your own GitHub account using the *fork* button above.

2. Clone the fork to your local computer using *git clone*. Checkout the branch you want to work on.

3. Please install <a href="https://pre-commit.com/" target="_blank">pre-commit</a> to ensure your code is formatted to our style guidelines.

4. Add your contribution as needed.

5. Use *git add*, *git commit*, *git push* to add changes to your fork.

6. If you are introducing new functionality or fixing a bug, add at least one unit test under the `tests/` folder and make sure it passes before you submit the pull request.

7. Submit a pull request by clicking the *pull request* button on your GitHub repo:
* The <a href="https://github.com/Xilinx/brevitas" target="_blank">main branch</a> should always be treated as stable and clean. Only hot fixes are allowed to be pull-requested. The hot fix is supposed to be very important such that without this fix, a lot of things will break.
* For new features, bug fixes, doc updates, users should pull request against the <a href="https://github.com/Xilinx/brevitas/tree/dev" target="_blank">development branch</a>.

3. Sign Your Work

Please use the *Signed-off-by* line at the end of your patch which indicates that you accept the Developer Certificate of Origin (DCO) defined by https://developercertificate.org/ reproduced below::

```
Developer Certificate of Origin
Version 1.1
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
1 Letterman Drive
Suite D4700
San Francisco, CA, 94129
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
Developer's Certificate of Origin 1.1
By making a contribution to this project, I certify that:
(a) The contribution was created in whole or in part by me and I
have the right to submit it under the open source license
indicated in the file; or
(b) The contribution is based upon previous work that, to the best
of my knowledge, is covered under an appropriate open source
license and I have the right under that license to submit that
work with modifications, whether created in whole or in part
by me, under the same open source license (unless I am
permitted to submit under a different license), as indicated
in the file; or
(c) The contribution was provided directly to me by some other
person who certified (a), (b) or (c) and I have not modified
it.
(d) I understand and agree that this project and the contribution
are public and that a record of the contribution (including all
personal information I submit with it, including my sign-off) is
maintained indefinitely and may be redistributed consistent with
this project or the open source license(s) involved.
```

You can enable Signed-off-by automatically by adding the `-s` flag to the `git commit` command.

Here is an example Signed-off-by line which indicates that the contributor accepts DCO:

```
This is my commit message
Signed-off-by: Jane Doe <[email protected]>
```

4. We will review your contribution and, if any additional fixes or modifications are
necessary, may provide feedback to guide you. When accepted, your pull request will
be merged to the repository. If you have more questions please contact us.
2 changes: 1 addition & 1 deletion requirements/requirements-export.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
onnx==1.15
onnx==1.17.0
onnxoptimizer
2 changes: 1 addition & 1 deletion requirements/requirements-finn-integration.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
bitstring
onnx==1.15
onnx==1.17.0
onnxoptimizer
onnxruntime>=1.15.0
qonnx
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-llm.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main
tqdm
transformers
transformers[sentencepiece]==4.45.2
2 changes: 1 addition & 1 deletion requirements/requirements-ort-integration.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
onnx==1.15
onnx==1.17.0
onnxoptimizer
onnxruntime>=1.15.0
qonnx
5 changes: 3 additions & 2 deletions src/brevitas/core/function_wrapper/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,9 @@ def forward(self, x):

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
block_dim = self.group_dim + 1 if self.group_dim != -1 else -1
tensor_shape_list[self.group_dim] = (
tensor_shape_list[self.group_dim] + self.group_size - 1) // self.group_size
block_dim = self.group_dim + 1 if self.group_dim != -1 else len(tensor_shape_list)
tensor_shape_list.insert(block_dim, self.group_size)
x = x.view(tensor_shape_list)
return x
Expand Down
17 changes: 4 additions & 13 deletions src/brevitas/core/restrict_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@

class _RestrictClampValue(brevitas.jit.ScriptModule):

def __init__(self, scaling_min_val: Optional[float], restrict_value_impl: Optional[Module]):
def __init__(
self,
scaling_min_val: Optional[float] = None,
restrict_value_impl: Optional[Module] = None):
super(_RestrictClampValue, self).__init__()
if scaling_min_val is not None and scaling_min_val != 0:
self.clamp_min_ste = ScalarClampMinSte(scaling_min_val)
Expand Down Expand Up @@ -90,9 +93,6 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return Identity()

def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
return x / threshold

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tensor:
return x
Expand All @@ -116,9 +116,6 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return InplaceLogTwo()

def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
return x - threshold

@brevitas.jit.script_method
def forward(self, x: Tensor):
x = self.power_of_two(x)
Expand All @@ -143,9 +140,6 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return Identity()

def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
return x / threshold

@brevitas.jit.script_method
def forward(self, x: Tensor):
x = self.float_to_int_impl(x)
Expand All @@ -171,9 +165,6 @@ def restrict_init_module(self):
def restrict_init_inplace_module(self):
return InplaceLogTwo()

def combine_scale_threshold(self, x: Tensor, threshold: Tensor) -> Tensor:
return x - threshold

@brevitas.jit.script_method
def forward(self, x: Tensor):
x = self.float_to_int_impl(x)
Expand Down
56 changes: 44 additions & 12 deletions src/brevitas/core/scaling/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,18 @@ def __init__(
tracked_parameter_list: List[torch.nn.Parameter],
scaling_shape: Tuple[int, ...],
restrict_scaling_impl: Module = FloatRestrictValue(),
restrict_threshold_impl: Optional[Module] = None,
affine_rescaling: bool = False,
affine_shift_scale: bool = False,
scaling_min_val: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None) -> None:
super(StatsFromParameterScaling, self).__init__()

# Ensure retro-compatibility with shared threshold/scaling restrict
if restrict_threshold_impl is None:
restrict_threshold_impl = restrict_scaling_impl

self.parameter_list_stats = _ParameterListStats(
scaling_stats_impl,
scaling_shape,
Expand All @@ -44,6 +50,7 @@ def __init__(
tracked_parameter_list)
self.stats_scaling_impl = _StatsScaling(
restrict_scaling_impl,
restrict_threshold_impl,
scaling_shape,
scaling_min_val,
affine_rescaling,
Expand All @@ -65,6 +72,7 @@ class _StatsScaling(brevitas.jit.ScriptModule):
def __init__(
self,
restrict_scaling_impl: Module,
restrict_threshold_impl: Module,
scaling_shape: Tuple[int, ...],
scaling_min_val: Optional[float],
affine_rescaling: bool,
Expand All @@ -81,19 +89,22 @@ def __init__(
else:
self.affine_rescaling = Identity()
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
self.restrict_clamp_threshold = _RestrictClampValue(
restrict_value_impl=restrict_threshold_impl)
self.restrict_scaling_pre = restrict_scaling_impl.restrict_init_module()
self.restrict_scaling_impl = restrict_scaling_impl
self.restrict_threshold_pre = restrict_threshold_impl.restrict_init_module()

@brevitas.jit.script_method
def forward(
self, stats: torch.Tensor, threshold: Optional[torch.Tensor] = None) -> torch.Tensor:
if threshold is None:
threshold = torch.ones(1).type_as(stats)
threshold = self.restrict_scaling_pre(threshold)
threshold = self.restrict_threshold_pre(threshold)
threshold = self.restrict_clamp_threshold(threshold)
stats = self.restrict_scaling_pre(stats)
stats = self.restrict_scaling_impl.combine_scale_threshold(stats, threshold)
stats = self.affine_rescaling(stats)
stats = self.restrict_clamp_scaling(stats)
stats = stats / threshold
return stats


Expand All @@ -107,12 +118,17 @@ def __init__(
affine_rescaling: bool = False,
affine_shift_scale: bool = False,
restrict_scaling_impl: Module = FloatRestrictValue(),
restrict_threshold_impl: Optional[Module] = None,
scaling_stats_momentum: float = DEFAULT_MOMENTUM,
scaling_min_val: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None) -> None:
super(RuntimeStatsScaling, self).__init__()

# Ensure retro-compatibility with shared threshold/scaling restrict
if restrict_threshold_impl is None:
restrict_threshold_impl = restrict_scaling_impl

self.runtime_stats = _RuntimeStats(
scaling_stats_impl,
scaling_shape,
Expand All @@ -122,6 +138,7 @@ def __init__(
device)
self.stats_scaling_impl = _StatsScaling(
restrict_scaling_impl,
restrict_threshold_impl,
scaling_shape,
scaling_min_val,
affine_rescaling,
Expand Down Expand Up @@ -173,20 +190,32 @@ def _load_from_state_dict(
class RuntimeDynamicGroupStatsScaling(brevitas.jit.ScriptModule):

def __init__(
self,
group_size: int,
group_dim: int,
input_view_impl: Module,
scaling_stats_impl: Module,
scaling_min_val: Optional[float],
restrict_scaling_impl: Module = FloatRestrictValue()) -> None:
self,
group_size: int,
group_dim: int,
input_view_impl: Module,
scaling_stats_impl: Module,
scaling_min_val: Optional[float],
restrict_scaling_impl: Module = FloatRestrictValue(),
restrict_threshold_impl: Optional[Module] = None) -> None:
super(RuntimeDynamicGroupStatsScaling, self).__init__()

# Ensure retro-compatibility with shared threshold/scaling restrict
if restrict_threshold_impl is None:
restrict_threshold_impl = restrict_scaling_impl

self.group_size = group_size
self.group_dim = group_dim
self.scaling_stats_impl = scaling_stats_impl
self.scaling_min_val = scaling_min_val
self.input_view_impl = input_view_impl
self.restrict_clamp_scaling = _RestrictClampValue(scaling_min_val, restrict_scaling_impl)
self.restrict_clamp_threshold = _RestrictClampValue(
restrict_value_impl=restrict_threshold_impl)
self.restrict_scaling_pre = self.restrict_clamp_scaling.restrict_value_impl.restrict_init_module(
)
self.restrict_threshold_pre = self.restrict_clamp_threshold.restrict_value_impl.restrict_init_module(
)

@brevitas.jit.script_method
def forward(
Expand All @@ -196,7 +225,10 @@ def forward(
if threshold is None:
threshold = torch.ones(1).type_as(stats_input)
stats_input_reshaped = self.input_view_impl(stats_input)
out = self.scaling_stats_impl(stats_input_reshaped) / threshold
threshold = self.restrict_clamp_threshold(self.restrict_threshold_pre(threshold))
out = self.scaling_stats_impl(stats_input_reshaped)
# Apply log scaling
out = self.restrict_scaling_pre(out)
# Scaling min val
out = self.restrict_clamp_scaling(out)
out = self.restrict_clamp_scaling(out) / threshold
return out
Loading

0 comments on commit 75c45d9

Please sign in to comment.