-
Notifications
You must be signed in to change notification settings - Fork 199
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feat (minifloat): add support for user specified minifloat format (#821)
- Loading branch information
1 parent
9443076
commit 0a023aa
Showing
11 changed files
with
453 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
from brevitas.quant.base import MSESymmetricScale | ||
from brevitas.quant.experimental.float_base import FloatActBase | ||
from brevitas.quant.experimental.float_base import FloatWeightBase | ||
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin | ||
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin | ||
from brevitas.quant.experimental.float_base import ScaledFloatActBase | ||
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase | ||
|
||
|
||
class Fp8e4m3OCPMixin(Fp8e4m3Mixin):
    """
    Special-value settings layered on top of Fp8e4m3Mixin for the OCP variant.
    """
    # NOTE(review): None presumably means that no encoding is reserved for
    # infinity -- confirm against the consumer in float_base.
    inf_values = None
    # One-element tuple holding the pattern '111'.
    # NOTE(review): presumably a mantissa bit string reserved for NaN -- confirm.
    nan_values = ('111',)
|
||
|
||
class Fp8e5m2OCPMixin(Fp8e5m2Mixin):
    """
    Special-value settings layered on top of Fp8e5m2Mixin for the OCP variant.
    """
    # One-element tuple holding the pattern '00'.
    # NOTE(review): presumably the mantissa bit string reserved for infinity -- confirm.
    inf_values = ('00',)
    # The three remaining two-character patterns.
    # NOTE(review): presumably mantissa bit strings reserved for NaN -- confirm.
    nan_values = ('01', '11', '10')
|
||
|
||
class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer.

    Combines Fp8e4m3OCPMixin with FloatWeightBase and defines no settings of
    its own.
    """
|
||
|
||
class Fp8e5m2OCPWeight(Fp8e5m2OCPMixin, FloatWeightBase):
    """
    Signed FP8 (E5M2) quantizer for weights.

    Combines Fp8e5m2OCPMixin with FloatWeightBase and defines no settings of
    its own.
    """
|
||
|
||
class Fp8e4m3OCPAct(Fp8e4m3OCPMixin, FloatActBase):
    """
    Signed FP8 (E4M3) quantizer for activations.

    Combines Fp8e4m3OCPMixin with FloatActBase and defines no settings of
    its own.
    """
|
||
|
||
class Fp8e5m2OCPAct(Fp8e5m2OCPMixin, FloatActBase):
    """
    Signed FP8 (E5M2) quantizer for activations.

    Combines Fp8e5m2OCPMixin with FloatActBase and defines no settings of
    its own.
    """
|
||
|
||
class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-tensor absmax-based scaling.
    """
    # Per-tensor scaling: per-output-channel scaling is switched off.
    scaling_per_output_channel = False
|
||
|
||
class Fp8e5m2OCPWeightPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
    """
    Signed FP8 (E5M2) quantizer for weights, using absmax-based scaling
    computed per tensor.
    """
    # Per-tensor scaling: per-output-channel scaling is switched off.
    scaling_per_output_channel = False
|
||
|
||
class Fp8e4m3OCPActPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatActBase):
    """
    Signed FP8 (E4M3) quantizer for activations, using static
    percentile-based scaling computed per tensor.
    """
    # Per-tensor scaling: per-output-channel scaling is switched off.
    scaling_per_output_channel = False
|
||
|
||
class Fp8e5m2OCPActPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatActBase):
    """
    Signed FP8 (E5M2) quantizer for activations, using static
    percentile-based scaling computed per tensor.
    """
    # Per-tensor scaling: per-output-channel scaling is switched off.
    scaling_per_output_channel = False
|
||
|
||
class Fp8e4m3OCPWeightPerChannelFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-channel absmax-based scaling.
    """
    # Per-channel scaling: one scale per output channel is requested.
    scaling_per_output_channel = True
|
||
|
||
class Fp8e5m2OCPWeightPerChannelFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
    """
    Signed FP8 (E5M2) quantizer for weights, using absmax-based scaling
    computed per channel.
    """
    # Per-channel scaling: one scale per output channel is requested.
    scaling_per_output_channel = True
|
||
|
||
class Fp8e4m3OCPActPerChannelFloat2d(Fp8e4m3OCPMixin, ScaledFloatActBase):
    """
    Signed FP8 (E4M3) quantizer for activations, using static
    percentile-based scaling computed per channel.
    """
    # Swaps the first two of four dimensions before statistics are gathered.
    # NOTE(review): presumably assumes a 4-D (N, C, H, W) input -- confirm.
    scaling_stats_permute_dims = (1, 0, 2, 3)
    # Per-channel scaling: one scale per output channel is requested.
    scaling_per_output_channel = True
|
||
|
||
class Fp8e5m2OCPActPerChannelFloat2d(Fp8e5m2OCPMixin, ScaledFloatActBase):
    """
    Signed FP8 (E5M2) quantizer for activations, using static
    percentile-based scaling computed per channel.
    """
    # Swaps the first two of four dimensions before statistics are gathered.
    # NOTE(review): presumably assumes a 4-D (N, C, H, W) input -- confirm.
    scaling_stats_permute_dims = (1, 0, 2, 3)
    # Per-channel scaling: one scale per output channel is requested.
    scaling_per_output_channel = True
|
||
|
||
class Fp8e4m3OCPActPerTensorFloatMSE(
        Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """
    Signed FP8 (E4M3) quantizer for activations, using static MSE-based
    scaling computed per tensor.
    """
    # Per-tensor scaling: per-output-channel scaling is switched off.
    scaling_per_output_channel = False
|
||
|
||
class Fp8e5m2OCPActPerTensorFloatMSE(
        Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """
    Signed FP8 (E5M2) quantizer for activations, using static MSE-based
    scaling computed per tensor.
    """
    # Per-tensor scaling: per-output-channel scaling is switched off.
    scaling_per_output_channel = False
|
||
|
||
class Fp8e4m3OCPActPerChannelFloat2dMSE(
        Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """
    Signed FP8 (E4M3) quantizer for activations, using static MSE-based
    scaling computed per channel.
    """
    # Swaps the first two of four dimensions before statistics are gathered.
    # NOTE(review): presumably assumes a 4-D (N, C, H, W) input -- confirm.
    scaling_stats_permute_dims = (1, 0, 2, 3)
    # Per-channel scaling: one scale per output channel is requested.
    scaling_per_output_channel = True
|
||
|
||
class Fp8e5m2OCPActPerChannelFloat2dMSE(
        Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """
    Signed FP8 (E5M2) quantizer for activations, using static MSE-based
    scaling computed per channel.
    """
    # Swaps the first two of four dimensions before statistics are gathered.
    # NOTE(review): presumably assumes a 4-D (N, C, H, W) input -- confirm.
    scaling_stats_permute_dims = (1, 0, 2, 3)
    # Per-channel scaling: one scale per output channel is requested.
    scaling_per_output_channel = True
|
||
|
||
class Fp8e4m3OCPWeightPerChannelFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-channel MSE-based scaling.
    """
    # Per-channel scaling: one scale per output channel is requested.
    scaling_per_output_channel = True
|
||
|
||
class Fp8e4m3OCPWeightPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-tensor MSE-based scaling.
    """
    # Per-tensor scaling: per-output-channel scaling is switched off.
    scaling_per_output_channel = False
Oops, something went wrong.