diff --git a/frontend/e2e_test.py b/frontend/e2e_test.py index 30cd2ff20..e06860bb2 100644 --- a/frontend/e2e_test.py +++ b/frontend/e2e_test.py @@ -1,12 +1,13 @@ from heir import compile from heir.mlir import * + from absl.testing import absltest # fmt: skip class EndToEndTest(absltest.TestCase): def test_simple_arithmetic(self): - @compile() # defaults to BGV and OpenFHE - def foo(a : Secret[I64], b : Secret[I64]): + @compile() # defaults to BGV and OpenFHE + def foo(a: Secret[I64], b: Secret[I64]): return a * a - b * b # Test plaintext functionality diff --git a/frontend/example.py b/frontend/example.py index 943a548b0..71a021a53 100644 --- a/frontend/example.py +++ b/frontend/example.py @@ -5,9 +5,10 @@ # TODO (#1162): Also add the tensorflow-to-tosa-to-HEIR example in example.py, even it doesn't use the main Python frontend? + ### Simple Example -@compile() # defaults to scheme="bgv", OpenFHE backend, and debug=False -def foo(x : Secret[I16], y : Secret[I16]): +@compile() # defaults to scheme="bgv", OpenFHE backend, and debug=False +def foo(x: Secret[I16], y: Secret[I16]): sum = x + y diff = x - y mul = x * y @@ -15,7 +16,8 @@ def foo(x : Secret[I16], y : Secret[I16]): deadcode = expression * mul return expression -foo.setup() # runs keygen/etc + +foo.setup() # runs keygen/etc enc_x = foo.encrypt_x(7) enc_y = foo.encrypt_y(8) result_enc = foo.eval(enc_x, enc_y) @@ -74,7 +76,6 @@ def foo(x : Secret[I16], y : Secret[I16]): # print(f"Expected result for `baz2`: {baz2(7,8,9)}, decrypted result: {result}") - # ### Custom Pipeline Example # @compile(heir_opt_options=["--mlir-to-secret-arithmetic", "--canonicalize", "--cse"], backend=None, debug=True) # def foo(x : Secret[I16], y : Secret[I16]): diff --git a/frontend/heir/backends/dummy.py b/frontend/heir/backends/dummy.py index a09160585..553d89abb 100644 --- a/frontend/heir/backends/dummy.py +++ b/frontend/heir/backends/dummy.py @@ -2,7 +2,8 @@ from colorama import Fore, Style, init -from heir.core import BackendInterface, CompilationResult, ClientInterface +from heir.core import BackendInterface, CompilationResult, ClientInterface + class DummyClientInterface(ClientInterface): @@ -10,10 +11,22 @@ def __init__(self, compilation_result: CompilationResult): self.compilation_result = compilation_result def setup(self): - print("HEIR Warning (Dummy Backend): " + Fore.YELLOW + Style.BRIGHT + f"{self.compilation_result.func_name}.setup() is a no-op in the Dummy Backend") + print( + "HEIR Warning (Dummy Backend): " + + Fore.YELLOW + + Style.BRIGHT + + f"{self.compilation_result.func_name}.setup() is a no-op in the Dummy" + " Backend" + ) def decrypt_result(self, result): - print("HEIR Warning (Dummy Backend): " + Fore.YELLOW + Style.BRIGHT + f"{self.compilation_result.func_name}.decrypt() is a no-op in the Dummy Backend") + print( + "HEIR Warning (Dummy Backend): " + + Fore.YELLOW + + Style.BRIGHT + + f"{self.compilation_result.func_name}.decrypt() is a no-op in the" + " Dummy Backend" + ) return result def __getattr__(self, key): @@ -22,13 +35,25 @@ def __getattr__(self, key): arg_name = key[len("encrypt_") :] def wrapper(arg): - print("HEIR Warning (Dummy Backend): " + Fore.YELLOW + Style.BRIGHT + f"{self.compilation_result.func_name}.{key}() is a no-op in the Dummy Backend") + print( + "HEIR Warning (Dummy Backend): " + + Fore.YELLOW + + Style.BRIGHT + + f"{self.compilation_result.func_name}.{key}() is a no-op in the" + " Dummy Backend" + ) return arg return wrapper if key == self.compilation_result.func_name or key == "eval": - print("HEIR Warning (Dummy Backend): " + Fore.YELLOW + Style.BRIGHT + f"{self.compilation_result.func_name}.eval() is the same as {self.compilation_result.func_name}() in the Dummy Backend.") + print( + "HEIR Warning (Dummy Backend): " + + Fore.YELLOW + + Style.BRIGHT + + f"{self.compilation_result.func_name}.eval() is the same as" + f" {self.compilation_result.func_name}() in the Dummy Backend." + ) return self.func raise AttributeError(f"Attribute {key} not found") @@ -36,17 +61,27 @@ def wrapper(arg): class DummyBackend(BackendInterface): - def run_backend(self, workspace_dir, heir_opt, heir_translate, func_name, arg_names, secret_args, heir_opt_output, debug): - - result = CompilationResult( - module=None, - func_name=func_name, - secret_args=secret_args, - arg_names=arg_names, - arg_enc_funcs=None, - result_dec_func=None, - main_func=None, - setup_funcs=None - ) - - return DummyClientInterface(result) + def run_backend( + self, + workspace_dir, + heir_opt, + heir_translate, + func_name, + arg_names, + secret_args, + heir_opt_output, + debug, + ): + + result = CompilationResult( + module=None, + func_name=func_name, + secret_args=secret_args, + arg_names=arg_names, + arg_enc_funcs=None, + result_dec_func=None, + main_func=None, + setup_funcs=None, + ) + + return DummyClientInterface(result) diff --git a/frontend/heir/backends/openfhe/__init__.py b/frontend/heir/backends/openfhe/__init__.py index 6b868823e..392efe851 100644 --- a/frontend/heir/backends/openfhe/__init__.py +++ b/frontend/heir/backends/openfhe/__init__.py @@ -1,4 +1,9 @@ from .backend import OpenFHEBackend from .config import OpenFHEConfig, DEFAULT_INSTALLED_OPENFHE_CONFIG, from_os_env -__all__ = ["OpenFHEBackend", "OpenFHEConfig", "DEFAULT_INSTALLED_OPENFHE_CONFIG", "from_os_env"] +__all__ = [ + "OpenFHEBackend", + "OpenFHEConfig", + "DEFAULT_INSTALLED_OPENFHE_CONFIG", + "from_os_env", +] diff --git a/frontend/heir/backends/openfhe/backend.py b/frontend/heir/backends/openfhe/backend.py index 4cbf43387..646ca1d2f 100644 --- a/frontend/heir/backends/openfhe/backend.py +++ b/frontend/heir/backends/openfhe/backend.py @@ -6,7 +6,7 @@ from colorama import Fore, Style, init -from heir.core import BackendInterface, CompilationResult, ClientInterface +from heir.core import BackendInterface, CompilationResult, ClientInterface from heir.backends.util import clang, pybind_helpers from .config import OpenFHEConfig @@ -14,6 +14,7 @@ pyconfig_ext_suffix = pybind_helpers.pyconfig_ext_suffix pybind11_includes = pybind_helpers.pybind11_includes + class OpenfheClientInterface(ClientInterface): def __init__(self, compilation_result: CompilationResult): @@ -42,10 +43,14 @@ def decrypt_result(self, result, *, crypto_context=None, secret_key=None): def __getattr__(self, key): if key == "crypto_context": - msg = f"HEIR Error: Please call {self.compilation_result.func_name}.setup() before calling {self.compilation_result.func_name}.encrypt/eval/decrypt"; - init(autoreset=True) - print(Fore.RED + Style.BRIGHT + msg) - raise RuntimeError(msg) + msg = ( + f"HEIR Error: Please call {self.compilation_result.func_name}.setup()" + " before calling" + f" {self.compilation_result.func_name}.encrypt/eval/decrypt" + ) + init(autoreset=True) + print(Fore.RED + Style.BRIGHT + msg) + raise RuntimeError(msg) if key.startswith("encrypt_"): # TODO (#1162): Prevent people from doing a = enc_x, b = enc_y for foo(x,y) but then calling foo(b,a)! @@ -81,147 +86,185 @@ def wrapper(*args, crypto_context=None): class OpenFHEBackend(BackendInterface): - def __init__(self, openfhe_config: OpenFHEConfig): - self.openfhe_config = openfhe_config - self.heir_opt_options = [] - - def run_backend(self, workspace_dir, heir_opt, heir_translate, func_name, arg_names, secret_args, heir_opt_output, debug): - # Initialize Colorama for error and debug messages - init(autoreset=True) - - # Convert from "scheme" to openfhe: - heir_opt_options = [f"--scheme-to-openfhe=entry-function={func_name}"] - if(debug): - heir_opt_options.append("--view-op-graph") - print("HEIRpy Debug (OpenFHE Backend): " + Style.BRIGHT + f"Running heir-opt {' '.join(heir_opt_options)}") - heir_opt_output, graph = heir_opt.run_binary_stderr( - input=heir_opt_output, - options=(heir_opt_options), - ) - if(debug): - # Print output after heir_opt: - mlirpath = Path(workspace_dir) / f"{func_name}.backend.mlir" - graphpath = Path(workspace_dir) / f"{func_name}.backend.dot" - print(f"HEIRpy Debug (OpenFHE Backend): Writing backend MLIR to {mlirpath}") - with open(mlirpath, "w") as f: - f.write(heir_opt_output) - print(f"HEIRpy Debug (OpenFHE Backend): Writing backend graph to {graphpath}") - with open(graphpath, "w") as f: - f.write(graph) - - - # Translate to *.cpp and Pybind - module_name = f"_heir_{func_name}" - cpp_filepath = Path(workspace_dir) / f"{func_name}.cpp" - h_filepath = Path(workspace_dir) / f"{func_name}.h" - pybind_filepath = Path(workspace_dir) / f"{func_name}_bindings.cpp" - include_type_flag = "--openfhe-include-type=" + self.openfhe_config.include_type - header_options = [ - "--emit-openfhe-pke-header", - include_type_flag, - "-o", - h_filepath, - ] - cpp_options = ["--emit-openfhe-pke", include_type_flag, "-o", cpp_filepath] - pybind_options = [ - "--emit-openfhe-pke-pybind", - f"--pybind-header-include={h_filepath.name}", - f"--pybind-module-name={module_name}", - "-o", - pybind_filepath, - ] - if(debug): - print("HEIRpy Debug (OpenFHE Backend): " + Style.BRIGHT + f"Running heir-translate {' '.join(str(o) for o in header_options)}") - heir_translate.run_binary( - input=heir_opt_output, - options=header_options, - ) - if(debug): - print("HEIRpy Debug (OpenFHE Backend): " + Style.BRIGHT + f"Running heir-translate {' '.join(str(o) for o in cpp_options)}") - heir_translate.run_binary( - input=heir_opt_output, - options=cpp_options, - ) - if(debug): - print("HEIRpy Debug (OpenFHE Backend): " + Style.BRIGHT + f"Running heir-translate {' '.join(str(o) for o in pybind_options)}") - heir_translate.run_binary( - input=heir_opt_output, - options=pybind_options, - ) - clang_backend = clang.ClangBackend() - so_filepath = Path(workspace_dir) / f"{func_name}.so" - linker_search_paths = [self.openfhe_config.lib_dir] - if(debug): - args = clang_backend.clang_arg_helper(cpp_source_filepath=cpp_filepath, - shared_object_output_filepath=so_filepath, - include_paths=self.openfhe_config.include_dirs, - linker_search_paths=linker_search_paths, - link_libs=self.openfhe_config.link_libs) - print( - f"HEIRpy Debug (OpenFHE Backend):\033[1m Running clang {' '.join(str(arg) for arg in args)}\033[0m" - ) - - clang_backend.compile_to_shared_object( - cpp_source_filepath=cpp_filepath, - shared_object_output_filepath=so_filepath, - include_paths=self.openfhe_config.include_dirs, - linker_search_paths=linker_search_paths, - link_libs=self.openfhe_config.link_libs, - ) + def __init__(self, openfhe_config: OpenFHEConfig): + self.openfhe_config = openfhe_config + self.heir_opt_options = [] - ext_suffix = pyconfig_ext_suffix() - pybind_so_filepath = Path(workspace_dir) / f"{module_name}{ext_suffix}" - if(debug): - args = clang_backend.clang_arg_helper(cpp_source_filepath=pybind_filepath, - shared_object_output_filepath=pybind_so_filepath, - include_paths=self.openfhe_config.include_dirs - + pybind11_includes() - + [workspace_dir], - linker_search_paths=linker_search_paths, - link_libs=self.openfhe_config.link_libs, - linker_args=["-rpath", ":".join(linker_search_paths)], - abs_link_lib_paths=[so_filepath],) - print( - f"HEIRpy Debug (OpenFHE Backend):\033[1m Running clang {' '.join(str(arg) for arg in args)}\033[0m" - ) - clang_backend.compile_to_shared_object( - cpp_source_filepath=pybind_filepath, - shared_object_output_filepath=pybind_so_filepath, - include_paths=self.openfhe_config.include_dirs - + pybind11_includes() - + [workspace_dir], - linker_search_paths=linker_search_paths, - link_libs=self.openfhe_config.link_libs, - linker_args=["-rpath", ":".join(linker_search_paths)], - abs_link_lib_paths=[so_filepath], - ) + def run_backend( + self, + workspace_dir, + heir_opt, + heir_translate, + func_name, + arg_names, + secret_args, + heir_opt_output, + debug, + ): + # Initialize Colorama for error and debug messages + init(autoreset=True) + + # Convert from "scheme" to openfhe: + heir_opt_options = [f"--scheme-to-openfhe=entry-function={func_name}"] + if debug: + heir_opt_options.append("--view-op-graph") + print( + "HEIRpy Debug (OpenFHE Backend): " + + Style.BRIGHT + + f"Running heir-opt {' '.join(heir_opt_options)}" + ) + heir_opt_output, graph = heir_opt.run_binary_stderr( + input=heir_opt_output, + options=(heir_opt_options), + ) + if debug: + # Print output after heir_opt: + mlirpath = Path(workspace_dir) / f"{func_name}.backend.mlir" + graphpath = Path(workspace_dir) / f"{func_name}.backend.dot" + print( + f"HEIRpy Debug (OpenFHE Backend): Writing backend MLIR to {mlirpath}" + ) + with open(mlirpath, "w") as f: + f.write(heir_opt_output) + print( + "HEIRpy Debug (OpenFHE Backend): Writing backend graph to" + f" {graphpath}" + ) + with open(graphpath, "w") as f: + f.write(graph) + + # Translate to *.cpp and Pybind + module_name = f"_heir_{func_name}" + cpp_filepath = Path(workspace_dir) / f"{func_name}.cpp" + h_filepath = Path(workspace_dir) / f"{func_name}.h" + pybind_filepath = Path(workspace_dir) / f"{func_name}_bindings.cpp" + include_type_flag = ( + "--openfhe-include-type=" + self.openfhe_config.include_type + ) + header_options = [ + "--emit-openfhe-pke-header", + include_type_flag, + "-o", + h_filepath, + ] + cpp_options = ["--emit-openfhe-pke", include_type_flag, "-o", cpp_filepath] + pybind_options = [ + "--emit-openfhe-pke-pybind", + f"--pybind-header-include={h_filepath.name}", + f"--pybind-module-name={module_name}", + "-o", + pybind_filepath, + ] + if debug: + print( + "HEIRpy Debug (OpenFHE Backend): " + + Style.BRIGHT + + f"Running heir-translate {' '.join(str(o) for o in header_options)}" + ) + heir_translate.run_binary( + input=heir_opt_output, + options=header_options, + ) + if debug: + print( + "HEIRpy Debug (OpenFHE Backend): " + + Style.BRIGHT + + f"Running heir-translate {' '.join(str(o) for o in cpp_options)}" + ) + heir_translate.run_binary( + input=heir_opt_output, + options=cpp_options, + ) + if debug: + print( + "HEIRpy Debug (OpenFHE Backend): " + + Style.BRIGHT + + f"Running heir-translate {' '.join(str(o) for o in pybind_options)}" + ) + heir_translate.run_binary( + input=heir_opt_output, + options=pybind_options, + ) + + clang_backend = clang.ClangBackend() + so_filepath = Path(workspace_dir) / f"{func_name}.so" + linker_search_paths = [self.openfhe_config.lib_dir] + if debug: + args = clang_backend.clang_arg_helper( + cpp_source_filepath=cpp_filepath, + shared_object_output_filepath=so_filepath, + include_paths=self.openfhe_config.include_dirs, + linker_search_paths=linker_search_paths, + link_libs=self.openfhe_config.link_libs, + ) + print( + "HEIRpy Debug (OpenFHE Backend):\033[1m Running clang" + f" {' '.join(str(arg) for arg in args)}\033[0m" + ) + + clang_backend.compile_to_shared_object( + cpp_source_filepath=cpp_filepath, + shared_object_output_filepath=so_filepath, + include_paths=self.openfhe_config.include_dirs, + linker_search_paths=linker_search_paths, + link_libs=self.openfhe_config.link_libs, + ) + + ext_suffix = pyconfig_ext_suffix() + pybind_so_filepath = Path(workspace_dir) / f"{module_name}{ext_suffix}" + if debug: + args = clang_backend.clang_arg_helper( + cpp_source_filepath=pybind_filepath, + shared_object_output_filepath=pybind_so_filepath, + include_paths=self.openfhe_config.include_dirs + + pybind11_includes() + + [workspace_dir], + linker_search_paths=linker_search_paths, + link_libs=self.openfhe_config.link_libs, + linker_args=["-rpath", ":".join(linker_search_paths)], + abs_link_lib_paths=[so_filepath], + ) + print( + "HEIRpy Debug (OpenFHE Backend):\033[1m Running clang" + f" {' '.join(str(arg) for arg in args)}\033[0m" + ) + clang_backend.compile_to_shared_object( + cpp_source_filepath=pybind_filepath, + shared_object_output_filepath=pybind_so_filepath, + include_paths=self.openfhe_config.include_dirs + + pybind11_includes() + + [workspace_dir], + linker_search_paths=linker_search_paths, + link_libs=self.openfhe_config.link_libs, + linker_args=["-rpath", ":".join(linker_search_paths)], + abs_link_lib_paths=[so_filepath], + ) + + sys.path.append(workspace_dir) + importlib.invalidate_caches() + bound_module = importlib.import_module(module_name) + + result = CompilationResult( + module=bound_module, + func_name=func_name, + secret_args=secret_args, + arg_names=arg_names, + arg_enc_funcs={ + arg_name: getattr(bound_module, f"{func_name}__encrypt__arg{i}") + for i, arg_name in enumerate(arg_names) + if i in secret_args + }, + result_dec_func=getattr(bound_module, f"{func_name}__decrypt__result0"), + main_func=getattr(bound_module, func_name), + setup_funcs={ + "generate_crypto_context": getattr( + bound_module, f"{func_name}__generate_crypto_context" + ), + "configure_crypto_context": getattr( + bound_module, f"{func_name}__configure_crypto_context" + ), + }, + ) - sys.path.append(workspace_dir) - importlib.invalidate_caches() - bound_module = importlib.import_module(module_name) - - - result = CompilationResult( - module=bound_module, - func_name=func_name, - secret_args=secret_args, - arg_names=arg_names, - arg_enc_funcs={ - arg_name: getattr(bound_module, f"{func_name}__encrypt__arg{i}") - for i, arg_name in enumerate(arg_names) - if i in secret_args - }, - result_dec_func=getattr(bound_module, f"{func_name}__decrypt__result0"), - main_func=getattr(bound_module, func_name), - setup_funcs={ - "generate_crypto_context": getattr( - bound_module, f"{func_name}__generate_crypto_context" - ), - "configure_crypto_context": getattr( - bound_module, f"{func_name}__configure_crypto_context" - ), - }, - ) - - return OpenfheClientInterface(result) + return OpenfheClientInterface(result) diff --git a/frontend/heir/backends/util/clang.py b/frontend/heir/backends/util/clang.py index f63f32994..e7b9c60e7 100644 --- a/frontend/heir/backends/util/clang.py +++ b/frontend/heir/backends/util/clang.py @@ -30,7 +30,6 @@ def _find_clang(self): return clang return shutil.which("clang") - def clang_arg_helper( self, cpp_source_filepath: Path, diff --git a/frontend/heir/core.py b/frontend/heir/core.py index 368661c83..075e2251a 100644 --- a/frontend/heir/core.py +++ b/frontend/heir/core.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass + @dataclass class CompilationResult: # The module object containing the compiled functions @@ -27,6 +28,7 @@ class CompilationResult: # Backend setup functions, if any setup_funcs: dict[str, object] + class ClientInterface(ABC): @abstractmethod @@ -51,5 +53,6 @@ def __call__(self, *args, **kwargs): """Forwards to the original function.""" return self.func(*args, **kwargs) + class BackendInterface(ABC): ... diff --git a/frontend/heir/heir_cli/heir_cli_config.py b/frontend/heir/heir_cli/heir_cli_config.py index 68bfb4f84..81d5cea8d 100644 --- a/frontend/heir/heir_cli/heir_cli_config.py +++ b/frontend/heir/heir_cli/heir_cli_config.py @@ -13,13 +13,15 @@ class HEIRConfig: heir_opt_path: str heir_translate_path: str + repo_root = pathlib.Path(__file__).resolve().parents[3] DEVELOPMENT_HEIR_CONFIG = HEIRConfig( - heir_opt_path= repo_root/"bazel-bin/tools/heir-opt", - heir_translate_path= repo_root/"bazel-bin/tools/heir-translate", + heir_opt_path=repo_root / "bazel-bin/tools/heir-opt", + heir_translate_path=repo_root / "bazel-bin/tools/heir-translate", ) -#TODO (#1326): Add a config that automagically downloads the nightlies +# TODO (#1326): Add a config that automagically downloads the nightlies + def from_os_env() -> HEIRConfig: """Create a HEIRConfig from environment variables. diff --git a/frontend/heir/mlir/linalg.py b/frontend/heir/mlir/linalg.py index 33ff52226..7deeeda11 100644 --- a/frontend/heir/mlir/linalg.py +++ b/frontend/heir/mlir/linalg.py @@ -2,9 +2,15 @@ from numba.core import sigutils import numpy as np + ## Example of an overload: matmul @mlir def matmul(typingctx, X, Y): # TODO (#1162): add a check if input types are valid! # TODO (#1162): How to get the shape (at least rank) of "intrinsic" (e.g. linalg.matul) arguments for correct type inference? - return sigutils._parse_signature_string("float32[:,:](float32[:,:],float32[:,:])"), np.matmul + return ( + sigutils._parse_signature_string( + "float32[:,:](float32[:,:],float32[:,:])" + ), + np.matmul, + ) diff --git a/frontend/heir/mlir/mlir_decorator.py b/frontend/heir/mlir/mlir_decorator.py index 285ddd4ab..c02bae2ee 100644 --- a/frontend/heir/mlir/mlir_decorator.py +++ b/frontend/heir/mlir/mlir_decorator.py @@ -11,106 +11,115 @@ class _MLIR(ReduceMixin): + """ + Dummy callable for _MLIR + """ + + _memo = weakref.WeakValueDictionary() + # hold refs to last N functions deserialized, retaining them in _memo + # regardless of whether there is another reference + _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE) + + __uuid = None + + def __init__(self, name, defn, prefer_literal=False, **kwargs): + self._ctor_kwargs = kwargs + self._name = name + self._defn = defn + self._prefer_literal = prefer_literal + functools.update_wrapper(self, defn) + + @property + def _uuid(self): """ - Dummy callable for _MLIR + An instance-specific UUID, to avoid multiple deserializations of + a given instance. + + Note this is lazily-generated, for performance reasons. + """ + u = self.__uuid + if u is None: + u = str(uuid.uuid1()) + self._set_uuid(u) + return u + + def _set_uuid(self, u): + assert self.__uuid is None + self.__uuid = u + self._memo[u] = self + self._recent.append(self) + + def _register(self): + # _ctor_kwargs + from numba.core.typing.templates import ( + make_intrinsic_template, + infer_global, + ) + + template = make_intrinsic_template( + self, + self._defn, + self._name, + prefer_literal=self._prefer_literal, + kwargs=self._ctor_kwargs, + ) + infer(template) + infer_global(self, types.Function(template)) + + def __call__(self, *args, **kwargs): + """ + Calls the Python Impl """ - _memo = weakref.WeakValueDictionary() - # hold refs to last N functions deserialized, retaining them in _memo - # regardless of whether there is another reference - _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE) - - __uuid = None - - def __init__(self, name, defn, prefer_literal=False, **kwargs): - self._ctor_kwargs = kwargs - self._name = name - self._defn = defn - self._prefer_literal = prefer_literal - functools.update_wrapper(self, defn) - - @property - def _uuid(self): - """ - An instance-specific UUID, to avoid multiple deserializations of - a given instance. - - Note this is lazily-generated, for performance reasons. - """ - u = self.__uuid - if u is None: - u = str(uuid.uuid1()) - self._set_uuid(u) - return u - - def _set_uuid(self, u): - assert self.__uuid is None - self.__uuid = u - self._memo[u] = self - self._recent.append(self) - - def _register(self): - # _ctor_kwargs - from numba.core.typing.templates import (make_intrinsic_template, - infer_global) - - template = make_intrinsic_template(self, self._defn, self._name, - prefer_literal=self._prefer_literal, - kwargs=self._ctor_kwargs) - infer(template) - infer_global(self, types.Function(template)) - - def __call__(self, *args, **kwargs): - """ - Calls the Python Impl - """ - _, impl = self._defn(None, *args, **kwargs) - return impl(*args, **kwargs) - - def __repr__(self): - return "".format(self._name) - - def __deepcopy__(self, memo): - # NOTE: Intrinsic are immutable and we don't need to copy. - # This is triggered from deepcopy of statements. - return self - - def _reduce_states(self): - """ - NOTE: part of ReduceMixin protocol - """ - return dict(uuid=self._uuid, name=self._name, defn=self._defn) - - @classmethod - def _rebuild(cls, uuid, name, defn): - """ - NOTE: part of ReduceMixin protocol - """ - try: - return cls._memo[uuid] - except KeyError: - llc = cls(name=name, defn=defn) - llc._register() - llc._set_uuid(uuid) - return llc + _, impl = self._defn(None, *args, **kwargs) + return impl(*args, **kwargs) + def __repr__(self): + return "".format(self._name) -def mlir(*args, **kwargs): + def __deepcopy__(self, memo): + # NOTE: Intrinsic are immutable and we don't need to copy. + # This is triggered from deepcopy of statements. + return self + + def _reduce_states(self): """ - TODO (#1162): update this doc + NOTE: part of ReduceMixin protocol """ - # Make inner function for the actual work - def _mlir(func): - name = getattr(func, '__name__', str(func)) - llc = _MLIR(name, func, **kwargs) - llc._register() - return llc - - if not kwargs: - # No option is given - return _mlir(*args) - else: - # options are given, create a new callable to recv the - # definition function - def wrapper(func): - return _mlir(func) - return wrapper + return dict(uuid=self._uuid, name=self._name, defn=self._defn) + + @classmethod + def _rebuild(cls, uuid, name, defn): + """ + NOTE: part of ReduceMixin protocol + """ + try: + return cls._memo[uuid] + except KeyError: + llc = cls(name=name, defn=defn) + llc._register() + llc._set_uuid(uuid) + return llc + + +def mlir(*args, **kwargs): + """ + TODO (#1162): update this doc + """ + + # Make inner function for the actual work + def _mlir(func): + name = getattr(func, "__name__", str(func)) + llc = _MLIR(name, func, **kwargs) + llc._register() + return llc + + if not kwargs: + # No option is given + return _mlir(*args) + else: + # options are given, create a new callable to recv the + # definition function + def wrapper(func): + return _mlir(func) + + return wrapper diff --git a/frontend/heir/mlir/types.py b/frontend/heir/mlir/types.py index 16814b07a..a682c4f09 100644 --- a/frontend/heir/mlir/types.py +++ b/frontend/heir/mlir/types.py @@ -2,85 +2,110 @@ from typing import TypeVar, TypeVarTuple, Generic, get_args, get_origin, _GenericAlias -T = TypeVar('T') +T = TypeVar("T") Ts = TypeVarTuple("Ts") + class MLIRTypeAnnotation: + def numba_str(): - raise NotImplementedError("No numba type exists for a generic MLIRTypeAnnotation") + raise NotImplementedError( + "No numba type exists for a generic MLIRTypeAnnotation" + ) + class Secret(Generic[T], MLIRTypeAnnotation): + def numba_str(): raise NotImplementedError("No numba type exists for a generic Secret") + class Tensor(Generic[*Ts], MLIRTypeAnnotation): + def numba_str(): raise NotImplementedError("No numba type exists for a generic Tensor") + class F32(MLIRTypeAnnotation): # TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable def numba_str(): return "float32" + class F64(MLIRTypeAnnotation): # TODO (#1162): For CKKS/Float: allow specifying actual intended precision/scale and warn/error if not achievable def numba_str(): return "float64" + class I1(MLIRTypeAnnotation): + def numba_str(): return "bool" + class I4(MLIRTypeAnnotation): + def numba_str(): return "int4" + class I8(MLIRTypeAnnotation): + def numba_str(): return "int8" + class I16(MLIRTypeAnnotation): + def numba_str(): return "int16" + class I32(MLIRTypeAnnotation): + def numba_str(): return "int32" + + class I64(MLIRTypeAnnotation): - def numba_str(): - return "int64" + def numba_str(): + return "int64" # Helper functions + def to_numba_str(type) -> str: - if(get_origin(type) == Secret): - raise TypeError("Secret type should not appear inside another type annotation.") + if get_origin(type) == Secret: + raise TypeError( + "Secret type should not appear inside another type annotation." + ) - if(get_origin(type) == Tensor): + if get_origin(type) == Tensor: args = get_args(type) inner_type = args[len(args) - 1] - if(get_origin(inner_type) == Tensor): + if get_origin(inner_type) == Tensor: raise TypeError("Nested Tensors are not yet supported.") return f"{to_numba_str(inner_type)}[{','.join([':'] * (len(args) - 1))}]" - if(issubclass(type, MLIRTypeAnnotation)): + if issubclass(type, MLIRTypeAnnotation): return type.numba_str() raise TypeError(f"Unsupported type annotation: {type}") def parse_annotations(annotations): - if (not annotations): + if not annotations: raise TypeError("Function is missing type annotations.") signature = "" secret_args = [] for idx, (_, arg_type) in enumerate(annotations.items()): - if get_origin(arg_type) == Secret: - secret_args.append(idx) - assert(len(get_args(arg_type)) == 1) - signature+= f"{to_numba_str(get_args(arg_type)[0])}," - else: - signature += f"{to_numba_str(arg_type)}," + if get_origin(arg_type) == Secret: + secret_args.append(idx) + assert len(get_args(arg_type)) == 1 + signature += f"{to_numba_str(get_args(arg_type)[0])}," + else: + signature += f"{to_numba_str(arg_type)}," return signature, secret_args diff --git a/frontend/heir/mlir_emitter.py b/frontend/heir/mlir_emitter.py index f95cdf439..9d2d8b181 100644 --- a/frontend/heir/mlir_emitter.py +++ b/frontend/heir/mlir_emitter.py @@ -7,9 +7,10 @@ from numba.core import ir from numba.core import types + def mlirType(numba_type): if isinstance(numba_type, types.Integer): - #TODO (#1162): fix handling of signedness + # TODO (#1162): fix handling of signedness # Since `arith` only allows signless integers, we ignore signedness here. return "i" + str(numba_type.bitwidth) if isinstance(numba_type, types.RangeType): @@ -21,14 +22,18 @@ def mlirType(numba_type): if isinstance(numba_type, types.Complex): return "complex<" + str(numba_type.bitwidth) + ">" if isinstance(numba_type, types.Array): - #TODO (#1162): implement support for statically sized tensors + # TODO (#1162): implement support for statically sized tensors # this probably requires extending numba with a new type # See https://numba.readthedocs.io/en/stable/extending/index.html - return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">" + return "tensor<" + "?x" * numba_type.ndim + mlirType(numba_type.dtype) + ">" raise NotImplementedError("Unsupported type: " + str(numba_type)) -def mlirLoc(loc : ir.Loc): - return f"loc(\"{loc.filename or ''}\":{loc.line or 0}:{loc.col or 0})" + +def mlirLoc(loc: ir.Loc): + return ( + f"loc(\"{loc.filename or ''}\":{loc.line or 0}:{loc.col or 0})" + ) + def arithSuffix(numba_type): if isinstance(numba_type, types.Integer): @@ -38,7 +43,9 @@ def arithSuffix(numba_type): if isinstance(numba_type, types.Float): return "f" if isinstance(numba_type, types.Complex): - raise NotImplementedError("Complex numbers not supported in `arith` dialect") + raise NotImplementedError( + "Complex numbers not supported in `arith` dialect" + ) if isinstance(numba_type, types.Array): return arithSuffix(numba_type.dtype) raise NotImplementedError("Unsupported type: " + str(numba_type)) @@ -132,11 +139,12 @@ def is_range_call(instr, ssa_ir): class TextualMlirEmitter: + def __init__(self, ssa_ir, secret_args, typemap, retty): self.ssa_ir = ssa_ir self.secret_args = secret_args self.typemap = typemap - self.retty = retty, + self.retty = (retty,) self.temp_var_id = 0 self.numba_names_to_ssa_var_names = {} self.globals_map = {} @@ -149,10 +157,14 @@ def emit(self): secret_flag = " {secret.secret}" # probably should use unique name... # func_name = ssa_ir.func_id.unique_name - args_str = ", ".join([f"%{name}: {mlirType(self.typemap.get(name))}{secret_flag if idx in self.secret_args else str()} {mlirLoc(self.ssa_ir.loc)}" for idx, name in enumerate(self.ssa_ir.arg_names)]) + args_str = ", ".join([ + f"%{name}:" + f" {mlirType(self.typemap.get(name))}{secret_flag if idx in self.secret_args else str()} {mlirLoc(self.ssa_ir.loc)}" + for idx, name in enumerate(self.ssa_ir.arg_names) + ]) # TODO(#1162): support multiple return values! - if(len(self.retty) > 1): + if len(self.retty) > 1: raise NotImplementedError("Multiple return values not supported") return_types_str = mlirType(self.retty[0]) @@ -262,7 +274,10 @@ def emit_assign(self, assign): case ir.Expr(op="binop"): name = self.get_or_create_name(assign.target) emitted_expr = self.emit_binop(assign.value) - return f"{name} = {emitted_expr} : {mlirType(self.typemap.get(assign.target.name))} {mlirLoc(assign.loc)}" + return ( + f"{name} = {emitted_expr} :" + f" {mlirType(self.typemap.get(assign.target.name))} {mlirLoc(assign.loc)}" + ) case ir.Expr(op="call"): func = assign.value.func # if assert fails, variable was undefined @@ -287,20 +302,34 @@ def emit_assign(self, assign): dims = [] for i in range(target_numba_type.ndim): cst = self.get_next_name() - str += f"{cst} = arith.constant {i} : index {mlirLoc(assign.loc)}\n" + str += ( + f"{cst} = arith.constant {i} : index {mlirLoc(assign.loc)}\n" + ) dim = self.get_next_name() dims.append(dim) - str += f"{dim} = tensor.dim {lhs}, {cst} : {lhs_ty} {mlirLoc(assign.loc)}\n" + str += ( + f"{dim} = tensor.dim {lhs}, {cst} :" + f" {lhs_ty} {mlirLoc(assign.loc)}\n" + ) empty = self.get_next_name() - str += f"{empty} = tensor.empty({','.join(dims)}) : {out_ty} {mlirLoc(assign.loc)}\n" - str += f"{name} = linalg.matmul ins({lhs}, {rhs} : {lhs_ty}, {rhs_ty}) outs({empty} : {out_ty}) -> {out_ty} {mlirLoc(assign.loc)}" + str += ( + f"{empty} = tensor.empty({','.join(dims)}) :" + f" {out_ty} {mlirLoc(assign.loc)}\n" + ) + str += ( + f"{name} = linalg.matmul ins({lhs}, {rhs} : {lhs_ty}, {rhs_ty})" + f" outs({empty} : {out_ty}) -> {out_ty} {mlirLoc(assign.loc)}" + ) return str else: - #TODO (#1162): implement support for statically sized tensors + # TODO (#1162): implement support for statically sized tensors # this probably requires extending numba with a new type # See https://numba.readthedocs.io/en/stable/extending/index.html - raise NotImplementedError(f"Unsupported target type {target_numba_type} for {assign.target.name}.") + raise NotImplementedError( + f"Unsupported target type {target_numba_type} for" + f" {assign.target.name}." + ) else: raise NotImplementedError("Unknown global " + func.name) case ir.Expr(op="cast"): @@ -312,7 +341,10 @@ def emit_assign(self, assign): name = self.get_or_create_name(assign.target) # TODO(#1162): fix type (somehow the pretty printer on assign.value # knows it's an int???) - return f"{name} = arith.constant {assign.value.value} : i64 {mlirLoc(assign.loc)}" + return ( + f"{name} = arith.constant {assign.value.value} : i64" + f" {mlirLoc(assign.loc)}" + ) case ir.Global(): self.globals_map[assign.target.name] = assign.value.name return "" @@ -433,4 +465,7 @@ def emit_loop(self, target): def emit_return(self, ret): var = self.get_name(ret.value) - return f"func.return {var} : {mlirType(self.typemap.get(str(ret.value)))} {mlirLoc(ret.loc)}" + return ( + f"func.return {var} :" + f" {mlirType(self.typemap.get(str(ret.value)))} {mlirLoc(ret.loc)}" + ) diff --git a/frontend/heir/pipeline.py b/frontend/heir/pipeline.py index aa4ff038f..5830b795e 100644 --- a/frontend/heir/pipeline.py +++ b/frontend/heir/pipeline.py @@ -1,4 +1,5 @@ """The compilation pipeline.""" + import pathlib import shutil import tempfile @@ -24,13 +25,13 @@ Path = pathlib.Path HEIRConfig = heir_cli_config.HEIRConfig + def run_pipeline( function, - heir_opt_options : list[str], - backend : BackendInterface, + heir_opt_options: list[str], + backend: BackendInterface, heir_config: HEIRConfig = heir_cli_config.DEVELOPMENT_HEIR_CONFIG, debug: bool = False, - ) -> ClientInterface: """Run the pipeline.""" # The temporary workspace dir is so that heir-opt and the backend @@ -50,32 +51,52 @@ def run_pipeline( numba_signature = "" secret_args = "" try: - numba_signature, secret_args = parse_annotations(function.__annotations__) + numba_signature, secret_args = parse_annotations(function.__annotations__) except Exception as e: - print(Fore.RED + Style.BRIGHT + f"HEIR Error: Signature parsing failed for function {func_name} with {type(e).__name__}: {e}") - raise + print( + Fore.RED + + Style.BRIGHT + + "HEIR Error: Signature parsing failed for function" + f" {func_name} with {type(e).__name__}: {e}" + ) + raise try: - fn_args, _ = sigutils.normalize_signature(numba_signature) + fn_args, _ = sigutils.normalize_signature(numba_signature) except Exception as e: - print(Fore.RED + Style.BRIGHT + f"HEIR Error: Signature normalization failed for function {func_name} with signature {numba_signature} with {type(e).__name__}: {e}") - raise + print( + Fore.RED + + Style.BRIGHT + + "HEIR Error: Signature normalization failed for function" + f" {func_name} with signature {numba_signature} with" + f" {type(e).__name__}: {e}" + ) + raise typingctx = cpu_target.typing_context targetctx = cpu_target.target_context typingctx.refresh() targetctx.refresh() try: - typemap, restype, _, _ = type_inference_stage(typingctx, targetctx, ssa_ir, fn_args, None) + typemap, restype, _, _ = type_inference_stage( + typingctx, targetctx, ssa_ir, fn_args, None + ) except Exception as e: - print(Fore.RED + Style.BRIGHT + f"HEIR Error: Type inference failed for function {func_name} with signature {numba_signature} with {type(e).__name__}: {e}") - raise + print( + Fore.RED + + Style.BRIGHT + + f"HEIR Error: Type inference failed for function {func_name} with" + f" signature {numba_signature} with {type(e).__name__}: {e}" + ) + raise # Emit Textual IR - mlir_textual = TextualMlirEmitter(ssa_ir, secret_args, typemap, restype).emit() - if(debug): - mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir" - print(f"HEIR Debug: Writing input MLIR to \t \t {mlir_in_filepath}") - with open(mlir_in_filepath, "w") as f: - f.write(mlir_textual) + mlir_textual = TextualMlirEmitter( + ssa_ir, secret_args, typemap, restype + ).emit() + if debug: + mlir_in_filepath = Path(workspace_dir) / f"{func_name}.in.mlir" + print(f"HEIR Debug: Writing input MLIR to \t \t {mlir_in_filepath}") + with open(mlir_in_filepath, "w") as f: + f.write(mlir_textual) # Try to find heir_opt and heir_translate heir_opt = heir_cli.HeirOptBackend(heir_config.heir_opt_path) @@ -84,43 +105,55 @@ def run_pipeline( ) # Print type annotated version of the input - if(debug): - mlirpath = Path(workspace_dir) / f"{func_name}.annotated.mlir" - graphpath = Path(workspace_dir) / f"{func_name}.annotated.dot" - heir_opt_output, graph = heir_opt.run_binary_stderr( - input=mlir_textual, - options=["--annotate-secretness", "--view-op-graph"], - ) - print(f"HEIR Debug: Writing secretness-annotated MLIR to \t {mlirpath}") - with open(mlirpath, "w") as f: - f.write(heir_opt_output) - - print(f"HEIR Debug: Writing secretness-annotated graph to \t {graphpath}") - with open(graphpath, "w") as f: - f.write(graph) - + if debug: + mlirpath = Path(workspace_dir) / f"{func_name}.annotated.mlir" + graphpath = Path(workspace_dir) / f"{func_name}.annotated.dot" + heir_opt_output, graph = heir_opt.run_binary_stderr( + input=mlir_textual, + options=["--annotate-secretness", "--view-op-graph"], + ) + print(f"HEIR Debug: Writing secretness-annotated MLIR to \t {mlirpath}") + with open(mlirpath, "w") as f: + f.write(heir_opt_output) + + print(f"HEIR Debug: Writing secretness-annotated graph to \t {graphpath}") + with open(graphpath, "w") as f: + f.write(graph) # Run heir_opt - if(debug): - heir_opt_options.append("--view-op-graph") - print("HEIR Debug: " + Style.BRIGHT + f"Running heir-opt {' '.join(heir_opt_options)}") + if debug: + heir_opt_options.append("--view-op-graph") + print( + "HEIR Debug: " + + Style.BRIGHT + + f"Running heir-opt {' '.join(heir_opt_options)}" + ) heir_opt_output, graph = heir_opt.run_binary_stderr( input=mlir_textual, options=(heir_opt_options), ) - if(debug): - # Print output after heir_opt: - mlirpath = Path(workspace_dir) / f"{func_name}.out.mlir" - graphpath = Path(workspace_dir) / f"{func_name}.out.dot" - print(f"HEIR Debug: Writing output MLIR to \t \t {mlirpath}") - with open(mlirpath, "w") as f: - f.write(heir_opt_output) - print(f"HEIR Debug: Writing output graph to \t \t {graphpath}") - with open(graphpath, "w") as f: - f.write(graph) + if debug: + # Print output after heir_opt: + mlirpath = Path(workspace_dir) / f"{func_name}.out.mlir" + graphpath = Path(workspace_dir) / f"{func_name}.out.dot" + print(f"HEIR Debug: Writing output MLIR to \t \t {mlirpath}") + with open(mlirpath, "w") as f: + f.write(heir_opt_output) + print(f"HEIR Debug: Writing output graph to \t \t {graphpath}") + with open(graphpath, "w") as f: + f.write(graph) # Run backend (which will call heir_translate and other tools, e.g., clang, as needed) - result = backend.run_backend(workspace_dir, heir_opt, heir_translate, func_name, arg_names, secret_args, heir_opt_output, debug) + result = backend.run_backend( + workspace_dir, + heir_opt, + heir_translate, + func_name, + arg_names, + secret_args, + heir_opt_output, + debug, + ) # Attach the original python func result.func = function @@ -129,18 +162,24 @@ def run_pipeline( finally: if debug: - print(f"HEIR Debug: Leaving workspace_dir {workspace_dir} for manual inspection.\n") + print( + f"HEIR Debug: Leaving workspace_dir {workspace_dir} for manual" + " inspection.\n" + ) else: shutil.rmtree(workspace_dir) - def compile( scheme: Optional[str] = "bgv", - backend: Optional[BackendInterface] = OpenFHEBackend(openfhe_config.DEFAULT_INSTALLED_OPENFHE_CONFIG), - heir_config: Optional[heir_cli_config.HEIRConfig] = heir_cli_config.from_os_env(), + backend: Optional[BackendInterface] = OpenFHEBackend( + openfhe_config.DEFAULT_INSTALLED_OPENFHE_CONFIG + ), + heir_config: Optional[ + heir_cli_config.HEIRConfig + ] = heir_cli_config.from_os_env(), debug: Optional[bool] = False, - heir_opt_options: Optional[list[str]] = None + heir_opt_options: Optional[list[str]] = None, ): """Compile a function to its private equivalent in FHE. @@ -158,7 +197,7 @@ def compile( Returns: The decorator to apply to the given function. """ - if(debug and heir_opt_options is not None): + if debug and heir_opt_options is not None: print(f"HEIR Debug: Overriding scheme with options {heir_opt_options}") def decorator(func): @@ -167,7 +206,7 @@ def decorator(func): heir_opt_options=heir_opt_options or [f"--mlir-to-{scheme}"], backend=backend or DummyBackend(), heir_config=heir_config or heir_config.from_os_env(), - debug = debug, + debug=debug, ) return decorator