diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index ca76d2097..c560afd0e 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -371,8 +371,14 @@ def __init__( self.bit = bit # This is a hack to support the int4 and uint4 + # legalize the backend (hacky implementation) + # TODO(lei): In future release we should remove + # by implementing all the operators in the tl backend. if config.A_dtype in ["int4", "uint4"]: backend = "tl" + if source_format in ["nf"]: + backend = "tir" + super().__init__(name, config, target, backend) if source_format == "int" and self.with_zeros: