Skip to content

Commit

Permalink
Add DotAlgorithm to StableHLO Python API (#2521)
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK authored Sep 3, 2024
1 parent 21dcdd2 commit 1456dfa
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 0 deletions.
54 changes: 54 additions & 0 deletions stablehlo/integrations/c/StablehloAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,60 @@ int64_t stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
.getIndexVectorDim();
}

//===----------------------------------------------------------------------===//
// DotAlgorithm
//===----------------------------------------------------------------------===//

MlirAttribute stablehloDotAlgorithmGet(
MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
MlirType accumulationType, int64_t lhsComponentCount,
int64_t rhsComponentCount, int64_t numPrimitiveOperations,
bool allowImpreciseAccumulation) {
return wrap(mlir::stablehlo::DotAlgorithmAttr::get(
unwrap(ctx), unwrap(lhsPrecisionType), unwrap(rhsPrecisionType),
unwrap(accumulationType), lhsComponentCount, rhsComponentCount,
numPrimitiveOperations, allowImpreciseAccumulation));
}

bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr) {
return llvm::isa<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr));
}

MlirType stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr) {
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getLhsPrecisionType());
}

MlirType stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr) {
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getRhsPrecisionType());
}

MlirType stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr) {
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getAccumulationType());
}

int64_t stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr) {
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getLhsComponentCount();
}

int64_t stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr) {
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getRhsComponentCount();
}

int64_t stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr) {
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getNumPrimitiveOperations();
}

bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(MlirAttribute attr) {
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getAllowImpreciseAccumulation();
}

//===----------------------------------------------------------------------===//
// DotDimensionNumbers
//===----------------------------------------------------------------------===//
Expand Down
33 changes: 33 additions & 0 deletions stablehlo/integrations/c/StablehloAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,39 @@ MLIR_CAPI_EXPORTED int64_t stablehloGatherDimensionNumbersGetStartIndexMapElem(
MLIR_CAPI_EXPORTED int64_t
stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr);

//===----------------------------------------------------------------------===//
// DotAlgorithm
//===----------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED MlirAttribute stablehloDotAlgorithmGet(
MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
MlirType accumulationType, int64_t lhsComponentCount,
int64_t rhsComponentCount, int64_t numPrimitiveOperations,
bool allowImpreciseAccumulation);

MLIR_CAPI_EXPORTED bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(
MlirAttribute attr);

//===----------------------------------------------------------------------===//
// DotDimensionNumbers
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 56 additions & 0 deletions stablehlo/integrations/python/StablehloModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,62 @@ PYBIND11_MODULE(_stablehlo, m) {
return stablehloGatherDimensionNumbersGetIndexVectorDim(self);
});

mlir::python::adaptors::mlir_attribute_subclass(
m, "DotAlgorithm", stablehloAttributeIsADotAlgorithm)
.def_classmethod(
"get",
[](py::object cls, MlirType lhsPrecisionType,
MlirType rhsPrecisionType, MlirType accumulationType,
int64_t lhsComponentCount, int64_t rhsComponentCount,
int64_t numPrimitiveOperations, bool allowImpreciseAccumulation,
MlirContext ctx) {
return cls(stablehloDotAlgorithmGet(
ctx, lhsPrecisionType, rhsPrecisionType, accumulationType,
lhsComponentCount, rhsComponentCount, numPrimitiveOperations,
allowImpreciseAccumulation));
},
py::arg("cls"), py::arg("lhs_precision_type"),
py::arg("rhs_precision_type"), py::arg("accumulation_type"),
py::arg("lhs_component_count"), py::arg("rhs_component_count"),
py::arg("num_primitive_operations"),
py::arg("allow_imprecise_accumulation"), py::arg("ctx") = py::none(),
"Creates a DotAlgorithm attribute with the given dimension "
"configuration.")
.def_property_readonly(
"lhs_precision_type",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetLhsPrecisionType(self);
})
.def_property_readonly(
"rhs_precision_type",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetRhsPrecisionType(self);
})
.def_property_readonly(
"accumulation_type",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetAccumulationType(self);
})
.def_property_readonly(
"lhs_component_count",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetLhsComponentCount(self);
})
.def_property_readonly(
"rhs_component_count",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetRhsComponentCount(self);
})
.def_property_readonly(
"num_primitive_operations",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetNumPrimitiveOperations(self);
})
.def_property_readonly(
"allow_imprecise_accumulation", [](MlirAttribute self) {
return stablehloDotAlgorithmGetAllowImpreciseAccumulation(self);
});

mlir::python::adaptors::mlir_attribute_subclass(
m, "DotDimensionNumbers", stablehloAttributeIsADotDimensionNumbers)
.def_classmethod(
Expand Down
26 changes: 26 additions & 0 deletions stablehlo/integrations/python/tests/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,32 @@ def test_conv_dimension_numbers():
assert attr.output_spatial_dimensions == [2, 3]


@run
def test_dot_algorithm():
# BF16_BF16_F32_X3
attr = stablehlo.DotAlgorithm.get(
lhs_precision_type=ir.BF16Type.get(),
rhs_precision_type=ir.BF16Type.get(),
accumulation_type=ir.F32Type.get(),
lhs_component_count=1,
rhs_component_count=1,
num_primitive_operations=3,
allow_imprecise_accumulation=False)
assert attr is not None
assert str(attr) == ("#stablehlo.dot_algorithm<lhs_precision_type = bf16, "
"rhs_precision_type = bf16, accumulation_type = f32, "
"lhs_component_count = 1, rhs_component_count = 1, "
"num_primitive_operations = 3, "
"allow_imprecise_accumulation = false>")
assert isinstance(attr.lhs_precision_type, ir.BF16Type)
assert isinstance(attr.rhs_precision_type, ir.BF16Type)
assert isinstance(attr.accumulation_type, ir.F32Type)
assert attr.lhs_component_count == 1
assert attr.rhs_component_count == 1
assert attr.num_primitive_operations == 3
assert attr.allow_imprecise_accumulation == False


@run
def test_dot_dimension_numbers():
attr = stablehlo.DotDimensionNumbers.get(
Expand Down

0 comments on commit 1456dfa

Please sign in to comment.