-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RFC] Add f8E4M3 and f8E3M4 types support
- Loading branch information
1 parent
726ac03
commit 389e72c
Showing
1 changed file
with
178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# RFC: Float8E4M3 and Float8E3M4 | ||
|
||
Status: In Review<br/> | ||
Initial version: 8/8/2024<br/> | ||
Last updated: 8/9/2024<br/> | ||
Discussion thread: [PR-2486](https://github.com/openxla/stablehlo/pull/2486) | ||
[RFC] Add f8E4M3 and f8E3M4 types support | ||
|
||
## Summary | ||
|
||
Amazon has proposed two new FP8 types, Float8E4M3 and Float8E3M4. These | ||
types are implemented in commercially available hardware[^1], and added to MLIR | ||
builtin types[^2]˒[^3] and LLVM APFloat[^4]˒[^5]. | ||
|
||
Both Float8E4M3 and Float8E3M4 follows IEEE 754 convention similar to existing | ||
type Float8E5M2. | ||
|
||
### Float8E4M3 | ||
|
||
8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa | ||
following IEEE-754 conventions with bit layout S1E4M3. | ||
|
||
```c | ||
f8E4M3 (IEEE 754) | ||
- Exponent bias: 7 | ||
- Minimum stored exponent value: 1 (binary 0001) | ||
- Maximum stored exponent value: 14 (binary 1110) | ||
- Minimum unbiased exponent value: 1 − 7 = −6 | ||
- Maximum unbiased exponent value: 14 - 7 = 7 | ||
- Precision specifies the total number of bits used for the significand | ||
(mantisa), including implicit leading integer bit = 3 + 1 = 4 | ||
- Follows IEEE 754 conventions for representation of special values | ||
- Has Positive and Negative zero | ||
- Has Positive and Negative infinity | ||
- Has NaNs | ||
|
||
Additional details: | ||
- Min exp (unbiased): -6 | ||
- Max exp (unbiased): 7 | ||
- Infinities (+/-): S.1111.000 | ||
- Zeros (+/-): S.0000.000 | ||
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} | ||
- Min normal number: S.0001.000 = +/-2^(1 - 7) x (1 + 0) = +/-2^(-6) | ||
- Max normal number: S.1110.111 = +/-2^(14 - 7) x (1 + 7/8) = +/-240 | ||
- Min subnormal number: S.0000.001 = +/-2^(-6) x 1/8 = +/-2^(-9) | ||
- Max subnormal number: S.0000.111 = +/-2^(-6) x 7/8 = +/-2^(-9) x 7 | ||
``` | ||
#### Comparison of Float8E4M3FN and Float8E4M3 | ||
| |Float8E4M3FN |Float8E4M3 | | ||
|-------------------|------------------------------------------------------------------------|-------------------------------------------------------------------------| | ||
|Bias |7 |7 | | ||
|Min Normal Value |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> | | ||
|Max Normal Value |`0bS1111110` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>8</sup> = 448|`0bS1110111` = -1<sup>S</sup> $\times$ 1.875 $\times$ 2<sup>7</sup> = 240| | ||
|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> | | ||
|Max Subnormal Value|`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> | | ||
|NaN |`0bS1111111` |`0bS1111MMM`, where `MMM` is non-zero. | | ||
|Infinity |N/A |`0bS1111000` | | ||
|-0.0 |`0b10000000` |`0b10000000` | | ||
### Float8E3M4 | ||
8-bit floating point type with 1 sign bit, 3 bits exponent and 4 bits mantissa | ||
following IEEE-754 conventions with bit layout S1E3M4. | ||
```c | ||
f8E3M4 (IEEE 754) | ||
- Exponent bias: 3 | ||
- Minimum stored exponent value: 1 (binary 001) | ||
- Maximum stored exponent value: 6 (binary 110) | ||
- Minimum unbiased exponent value: 1 − 3 = −2 | ||
- Maximum unbiased exponent value: 6 - 3 = 3 | ||
- Precision specifies the total number of bits used for the significand | ||
(mantissa), including implicit leading integer bit = 4 + 1 = 5 | ||
- Follows IEEE 754 conventions for representation of special values | ||
- Has Positive and Negative zero | ||
- Has Positive and Negative infinity | ||
- Has NaNs | ||
Additional details: | ||
- Min exp (unbiased): -2 | ||
- Max exp (unbiased): 3 | ||
- Infinities (+/-): S.111.0000 | ||
- Zeros (+/-): S.000.0000 | ||
- NaNs: S.111.{0,1}⁴ except S.111.0000 | ||
- Min normal number: S.001.0000 = +/-2^(1 - 3) x (1 + 0) = +/-0.25 | ||
- Max normal number: S.110.1111 = +/-2^(6 - 3) x (1 + 15/16) = +/-15.5 | ||
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-6) | ||
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-6) x 15 | ||
``` | ||
|
||
### Comparison of Float8E5M2, Float8E4M3 and Float8E3M4 | ||
|
||
| |Float8E5M2 |Float8E4M3 |Float8E3M4 | | ||
|-------------------|----------------------------------------------------------------------------|-------------------------------------------------------------------------|---------------------------------------------------------------------------| | ||
|Bias |15 |7 |3 | | ||
|Min Normal Value |`0bS0000100` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-14</sup> |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |`0bS0010000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-2</sup> | | ||
|Max Normal Value |`0bS1111011` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>15</sup> = 57344 |`0bS1110111` = -1<sup>S</sup> $\times$ 1.875 $\times$ 2<sup>7</sup> = 240|`0bS1101111` = -1<sup>S</sup> $\times$ 1.9375 $\times$ 2<sup>3</sup> = 15.5| | ||
|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.25 $\times$ 2<sup>-14</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.0625 $\times$ 2<sup>-2</sup> | | ||
|Max Subnormal Value|`0bS0000011` = -1<sup>S</sup> $\times$ 0.75 $\times$ 2<sup>-14</sup> |`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |`0bS0001111` = -1<sup>S</sup> $\times$ 0.9375 $\times$ 2<sup>-2</sup> | | ||
|NaN |`0bS11111MM`, where `MM` is non-zero. |`0bS1111MMM`, where `MMM` is non-zero. |`0bS111MMMM`, where `MMMM` is non-zero. | | ||
|Infinity |`0bS1111100` |`0bS1111000` |`0bS1110000` | | ||
|-0.0 |`0b10000000` |`0b10000000` |`0b10000000` | | ||
|
||
## Changes in StableHLO | ||
|
||
I propose adding Float8E4M3 and Float8E3M4 types to StableHLO similar to the | ||
previously introduces FP8 types (below) with some differences: | ||
|
||
- [FP8 RFC](https://github.com/openxla/xla/discussions/22) | ||
- [[RFC] Add Float8E4M3FNUZ and Float8E5M2FNUZ to StableHLO](https://github.com/openxla/stablehlo/pull/1342) | ||
|
||
### StableHLO Interpreter | ||
|
||
To provide a reference implementation, I intend to add support for | ||
Float8E4M3 and Float8E3M4 in the StableHLO interpreter. This will be | ||
useful for testing other backends and validating new implementations. This will | ||
be achieved in two ways: | ||
|
||
1. Map directly to the appropriate APFloat operation. | ||
2. Cast up to the appropriate type, use that implementation, cast back down. | ||
|
||
### Float8E4M3 and Float8E3M4 Arithmetic | ||
|
||
I intend for Float8E4M3 and Float8E3M4 to be types that support the | ||
appropriate arithmetic operations, like any other floating point type. For | ||
platforms that don't have hardware support for these types, they may either | ||
throw an error and reject the program or cast up to an appropriate higher | ||
precision type that is supported, compute the answer, and cast back down. | ||
|
||
This is a simple approach that aligns with user expectations of a floating | ||
point data type, and is the approach taken by BFloat16. This also gives | ||
backends freedom to exploit any hardware support. | ||
|
||
Here's an example of a real JAX program (logging the MLIR) computing a simple | ||
dot product in Float8E4M3. Note the answer is slightly "wrong", as expected | ||
due to the lower precision (round-to-nearest). | ||
|
||
```python | ||
>>> import jax | ||
>>> import jax.numpy as jnp | ||
>>> x = jnp.arange(8, dtype=jnp.float8_e4m3) | ||
module @jit_iota { | ||
func.func public @main() -> tensor<8xf8E4M3> { | ||
%0 = stablehlo.iota dim = 0 : tensor<8xf8E4M3> | ||
return %0 : tensor<8xf8E4M3> | ||
} | ||
} | ||
>>> x | ||
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=float8_e4m3) | ||
>>> x @ x | ||
module @jit_matmul { | ||
func.func public @main(%arg0: tensor<8xf8E4M3> {mhlo.sharding = ""}, %arg1: tensor<8xf8E4M3> {mhlo.sharding = ""}) -> tensor<f8E4M3> { | ||
%0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<8xf8E4M3>, tensor<8xf8E4M3>) -> tensor<f8E4M3> | ||
return %0 : tensor<f8E4M3> | ||
} | ||
} | ||
Array(144, dtype=float8_e4m3) | ||
``` | ||
|
||
### Testing | ||
|
||
Built on the StableHLO interpreter, I intend to introduce tests for all | ||
possible operations with Float8E4M3 and Float8E3M4 inputs. This will at | ||
a minimum mean adding additional cases to the `interpret_X.mlir` family of | ||
tests. | ||
|
||
### References and Links | ||
|
||
- [RFC: FP8 in StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md) | ||
- [RFC: Float8E4M3FNUZ and Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md) | ||
|
||
[^1]: [Amazon EC2 Trn1 Instances](https://aws.amazon.com/ec2/instance-types/trn1/) | ||
[^2]: LLVM [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged) | ||
[^3]: LLVM [PR-101230](https://github.com/llvm/llvm-project/pull/101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged) | ||
[^4]: LLVM [PR-97179](https://github.com/llvm/llvm-project/pull/97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged) | ||
[^5]: LLVM [PR-99698](https://github.com/llvm/llvm-project/pull/99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged) |