From 68db66a41189f593a8810ff398452d9d326906e1 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 8 Nov 2024 15:06:08 -0500 Subject: [PATCH] [WIP] Expose StableHLO python bindings --- CMakeLists.txt | 11 +- compiler/bindings/python/CMakeLists.txt | 101 +++++ compiler/bindings/python/test/ir/stablehlo.py | 388 ++++++++++++++++++ compiler/src/iree/compiler/API/CMakeLists.txt | 51 ++- 4 files changed, 543 insertions(+), 8 deletions(-) create mode 100644 compiler/bindings/python/test/ir/stablehlo.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 16ab9a3c4d05c..4572763bf99e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -836,12 +836,6 @@ else() iree_llvm_add_usage_requirements(MLIRSupport IREELLVMIncludeSetup) # Add external projects. - - message(STATUS "Configuring llvm-external-projects/mlir-iree-dialects") - list(APPEND CMAKE_MESSAGE_INDENT " ") - iree_llvm_add_external_project(mlir-iree-dialects ${CMAKE_CURRENT_SOURCE_DIR}/llvm-external-projects/iree-dialects) - list(POP_BACK CMAKE_MESSAGE_INDENT) - if(IREE_INPUT_STABLEHLO) message(STATUS "Configuring third_party/stablehlo") list(APPEND CMAKE_MESSAGE_INDENT " ") @@ -849,6 +843,11 @@ else() list(POP_BACK CMAKE_MESSAGE_INDENT) endif() + message(STATUS "Configuring llvm-external-projects/mlir-iree-dialects") + list(APPEND CMAKE_MESSAGE_INDENT " ") + iree_llvm_add_external_project(mlir-iree-dialects ${CMAKE_CURRENT_SOURCE_DIR}/llvm-external-projects/iree-dialects) + list(POP_BACK CMAKE_MESSAGE_INDENT) + # Ensure that LLVM-based dependencies needed for testing are included. add_dependencies(iree-test-deps FileCheck) if(IREE_LLD_TARGET) diff --git a/compiler/bindings/python/CMakeLists.txt b/compiler/bindings/python/CMakeLists.txt index 76caebb180cc6..740c7fe601b9e 100644 --- a/compiler/bindings/python/CMakeLists.txt +++ b/compiler/bindings/python/CMakeLists.txt @@ -185,6 +185,107 @@ declare_mlir_python_extension(IREECompilerPythonExtensions.CompilerDialects LLVMSupport ) +if (IREE_INPUT_STABLEHLO) + set(STABLEHLO_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/stablehlo") + set(STABLEHLO_PYTHON_SOURCE_DIR "${STABLEHLO_SOURCE_DIR}/stablehlo/integrations/python") + include_directories(${STABLEHLO_SOURCE_DIR}) + + declare_mlir_python_sources(CheckPythonSources.Dialects + ADD_TO_PARENT IREEPythonSources + ) + + declare_mlir_dialect_python_bindings( + ADD_TO_PARENT CheckPythonSources.Dialects + ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir" + TD_FILE dialects/CheckOps.td + SOURCES dialects/check.py + DIALECT_NAME check) + + declare_mlir_python_sources(ChloPythonSources.Dialects + ADD_TO_PARENT IREEPythonSources + ) + + declare_mlir_dialect_python_bindings( + ADD_TO_PARENT ChloPythonSources.Dialects + ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir" + TD_FILE dialects/ChloOps.td + SOURCES dialects/chlo.py + DIALECT_NAME chlo) + + declare_mlir_python_sources(StablehloPythonSources.Dialects + ADD_TO_PARENT IREEPythonSources + ) + + declare_mlir_dialect_python_bindings( + ADD_TO_PARENT StablehloPythonSources.Dialects + ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir" + TD_FILE dialects/StablehloOps.td + SOURCES dialects/stablehlo.py + DIALECT_NAME stablehlo) + + declare_mlir_python_sources(VhloPythonSources.Dialects + ADD_TO_PARENT IREEPythonSources + ) + + declare_mlir_dialect_python_bindings( + ADD_TO_PARENT VhloPythonSources.Dialects + ROOT_DIR "${STABLEHLO_PYTHON_SOURCE_DIR}/mlir" + TD_FILE dialects/VhloOps.td + SOURCES dialects/vhlo.py + DIALECT_NAME vhlo) + + ################################################################################ + # Extensions + ################################################################################ + + set(STABLEHLO_PYTHON_SOURCE_DIR "/../../../third_party/stablehlo/stablehlo/integrations/python") + + declare_mlir_python_extension(CheckPythonExtensions.Main + MODULE_NAME _check + ADD_TO_PARENT IREECompilerPythonExtensions.CompilerDialects + SOURCES + "${STABLEHLO_PYTHON_SOURCE_DIR}/CheckModule.cpp" + EMBED_CAPI_LINK_LIBS + CheckCAPI + PRIVATE_LINK_LIBS + LLVMSupport + ) + + declare_mlir_python_extension(ChloPythonExtensions.Main + MODULE_NAME _chlo + ADD_TO_PARENT IREECompilerPythonExtensions.CompilerDialects + SOURCES + "${STABLEHLO_PYTHON_SOURCE_DIR}/ChloModule.cpp" + EMBED_CAPI_LINK_LIBS + ChloCAPI + PRIVATE_LINK_LIBS + LLVMSupport + ) + + declare_mlir_python_extension(StablehloPythonExtensions.Main + MODULE_NAME _stablehlo + ADD_TO_PARENT IREECompilerPythonExtensions.CompilerDialects + SOURCES + "${STABLEHLO_PYTHON_SOURCE_DIR}/StablehloApi.cpp" + "${STABLEHLO_PYTHON_SOURCE_DIR}/StablehloModule.cpp" + EMBED_CAPI_LINK_LIBS + StablehloCAPI + PRIVATE_LINK_LIBS + LLVMSupport + ) + + declare_mlir_python_extension(VhloPythonExtensions.Main + MODULE_NAME _vhlo + ADD_TO_PARENT IREECompilerPythonExtensions.CompilerDialects + SOURCES + "${STABLEHLO_PYTHON_SOURCE_DIR}/VhloModule.cpp" + EMBED_CAPI_LINK_LIBS + VhloCAPI + PRIVATE_LINK_LIBS + LLVMSupport + ) +endif() + ################################################################################ # Generate packages and shared library ################################################################################ diff --git a/compiler/bindings/python/test/ir/stablehlo.py b/compiler/bindings/python/test/ir/stablehlo.py new file mode 100644 index 0000000000000..eb8cb63c62ffa --- /dev/null +++ b/compiler/bindings/python/test/ir/stablehlo.py @@ -0,0 +1,388 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# Copyright 2022 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for StableHLO Python APIs.""" + +# pylint: disable=wildcard-import,undefined-variable + +import io +import re +from iree.compiler import ir, passmanager as pm +from iree.compiler.dialects import stablehlo + +import numpy as np + + +def run(f): + with ir.Context() as context: + stablehlo.register_dialect(context) + f() + return f + + +@run +def test_channel_handle(): + attr = stablehlo.ChannelHandle.get(handle=1, type=2) + assert attr is not None + assert attr.handle == 1 + assert attr.channel_type == 2 + + +@run +def test_comparison_direction_attr(): + attr = stablehlo.ComparisonDirectionAttr.get("EQ") + assert attr is not None + assert str(attr) == ("#stablehlo") + assert attr.value == "EQ" + + +@run +def test_comparison_type_attr(): + attr = stablehlo.ComparisonTypeAttr.get("FLOAT") + assert attr is not None + assert str(attr) == ("#stablehlo") + assert attr.value == "FLOAT" + + +@run +def test_conv_dimension_numbers(): + attr = stablehlo.ConvDimensionNumbers.get( + input_batch_dimension=0, + input_feature_dimension=1, + input_spatial_dimensions=[2, 3, 4], + kernel_input_feature_dimension=0, + kernel_output_feature_dimension=1, + kernel_spatial_dimensions=[2, 3], + output_batch_dimension=0, + output_feature_dimension=1, + output_spatial_dimensions=[2, 3]) + assert str(attr) == ("#stablehlo.conv<[b, f, 0, 1, 2]x[i, o, 0, 1]->" + "[b, f, 0, 1]>") + assert attr is not None + assert attr.input_batch_dimension == 0 + assert attr.input_feature_dimension == 1 + assert attr.input_spatial_dimensions == [2, 3, 4] + assert attr.kernel_input_feature_dimension == 0 + assert attr.kernel_output_feature_dimension == 1 + assert attr.kernel_spatial_dimensions == [2, 3] + assert attr.output_batch_dimension == 0 + assert attr.output_feature_dimension == 1 + assert attr.output_spatial_dimensions == [2, 3] + + +@run +def test_dot_algorithm(): + # BF16_BF16_F32_X3 + attr = stablehlo.DotAlgorithm.get( + lhs_precision_type=ir.BF16Type.get(), + rhs_precision_type=ir.BF16Type.get(), + accumulation_type=ir.F32Type.get(), + lhs_component_count=1, + rhs_component_count=1, + num_primitive_operations=3, + allow_imprecise_accumulation=False) + assert attr is not None + assert str(attr) == ("#stablehlo.dot_algorithm") + assert isinstance(attr.lhs_precision_type, ir.BF16Type) + assert isinstance(attr.rhs_precision_type, ir.BF16Type) + assert isinstance(attr.accumulation_type, ir.F32Type) + assert attr.lhs_component_count == 1 + assert attr.rhs_component_count == 1 + assert attr.num_primitive_operations == 3 + assert attr.allow_imprecise_accumulation == False + + +@run +def test_dot_dimension_numbers(): + attr = stablehlo.DotDimensionNumbers.get( + lhs_batching_dimensions=[0, 1], + rhs_batching_dimensions=[2, 3], + lhs_contracting_dimensions=[4, 5], + rhs_contracting_dimensions=[6, 7]) + assert attr is not None + assert str(attr) == ("#stablehlo.dot") + assert attr.lhs_batching_dimensions == [0, 1] + assert attr.rhs_batching_dimensions == [2, 3] + assert attr.lhs_contracting_dimensions == [4, 5] + assert attr.rhs_contracting_dimensions == [6, 7] + + +@run +def test_fft_type_attr(): + attr = stablehlo.FftTypeAttr.get("FFT") + assert attr is not None + assert str(attr) == ("#stablehlo") + assert attr.value == "FFT" + + +@run +def test_gather_dimension_numbers(): + attr = stablehlo.GatherDimensionNumbers.get( + offset_dims=[1, 2], + collapsed_slice_dims=[3, 4, 5], + operand_batching_dims=[6, 7], + start_indices_batching_dims=[8, 9], + start_index_map=[10], + index_vector_dim=11, + ) + assert attr is not None + assert str(attr) == ( + "#stablehlo.gather" + ) + assert attr.offset_dims == [1, 2] + assert attr.collapsed_slice_dims == [3, 4, 5] + assert attr.operand_batching_dims == [6, 7] + assert attr.start_indices_batching_dims == [8, 9] + assert attr.start_index_map == [10] + assert attr.index_vector_dim == 11 + + +@run +def test_output_operand_alias(): + attr = stablehlo.OutputOperandAlias.get( + output_tuple_indices=[0], + operand_index=0, + operand_tuple_indices=[1]) + assert attr is not None + assert str(attr) == ("#stablehlo.output_operand_alias") + assert attr.output_tuple_indices == [0] + assert attr.operand_index == 0 + assert attr.operand_tuple_indices == [1] + + +@run +def test_precision_attr(): + attr = stablehlo.PrecisionAttr.get("DEFAULT") + assert attr is not None + assert str(attr) == ("#stablehlo") + assert attr.value == "DEFAULT" + + +@run +def test_rng_algorithm_attr(): + attr = stablehlo.RngAlgorithmAttr.get("DEFAULT") + assert attr is not None + assert str(attr) == ("#stablehlo") + assert attr.value == "DEFAULT" + + +@run +def test_rng_distribution_attr(): + attr = stablehlo.RngDistributionAttr.get("UNIFORM") + assert attr is not None + assert str(attr) == ("#stablehlo") + assert attr.value == "UNIFORM" + + +@run +def test_scatter_dimension_numbers(): + attr = stablehlo.ScatterDimensionNumbers.get( + update_window_dims=[1, 2, 3], + inserted_window_dims=[4, 5], + input_batching_dims=[6, 7], + scatter_indices_batching_dims=[8, 9], + scattered_dims_to_operand_dims=[10, 11], + index_vector_dim=12, + ) + assert attr is not None + assert str(attr) == ( + "#stablehlo.scatter" + ) + assert attr.update_window_dims == [1, 2, 3] + assert attr.inserted_window_dims == [4, 5] + assert attr.input_batching_dims == [6, 7] + assert attr.scatter_indices_batching_dims == [8, 9] + assert attr.scattered_dims_to_operand_dims == [10, 11] + assert attr.index_vector_dim == 12 + + +@run +def test_transpose_attr(): + attr = stablehlo.TransposeAttr.get("TRANSPOSE") + assert attr is not None + assert str(attr) == ("#stablehlo") + assert attr.value == "TRANSPOSE" + + +@run +def test_token_type(): + type = stablehlo.TokenType.get() + assert type is not None + assert str(type) == "!stablehlo.token" + + +@run +def test_type_extensions(): + dyn_size = ir.ShapedType.get_dynamic_size() + attr = stablehlo.TypeExtensions.get(bounds=[128, dyn_size]) + assert attr is not None + assert attr.bounds == [128, dyn_size] + + +@run +def test_api_version(): + api_version = stablehlo.get_api_version() + assert type(api_version) == int + assert api_version > 0 + + +def is_semver_format(version_str): + return re.match("^\d+\.\d+\.\d+$", version_str) + + +@run +def test_current_version(): + curr_version = stablehlo.get_current_version() + assert is_semver_format(curr_version) + + +@run +def test_minimum_version(): + curr_version = stablehlo.get_minimum_version() + assert is_semver_format(curr_version) + + +@run +def test_version_requirements(): + for req in ( + stablehlo.StablehloCompatibilityRequirement.NONE, + stablehlo.StablehloCompatibilityRequirement.WEEK_4, + stablehlo.StablehloCompatibilityRequirement.WEEK_12, + stablehlo.StablehloCompatibilityRequirement.MAX, + ): + assert is_semver_format( + stablehlo.get_version_from_compatibility_requirement(req) + ) + + +ASM_FORMAT = """ +func.func @test(%arg0: tensor<{0}>) -> tensor<{0}> {{ + %0 = stablehlo.add %arg0, %arg0 : (tensor<{0}>, tensor<{0}>) -> tensor<{0}> + func.return %0 : tensor<{0}> +}} +""" + + +# @run +# def test_reference_api(): +# # Formatted as (tensor_type, np_value) +# # Program runs arg + arg, which is used for expected value +# tests = [ +# # No numpy types for f8 - skipping fp8 tests +# ("f16", np.asarray(1, np.float16)), +# ("f32", np.asarray(2, np.float32)), +# ("f64", np.asarray(3, np.double)), +# ("1xi8", np.asarray([4], np.int8)), +# ("1xi16", np.asarray([5], np.int16)), +# ("1xi32", np.asarray([-6], np.int32)), +# # Numpy's uint treated as int by DenseElementsAttr, skipping np.uint tests +# ("2x2xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,2)), +# ("2x1x2xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,1,2)), +# ("?x?xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,2)), +# ("?x2xf16", np.asarray([1, 2, 3, 4], np.float16).reshape(2,2)), +# ] +# for test in tests: +# tensor_type, arg = test +# with ir.Context() as context: +# stablehlo.register_dialect(context) +# m = ir.Module.parse(ASM_FORMAT.format(tensor_type)) +# args = [ir.DenseIntElementsAttr.get(arg)] +# +# actual = np.array(stablehlo.eval_module(m, args)[0]) +# expected = arg + arg +# assert (actual == expected).all() +# + +@run +def test_get_smaller_version(): + curr_version = stablehlo.get_current_version() + min_version = stablehlo.get_minimum_version() + assert stablehlo.get_smaller_version(curr_version, min_version) == min_version + + +@run +def test_serialization_apis(): + curr_version = stablehlo.get_current_version() + + with ir.Context() as context: + stablehlo.register_dialect(context) + m = ir.Module.parse(ASM_FORMAT.format("2xf32")) + assert m is not None + module_str = str(m) + serialized = stablehlo.serialize_portable_artifact(m, curr_version) + deserialized = stablehlo.deserialize_portable_artifact(context, serialized) + assert module_str == str(deserialized) + + +@run +def test_str_serialization_apis(): + curr_version = stablehlo.get_current_version() + + def module_to_bytecode(module: ir.Module) -> bytes: + output = io.BytesIO() + module.operation.write_bytecode(file=output) + return output.getvalue() + + with ir.Context() as context: + stablehlo.register_dialect(context) + m = ir.Module.parse(ASM_FORMAT.format("2xf32")) + assert m is not None + module_str = str(m) + bytecode = module_to_bytecode(m) + serialized = stablehlo.serialize_portable_artifact_str( + bytecode, curr_version + ) + deserialized = stablehlo.deserialize_portable_artifact_str(serialized) + deserialized_module = ir.Module.parse(deserialized) + assert module_str == str(deserialized_module) + + +@run +def test_register_passes(): + """Tests pass registration.""" + with ir.Context() as context: + stablehlo.register_dialect(context) + module = ir.Module.parse(ASM_FORMAT.format("2xf32")) + assert module is not None + + stablehlo.register_stablehlo_passes() + pipeline = [ + "stablehlo-legalize-to-vhlo", + "vhlo-legalize-to-stablehlo", + ] + pipeline = pm.PassManager.parse(f"builtin.module({','.join(pipeline)})") + + cloned_module = module.operation.clone() + pipeline.run(cloned_module.operation) + assert str(module) == str(cloned_module) diff --git a/compiler/src/iree/compiler/API/CMakeLists.txt b/compiler/src/iree/compiler/API/CMakeLists.txt index 00a7e395593a1..9b858a0add972 100644 --- a/compiler/src/iree/compiler/API/CMakeLists.txt +++ b/compiler/src/iree/compiler/API/CMakeLists.txt @@ -25,6 +25,7 @@ iree_cc_library( MLIRCAPITransformDialect MLIRCAPITransformDialectTransforms MLIRCAPITransforms + StablehloCAPI iree::compiler::API::Internal::CompilerDriver iree::compiler::API::Internal::IREECompileToolEntryPoint iree::compiler::API::Internal::IREEMLIRLSPServerToolEntryPoint @@ -76,6 +77,8 @@ set(_EXPORT_OBJECT_LIBS obj.MLIRCAPITransforms obj.MLIRCAPITransformDialect obj.MLIRCAPITransformDialectTransforms + obj.StablehloCAPI + StablehloCAPI iree_compiler_API_Internal_CompilerDriver.objects iree_compiler_API_Internal_IREECompileToolEntryPoint.objects iree_compiler_API_Internal_IREEGPUDialectCAPI.objects @@ -90,6 +93,42 @@ if(DEFINED IREE_COMPILER_API_ADDL_EXPORT_OBJECTS) list(APPEND _EXPORT_OBJECT_LIBS ${IREE_COMPILER_API_ADDL_EXPORT_OBJECTS}) endif() +# Get all propreties that cmake supports +if(NOT CMAKE_PROPERTY_LIST) + execute_process(COMMAND cmake --help-property-list OUTPUT_VARIABLE CMAKE_PROPERTY_LIST) + + # Convert command output into a CMake list + string(REGEX REPLACE ";" "\\\\;" CMAKE_PROPERTY_LIST "${CMAKE_PROPERTY_LIST}") + string(REGEX REPLACE "\n" ";" CMAKE_PROPERTY_LIST "${CMAKE_PROPERTY_LIST}") + list(REMOVE_DUPLICATES CMAKE_PROPERTY_LIST) +endif() + +function(print_properties) + message("CMAKE_PROPERTY_LIST = ${CMAKE_PROPERTY_LIST}") +endfunction() + +function(print_target_properties target) + if(NOT TARGET ${target}) + message(STATUS "There is no target named '${target}'") + return() + endif() + + foreach(property ${CMAKE_PROPERTY_LIST}) + string(REPLACE "" "${CMAKE_BUILD_TYPE}" property ${property}) + + # Fix https://stackoverflow.com/questions/32197663/how-can-i-remove-the-the-location-property-may-not-be-read-from-target-error-i + if(property STREQUAL "LOCATION" OR property MATCHES "^LOCATION_" OR property MATCHES "_LOCATION$") + continue() + endif() + + get_property(was_set TARGET ${target} PROPERTY ${property} SET) + if(was_set) + get_target_property(value ${target} ${property}) + message("${target} ${property} = ${value}") + endif() + endforeach() +endfunction() + set(_EXPORT_OBJECT_SRCS) set(_EXPORT_OBJECT_DEPS) foreach(_object_lib ${_EXPORT_OBJECT_LIBS}) @@ -136,11 +175,19 @@ foreach(_object_lib ${_EXPORT_OBJECT_LIBS}) # that show like generator expressions are showing up in link lines, this is # the culprit. Look at the export_objects_debug.txt to confirm. Then, add # another level of fix upstream if you like pain. - list(APPEND _EXPORT_OBJECT_DEPS "$>>>>") + message(STATUS "wtfbbq ${_object_lib}") + print_target_properties("${_object_lib}") + get_target_property(type "${_object_lib}" TYPE) + if (${type} STREQUAL "OBJECT_LIBRARY") + list(APPEND _EXPORT_OBJECT_DEPS "$>>>>>>>") + elseif (${type} STREQUAL "STATIC_LIBRARY") + get_target_property(_libs "${_object_lib}" INTERFACE_LINK_LIBRARIES) + list(APPEND _EXPORT_OBJECT_DEPS "${_libs}") + endif() endforeach() # UNCOMMENT TO DEBUG WHAT IS GOING ON. -# file(GENERATE OUTPUT export_objects_debug.txt CONTENT "OBJECTS:${_EXPORT_OBJECT_SRCS}\n\nDEPS:${_EXPORT_OBJECT_DEPS}") + file(GENERATE OUTPUT export_objects_debug.txt CONTENT "OBJECTS:${_EXPORT_OBJECT_SRCS}\n\nDEPS:${_EXPORT_OBJECT_DEPS}") # Disable .so.0 style naming/linking. In order to be consistent across platforms # and bindings, we will embed a major version in the library name when it is time.