diff --git a/stablehlo/integrations/c/StablehloApi.cpp b/stablehlo/integrations/c/StablehloApi.cpp index 8d9221989a..5f730eca12 100644 --- a/stablehlo/integrations/c/StablehloApi.cpp +++ b/stablehlo/integrations/c/StablehloApi.cpp @@ -78,7 +78,7 @@ MlirLogicalResult stablehloGetSmallerVersion(MlirStringRef version1, return mlirLogicalResultSuccess(); } -MlirLogicalResult stablehloSerializePortableArtifact( +MlirLogicalResult stablehloSerializePortableArtifactFromModule( MlirModule moduleStr, MlirStringRef targetVersion, MlirStringCallback callback, void *userData) { mlir::detail::CallbackOstream stream(callback, userData); @@ -88,7 +88,7 @@ MlirLogicalResult stablehloSerializePortableArtifact( return mlirLogicalResultSuccess(); } -MlirLogicalResult stablehloSerializePortableArtifact( +MlirLogicalResult stablehloSerializePortableArtifactFromStringRef( MlirStringRef moduleStr, MlirStringRef targetVersion, MlirStringCallback callback, void *userData) { mlir::detail::CallbackOstream stream(callback, userData); @@ -107,8 +107,8 @@ MlirLogicalResult stablehloDeserializePortableArtifact( return mlirLogicalResultSuccess(); } -MlirModule stablehloDeserializePortableArtifact(MlirStringRef artifactStr, - MlirContext ctx) { +MlirModule stablehloDeserializePortableArtifactNoError( + MlirStringRef artifactStr, MlirContext ctx) { return wrap(mlir::stablehlo::deserializePortableArtifact(unwrap(artifactStr), unwrap(ctx)) .release()); diff --git a/stablehlo/integrations/c/StablehloApi.h b/stablehlo/integrations/c/StablehloApi.h index 4c5425081e..4df239438e 100644 --- a/stablehlo/integrations/c/StablehloApi.h +++ b/stablehlo/integrations/c/StablehloApi.h @@ -16,6 +16,10 @@ limitations under the License. #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#ifdef __cplusplus +extern "C" { +#endif + // Get the current StableHLO API version. // // This value is incremented as needed to help integrate API changes. @@ -72,9 +76,11 @@ stablehloGetSmallerVersion(MlirStringRef version1, MlirStringRef version2, // `targetVersion` version of StableHLO, e.g. if it's using new or removed // features, or if it involves unsupported dialects. // Returns false on failure. -MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact( - MlirStringRef moduleStr, MlirStringRef targetVersion, - MlirStringCallback callback, void* userData); +MLIR_CAPI_EXPORTED MlirLogicalResult +stablehloSerializePortableArtifactFromStringRef(MlirStringRef moduleStr, + MlirStringRef targetVersion, + MlirStringCallback callback, + void* userData); // Write a StableHLO program expressed as a string (either prettyprinted MLIR // module or MLIR bytecode) to a portable artifact. @@ -82,9 +88,11 @@ MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact( // `targetVersion` version of StableHLO, e.g. if it's using new or removed // features, or if it involves unsupported dialects. // Returns false on failure. -MLIR_CAPI_EXPORTED MlirLogicalResult stablehloSerializePortableArtifact( - MlirModule moduleStr, MlirStringRef targetVersion, - MlirStringCallback callback, void* userData); +MLIR_CAPI_EXPORTED MlirLogicalResult +stablehloSerializePortableArtifactFromModule(MlirModule moduleStr, + MlirStringRef targetVersion, + MlirStringCallback callback, + void* userData); // Read a StableHLO program from a portable artifact, returning the module as // MLIR bytecode. Note, this bytecode returned is not a portable artifact, @@ -104,12 +112,7 @@ MLIR_CAPI_EXPORTED MlirLogicalResult stablehloDeserializePortableArtifact( // StableHLO, e.g. if it's using incompatible features. // // Returns empty module on failure. -MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact( - MlirStringRef artifactStr, MlirContext ctx); - -// Call the Interpreter, returns MlirArrayAttr of dense element -// MlirAttribute results -MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifact( +MLIR_CAPI_EXPORTED MlirModule stablehloDeserializePortableArtifactNoError( MlirStringRef artifactStr, MlirContext ctx); // Entrypoint for calling the StableHLO reference interpreter. @@ -120,4 +123,8 @@ MLIR_CAPI_EXPORTED MlirAttribute stablehloEvalModule(MlirModule module, MlirAttribute const* args, int* errorCode); +#ifdef __cplusplus +} +#endif + #endif // STABLEHLO_INTEGRATIONS_C_STABLEHLOAPI_H_ diff --git a/stablehlo/integrations/python/StablehloApi.cpp b/stablehlo/integrations/python/StablehloApi.cpp index 46a640e103..68ff3fa600 100644 --- a/stablehlo/integrations/python/StablehloApi.cpp +++ b/stablehlo/integrations/python/StablehloApi.cpp @@ -95,10 +95,11 @@ void AddStablehloApi(py::module &m) { "serialize_portable_artifact", [](MlirModule module, std::string_view target) -> py::bytes { StringWriterHelper accumulator; - if (mlirLogicalResultIsFailure(stablehloSerializePortableArtifact( - module, toMlirStringRef(target), - accumulator.getMlirStringCallback(), - accumulator.getUserData()))) { + if (mlirLogicalResultIsFailure( + stablehloSerializePortableArtifactFromModule( + module, toMlirStringRef(target), + accumulator.getMlirStringCallback(), + accumulator.getUserData()))) { PyErr_SetString(PyExc_ValueError, "failed to serialize module"); return ""; } @@ -110,7 +111,7 @@ void AddStablehloApi(py::module &m) { m.def( "deserialize_portable_artifact", [](MlirContext context, std::string_view artifact) -> MlirModule { - auto module = stablehloDeserializePortableArtifact( + auto module = stablehloDeserializePortableArtifactNoError( toMlirStringRef(artifact), context); if (mlirModuleIsNull(module)) { PyErr_SetString(PyExc_ValueError, "failed to deserialize module"); @@ -197,11 +198,12 @@ void AddPortableApi(py::module &m) { [](std::string_view moduleStrOrBytecode, std::string_view targetVersion) -> py::bytes { StringWriterHelper accumulator; - if (mlirLogicalResultIsFailure(stablehloSerializePortableArtifact( - toMlirStringRef(moduleStrOrBytecode), - toMlirStringRef(targetVersion), - accumulator.getMlirStringCallback(), - accumulator.getUserData()))) { + if (mlirLogicalResultIsFailure( + stablehloSerializePortableArtifactFromStringRef( + toMlirStringRef(moduleStrOrBytecode), + toMlirStringRef(targetVersion), + accumulator.getMlirStringCallback(), + accumulator.getUserData()))) { PyErr_SetString(PyExc_ValueError, "failed to serialize module"); return ""; }