Skip to content

Commit

Permalink
Fix JAX windows ci build error : missing stablehlo C API symbols (#2530)
Browse files Browse the repository at this point in the history
Issue: JAX windows build is failing because of missing stablehlo C API
symbols
https://github.com/google/jax/actions/runs/10739809804/job/29786443538

root cause: `StablehloApi.h` defs has function overloads. Compiler does
name mangling (decorating function names with additional information).
These symbols are missing in JAX APIs, JAX only allowlist symbol exports
for symbols starting with `stablehlo`, but the mangled c++ names don't
have that property.

fix:
1. add `extern "C" ` around stablehloapi.h declarations. extern "C"
instructs the compiler to suppress the mangling.
2. rename functions to avoid function overloading


thank you @hawkinsp  for help with root causing and validating the fix.
  • Loading branch information
abhigunj authored Sep 6, 2024
1 parent 3cc3767 commit c667089
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 26 deletions.
8 changes: 4 additions & 4 deletions stablehlo/integrations/c/StablehloApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ MlirLogicalResult stablehloGetSmallerVersion(MlirStringRef version1,
return mlirLogicalResultSuccess();
}

MlirLogicalResult stablehloSerializePortableArtifact(
MlirLogicalResult stablehloSerializePortableArtifactFromModule(
MlirModule moduleStr, MlirStringRef targetVersion,
MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
Expand All @@ -88,7 +88,7 @@ MlirLogicalResult stablehloSerializePortableArtifact(
return mlirLogicalResultSuccess();
}

MlirLogicalResult stablehloSerializePortableArtifact(
MlirLogicalResult stablehloSerializePortableArtifactFromStringRef(
MlirStringRef moduleStr, MlirStringRef targetVersion,
MlirStringCallback callback, void *userData) {
mlir::detail::CallbackOstream stream(callback, userData);
Expand All @@ -107,8 +107,8 @@ MlirLogicalResult stablehloDeserializePortableArtifact(
return mlirLogicalResultSuccess();
}

MlirModule stablehloDeserializePortableArtifact(MlirStringRef artifactStr,
MlirContext ctx) {
MlirModule stablehloDeserializePortableArtifactNoError(
MlirStringRef artifactStr, MlirContext ctx) {
return wrap(mlir::stablehlo::deserializePortableArtifact(unwrap(artifactStr),
unwrap(ctx))
.release());
Expand Down
31 changes: 19 additions & 12 deletions stablehlo/integrations/c/StablehloApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ limitations under the License.
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

#ifdef __cplusplus
extern "C" {
#endif

// Get the current StableHLO API version.
//
// This value is incremented as needed to help integrate API changes.
Expand Down Expand Up @@ -72,19 +76,23 @@ stablehloGetSmallerVersion(MlirStringRef version1, MlirStringRef version2,
// `targetVersion` version of StableHLO, e.g. if it's using new or removed
// features, or if it involves unsupported dialects.
// Returns false on failure.
MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact(
MlirStringRef moduleStr, MlirStringRef targetVersion,
MlirStringCallback callback, void* userData);
MLIR_CAPI_EXPORTED MlirLogicalResult
stablehloSerializePortableArtifactFromStringRef(MlirStringRef moduleStr,
MlirStringRef targetVersion,
MlirStringCallback callback,
void* userData);

// Write a StableHLO program expressed as a string (either prettyprinted MLIR
// module or MLIR bytecode) to a portable artifact.
// Can fail if `moduleStr` cannot be parsed, or if it cannot be expressed in the
// `targetVersion` version of StableHLO, e.g. if it's using new or removed
// features, or if it involves unsupported dialects.
// Returns false on failure.
MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact(
MlirModule moduleStr, MlirStringRef targetVersion,
MlirStringCallback callback, void* userData);
MLIR_CAPI_EXPORTED MlirLogicalResult
stablehloSerializePortableArtifactFromModule(MlirModule moduleStr,
MlirStringRef targetVersion,
MlirStringCallback callback,
void* userData);

// Read a StableHLO program from a portable artifact, returning the module as
// MLIR bytecode. Note, this bytecode returned is not a portable artifact,
Expand All @@ -104,12 +112,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult stablehloDeserializePortableArtifact(
// StableHLO, e.g. if it's using incompatible features.
//
// Returns empty module on failure.
MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact(
MlirStringRef artifactStr, MlirContext ctx);

// Call the Interpreter, returns MlirArrayAttr of dense element
// MlirAttribute results
MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact(
MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifactNoError(
MlirStringRef artifactStr, MlirContext ctx);

// Entrypoint for calling the StableHLO reference interpreter.
Expand All @@ -120,4 +123,8 @@ MLIR_CAPI_EXPORTED MlirAttribute stablehloEvalModule(MlirModule module,
MlirAttribute const* args,
int* errorCode);

#ifdef __cplusplus
}
#endif

#endif // STABLEHLO_INTEGRATIONS_C_STABLEHLOAPI_H_
22 changes: 12 additions & 10 deletions stablehlo/integrations/python/StablehloApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,11 @@ void AddStablehloApi(py::module &m) {
"serialize_portable_artifact",
[](MlirModule module, std::string_view target) -> py::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(stablehloSerializePortableArtifact(
module, toMlirStringRef(target),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
if (mlirLogicalResultIsFailure(
stablehloSerializePortableArtifactFromModule(
module, toMlirStringRef(target),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError, "failed to serialize module");
return "";
}
Expand All @@ -110,7 +111,7 @@ void AddStablehloApi(py::module &m) {
m.def(
"deserialize_portable_artifact",
[](MlirContext context, std::string_view artifact) -> MlirModule {
auto module = stablehloDeserializePortableArtifact(
auto module = stablehloDeserializePortableArtifactNoError(
toMlirStringRef(artifact), context);
if (mlirModuleIsNull(module)) {
PyErr_SetString(PyExc_ValueError, "failed to deserialize module");
Expand Down Expand Up @@ -197,11 +198,12 @@ void AddPortableApi(py::module &m) {
[](std::string_view moduleStrOrBytecode,
std::string_view targetVersion) -> py::bytes {
StringWriterHelper accumulator;
if (mlirLogicalResultIsFailure(stablehloSerializePortableArtifact(
toMlirStringRef(moduleStrOrBytecode),
toMlirStringRef(targetVersion),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
if (mlirLogicalResultIsFailure(
stablehloSerializePortableArtifactFromStringRef(
toMlirStringRef(moduleStrOrBytecode),
toMlirStringRef(targetVersion),
accumulator.getMlirStringCallback(),
accumulator.getUserData()))) {
PyErr_SetString(PyExc_ValueError, "failed to serialize module");
return "";
}
Expand Down

0 comments on commit c667089

Please sign in to comment.