Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DotAlgorithm to StableHLO Python API #2521

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions stablehlo/integrations/c/StablehloAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,61 @@ int64_t stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
.getIndexVectorDim();
}


//===----------------------------------------------------------------------===//
// DotAlgorithm
//===----------------------------------------------------------------------===//

MlirAttribute stablehloDotAlgorithmGet(
MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
MlirType accumulationType, int64_t lhsComponentCount,
int64_t rhsComponentCount, int64_t numPrimitiveOperations,
bool allowImpreciseAccumulation) {
return wrap(mlir::stablehlo::DotAlgorithmAttr::get(
unwrap(ctx), unwrap(lhsPrecisionType), unwrap(rhsPrecisionType),
unwrap(accumulationType), lhsComponentCount, rhsComponentCount,
numPrimitiveOperations, allowImpreciseAccumulation));
}

bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr) {
return llvm::isa<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr));
}

MlirType stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr) {
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getLhsPrecisionType());
}

MlirType stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr) {
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getRhsPrecisionType());
}

MlirType stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr) {
return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getAccumulationType());
}

int64_t stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr) {
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getLhsComponentCount();
}

int64_t stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr) {
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getRhsComponentCount();
}

int64_t stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr) {
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getNumPrimitiveOperations();
}

bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(MlirAttribute attr){
return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
.getAllowImpreciseAccumulation();
}

//===----------------------------------------------------------------------===//
// DotDimensionNumbers
//===----------------------------------------------------------------------===//
Expand Down
33 changes: 33 additions & 0 deletions stablehlo/integrations/c/StablehloAttributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,39 @@ MLIR_CAPI_EXPORTED int64_t stablehloGatherDimensionNumbersGetStartIndexMapElem(
MLIR_CAPI_EXPORTED int64_t
stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr);

//===----------------------------------------------------------------------===//
// DotAlgorithm
//===----------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED MlirAttribute stablehloDotAlgorithmGet(
MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
MlirType accumulationType, int64_t lhsComponentCount,
int64_t rhsComponentCount, int64_t numPrimitiveOperations,
bool allowImpreciseAccumulation);

MLIR_CAPI_EXPORTED bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(
MlirAttribute attr);

//===----------------------------------------------------------------------===//
// DotDimensionNumbers
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 56 additions & 0 deletions stablehlo/integrations/python/StablehloModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,62 @@ PYBIND11_MODULE(_stablehlo, m) {
return stablehloGatherDimensionNumbersGetIndexVectorDim(self);
});

mlir::python::adaptors::mlir_attribute_subclass(
m, "DotAlgorithm", stablehloAttributeIsADotAlgorithm)
.def_classmethod(
"get",
[](py::object cls, MlirType lhsPrecisionType,
MlirType rhsPrecisionType, MlirType accumulationType,
int64_t lhsComponentCount, int64_t rhsComponentCount,
int64_t numPrimitiveOperations, bool allowImpreciseAccumulation,
MlirContext ctx) {
return cls(stablehloDotAlgorithmGet(
ctx, lhsPrecisionType, rhsPrecisionType, accumulationType,
lhsComponentCount, rhsComponentCount, numPrimitiveOperations,
allowImpreciseAccumulation));
},
py::arg("cls"), py::arg("lhs_precision_type"),
py::arg("rhs_precision_type"), py::arg("accumulation_type"),
py::arg("lhs_component_count"), py::arg("rhs_component_count"),
py::arg("num_primitive_operations"),
py::arg("allow_imprecise_accumulation"), py::arg("ctx") = py::none(),
"Creates a DotAlgorithm attribute with the given dimension "
"configuration.")
.def_property_readonly(
"lhs_precision_type",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetLhsPrecisionType(self);
})
.def_property_readonly(
"rhs_precision_type",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetRhsPrecisionType(self);
})
.def_property_readonly(
"accumulation_type",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetAccumulationType(self);
})
.def_property_readonly(
"lhs_component_count",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetLhsComponentCount(self);
})
.def_property_readonly(
"rhs_component_count",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetRhsComponentCount(self);
})
.def_property_readonly(
"num_primitive_operations",
[](MlirAttribute self) {
return stablehloDotAlgorithmGetNumPrimitiveOperations(self);
})
.def_property_readonly(
"allow_imprecise_accumulation", [](MlirAttribute self) {
return stablehloDotAlgorithmGetAllowImpreciseAccumulation(self);
});

mlir::python::adaptors::mlir_attribute_subclass(
m, "DotDimensionNumbers", stablehloAttributeIsADotDimensionNumbers)
.def_classmethod(
Expand Down
26 changes: 26 additions & 0 deletions stablehlo/integrations/python/tests/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,32 @@ def test_conv_dimension_numbers():
assert attr.output_spatial_dimensions == [2, 3]


@run
def test_dot_algorithm():
# BF16_BF16_F32_X3
attr = stablehlo.DotAlgorithm.get(
lhs_precision_type=ir.BF16Type.get(),
rhs_precision_type=ir.BF16Type.get(),
accumulation_type=ir.F32Type.get(),
lhs_component_count=1,
rhs_component_count=1,
num_primitive_operations=3,
allow_imprecise_accumulation=False)
assert attr is not None
assert str(attr) == ("#stablehlo.dot_algorithm<lhs_precision_type = bf16, "
"rhs_precision_type = bf16, accumulation_type = f32, "
"lhs_component_count = 1, rhs_component_count = 1, "
"num_primitive_operations = 3, "
"allow_imprecise_accumulation = false>")
assert isinstance(attr.lhs_precision_type, ir.BF16Type)
assert isinstance(attr.rhs_precision_type, ir.BF16Type)
assert isinstance(attr.accumulation_type, ir.F32Type)
assert attr.lhs_component_count == 1
assert attr.rhs_component_count == 1
assert attr.num_primitive_operations == 3
assert attr.allow_imprecise_accumulation == False


@run
def test_dot_dimension_numbers():
attr = stablehlo.DotDimensionNumbers.get(
Expand Down
Loading