Adding FP8 weight export #907

Merged May 29, 2024

Changes from all commits

Commits (46)
5b168c5
placeholder version
costigt-dev Apr 11, 2024
d2b7d2d
checkpoint commit
costigt-dev Apr 12, 2024
e10e630
first working flow end to end
costigt-dev Apr 12, 2024
84e70f7
formatting
costigt-dev Apr 12, 2024
ef4c737
changes to tests
costigt-dev Apr 12, 2024
4aa4b21
added version check for test
costigt-dev Apr 12, 2024
3b05883
using existing functionality over homespun
costigt-dev Apr 12, 2024
cad5802
corrected mistake in copying and restored FloatClipMixin
costigt-dev Apr 12, 2024
4848248
fixed mistake
costigt-dev Apr 12, 2024
5188aa6
first pass activation fp8 export
costigt-dev Apr 16, 2024
29cb952
beginnings of activation fp8 export and change name of QCDQCastFloatW…
costigt-dev Apr 16, 2024
9bf9240
more changes to make naming scheme more consistent
costigt-dev Apr 16, 2024
f9406f1
added FloatFusedActivationQuantProxy
costigt-dev Apr 16, 2024
991ddb7
replaced zero_point workaround with placeholder implementation of fp8…
costigt-dev Apr 17, 2024
520db85
removed verbose flag
costigt-dev Apr 17, 2024
2bb2895
created context manager for fp8 workaround
costigt-dev Apr 17, 2024
8ffce48
added check that objects being compared are tensors in the fp8 workar…
costigt-dev Apr 17, 2024
7edf5bd
General equal implementation
Giuseppe5 May 14, 2024
bbd5362
fallback to fp32 if fp8
Giuseppe5 May 14, 2024
4bc126d
Fix for PT < 2.1
Giuseppe5 May 14, 2024
a55dcd0
Remove non existent destroy
Giuseppe5 May 14, 2024
cd6cad6
Merge branch 'dev' into feat/export_fp8
Giuseppe5 May 23, 2024
fabc8ae
Remove import
Giuseppe5 May 23, 2024
74b65a9
Fixed imports
Giuseppe5 May 23, 2024
cf1ea02
Fixed imports
Giuseppe5 May 23, 2024
cda7f1f
Fix export
Giuseppe5 May 23, 2024
8349391
more testing
Giuseppe5 May 23, 2024
11387d3
Fix
Giuseppe5 May 24, 2024
592ccd3
Fix
Giuseppe5 May 24, 2024
1fc5642
fix
Giuseppe5 May 25, 2024
58f46bc
Fix minifloat check
Giuseppe5 May 25, 2024
bd657b8
Last fix
Giuseppe5 May 25, 2024
630a3e3
Fix minifloat
Giuseppe5 May 27, 2024
38a37fb
Review
Giuseppe5 May 28, 2024
76b3193
Review 2
Giuseppe5 May 28, 2024
529470f
Merge branch 'dev' into feat/export_fp8
Giuseppe5 May 28, 2024
f2f8969
fix
Giuseppe5 May 28, 2024
44579f8
Typo
Giuseppe5 May 28, 2024
038cba9
fix tests
Giuseppe5 May 28, 2024
198c5af
Typo
Giuseppe5 May 28, 2024
c3d7d3c
fix
Giuseppe5 May 28, 2024
fef531d
last fix
Giuseppe5 May 28, 2024
6431882
Fix JIT
Giuseppe5 May 29, 2024
4b78543
Fix import
Giuseppe5 May 29, 2024
d762c99
Last fix
Giuseppe5 May 29, 2024
ac5e58c
correct skip
Giuseppe5 May 29, 2024
13 changes: 3 additions & 10 deletions src/brevitas/core/quant/float.py
@@ -10,8 +10,7 @@
 from brevitas.core.function_wrapper import RoundSte
 from brevitas.core.scaling import ConstScaling
 from brevitas.core.utils import StatelessBuffer
-from brevitas.function.ops import max_float
-from brevitas.function.ops_ste import floor_ste
+from brevitas.utils.torch_utils import float_internal_scale


 class FloatQuant(brevitas.jit.ScriptModule):
@@ -64,21 +63,15 @@ def __init__(
         dtype = torch.get_default_dtype()
         self.eps = torch.finfo(dtype).tiny

-    @brevitas.jit.script_method
-    def internal_scale(self, x):
-        internal_scale = floor_ste(torch.log2(torch.abs(x) + self.eps)) - self.mantissa_bit_width()
-        internal_scale = torch.clamp_min(internal_scale, self.fp_internal_scale_min())
-        internal_scale = torch.exp2(internal_scale)
-        return internal_scale
-
     @brevitas.jit.script_method
     def quantize(self, x: torch.Tensor):
         scaling_impl_value = self.scaling_impl(x)
         float_scaling_impl_value = self.float_scaling_impl(
             self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
         scale = scaling_impl_value / float_scaling_impl_value
         scaled_x = x / scale
-        internal_scale = self.internal_scale(scaled_x)
+        internal_scale = float_internal_scale(
+            scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min())
         val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
         return val_fp_quant, scale
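For context on what the new helper computes: the removed `internal_scale` method contains the exact math that `float_internal_scale` from `brevitas.utils.torch_utils` now centralizes, and the diff shows it being called with the scaled input, the mantissa width, and the minimum internal scale. Below is a minimal, self-contained sketch of that computation; the function name, the plain `torch.floor`, the `eps` default, and the example parameters are illustrative (the in-tree helper uses a straight-through-estimator floor so gradients can pass through the rounding).

```python
import torch

def float_internal_scale_sketch(scaled_x: torch.Tensor,
                                mantissa_bit_width: torch.Tensor,
                                fp_internal_scale_min: torch.Tensor,
                                eps: float = torch.finfo(torch.float32).tiny) -> torch.Tensor:
    # Per-element power-of-two scale that places rounding at the last mantissa bit.
    internal_scale = torch.floor(torch.log2(torch.abs(scaled_x) + eps)) - mantissa_bit_width
    # Clamp to the minimum internal scale so subnormal values share a fixed scale.
    internal_scale = torch.clamp_min(internal_scale, fp_internal_scale_min)
    return torch.exp2(internal_scale)

# Example: rounding a tensor the way FloatQuant.quantize does for an e4m3 format.
# fp_internal_scale_min is assumed here to be 1 - exponent_bias - mantissa_bit_width.
mantissa_bit_width = torch.tensor(3.0)
fp_internal_scale_min = torch.tensor(1.0 - 7.0 - 3.0)
scaled_x = torch.randn(4)
internal_scale = float_internal_scale_sketch(scaled_x, mantissa_bit_width, fp_internal_scale_min)
val_fp_quant = internal_scale * torch.round(scaled_x / internal_scale)
print(val_fp_quant)
```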
23 changes: 22 additions & 1 deletion src/brevitas/export/common/handler/base.py
@@ -4,6 +4,7 @@
 from abc import ABC
 from abc import abstractmethod
 import math
+from warnings import warn

 import torch
 from torch import Tensor
@@ -12,7 +13,8 @@
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int

-__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin']
+__all__ = [
+    'BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin']


 class BaseHandler(Module, ABC):
@@ -38,6 +40,13 @@ def quant_axis(cls, scale):
         return None


+class FloatClipMixin(ABC):
+
+    @classmethod
+    def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width):
+        return None
+
+
 class ClipMixin(ABC):

     @classmethod
@@ -112,6 +121,18 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor):
         return -cls.validate_scalar_int_exponent(scale)


+class FloatZeroPointHandlerMixin(ABC):
+
+    @classmethod
+    def zero_point_with_dtype(cls, exponent_bit_width, mantissa_bit_width, zero_point):
+        if exponent_bit_width == 4 and mantissa_bit_width == 3:
+            return zero_point.type(torch.float8_e4m3fn)
+        elif exponent_bit_width == 5 and mantissa_bit_width == 2:
+            return zero_point.type(torch.float8_e5m2)
+        else:
+            return zero_point.type(torch.float32)
+
+
 class ZeroPointHandlerMixin(ABC):

     @classmethod
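The new `FloatZeroPointHandlerMixin` is where the export path picks a concrete PyTorch dtype from the minifloat configuration: e4m3 and e5m2 map to the native `torch.float8_*` dtypes (which, as the "Fix for PT < 2.1" commit above suggests, need a recent PyTorch), and any other width combination falls back to fp32. A small standalone illustration of that mapping follows; the `fp8_dtype` helper name is made up for the example and is not part of the PR.

```python
import torch

def fp8_dtype(exponent_bit_width: int, mantissa_bit_width: int) -> torch.dtype:
    # Same selection logic as FloatZeroPointHandlerMixin.zero_point_with_dtype.
    if exponent_bit_width == 4 and mantissa_bit_width == 3:
        return torch.float8_e4m3fn
    if exponent_bit_width == 5 and mantissa_bit_width == 2:
        return torch.float8_e5m2
    return torch.float32  # non-FP8 minifloat configs fall back to fp32

zero_point = torch.zeros(1)
print(zero_point.type(fp8_dtype(4, 3)).dtype)   # torch.float8_e4m3fn
print(zero_point.type(fp8_dtype(5, 2)).dtype)   # torch.float8_e5m2
print(zero_point.type(fp8_dtype(8, 23)).dtype)  # torch.float32 fallback
```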