diff --git a/stablehlo/integrations/c/StablehloAttributes.cpp b/stablehlo/integrations/c/StablehloAttributes.cpp index 8707790810d..4f888c3d4bc 100644 --- a/stablehlo/integrations/c/StablehloAttributes.cpp +++ b/stablehlo/integrations/c/StablehloAttributes.cpp @@ -212,6 +212,60 @@ int64_t stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) { .getIndexVectorDim(); } +//===----------------------------------------------------------------------===// +// DotAlgorithm +//===----------------------------------------------------------------------===// + +MlirAttribute stablehloDotAlgorithmGet( + MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType, + MlirType accumulationType, int64_t lhsComponentCount, + int64_t rhsComponentCount, int64_t numPrimitiveOperations, + bool allowImpreciseAccumulation) { + return wrap(mlir::stablehlo::DotAlgorithmAttr::get( + unwrap(ctx), unwrap(lhsPrecisionType), unwrap(rhsPrecisionType), + unwrap(accumulationType), lhsComponentCount, rhsComponentCount, + numPrimitiveOperations, allowImpreciseAccumulation)); +} + +bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirType stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)) + .getLhsPrecisionType()); +} + +MlirType stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)) + .getRhsPrecisionType()); +} + +MlirType stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)) + .getAccumulationType()); +} + +int64_t stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getLhsComponentCount(); +} + +int64_t stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getRhsComponentCount(); +} + +int64_t stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getNumPrimitiveOperations(); +} + +bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getAllowImpreciseAccumulation(); +} + //===----------------------------------------------------------------------===// // DotDimensionNumbers //===----------------------------------------------------------------------===// diff --git a/stablehlo/integrations/c/StablehloAttributes.h b/stablehlo/integrations/c/StablehloAttributes.h index 2663e12a4da..897bfaa1a48 100644 --- a/stablehlo/integrations/c/StablehloAttributes.h +++ b/stablehlo/integrations/c/StablehloAttributes.h @@ -113,6 +113,39 @@ MLIR_CAPI_EXPORTED int64_t stablehloGatherDimensionNumbersGetStartIndexMapElem( MLIR_CAPI_EXPORTED int64_t stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr); +//===----------------------------------------------------------------------===// +// DotAlgorithm +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirAttribute stablehloDotAlgorithmGet( + MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType, + MlirType accumulationType, int64_t lhsComponentCount, + int64_t rhsComponentCount, int64_t numPrimitiveOperations, + bool allowImpreciseAccumulation); + +MLIR_CAPI_EXPORTED bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirType +stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirType +stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirType +stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr); + +MLIR_CAPI_EXPORTED int64_t +stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr); + +MLIR_CAPI_EXPORTED int64_t +stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr); + +MLIR_CAPI_EXPORTED int64_t +stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr); + +MLIR_CAPI_EXPORTED bool stablehloDotAlgorithmGetAllowImpreciseAccumulation( + MlirAttribute attr); + //===----------------------------------------------------------------------===// // DotDimensionNumbers //===----------------------------------------------------------------------===// diff --git a/stablehlo/integrations/python/StablehloModule.cpp b/stablehlo/integrations/python/StablehloModule.cpp index 5fd995dd950..a3f05b8a746 100644 --- a/stablehlo/integrations/python/StablehloModule.cpp +++ b/stablehlo/integrations/python/StablehloModule.cpp @@ -220,6 +220,62 @@ PYBIND11_MODULE(_stablehlo, m) { return stablehloGatherDimensionNumbersGetIndexVectorDim(self); }); + mlir::python::adaptors::mlir_attribute_subclass( + m, "DotAlgorithm", stablehloAttributeIsADotAlgorithm) + .def_classmethod( + "get", + [](py::object cls, MlirType lhsPrecisionType, + MlirType rhsPrecisionType, MlirType accumulationType, + int64_t lhsComponentCount, int64_t rhsComponentCount, + int64_t numPrimitiveOperations, bool allowImpreciseAccumulation, + MlirContext ctx) { + return cls(stablehloDotAlgorithmGet( + ctx, lhsPrecisionType, rhsPrecisionType, accumulationType, + lhsComponentCount, rhsComponentCount, numPrimitiveOperations, + allowImpreciseAccumulation)); + }, + py::arg("cls"), py::arg("lhs_precision_type"), + py::arg("rhs_precision_type"), py::arg("accumulation_type"), + py::arg("lhs_component_count"), py::arg("rhs_component_count"), + py::arg("num_primitive_operations"), + py::arg("allow_imprecise_accumulation"), py::arg("ctx") = py::none(), + "Creates a DotAlgorithm attribute with the given dimension " + "configuration.") + .def_property_readonly( + "lhs_precision_type", + [](MlirAttribute self) { + return stablehloDotAlgorithmGetLhsPrecisionType(self); + }) + .def_property_readonly( + "rhs_precision_type", + [](MlirAttribute self) { + return stablehloDotAlgorithmGetRhsPrecisionType(self); + }) + .def_property_readonly( + "accumulation_type", + [](MlirAttribute self) { + return stablehloDotAlgorithmGetAccumulationType(self); + }) + .def_property_readonly( + "lhs_component_count", + [](MlirAttribute self) { + return stablehloDotAlgorithmGetLhsComponentCount(self); + }) + .def_property_readonly( + "rhs_component_count", + [](MlirAttribute self) { + return stablehloDotAlgorithmGetRhsComponentCount(self); + }) + .def_property_readonly( + "num_primitive_operations", + [](MlirAttribute self) { + return stablehloDotAlgorithmGetNumPrimitiveOperations(self); + }) + .def_property_readonly( + "allow_imprecise_accumulation", [](MlirAttribute self) { + return stablehloDotAlgorithmGetAllowImpreciseAccumulation(self); + }); + mlir::python::adaptors::mlir_attribute_subclass( m, "DotDimensionNumbers", stablehloAttributeIsADotDimensionNumbers) .def_classmethod( diff --git a/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/integrations/python/tests/stablehlo.py index 4005d3fb2ea..39a0cfd3538 100644 --- a/stablehlo/integrations/python/tests/stablehlo.py +++ b/stablehlo/integrations/python/tests/stablehlo.py @@ -82,6 +82,32 @@ def test_conv_dimension_numbers(): assert attr.output_spatial_dimensions == [2, 3] +@run +def test_dot_algorithm(): + # BF16_BF16_F32_X3 + attr = stablehlo.DotAlgorithm.get( + lhs_precision_type=ir.BF16Type.get(), + rhs_precision_type=ir.BF16Type.get(), + accumulation_type=ir.F32Type.get(), + lhs_component_count=1, + rhs_component_count=1, + num_primitive_operations=3, + allow_imprecise_accumulation=False) + assert attr is not None + assert str(attr) == ("#stablehlo.dot_algorithm") + assert isinstance(attr.lhs_precision_type, ir.BF16Type) + assert isinstance(attr.rhs_precision_type, ir.BF16Type) + assert isinstance(attr.accumulation_type, ir.F32Type) + assert attr.lhs_component_count == 1 + assert attr.rhs_component_count == 1 + assert attr.num_primitive_operations == 3 + assert attr.allow_imprecise_accumulation == False + + @run def test_dot_dimension_numbers(): attr = stablehlo.DotDimensionNumbers.get(