[Dev][TL] Decouple 3rdparty TileLang Backend with TVM #217

Open: wants to merge 54 commits into base: main
54 commits
f3b1eb9
Refactor tilelang dequantize module and add matmul_blocked_weight_onl…
LeiWang1999 Sep 28, 2024
730d13e
remove un-implemented code.
LeiWang1999 Sep 28, 2024
8047ee7
Implement BaseScheduler to wrap some related items.
LeiWang1999 Sep 28, 2024
64db065
lint fix
LeiWang1999 Sep 28, 2024
cef04a8
test skip
LeiWang1999 Sep 28, 2024
f1652e9
Refactor tilelang dequantize module and add matmul_blocked_weight_onl…
LeiWang1999 Sep 29, 2024
4f6c545
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Sep 29, 2024
c485b68
test fix
LeiWang1999 Sep 29, 2024
ebe42a6
hardware tuning demo
LeiWang1999 Sep 29, 2024
88230ec
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Sep 29, 2024
44246a1
remove debug related items.
LeiWang1999 Sep 30, 2024
bb51e15
imlement tuner and cache fix
LeiWang1999 Oct 1, 2024
f42a3b9
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 1, 2024
de7ae18
lint fix
LeiWang1999 Oct 1, 2024
ef40bd8
test case fix.
LeiWang1999 Oct 1, 2024
85f0a5f
Adapt Tuning Space generation with Roller
LeiWang1999 Oct 1, 2024
e9f7db3
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 1, 2024
9e31336
lint fix
LeiWang1999 Oct 1, 2024
2f1a260
Refactor select_scheduler function for fine-grained interface
LeiWang1999 Oct 1, 2024
f1378d4
Refactor select_scheduler function for fine-grained interface
LeiWang1999 Oct 1, 2024
137cce3
Refactor NotImplementedError message in BaseTLHint class
LeiWang1999 Oct 1, 2024
fc19fa2
Update submodule reference in 3rdparty/tvm
LeiWang1999 Oct 2, 2024
fe51bb1
Refactor matmul_finetune function to use topk=20 for hardware-aware f…
LeiWang1999 Oct 2, 2024
79878cb
Refactor submodule reference in 3rdparty/tvm
LeiWang1999 Oct 2, 2024
0fc7ab9
lint fix
LeiWang1999 Oct 2, 2024
255e925
Refactor test_general_matmul_tilelang_impl.py and test_tilelang_gemm.py
LeiWang1999 Oct 2, 2024
df47f63
Refactor MatmulConfig to enable weight propagation on supported devices
LeiWang1999 Oct 2, 2024
826255d
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 2, 2024
48dc94e
Refactor test_general_matmul_tilelang_impl.py and test_general_matmul…
LeiWang1999 Oct 2, 2024
82f39d7
test fix
LeiWang1999 Oct 2, 2024
02ef258
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 2, 2024
e753ef2
test fix
LeiWang1999 Oct 2, 2024
f6dd744
Refactor flash attention tests to use centered random values for inpu…
LeiWang1999 Oct 2, 2024
7417372
Refactor flash attention tests to use centered random values for inpu…
LeiWang1999 Oct 2, 2024
145a850
Refactor flash attention tests to skip test if flash_attn is not inst…
LeiWang1999 Oct 2, 2024
3384458
lint fix
LeiWang1999 Oct 3, 2024
82f50ea
test fix
LeiWang1999 Oct 3, 2024
d2ed936
test fix
LeiWang1999 Oct 3, 2024
6c56273
test fix
LeiWang1999 Oct 3, 2024
2e59e58
Merge branch 'main' of https://github.com/microsoft/BitBLAS into tl_o…
LeiWang1999 Oct 6, 2024
074b9ca
Refactor quantization module imports
LeiWang1999 Oct 6, 2024
0923344
lint fix
LeiWang1999 Oct 6, 2024
b30bcd4
Update yapf version in requirements-dev.txt and requirements-test.txt
LeiWang1999 Oct 6, 2024
d0a88ac
Refactor shared memory to global memory storage in MatmulFineGrainSch…
LeiWang1999 Oct 6, 2024
62303e2
test fix
LeiWang1999 Oct 6, 2024
01dc3f9
format
LeiWang1999 Oct 6, 2024
c621664
test fix
LeiWang1999 Oct 7, 2024
652b061
Merge branch 'main' of https://github.com/microsoft/BitBLAS into dequant
LeiWang1999 Oct 9, 2024
c95d537
Add tile-lang submodule for TileLang integration
LeiWang1999 Oct 10, 2024
77676aa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into migr…
LeiWang1999 Oct 10, 2024
70c23c3
Update tile-lang submodule commit
LeiWang1999 Oct 10, 2024
20cf4a6
Update TileLang import
LeiWang1999 Oct 10, 2024
f54fc65
test fix
LeiWang1999 Oct 10, 2024
9244541
Refactor test_general_matmul_tile_schedule.py for dequantization with…
LeiWang1999 Oct 14, 2024
4 changes: 4 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@
path = 3rdparty/cutlass
url = https://github.com/TileLang/cutlass
branch = tldev
[submodule "3rdparty/tile-lang"]
path = 3rdparty/tile-lang
url = https://github.com/TileLang/tile-lang
branch = dev
1 change: 1 addition & 0 deletions 3rdparty/tile-lang
Submodule tile-lang added at 84e731
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 69 files
+0 −9 CMakeLists.txt
+2 −2 cmake/config.cmake
+0 −20 python/tvm/tl/__init__.py
+0 −21 python/tvm/tl/_ffi_api.py
+0 −139 python/tvm/tl/autotuner.py
+0 −167 python/tvm/tl/engine.py
+0 −341 python/tvm/tl/language.py
+0 −94 python/tvm/tl/layout.py
+0 −152 python/tvm/tl/transform.py
+0 −304 python/tvm/tl/utils.py
+0 −18 src/tir/analysis/block_access_region_detector.cc
+0 −381 src/tir/transforms/thread_partial_sync.cc
+0 −1 src/tir/transforms/thread_storage_sync.cc
+0 −162 src/tl/ir.cc
+0 −348 src/tl/layout/gemm_layouts.cc
+0 −417 src/tl/layout/layout.cc
+0 −167 src/tl/layout/layout.h
+0 −116 src/tl/layout/swizzle.cc
+0 −91 src/tl/layout/swizzle.h
+0 −262 src/tl/layout/utils.cc
+0 −76 src/tl/layout/utils.h
+0 −106 src/tl/op/builtin.cc
+0 −168 src/tl/op/builtin.h
+0 −393 src/tl/op/bulk_copy.cc
+0 −82 src/tl/op/bulk_copy.h
+0 −363 src/tl/op/elem.cc
+0 −82 src/tl/op/elem.h
+0 −216 src/tl/op/gemm.cc
+0 −62 src/tl/op/gemm.h
+0 −102 src/tl/op/op.cc
+0 −113 src/tl/op/op.h
+0 −190 src/tl/op/parallel.cc
+0 −88 src/tl/op/parallel.h
+0 −230 src/tl/op/reduce.cc
+0 −62 src/tl/op/reduce.h
+0 −203 src/tl/runtime/runtime.cc
+0 −37 src/tl/runtime/runtime.h
+0 −1,465 src/tl/target/codegen.cc
+0 −100 src/tl/target/codegen.h
+0 −109 src/tl/target/rt_mod.cc
+0 −85 src/tl/target/utils.cc
+0 −48 src/tl/target/utils.h
+0 −69 src/tl/tl_templates/common.h
+0 −73 src/tl/tl_templates/copy.h
+0 −227 src/tl/tl_templates/copy_sm90.h
+0 −10 src/tl/tl_templates/gemm.h
+0 −160 src/tl/tl_templates/gemm_sm70.h
+0 −314 src/tl/tl_templates/gemm_sm80.h
+0 −187 src/tl/tl_templates/gemm_sm90.h
+0 −100 src/tl/tl_templates/ldsm.h
+0 −55 src/tl/tl_templates/reduce.h
+0 −39 src/tl/tl_templates/threadblock_swizzle.h
+0 −133 src/tl/transform/cluster_planning.cc
+0 −94 src/tl/transform/frontend_legalize.cc
+0 −170 src/tl/transform/inject_fence_proxy.cc
+0 −242 src/tl/transform/inject_mbarrier.cc
+0 −934 src/tl/transform/inject_pipeline.cc
+0 −291 src/tl/transform/layout_inference.cc
+0 −164 src/tl/transform/loop_partition.cc
+0 −46 src/tl/transform/loop_partition.h
+0 −456 src/tl/transform/loop_vectorize.cc
+0 −45 src/tl/transform/loop_vectorize.h
+0 −157 src/tl/transform/lower_hopper_intrin.cc
+0 −306 src/tl/transform/lower_tile_op.cc
+0 −321 src/tl/transform/multi_version_buffer_rewriter.cc
+0 −242 src/tl/transform/pipeline_planning.cc
+0 −475 src/tl/transform/simplify.cc
+0 −849 src/tl/transform/warp_specialized_pipeline.cc
+0 −939 src/tl/transform/warp_specialized_rewriter.cc
157 changes: 107 additions & 50 deletions bitblas/__init__.py
Expand Up @@ -3,62 +3,21 @@
import sys
import os

# installing tvm
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
sys.path.insert(0, install_tvm_path + "/python")

develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
develop_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
sys.path.insert(0, develop_tvm_path + "/python")

import tvm as tvm # noqa: E402
from . import gpu # noqa: F401
from .base import (
TileDevice, # noqa: F401
fast_tune, # noqa: F401
ApplyDefaultSchedule, # noqa: F401
ApplyFastTuning, # noqa: F401
BlockInfo, # noqa: F401
IterInfo, # noqa: F401
ScheduleRule, # noqa: F401
normalize_prim_func, # noqa: F401
try_inline, # noqa: F401
try_inline_contiguous_spatial, # noqa: F401
)

from . import testing # noqa: F401
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401
from .module import Linear # noqa: F401

import warnings
import functools
import logging
from tqdm import tqdm


class TqdmLoggingHandler(logging.Handler):
""" Custom logging handler that directs log output to tqdm progress bar to avoid interference. """
"""Custom logging handler that directs log output to tqdm progress bar to avoid interference."""

def __init__(self, level=logging.NOTSET):
""" Initialize the handler with an optional log level. """
"""Initialize the handler with an optional log level."""
super().__init__(level)

def emit(self, record):
""" Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted. """
"""Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted."""
try:
msg = self.format(record)
tqdm.write(msg)
Expand All @@ -67,8 +26,8 @@ def emit(self, record):


def set_log_level(level):
""" Set the logging level for the module's logger.
"""Set the logging level for the module's logger.

Args:
level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO).
OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
Expand All @@ -80,15 +39,17 @@ def set_log_level(level):


def _init_logger():
""" Initialize the logger specific for this module with custom settings and a Tqdm-based handler. """
"""Initialize the logger specific for this module with custom settings and a Tqdm-based handler."""
logger = logging.getLogger(__name__)
handler = TqdmLoggingHandler()
formatter = logging.Formatter(
fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
set_log_level('WARNING')
set_log_level("WARNING")


_init_logger()
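The handler and initializer above can be tried in isolation. The following is a minimal sketch, assuming a stand-in `write` callable in place of `tqdm.write` so that it runs without tqdm installed. `CallbackHandler`, `init_logger`, the `demo` logger name, and the `handleError` fallback are illustrative assumptions (the hunk elides the real `except` branch and the body of `set_log_level`), not BitBLAS API.

```python
import logging


class CallbackHandler(logging.Handler):
    """Forward formatted records to a callable (the diff uses tqdm.write)."""

    def __init__(self, write, level=logging.NOTSET):
        super().__init__(level)
        self._write = write

    def emit(self, record):
        try:
            self._write(self.format(record))
        except Exception:
            # Assumed fallback: the standard logging error hook.
            self.handleError(record)


def set_log_level(logger, level):
    """Accept either a level name such as 'INFO' or a numeric level."""
    if isinstance(level, str):
        level = getattr(logging, level.upper())
    logger.setLevel(level)


def init_logger(name, write):
    """Attach the callback handler and default the logger to WARNING."""
    logger = logging.getLogger(name)
    handler = CallbackHandler(write)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        ))
    logger.addHandler(handler)
    logger.propagate = False  # keep records away from the root logger
    set_log_level(logger, "WARNING")
    return logger


captured = []
demo_logger = init_logger("demo", captured.append)
demo_logger.info("hidden at WARNING level")
demo_logger.warning("shown")
```

Routing every record through a single write function is what keeps log lines from tearing an active progress bar; swapping `captured.append` for `tqdm.write` would give the behaviour the docstring describes.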
Expand All @@ -107,12 +68,108 @@ def new_func(*args, **kwargs):
warnings.warn(
f"Call to deprecated function {func.__name__} ({reason}).",
category=DeprecationWarning,
stacklevel=2)
stacklevel=2,
)
return func(*args, **kwargs)

return new_func

return decorator
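The decorator whose tail is visible in this hunk follows the usual three-level pattern. The following is a self-contained sketch, assuming the outer function is called `deprecated` and that the wrapper uses `functools.wraps` (both are elided from the hunk, so treat them as assumptions). `old_add` is a hypothetical function used only for the demonstration.

```python
import functools
import warnings


def deprecated(reason):
    """Return a decorator that emits a DeprecationWarning on each call."""

    def decorator(func):

        @functools.wraps(func)
        def new_func(*args, **kwargs):
            warnings.warn(
                f"Call to deprecated function {func.__name__} ({reason}).",
                category=DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller
            )
            return func(*args, **kwargs)

        return new_func

    return decorator


@deprecated("use new_add instead")
def old_add(a, b):
    return a + b


# Record warnings so the demo can inspect them instead of printing.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_add(1, 2)
```

With `stacklevel=2` the warning points at the line that called `old_add`, which is usually the line a user has to change.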


logger = logging.getLogger(__name__)

# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path"
", which may lead to compilation bugs when using the tilelang backend.")
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path"
", which may lead to compilation bugs when using the tilelang backend.")

# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path
TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)

if TVM_IMPORT_PYTHON_PATH is not None:
os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, TVM_IMPORT_PYTHON_PATH)
else:
# installed 3rdparty tvm
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = (
install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, install_tvm_path + "/python")

# developed 3rdparty tvm
develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = (
develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, develop_tvm_path + "/python")

if os.environ.get("TL_CUTLASS_PATH", None) is None:
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
develop_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(install_cutlass_path):
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
elif os.path.exists(develop_cutlass_path):
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
else:
logger.warning(CUTLASS_NOT_FOUND_MESSAGE)

install_tilelang_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tile-lang")
develop_tilelang_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tile-lang")

if os.environ.get("TL_TEMPLATE_PATH", None) is None:
if os.path.exists(install_tilelang_path):
os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tilelang_path, "src")
elif os.path.exists(develop_tilelang_path):
os.environ["TL_TEMPLATE_PATH"] = os.path.join(develop_tilelang_path, "src")
else:
logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE)

if (os.path.exists(install_tilelang_path) and install_tilelang_path not in sys.path):
os.environ["PYTHONPATH"] = (
install_tilelang_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, install_tilelang_path + "/python")

if (os.path.exists(develop_tilelang_path) and develop_tilelang_path not in sys.path):
os.environ["PYTHONPATH"] = (
develop_tilelang_path + "/python:" + os.environ.get("PYTHONPATH", ""))
sys.path.insert(0, develop_tilelang_path + "/python")
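The lookup order used throughout this block (environment override, then the installed layout, then the development layout) can be factored into one helper. The following is a sketch under stated assumptions: `resolve_third_party` and the `DEMO_*` environment variable names are hypothetical and are not part of BitBLAS; the demo builds a throwaway directory tree instead of touching a real checkout.

```python
import os
import tempfile


def resolve_third_party(package_dir, name, env_var=None, subdir=""):
    """Return the path for a bundled dependency, or None if it is missing.

    Order: explicit environment override, then the installed layout
    (<package>/3rdparty/<name>), then the development layout
    (<package>/../3rdparty/<name>).
    """
    if env_var and os.environ.get(env_var):
        return os.environ[env_var]
    candidates = [
        os.path.join(package_dir, "3rdparty", name),
        os.path.join(package_dir, "..", "3rdparty", name),
    ]
    for root in candidates:
        if os.path.exists(root):
            return os.path.normpath(os.path.join(root, subdir))
    return None


# Demo on a temporary tree that mimics a source checkout:
#   <repo>/bitblas/            (the package directory)
#   <repo>/3rdparty/tile-lang/src
with tempfile.TemporaryDirectory() as repo:
    package_dir = os.path.join(repo, "bitblas")
    os.makedirs(package_dir)
    os.makedirs(os.path.join(repo, "3rdparty", "tile-lang", "src"))

    os.environ.pop("DEMO_TEMPLATE_PATH", None)
    os.environ.pop("DEMO_CUTLASS_PATH", None)
    found = resolve_third_party(package_dir, "tile-lang", "DEMO_TEMPLATE_PATH", "src")
    missing = resolve_third_party(package_dir, "cutlass", "DEMO_CUTLASS_PATH", "include")
    expected = os.path.normpath(os.path.join(repo, "3rdparty", "tile-lang", "src"))
```

Returning `None` for a missing dependency leaves the caller free to log a warning, as the block above does, rather than failing at import time.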

import tvm as tvm # noqa: E402
import tilelang as tilelang # noqa: E402
from . import gpu # noqa: F401
from .base import (
TileDevice, # noqa: F401
fast_tune, # noqa: F401
ApplyDefaultSchedule, # noqa: F401
ApplyFastTuning, # noqa: F401
BlockInfo, # noqa: F401
IterInfo, # noqa: F401
ScheduleRule, # noqa: F401
normalize_prim_func, # noqa: F401
try_inline, # noqa: F401
try_inline_contiguous_spatial, # noqa: F401
)

from . import testing # noqa: F401
from .utils import (
auto_detect_nvidia_target, # noqa: F401
apply_transform_on_input, # noqa: F401
)
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import (
MatmulConfigWithSplitK, # noqa: F401
MatmulWithSplitK, # noqa: F401
)
from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401
from .module import Linear # noqa: F401

__version__ = "0.0.1.dev15"
32 changes: 16 additions & 16 deletions bitblas/builder/lib_generator/__init__.py
Expand Up @@ -50,24 +50,24 @@ def compile_lib(self, timeout: float = None, with_tl: bool = False):
]

if with_tl:
install_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tvm")
develop_tvm_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tvm")
install_tilelang_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tile-lang")
develop_tilelang_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tile-lang")

tvm_root = next((path for path in [install_tvm_path, develop_tvm_path]
tilelang_root = next((path for path in [install_tilelang_path, develop_tilelang_path]
if os.path.exists(path) and path not in sys.path), None)

if "TL_TEMPLATE_PATH " in os.environ:
tl_template_path = os.environ["TL_TEMPLATE_PATH"]
else:
tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl"))

tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl"))
if "TL_CUTLASS_PATH" in os.environ:
cutlass_path = os.environ["TL_CUTLASS_PATH"]
else:
cutlass_path = osp.abspath(osp.join(tvm_root, "3rdparty/cutlass/include"))
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "cutlass")
develop_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "cutlass")
cutlass_root = next((path for path in [install_cutlass_path, develop_cutlass_path]
if os.path.exists(path) and path not in sys.path), None)

tl_template_path = os.environ["TL_TEMPLATE_PATH"] if "TL_TEMPLATE_PATH" in os.environ else osp.abspath(osp.join(tilelang_root, "src"))

cutlass_path = os.environ["TL_CUTLASS_PATH"] if "TL_CUTLASS_PATH" in os.environ else osp.abspath(osp.join(cutlass_root, "include"))

command += [
"-I" + tl_template_path,
Expand Down
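One edge case in the include-path logic of this hunk: `next(..., None)` yields `None` when neither candidate directory exists, and `osp.join(None, "src")` then raises a `TypeError`. The following is a minimal sketch of the same lookup with an explicit error instead. `include_flag` and the `DEMO_*` variable names are hypothetical, chosen only for this illustration.

```python
import os
import os.path as osp


def include_flag(env_var, candidates, subdir):
    """Build a "-I<path>" flag from an env override or the first existing root."""
    if env_var in os.environ:
        return "-I" + os.environ[env_var]
    root = next((path for path in candidates if osp.exists(path)), None)
    if root is None:
        # Without this check, osp.join(None, subdir) would raise a TypeError.
        raise FileNotFoundError(
            f"Set {env_var} or provide one of: {', '.join(candidates)}")
    return "-I" + osp.abspath(osp.join(root, subdir))


# 1) The environment override wins and is used verbatim.
os.environ["DEMO_TL_TEMPLATE_PATH"] = "/opt/tile-lang/src"
flag_from_env = include_flag("DEMO_TL_TEMPLATE_PATH", [], "src")

# 2) Otherwise the first existing candidate root is used.
os.environ.pop("DEMO_TL_CUTLASS_PATH", None)
cwd = os.getcwd()
flag_from_disk = include_flag("DEMO_TL_CUTLASS_PATH", ["/definitely/not/here", cwd], "include")

# 3) Nothing found: a descriptive error instead of a TypeError.
try:
    include_flag("DEMO_TL_CUTLASS_PATH", ["/definitely/not/here"], "include")
    raised = False
except FileNotFoundError:
    raised = True
```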
6 changes: 5 additions & 1 deletion bitblas/ops/base_scheduler.py
@@ -1,8 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
from bitblas import tilelang as tilelang
from tvm import IRModule
from tvm.tir import PrimFunc
from typing import Union, Callable
from dataclasses import dataclass, field
from tvm.tl.transform import Simplify
from tilelang.transform import Simplify
from abc import ABC, abstractmethod
from bitblas.base.arch import TileDevice
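The import changes in this and the following files swap `tvm.tl.*` for the standalone `tilelang` package. For code that has to run against both layouts during a migration, a fallback import is one option. The following is a sketch only: `import_first` is a hypothetical helper and is not something this PR adds.

```python
import importlib


def import_first(*module_names):
    """Import and return the first module in module_names that is available."""
    errors = []
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError as err:
            errors.append(f"{name}: {err}")
    raise ImportError("none of the candidates could be imported: " + "; ".join(errors))


# Intended use (needs tilelang, or a TVM build that still ships tvm.tl):
#   T = import_first("tilelang.language", "tvm.tl.language")

# Runnable stand-in that relies on the standard library only.
mod = import_first("module_that_does_not_exist_xyz", "json")
```

The PR itself takes the simpler route of switching every import to `tilelang` outright, which avoids carrying two code paths.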

Expand Down
3 changes: 2 additions & 1 deletion bitblas/ops/general_flashatten/tilelang/flashatten.py
Expand Up @@ -2,8 +2,9 @@
# Licensed under the MIT License.

from bitblas import tvm as tvm
from bitblas import tilelang as tilelang
from bitblas.ops.base_scheduler import BaseScheduler
import tvm.tl.language as T
import tilelang.language as T
from dataclasses import dataclass
from typing import Optional
import logging
Expand Down
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
from bitblas import tilelang as tilelang
from tvm import DataType
import tvm.tl.language as T
import tilelang.language as T
from typing import Optional, List
from bitblas.tl.utils import (
get_mma_micro_size,
Expand Down
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
from bitblas import tilelang as tilelang
from tvm import DataType
import tvm.tl.language as T
import tilelang.language as T
from typing import Optional, List, Literal
from bitblas.ops.base_scheduler import BaseScheduler
from bitblas.base.arch import TileDevice
Expand Down
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm as tvm
from bitblas import tilelang as tilelang
from tvm import DataType
import tvm.tl.language as T
import tilelang.language as T
from typing import Optional, List, Literal
from bitblas.tl.utils import (
get_mma_micro_size, # noqa: F401
Expand Down
4 changes: 2 additions & 2 deletions bitblas/ops/operator.py
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT License.
from abc import ABC, abstractmethod
from bitblas import tvm
from tvm import tl
from bitblas import tilelang
from tvm import IRModule
from tvm.runtime.module import Module
from tvm.target import Target
Expand Down Expand Up @@ -196,7 +196,7 @@ def tvm_callback_cuda_postproc(code, _):
"tir.disable_cse_tir": True,
**(self.pass_context if self.pass_context else {})
}):
rt_mod = tl.lower(tl_prim_func, target=target, runtime_only=True)
rt_mod = tilelang.lower(tl_prim_func, target=target, runtime_only=True)
else:
raise ValueError(f"Unsupported backend: {self.backend}")
except Exception as build_runtime_error: # noqa: F841
Expand Down
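The config dictionary built just before the `tilelang.lower` call merges a fixed default with the operator's optional `pass_context`. The following is a small sketch of that merge in plain Python. `merge_pass_config` and the `demo.extra_flag` key are hypothetical names used only for illustration.

```python
def merge_pass_config(user_config=None):
    """Defaults first, then user entries, so user values win on key clashes."""
    return {
        "tir.disable_cse_tir": True,
        **(user_config if user_config else {}),
    }


default_only = merge_pass_config()
overridden = merge_pass_config({"tir.disable_cse_tir": False, "demo.extra_flag": 1})
```

Because the unpacked dictionary comes last, a caller can override `tir.disable_cse_tir` as well as add new keys.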
5 changes: 3 additions & 2 deletions bitblas/tl/macro_generator.py
@@ -1,7 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import tvm.tl.language as T
from bitblas import tvm as tvm
from bitblas import tilelang as tilelang
import tilelang.language as T

from typing import Union
from bitblas.ops.common import TransformKind
Expand Down
6 changes: 3 additions & 3 deletions bitblas/tl/tuner.py
@@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas import tvm
from bitblas import tvm as tvm
from bitblas import tilelang as tilelang
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple, Optional, Literal
from tvm import tir, IRModule
from tvm.runtime import Module
from tvm.tir import Schedule
import tvm.tl as tl
from bitblas.ops.base_scheduler import BaseScheduler
from bitblas.base.arch import CUDA
from bitblas.base.utils import get_dummy_input_arrays
Expand Down Expand Up @@ -133,7 +133,7 @@ def tvm_callback_cuda_postproc(code, _):
"tir.disable_cse_tir": True,
**(config.pass_context if config.pass_context else {})
}):
rt_mod = tl.lower(tl_prim_func, arch.target, runtime_only=True)
rt_mod = tilelang.lower(tl_prim_func, arch.target, runtime_only=True)

from tvm.contrib.tar import tar # Import the tar module

Expand Down
5 changes: 3 additions & 2 deletions bitblas/tl/utils.py
@@ -1,9 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas import tvm as tvm
from bitblas import tilelang as tilelang
from tvm import arith
from tvm import DataType
import tvm.tl.language as T
import tilelang.language as T
from typing import Union, Literal
from .mma_layout import (
ldmatrix_32x8_to_shared_16x16_layout,
Expand Down