Skip to content

Commit

Permalink
Share mlir context across Module op and CompilerClient creation (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
parthchadha authored Aug 2, 2024
1 parent 9150fb4 commit aaae741
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
11 changes: 5 additions & 6 deletions tripy/tripy/backend/mlir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,23 @@
)
from tripy.logging import logger

G_MLIR_CONTEXT = None
G_COMPILER_CLIENT = None
G_TIMING_CACHE_FILE = cfg.timing_cache_file_path


# Avoid instantiating the compiler more than once.
def _get_compiler_objects() -> Tuple[ir.Context, compiler.CompilerClient]:
global G_MLIR_CONTEXT, G_COMPILER_CLIENT, G_TIMING_CACHE_FILE
global G_COMPILER_CLIENT, G_TIMING_CACHE_FILE
if G_TIMING_CACHE_FILE != cfg.timing_cache_file_path:
# Reinitialize the compiler if the timing cache file path has changed.
global G_COMPILER_CLIENT
G_COMPILER_CLIENT = None
G_TIMING_CACHE_FILE = cfg.timing_cache_file_path

if G_MLIR_CONTEXT is None or G_COMPILER_CLIENT is None:
G_MLIR_CONTEXT = make_ir_context()
G_COMPILER_CLIENT = compiler.CompilerClient(G_MLIR_CONTEXT)
return G_MLIR_CONTEXT, G_COMPILER_CLIENT
ctx = make_ir_context()
if G_COMPILER_CLIENT is None:
G_COMPILER_CLIENT = compiler.CompilerClient(ctx)
return ctx, G_COMPILER_CLIENT


class Compiler:
Expand Down
20 changes: 13 additions & 7 deletions tripy/tripy/backend/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,25 @@
from tripy.common.exception import OmitStackInfo, raise_error
from tripy.logging import logger

# MLIR context needs to be shared across the Module op and CompilerClient
class MLIRContext:
_instance = None

def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.context = ir.Context()
return cls._instance.context

def get_max_upper_bounds():
return sys.maxsize


def make_ir_context() -> ir.Context:
context = ir.Context()

context.enable_multithreading(False)
ctx = MLIRContext()
ctx.enable_multithreading(False)
# Allow unregistered dialects to assign trt shape_profile attribute to stablehlo program.
context.allow_unregistered_dialects = True
return context

ctx.allow_unregistered_dialects = True
return ctx

def get_mlir_dtype(dtype: "tripy.dtype"):
"""
Expand Down

0 comments on commit aaae741

Please sign in to comment.