diff --git a/rfcs/20240808-f8E4M3_f8E3M4.md b/rfcs/20240808-f8E4M3_f8E3M4.md
new file mode 100644
index 0000000000..8539267b96
--- /dev/null
+++ b/rfcs/20240808-f8E4M3_f8E3M4.md
@@ -0,0 +1,178 @@
+# RFC: Float8E4M3 and Float8E3M4
+
+Status: In Review
+Initial version: 8/8/2024
+Last updated: 8/9/2024
+Discussion thread: [PR-2486](https://github.com/openxla/stablehlo/pull/2486)
+[RFC] Add f8E4M3 and f8E3M4 types support
+
+## Summary
+
+Amazon has proposed two new FP8 types, Float8E4M3 and Float8E3M4. These
+types are implemented in commercially available hardware[^1], and added to MLIR
+builtin types[^2]˒[^3] and LLVM APFloat[^4]˒[^5].
+
+Both Float8E4M3 and Float8E3M4 follows IEEE 754 convention similar to existing
+type Float8E5M2.
+
+### Float8E4M3
+
+8-bit floating point type with 1 sign bit, 4 bits exponent and 3 bits mantissa
+following IEEE-754 conventions with bit layout S1E4M3.
+
+```c
+f8E4M3 (IEEE 754)
+- Exponent bias: 7
+- Minimum stored exponent value: 1 (binary 0001)
+- Maximum stored exponent value: 14 (binary 1110)
+- Minimum unbiased exponent value: 1 − 7 = −6
+- Maximum unbiased exponent value: 14 - 7 = 7
+- Precision specifies the total number of bits used for the significand
+ (mantisa), including implicit leading integer bit = 3 + 1 = 4
+- Follows IEEE 754 conventions for representation of special values
+- Has Positive and Negative zero
+- Has Positive and Negative infinity
+- Has NaNs
+
+Additional details:
+- Min exp (unbiased): -6
+- Max exp (unbiased): 7
+- Infinities (+/-): S.1111.000
+- Zeros (+/-): S.0000.000
+- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
+- Min normal number: S.0001.000 = +/-2^(1 - 7) x (1 + 0) = +/-2^(-6)
+- Max normal number: S.1110.111 = +/-2^(14 - 7) x (1 + 7/8) = +/-240
+- Min subnormal number: S.0000.001 = +/-2^(-6) x 1/8 = +/-2^(-9)
+- Max subnormal number: S.0000.111 = +/-2^(-6) x 7/8 = +/-2^(-9) x 7
+```
+
+#### Comparison of Float8E4M3FN and Float8E4M3
+
+| |Float8E4M3FN |Float8E4M3 |
+|-------------------|------------------------------------------------------------------------|-------------------------------------------------------------------------|
+|Bias |7 |7 |
+|Min Normal Value |`0bS0001000` = -1S $\times$ 1.0 $\times$ 2-6 |`0bS0001000` = -1S $\times$ 1.0 $\times$ 2-6 |
+|Max Normal Value |`0bS1111110` = -1S $\times$ 1.75 $\times$ 28 = 448|`0bS1110111` = -1S $\times$ 1.875 $\times$ 27 = 240|
+|Min Subnormal Value|`0bS0000001` = -1S $\times$ 0.125 $\times$ 2-6 |`0bS0000001` = -1S $\times$ 0.125 $\times$ 2-6 |
+|Max Subnormal Value|`0bS0000111` = -1S $\times$ 0.875 $\times$ 2-6 |`0bS0000111` = -1S $\times$ 0.875 $\times$ 2-6 |
+|NaN |`0bS1111111` |`0bS1111MMM`, where `MMM` is non-zero. |
+|Infinity |N/A |`0bS1111000` |
+|-0.0 |`0b10000000` |`0b10000000` |
+
+### Float8E3M4
+
+8-bit floating point type with 1 sign bit, 3 bits exponent and 4 bits mantissa
+following IEEE-754 conventions with bit layout S1E3M4.
+
+```c
+f8E3M4 (IEEE 754)
+- Exponent bias: 3
+- Minimum stored exponent value: 1 (binary 001)
+- Maximum stored exponent value: 6 (binary 110)
+- Minimum unbiased exponent value: 1 − 3 = −2
+- Maximum unbiased exponent value: 6 - 3 = 3
+- Precision specifies the total number of bits used for the significand
+ (mantissa), including implicit leading integer bit = 4 + 1 = 5
+- Follows IEEE 754 conventions for representation of special values
+- Has Positive and Negative zero
+- Has Positive and Negative infinity
+- Has NaNs
+
+Additional details:
+- Min exp (unbiased): -2
+- Max exp (unbiased): 3
+- Infinities (+/-): S.111.0000
+- Zeros (+/-): S.000.0000
+- NaNs: S.111.{0,1}⁴ except S.111.0000
+- Min normal number: S.001.0000 = +/-2^(1 - 3) x (1 + 0) = +/-0.25
+- Max normal number: S.110.1111 = +/-2^(6 - 3) x (1 + 15/16) = +/-15.5
+- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-6)
+- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-6) x 15
+```
+
+### Comparison of Float8E5M2, Float8E4M3 and Float8E3M4
+
+| |Float8E5M2 |Float8E4M3 |Float8E3M4 |
+|-------------------|----------------------------------------------------------------------------|-------------------------------------------------------------------------|---------------------------------------------------------------------------|
+|Bias |15 |7 |3 |
+|Min Normal Value |`0bS0000100` = -1S $\times$ 1.0 $\times$ 2-14 |`0bS0001000` = -1S $\times$ 1.0 $\times$ 2-6 |`0bS0010000` = -1S $\times$ 1.0 $\times$ 2-2 |
+|Max Normal Value |`0bS1111011` = -1S $\times$ 1.75 $\times$ 215 = 57344 |`0bS1110111` = -1S $\times$ 1.875 $\times$ 27 = 240|`0bS1101111` = -1S $\times$ 1.9375 $\times$ 23 = 15.5|
+|Min Subnormal Value|`0bS0000001` = -1S $\times$ 0.25 $\times$ 2-14 |`0bS0000001` = -1S $\times$ 0.125 $\times$ 2-6 |`0bS0000001` = -1S $\times$ 0.0625 $\times$ 2-2 |
+|Max Subnormal Value|`0bS0000011` = -1S $\times$ 0.75 $\times$ 2-14 |`0bS0000111` = -1S $\times$ 0.875 $\times$ 2-6 |`0bS0001111` = -1S $\times$ 0.9375 $\times$ 2-2 |
+|NaN |`0bS11111MM`, where `MM` is non-zero. |`0bS1111MMM`, where `MMM` is non-zero. |`0bS111MMMM`, where `MMMM` is non-zero. |
+|Infinity |`0bS1111100` |`0bS1111000` |`0bS1110000` |
+|-0.0 |`0b10000000` |`0b10000000` |`0b10000000` |
+
+## Changes in StableHLO
+
+I propose adding Float8E4M3 and Float8E3M4 types to StableHLO similar to the
+previously introduces FP8 types (below) with some differences:
+
+- [FP8 RFC](https://github.com/openxla/xla/discussions/22)
+- [[RFC] Add Float8E4M3FNUZ and Float8E5M2FNUZ to StableHLO](https://github.com/openxla/stablehlo/pull/1342)
+
+### StableHLO Interpreter
+
+To provide a reference implementation, I intend to add support for
+Float8E4M3 and Float8E3M4 in the StableHLO interpreter. This will be
+useful for testing other backends and validating new implementations. This will
+be achieved in two ways:
+
+1. Map directly to the appropriate APFloat operation.
+2. Cast up to the appropriate type, use that implementation, cast back down.
+
+### Float8E4M3 and Float8E3M4 Arithmetic
+
+I intend for Float8E4M3 and Float8E3M4 to be types that support the
+appropriate arithmetic operations, like any other floating point type. For
+platforms that don't have hardware support for these types, they may either
+throw an error and reject the program or cast up to an appropriate higher
+precision type that is supported, compute the answer, and cast back down.
+
+This is a simple approach that aligns with user expectations of a floating
+point data type, and is the approach taken by BFloat16. This also gives
+backends freedom to exploit any hardware support.
+
+Here's an example of a real JAX program (logging the MLIR) computing a simple
+dot product in Float8E4M3. Note the answer is slightly "wrong", as expected
+due to the lower precision (round-to-nearest).
+
+```python
+>>> import jax
+>>> import jax.numpy as jnp
+>>> x = jnp.arange(8, dtype=jnp.float8_e4m3)
+module @jit_iota {
+ func.func public @main() -> tensor<8xf8E4M3> {
+ %0 = stablehlo.iota dim = 0 : tensor<8xf8E4M3>
+ return %0 : tensor<8xf8E4M3>
+ }
+}
+>>> x
+Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=float8_e4m3)
+>>> x @ x
+module @jit_matmul {
+ func.func public @main(%arg0: tensor<8xf8E4M3> {mhlo.sharding = ""}, %arg1: tensor<8xf8E4M3> {mhlo.sharding = ""}) -> tensor {
+ %0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot, precision_config = [#stablehlo, #stablehlo]} : (tensor<8xf8E4M3>, tensor<8xf8E4M3>) -> tensor
+ return %0 : tensor
+ }
+}
+Array(144, dtype=float8_e4m3)
+```
+
+### Testing
+
+Built on the StableHLO interpreter, I intend to introduce tests for all
+possible operations with Float8E4M3 and Float8E3M4 inputs. This will at
+a minimum mean adding additional cases to the `interpret_X.mlir` family of
+tests.
+
+### References and Links
+
+- [RFC: FP8 in StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
+- [RFC: Float8E4M3FNUZ and Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)
+
+[^1]: [Amazon EC2 Trn1 Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
+[^2]: LLVM [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
+[^3]: LLVM [PR-101230](https://github.com/llvm/llvm-project/pull/101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
+[^4]: LLVM [PR-97179](https://github.com/llvm/llvm-project/pull/97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
+[^5]: LLVM [PR-99698](https://github.com/llvm/llvm-project/pull/99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
diff --git a/stablehlo/integrations/c/StablehloAttributes.cpp b/stablehlo/integrations/c/StablehloAttributes.cpp
index 8707790810..4f888c3d4b 100644
--- a/stablehlo/integrations/c/StablehloAttributes.cpp
+++ b/stablehlo/integrations/c/StablehloAttributes.cpp
@@ -212,6 +212,60 @@ int64_t stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
.getIndexVectorDim();
}
+//===----------------------------------------------------------------------===//
+// DotAlgorithm
+//===----------------------------------------------------------------------===//
+
+MlirAttribute stablehloDotAlgorithmGet(
+ MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
+ MlirType accumulationType, int64_t lhsComponentCount,
+ int64_t rhsComponentCount, int64_t numPrimitiveOperations,
+ bool allowImpreciseAccumulation) {
+ return wrap(mlir::stablehlo::DotAlgorithmAttr::get(
+ unwrap(ctx), unwrap(lhsPrecisionType), unwrap(rhsPrecisionType),
+ unwrap(accumulationType), lhsComponentCount, rhsComponentCount,
+ numPrimitiveOperations, allowImpreciseAccumulation));
+}
+
+bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr) {
+ return llvm::isa(unwrap(attr));
+}
+
+MlirType stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr) {
+ return wrap(llvm::cast(unwrap(attr))
+ .getLhsPrecisionType());
+}
+
+MlirType stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr) {
+ return wrap(llvm::cast(unwrap(attr))
+ .getRhsPrecisionType());
+}
+
+MlirType stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr) {
+ return wrap(llvm::cast(unwrap(attr))
+ .getAccumulationType());
+}
+
+int64_t stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr) {
+ return llvm::cast(unwrap(attr))
+ .getLhsComponentCount();
+}
+
+int64_t stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr) {
+ return llvm::cast(unwrap(attr))
+ .getRhsComponentCount();
+}
+
+int64_t stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr) {
+ return llvm::cast(unwrap(attr))
+ .getNumPrimitiveOperations();
+}
+
+bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(MlirAttribute attr) {
+ return llvm::cast(unwrap(attr))
+ .getAllowImpreciseAccumulation();
+}
+
//===----------------------------------------------------------------------===//
// DotDimensionNumbers
//===----------------------------------------------------------------------===//
diff --git a/stablehlo/integrations/c/StablehloAttributes.h b/stablehlo/integrations/c/StablehloAttributes.h
index 2663e12a4d..897bfaa1a4 100644
--- a/stablehlo/integrations/c/StablehloAttributes.h
+++ b/stablehlo/integrations/c/StablehloAttributes.h
@@ -113,6 +113,39 @@ MLIR_CAPI_EXPORTED int64_t stablehloGatherDimensionNumbersGetStartIndexMapElem(
MLIR_CAPI_EXPORTED int64_t
stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr);
+//===----------------------------------------------------------------------===//
+// DotAlgorithm
+//===----------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED MlirAttribute stablehloDotAlgorithmGet(
+ MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
+ MlirType accumulationType, int64_t lhsComponentCount,
+ int64_t rhsComponentCount, int64_t numPrimitiveOperations,
+ bool allowImpreciseAccumulation);
+
+MLIR_CAPI_EXPORTED bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirType
+stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirType
+stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED MlirType
+stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED int64_t
+stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED int64_t
+stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED int64_t
+stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr);
+
+MLIR_CAPI_EXPORTED bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(
+ MlirAttribute attr);
+
//===----------------------------------------------------------------------===//
// DotDimensionNumbers
//===----------------------------------------------------------------------===//
diff --git a/stablehlo/integrations/python/StablehloModule.cpp b/stablehlo/integrations/python/StablehloModule.cpp
index 5fd995dd95..a3f05b8a74 100644
--- a/stablehlo/integrations/python/StablehloModule.cpp
+++ b/stablehlo/integrations/python/StablehloModule.cpp
@@ -220,6 +220,62 @@ PYBIND11_MODULE(_stablehlo, m) {
return stablehloGatherDimensionNumbersGetIndexVectorDim(self);
});
+ mlir::python::adaptors::mlir_attribute_subclass(
+ m, "DotAlgorithm", stablehloAttributeIsADotAlgorithm)
+ .def_classmethod(
+ "get",
+ [](py::object cls, MlirType lhsPrecisionType,
+ MlirType rhsPrecisionType, MlirType accumulationType,
+ int64_t lhsComponentCount, int64_t rhsComponentCount,
+ int64_t numPrimitiveOperations, bool allowImpreciseAccumulation,
+ MlirContext ctx) {
+ return cls(stablehloDotAlgorithmGet(
+ ctx, lhsPrecisionType, rhsPrecisionType, accumulationType,
+ lhsComponentCount, rhsComponentCount, numPrimitiveOperations,
+ allowImpreciseAccumulation));
+ },
+ py::arg("cls"), py::arg("lhs_precision_type"),
+ py::arg("rhs_precision_type"), py::arg("accumulation_type"),
+ py::arg("lhs_component_count"), py::arg("rhs_component_count"),
+ py::arg("num_primitive_operations"),
+ py::arg("allow_imprecise_accumulation"), py::arg("ctx") = py::none(),
+ "Creates a DotAlgorithm attribute with the given dimension "
+ "configuration.")
+ .def_property_readonly(
+ "lhs_precision_type",
+ [](MlirAttribute self) {
+ return stablehloDotAlgorithmGetLhsPrecisionType(self);
+ })
+ .def_property_readonly(
+ "rhs_precision_type",
+ [](MlirAttribute self) {
+ return stablehloDotAlgorithmGetRhsPrecisionType(self);
+ })
+ .def_property_readonly(
+ "accumulation_type",
+ [](MlirAttribute self) {
+ return stablehloDotAlgorithmGetAccumulationType(self);
+ })
+ .def_property_readonly(
+ "lhs_component_count",
+ [](MlirAttribute self) {
+ return stablehloDotAlgorithmGetLhsComponentCount(self);
+ })
+ .def_property_readonly(
+ "rhs_component_count",
+ [](MlirAttribute self) {
+ return stablehloDotAlgorithmGetRhsComponentCount(self);
+ })
+ .def_property_readonly(
+ "num_primitive_operations",
+ [](MlirAttribute self) {
+ return stablehloDotAlgorithmGetNumPrimitiveOperations(self);
+ })
+ .def_property_readonly(
+ "allow_imprecise_accumulation", [](MlirAttribute self) {
+ return stablehloDotAlgorithmGetAllowImpreciseAccumulation(self);
+ });
+
mlir::python::adaptors::mlir_attribute_subclass(
m, "DotDimensionNumbers", stablehloAttributeIsADotDimensionNumbers)
.def_classmethod(
diff --git a/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/integrations/python/tests/stablehlo.py
index 4005d3fb2e..39a0cfd353 100644
--- a/stablehlo/integrations/python/tests/stablehlo.py
+++ b/stablehlo/integrations/python/tests/stablehlo.py
@@ -82,6 +82,32 @@ def test_conv_dimension_numbers():
assert attr.output_spatial_dimensions == [2, 3]
+@run
+def test_dot_algorithm():
+ # BF16_BF16_F32_X3
+ attr = stablehlo.DotAlgorithm.get(
+ lhs_precision_type=ir.BF16Type.get(),
+ rhs_precision_type=ir.BF16Type.get(),
+ accumulation_type=ir.F32Type.get(),
+ lhs_component_count=1,
+ rhs_component_count=1,
+ num_primitive_operations=3,
+ allow_imprecise_accumulation=False)
+ assert attr is not None
+ assert str(attr) == ("#stablehlo.dot_algorithm")
+ assert isinstance(attr.lhs_precision_type, ir.BF16Type)
+ assert isinstance(attr.rhs_precision_type, ir.BF16Type)
+ assert isinstance(attr.accumulation_type, ir.F32Type)
+ assert attr.lhs_component_count == 1
+ assert attr.rhs_component_count == 1
+ assert attr.num_primitive_operations == 3
+ assert attr.allow_imprecise_accumulation == False
+
+
@run
def test_dot_dimension_numbers():
attr = stablehlo.DotDimensionNumbers.get(