diff --git a/.gitmodules b/.gitmodules index c8a359670..bd90314bf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,3 +6,7 @@ path = 3rdparty/cutlass url = https://github.com/TileLang/cutlass branch = tldev +[submodule "3rdparty/tile-lang"] + path = 3rdparty/tile-lang + url = https://github.com/TileLang/tile-lang + branch = dev diff --git a/3rdparty/tile-lang b/3rdparty/tile-lang new file mode 160000 index 000000000..84e7317a7 --- /dev/null +++ b/3rdparty/tile-lang @@ -0,0 +1 @@ +Subproject commit 84e7317a7b518cac79217eaeda825b9650dbe988 diff --git a/3rdparty/tvm b/3rdparty/tvm index af0b40391..1ea229576 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit af0b403916c853160df4ee3d046bdd4182c1ea44 +Subproject commit 1ea229576f4ebc42d6aef7d878c1eb35ec0092aa diff --git a/bitblas/__init__.py b/bitblas/__init__.py index 3074e3fcb..e6e54f9a9 100644 --- a/bitblas/__init__.py +++ b/bitblas/__init__.py @@ -3,47 +3,6 @@ import sys import os -# installing tvm -install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") -install_cutlass_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") -if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: - os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") - os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" - os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl") - sys.path.insert(0, install_tvm_path + "/python") - -develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") -develop_cutlass_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") -if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: - os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "") - os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include" - os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl") - sys.path.insert(0, develop_tvm_path + "/python") - -import tvm as tvm # noqa: E402 -from . import gpu # noqa: F401 -from .base import ( - TileDevice, # noqa: F401 - fast_tune, # noqa: F401 - ApplyDefaultSchedule, # noqa: F401 - ApplyFastTuning, # noqa: F401 - BlockInfo, # noqa: F401 - IterInfo, # noqa: F401 - ScheduleRule, # noqa: F401 - normalize_prim_func, # noqa: F401 - try_inline, # noqa: F401 - try_inline_contiguous_spatial, # noqa: F401 -) - -from . import testing # noqa: F401 -from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401 -from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 -from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401 -from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401 -from .module import Linear # noqa: F401 - import warnings import functools import logging @@ -51,14 +10,14 @@ class TqdmLoggingHandler(logging.Handler): - """ Custom logging handler that directs log output to tqdm progress bar to avoid interference. """ + """Custom logging handler that directs log output to tqdm progress bar to avoid interference.""" def __init__(self, level=logging.NOTSET): - """ Initialize the handler with an optional log level. """ + """Initialize the handler with an optional log level.""" super().__init__(level) def emit(self, record): - """ Emit a log record. 
Messages are written to tqdm to ensure output in progress bars isn't corrupted. """ + """Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted.""" try: msg = self.format(record) tqdm.write(msg) @@ -67,8 +26,8 @@ def emit(self, record): def set_log_level(level): - """ Set the logging level for the module's logger. - + """Set the logging level for the module's logger. + Args: level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO). OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' @@ -80,15 +39,17 @@ def set_log_level(level): def _init_logger(): - """ Initialize the logger specific for this module with custom settings and a Tqdm-based handler. """ + """Initialize the logger specific for this module with custom settings and a Tqdm-based handler.""" logger = logging.getLogger(__name__) handler = TqdmLoggingHandler() formatter = logging.Formatter( - fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) handler.setFormatter(formatter) logger.addHandler(handler) logger.propagate = False - set_log_level('WARNING') + set_log_level("WARNING") _init_logger() @@ -107,7 +68,8 @@ def new_func(*args, **kwargs): warnings.warn( f"Call to deprecated function {func.__name__} ({reason}).", category=DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) return func(*args, **kwargs) return new_func @@ -115,4 +77,99 @@ def new_func(*args, **kwargs): return decorator +logger = logging.getLogger(__name__) + +# SETUP ENVIRONMENT VARIABLES +CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path, " + "which may lead to compilation bugs when utilizing the tilelang backend.") +TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path, " + "which may lead to compilation bugs when utilizing the tilelang backend.")
+ +# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path +TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None) + +if TVM_IMPORT_PYTHON_PATH is not None: + os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")) + sys.path.insert(0, TVM_IMPORT_PYTHON_PATH + "/python") +else: + # installed 3rdparty tvm + install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") + if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: + os.environ["PYTHONPATH"] = ( + install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) + sys.path.insert(0, install_tvm_path + "/python") + + # developed 3rdparty tvm + develop_tvm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") + if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: + os.environ["PYTHONPATH"] = ( + develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) + sys.path.insert(0, develop_tvm_path + "/python") + +if os.environ.get("TL_CUTLASS_PATH", None) is None: + install_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") + develop_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") + if os.path.exists(install_cutlass_path): + os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" + elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path): + os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include" + else: + logger.warning(CUTLASS_NOT_FOUND_MESSAGE) + +install_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tile-lang") +develop_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tile-lang") + +if os.environ.get("TL_TEMPLATE_PATH", None) is None: + sys.path.insert(0, install_tilelang_path + "/python") + if os.path.exists(install_tilelang_path): + os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tilelang_path, "src") + elif os.path.exists(develop_tilelang_path): + os.environ["TL_TEMPLATE_PATH"] = os.path.join(develop_tilelang_path, "src") + else: + logger.warning(TL_TEMPLATE_NOT_FOUND_MESSAGE) + +if (os.path.exists(install_tilelang_path) and install_tilelang_path not in sys.path): + os.environ["PYTHONPATH"] = ( + install_tilelang_path + "/python:" + os.environ.get("PYTHONPATH", "")) + sys.path.insert(0, install_tilelang_path + "/python") + +if (os.path.exists(develop_tilelang_path) and develop_tilelang_path not in sys.path): + os.environ["PYTHONPATH"] = ( + develop_tilelang_path + "/python:" + os.environ.get("PYTHONPATH", "")) + sys.path.insert(0, develop_tilelang_path + "/python") + +import tvm as tvm # noqa: E402 +import tilelang as tilelang # noqa: E402 +from . import gpu # noqa: F401 +from .base import ( + TileDevice, # noqa: F401 + fast_tune, # noqa: F401 + ApplyDefaultSchedule, # noqa: F401 + ApplyFastTuning, # noqa: F401 + BlockInfo, # noqa: F401 + IterInfo, # noqa: F401 + ScheduleRule, # noqa: F401 + normalize_prim_func, # noqa: F401 + try_inline, # noqa: F401 + try_inline_contiguous_spatial, # noqa: F401 +) + +from . 
import testing # noqa: F401 +from .utils import ( + auto_detect_nvidia_target, # noqa: F401 + apply_transform_on_input, # noqa: F401 +) +from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 +from .ops.general_matmul_splitk import ( + MatmulConfigWithSplitK, # noqa: F401 + MatmulWithSplitK, # noqa: F401 +) +from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401 +from .module import Linear # noqa: F401 + __version__ = "0.0.1.dev15" diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index 1a9ababd2..64b1fde95 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -50,24 +50,24 @@ def compile_lib(self, timeout: float = None, with_tl: bool = False): ] if with_tl: - install_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tvm") - develop_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tvm") + install_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "tile-lang") + develop_tilelang_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "tile-lang") - tvm_root = next((path for path in [install_tvm_path, develop_tvm_path] + tilelang_root = next((path for path in [install_tilelang_path, develop_tilelang_path] if os.path.exists(path) and path not in sys.path), None) - - if "TL_TEMPLATE_PATH " in os.environ: - tl_template_path = os.environ["TL_TEMPLATE_PATH"] - else: - tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) - - tl_template_path = osp.abspath(osp.join(tvm_root, "src/tl")) - if "TL_CUTLASS_PATH" in os.environ: - cutlass_path = os.environ["TL_CUTLASS_PATH"] - else: - cutlass_path = osp.abspath(osp.join(tvm_root, "3rdparty/cutlass/include")) + + install_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../..", "3rdparty", "cutlass") + develop_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "../../..", "3rdparty", "cutlass") + cutlass_root = next((path for path in [install_cutlass_path, develop_cutlass_path] + if os.path.exists(path) and path not in sys.path), None) + + tl_template_path = os.environ["TL_TEMPLATE_PATH"] if "TL_TEMPLATE_PATH" in os.environ else osp.abspath(osp.join(tilelang_root, "src")) + + cutlass_path = os.environ["TL_CUTLASS_PATH"] if "TL_CUTLASS_PATH" in os.environ else osp.abspath(osp.join(cutlass_root, "include")) command += [ "-I" + tl_template_path, diff --git a/bitblas/ops/base_scheduler.py b/bitblas/ops/base_scheduler.py index 35beeaf6c..82240875b 100644 --- a/bitblas/ops/base_scheduler.py +++ b/bitblas/ops/base_scheduler.py @@ -1,8 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License.
+from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import IRModule from tvm.tir import PrimFunc from typing import Union, Callable from dataclasses import dataclass, field -from tvm.tl.transform import Simplify +from tilelang.transform import Simplify from abc import ABC, abstractmethod from bitblas.base.arch import TileDevice diff --git a/bitblas/ops/general_flashatten/tilelang/flashatten.py b/bitblas/ops/general_flashatten/tilelang/flashatten.py index 2d5386022..8177d2c88 100644 --- a/bitblas/ops/general_flashatten/tilelang/flashatten.py +++ b/bitblas/ops/general_flashatten/tilelang/flashatten.py @@ -2,8 +2,9 @@ # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from bitblas.ops.base_scheduler import BaseScheduler -import tvm.tl.language as T +import tilelang.language as T from dataclasses import dataclass from typing import Optional import logging diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 227de7ad3..fa9e8e55f 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from typing import Optional, List from bitblas.tl.utils import ( get_mma_micro_size, diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py index 7a06d6959..3f19ef23e 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/block_primitive_tensorcore.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from typing import Optional, List, Literal from bitblas.ops.base_scheduler import BaseScheduler from bitblas.base.arch import TileDevice diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py index c98474ec0..d85a790f8 100644 --- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py +++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from typing import Optional, List, Literal from bitblas.tl.utils import ( get_mma_micro_size, # noqa: F401 diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index d928c451d..736a01348 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
from abc import ABC, abstractmethod from bitblas import tvm -from tvm import tl +from bitblas import tilelang from tvm import IRModule from tvm.runtime.module import Module from tvm.target import Target @@ -196,7 +196,7 @@ def tvm_callback_cuda_postproc(code, _): "tir.disable_cse_tir": True, **(self.pass_context if self.pass_context else {}) }): - rt_mod = tl.lower(tl_prim_func, target=target, runtime_only=True) + rt_mod = tilelang.lower(tl_prim_func, target=target, runtime_only=True) else: raise ValueError(f"Unsupported backend: {self.backend}") except Exception as build_runtime_error: # noqa: F841 diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/macro_generator.py index 0f7adb791..475740cbc 100644 --- a/bitblas/tl/macro_generator.py +++ b/bitblas/tl/macro_generator.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -import tvm.tl.language as T +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang +import tilelang.language as T from typing import Union from bitblas.ops.common import TransformKind diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py index 6747d0632..336795821 100644 --- a/bitblas/tl/tuner.py +++ b/bitblas/tl/tuner.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from bitblas import tvm +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang import os from concurrent.futures import ThreadPoolExecutor, as_completed from typing import List, Tuple, Optional, Literal from tvm import tir, IRModule from tvm.runtime import Module from tvm.tir import Schedule -import tvm.tl as tl from bitblas.ops.base_scheduler import BaseScheduler from bitblas.base.arch import CUDA from bitblas.base.utils import get_dummy_input_arrays @@ -133,7 +133,7 @@ def tvm_callback_cuda_postproc(code, _): "tir.disable_cse_tir": True, **(config.pass_context if config.pass_context else {}) }): - rt_mod = tl.lower(tl_prim_func, arch.target, runtime_only=True) + rt_mod = tilelang.lower(tl_prim_func, arch.target, runtime_only=True) from tvm.contrib.tar import tar # Import the tar module diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py index 4b8b4cf6e..4fb8c432c 100644 --- a/bitblas/tl/utils.py +++ b/bitblas/tl/utils.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import arith from tvm import DataType -import tvm.tl.language as T +import tilelang.language as T from typing import Union, Literal from .mma_layout import ( ldmatrix_32x8_to_shared_16x16_layout, diff --git a/bitblas/utils/rtmod_analysis.py b/bitblas/utils/rtmod_analysis.py index e3fe4c1cb..9fee977b9 100644 --- a/bitblas/utils/rtmod_analysis.py +++ b/bitblas/utils/rtmod_analysis.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from bitblas import tvm +from bitblas import tvm as tvm +from bitblas import tilelang as tilelang from tvm import IRModule from tvm.runtime import ndarray from tvm.driver import lower from tvm.target import Target from typing import Tuple, List from tvm import tir -from tvm import tl -from tvm.tl.engine import is_device_call +from tilelang.engine import is_device_call def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule": @@ -16,18 +16,18 @@ def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule target = tvm.target.Target(target, target_host) mod = tir.transform.BindTarget(target)(mod) - mod = tl.transform.FrontendLegalize()(mod) + mod = tilelang.transform.FrontendLegalize()(mod) mod = tir.transform.Simplify()(mod) - mod = tl.transform.LayoutInference()(mod) - mod = tl.transform.LowerTileOp()(mod) + mod = tilelang.transform.LayoutInference()(mod) + mod = tilelang.transform.LowerTileOp()(mod) mod = tir.transform.Simplify()(mod) if target.arch == "sm_90": - mod = tl.transform.WarpSpecializedPipeline()(mod) + mod = tilelang.transform.WarpSpecializedPipeline()(mod) else: mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tl.transform.PipelinePlanning()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tilelang.transform.PipelinePlanning()(mod) + mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tir.transform.LowerOpaqueBlock()(mod) mod = tir.transform.FlattenBuffer()(mod) @@ -48,7 +48,7 @@ def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule mod = tir.transform.ThreadSync("shared")(mod) # TODO(lei): This is a hack to make sure the # thread level allreduce pass can be applied - # in TL. As Tl only use one thread dimension + # in tilelang. As Tl only use one thread dimension # the var binding information will be lost # in the lowering process with Legalization # and Simplify pass. @@ -57,7 +57,7 @@ def get_annotated_device_mod_from_tl(mod: IRModule, target: Target) -> "IRModule # the Legalization. mod = tir.transform.LowerThreadAllreduce()(mod) mod = tir.transform.ThreadSync("shared.dyn")(mod) - mod = tl.transform.LowerHopperIntrin()(mod) + mod = tilelang.transform.LowerHopperIntrin()(mod) mod = tir.transform.InjectPTXAsyncCopy()(mod) mod = tir.transform.AnnotateDeviceRegions()(mod) diff --git a/install.sh b/install.sh index c3bb0fe0b..3a49bbd9b 100755 --- a/install.sh +++ b/install.sh @@ -3,6 +3,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +nproc=$(nproc) +if [ -z "$nproc" ]; then + nproc=1 +fi +# max 16 jobs +if [ $nproc -gt 16 ]; then + nproc=16 +fi + # install requirements pip install -r requirements.txt @@ -49,7 +58,7 @@ echo "Download and extraction completed successfully." LLVM_CONFIG_PATH="$(realpath ${EXTRACT_PATH}/$(basename ${FILE_NAME} .tar.xz)/bin/llvm-config)" echo "LLVM config path: $LLVM_CONFIG_PATH" -# clone and build tvm +# update and build tvm git submodule update --init --recursive cd 3rdparty/tvm @@ -59,11 +68,29 @@ fi mkdir build cp cmake/config.cmake build cd build + +# get the absolute path of the TVM prebuild path +ABS_TVM_PREBUILD_PATH=$(realpath .) + echo "set(USE_LLVM $LLVM_CONFIG_PATH)" >> config.cmake && echo "set(USE_CUDA /usr/local/cuda)" >> config.cmake -cmake .. && make -j && cd ../../.. +cmake .. && make -j $nproc && cd ../../.. + +# update and build tile-lang +cd 3rdparty/tile-lang +if [ -d build ]; then + rm -rf build +fi + +mkdir build + +cd build + +cmake .. 
-DTVM_PREBUILD_PATH=$ABS_TVM_PREBUILD_PATH && make -j $nproc && cd ../../.. echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc +# For 3rdparty/tile-lang import path +echo "export TVM_IMPORT_PYTHON_PATH=\$TVM_HOME/python" >> ~/.bashrc echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd):\$PYTHONPATH" >> ~/.bashrc source ~/.bashrc diff --git a/maint/scripts/format.sh b/maint/scripts/format.sh new file mode 100755 index 000000000..c5e81a1ef --- /dev/null +++ b/maint/scripts/format.sh @@ -0,0 +1,203 @@ +#!/usr/bin/env bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Usage: +# # Do work and commit your work. + +# # Format files that differ from origin/main. +# bash format.sh + +# # Commit changed files with message 'Run yapf and ruff' +# +# +# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase. +# You are encouraged to run this locally before pushing changes for review. + +# Cause the script to exit if a single command fails +set -eo pipefail + +# this stops git rev-parse from failing if we run this from the .git directory +builtin cd "$(dirname "${BASH_SOURCE:-$0}")" +ROOT="$(git rev-parse --show-toplevel)" +builtin cd "$ROOT" || exit 1 + +YAPF_VERSION=$(yapf --version | awk '{print $2}') +RUFF_VERSION=$(ruff --version | awk '{print $2}') +CODESPELL_VERSION=$(codespell --version) + +# # params: tool name, tool version, required version +tool_version_check() { + if [[ $2 != $3 ]]; then + echo "Wrong $1 version installed: $3 is required, not $2." + exit 1 + fi +} + +tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "ruff" $RUFF_VERSION "$(grep "ruff==" requirements-dev.txt | cut -d'=' -f3)" +tool_version_check "codespell" "$CODESPELL_VERSION" "$(grep codespell requirements-dev.txt | cut -d'=' -f3)" + +echo 'bitblas yapf: Check Start' + +YAPF_FLAGS=( + '--recursive' + '--parallel' +) + +YAPF_EXCLUDES=( + '--exclude' 'build/**' +) + +# Format specified files +format() { + yapf --in-place "${YAPF_FLAGS[@]}" "$@" +} + +# Format files that differ from main branch. Ignores dirs that are not slated +# for autoformat yet. +format_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause yapf to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # exist on both branches. + if git show-ref --verify --quiet refs/remotes/origin/main; then + BASE_BRANCH="origin/main" + else + BASE_BRANCH="main" + fi + + MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \ + yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" + fi + +} + +# Format all files +format_all() { + yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" . +} + +## This flag formats individual files. --files *must* be the first command line +## arg to use this option. +if [[ "$1" == '--files' ]]; then + format "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is formatted. +elif [[ "$1" == '--all' ]]; then + format_all +else + # Format only the files that changed in last commit. 
+ format_changed +fi +echo 'bitblas yapf: Done' + +echo 'bitblas codespell: Check Start' +# check spelling of specified files +spell_check() { + codespell "$@" +} + +spell_check_all(){ + codespell --toml pyproject.toml +} + +# Spelling check of files that differ from main branch. +spell_check_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause ruff to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that + # exist on both branches. + if git show-ref --verify --quiet refs/remotes/origin/main; then + BASE_BRANCH="origin/main" + else + BASE_BRANCH="main" + fi + + MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ + codespell + fi +} + +# Run Codespell +## This flag runs spell check of individual files. --files *must* be the first command line +## arg to use this option. +if [[ "$1" == '--files' ]]; then + spell_check "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is linted. +elif [[ "$1" == '--all' ]]; then + spell_check_all +else + # Check spelling only of the files that changed in last commit. + spell_check_changed +fi +echo 'bitblas codespell: Done' + +echo 'bitblas ruff: Check Start' +# Lint specified files +lint() { + ruff "$@" +} + +# Lint files that differ from main branch. Ignores dirs that are not slated +# for autolint yet. +lint_changed() { + # The `if` guard ensures that the list of filenames is not empty, which + # could cause ruff to receive 0 positional arguments, making it hang + # waiting for STDIN. + # + # `diff-filter=ACM` and $MERGEBASE is to ensure we only lint files that + # exist on both branches. + if git show-ref --verify --quiet refs/remotes/origin/main; then + BASE_BRANCH="origin/main" + else + BASE_BRANCH="main" + fi + + MERGEBASE="$(git merge-base $BASE_BRANCH HEAD)" + + if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then + git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ + ruff + fi + +} + +# Run Ruff +### This flag lints individual files. --files *must* be the first command line +### arg to use this option. +if [[ "$1" == '--files' ]]; then + lint "${@:2}" + # If `--all` is passed, then any further arguments are ignored and the + # entire python directory is linted. +elif [[ "$1" == '--all' ]]; then + lint BitBLAS tests +else + # Format only the files that changed in last commit. + lint_changed +fi + +if ! git diff --quiet &>/dev/null; then + echo 'Reformatted files. Please review and stage the changes.' 
+ echo 'Changes not staged for commit:' + echo + git --no-pager diff --name-only + + exit 1 +fi + +echo 'bitblas ruff: Done' + +echo 'bitblas: All checks passed' diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py index 3183efb8f..06dc3cc94 100644 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -3,7 +3,7 @@ import bitblas from bitblas.ops.general_matmul_splitk import MatmulWithSplitK, MatmulConfigWithSplitK - +bitblas.set_log_level("DEBUG") def get_codegen_result(ops): code = ops.get_source() @@ -107,7 +107,7 @@ def matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, propagate_a=False, propagate_b=False, ) - matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) + matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=True) input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) @@ -152,10 +152,8 @@ def map_torch_type(intype): def test_matmul_torch_forward_fp8e4m3(): matmul_torch_forward_fp8e4m3(1, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", "float16", "nt", False, -1, False, False, None) - matmul_torch_forward_fp8e4m3(4, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", - "float16", "nt", False, -1, False, False, None) # fmt: on if __name__ == "__main__": - bitblas.testing.main() + bitblas.testing.main() \ No newline at end of file diff --git a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 58f595984..7f573223b 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -159,7 +159,7 @@ def assert_correctness_with_ladder_ldmatrix_propagate( "block": [16, 128], "warp": [16, 32], "rstep": [128], - "pipeline_stage": 4, + "pipeline_stage": 2, "use_async": True, "intrin_info": intrin_info, "shared_scope": "shared.dyn", @@ -411,7 +411,7 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( "block": [16, 128], "warp": [16, 32], "rstep": [128], - "pipeline_stage": 4, + "pipeline_stage": 2, "use_async": True, "intrin_info": intrin_info, "shared_scope": "shared.dyn", diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py index 5c98cb948..e692d4afb 100644 --- a/testing/python/operators/test_general_matmul_tilelang_impl.py +++ b/testing/python/operators/test_general_matmul_tilelang_impl.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang from bitblas.ops.general_matmul.tilelang.dense import ( matmul_blocked, matmul_macro_tensorcore, @@ -47,7 +47,7 @@ def assert_matmul_blocked_correctness(M, enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -57,7 +57,7 @@ def assert_matmul_blocked_correctness(M, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -105,7 +105,7 @@ def assert_matmul_macro_tensorcore_correctness( 
num_stages=num_stages, enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code represents generated cuda source @@ -115,7 +115,7 @@ def assert_matmul_macro_tensorcore_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -164,7 +164,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -185,7 +185,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py index 349a69752..c317d5061 100644 --- a/testing/python/operators/test_general_matmul_tilelang_kernel.py +++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import ( MatmulScheduler, MatmulFineGrainScheduler, @@ -40,7 +40,7 @@ def assert_matmul_blocked_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -50,7 +50,7 @@ def assert_matmul_blocked_with_default_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -98,7 +98,7 @@ def assert_matmul_blocked_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -108,7 +108,7 @@ def assert_matmul_blocked_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -144,7 +144,7 @@ def assert_matmul_fine_grained_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -152,7 +152,7 @@ def assert_matmul_fine_grained_with_default_correctness( B = (torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) if trans_B else torch.rand( K, N, device="cuda", 
dtype=getattr(torch, in_dtype))) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) latency = mod.do_bench(mod.func, warmup=25) @@ -264,7 +264,7 @@ def assert_matmul_fine_grained_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -274,7 +274,7 @@ def assert_matmul_fine_grained_apply_config_correctness( B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) - 0.5 C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -310,7 +310,7 @@ def assert_matmul_weight_propagation_with_default_correctness( accum_dtype=accum_dtype, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -331,7 +331,7 @@ def assert_matmul_weight_propagation_with_default_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -384,7 +384,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( enable_rasterization=enable_rasterization, ) - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -405,7 +405,7 @@ def assert_matmul_weight_propagation_apply_config_correctness( LB = ladder_permutate(B.cpu()).cuda() - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, LB, C) @@ -458,7 +458,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( zeros_mode=zeros_mode, ).with_default_config() - mod, params = tl.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source assert src_code is not None @@ -517,7 +517,7 @@ def assert_matmul_blocked_dequant_with_default_correctness( permuted_inputs.append(inputs[2]) - mod = tl.Profiler(mod, params, [], tl.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(*permuted_inputs) diff --git a/testing/python/tilelang/test_simplifier.py b/testing/python/tilelang/test_simplifier.py index 96536670a..18613edc9 100644 --- a/testing/python/tilelang/test_simplifier.py +++ b/testing/python/tilelang/test_simplifier.py @@ -1,6 +1,6 @@ import tvm -from tvm import tl -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T def modify( @@ -36,7 +36,7 @@ def main( def test_modify(with_B=False, with_bias=False): tester = modify(with_B=with_B, with_bias=with_bias) mod = tvm.IRModule({tester.attrs["global_symbol"]: tester}) - mod2 = tl.transform.Simplify()(mod) + mod2 = tilelang.transform.Simplify()(mod) assert mod != mod2 @@ -71,11 +71,11 @@ def main( def test_matmul(): func = matmul(1024, 1024, 1024, 128, 128, 32) mod = tvm.IRModule({func.attrs["global_symbol"]: func}) - mod = 
tl.transform.Simplify()(mod) + mod = tilelang.transform.Simplify()(mod) - rt_mod, params = tl.lower(mod.functions_items()[0][1], runtime_only=False) + rt_mod, params = tilelang.lower(mod.functions_items()[0][1], runtime_only=False) # TODO Profiler only support TensorType, not dynamic variable - profiler = tl.Profiler(rt_mod, params, result_idx=[2]) + profiler = tilelang.Profiler(rt_mod, params, result_idx=[2]) import torch a = torch.randn(1024, 1024, dtype=torch.float16).cuda().half() diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py index 006b0665a..a0776fee1 100644 --- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py +++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py @@ -5,8 +5,8 @@ import bitblas from bitblas import tvm as tvm from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.quantization import _tir_packed_to_unsigned_convert from bitblas.tl.utils import (make_swizzle_layout) from bitblas.tl.macro_generator import ( @@ -44,7 +44,7 @@ def matmul( local_size = MAX_TRANSACTION_SIZE_IN_BITS // DataType(in_dtype).bits local_size_compressed = local_size // num_elems_per_byte - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main( @@ -122,8 +122,8 @@ def run_gemm( num_threads, ) - mod, params = TL.lower(program) - mod = TL.Profiler(mod, params, [2], TL.TensorSupplyType.Integer) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) out = mod.run_once() assert out is not None @@ -366,7 +366,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -405,7 +405,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C) diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py index d0587ebef..784265704 100644 --- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py +++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.macro_generator import (TensorCoreIntrinEmitter) @@ -178,7 +178,7 @@ def main( def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_macro(N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -188,7 +188,7 @@ def assert_tl_matmul_macro_correctness(M, N, K, in_dtype, out_dtype, accum_dtype B = torch.rand(N, K, device="cuda", 
dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -217,7 +217,7 @@ def tl_matmul_block( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -271,13 +271,13 @@ def assert_tl_matmul_block_correctness( num_stages, num_threads, ) - mod, params = TL.lower(program) + mod, params = tilelang.lower(program) A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) def ref_program(A, B): @@ -318,7 +318,7 @@ def tl_matmul_block_all_dynamic( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -370,13 +370,13 @@ def assert_tl_matmul_block_all_dynamic_correctness( num_stages, num_threads, ) - mod, params = TL.lower(program) + mod, params = tilelang.lower(program) A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) print(mod.mod.imported_modules[0].get_source()) diff --git a/testing/python/tilelang/test_tilelang_flash_atten.py b/testing/python/tilelang/test_tilelang_flash_atten.py index e0e72c5d5..d71994609 100644 --- a/testing/python/tilelang/test_tilelang_flash_atten.py +++ b/testing/python/tilelang/test_tilelang_flash_atten.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from tvm import tl -import tvm.tl.language as T -from tvm.tl.autotuner import * +from bitblas import tilelang as tilelang +import tilelang.language as T +from tilelang.autotuner import * from functools import partial import itertools import torch @@ -64,10 +64,10 @@ def flashattn_tilelang(batch, heads, seq_len, dim, trans_K, dtypeQKV, dtypeAccu, num_stages=num_stages, is_causal=is_causal, ) - mod, params = tl.lower(tl_prim_func) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(tl_prim_func) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) from flash_attn.flash_attn_interface import flash_attn_func - # TODO Now hack to internal function get the same input, may need to modify 3rdparty:tvm.tl.utils + # TODO Now hack to internal function get the same input, may need to modify 3rdparty:tilelang.utils ins = mod._get_inputs() tilelang_res = mod(*ins) Q, K, V = ins[0], ins[1], ins[2] @@ -175,8 +175,8 @@ def main( return main - mod, params = tl.lower(kernel()) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(kernel()) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) @@ -204,7 +204,7 @@ def flashattn_autotune(batch, heads, seq_len, dim, is_causal): ) @jit( out_idx=[3], - supply_type=tl.TensorSupplyType.Normal, + supply_type=tilelang.TensorSupplyType.Normal, ref_prog=partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01, @@ -396,8 +396,8 @@ def main( return main - mod, params = tl.lower(kernel()) - mod = tl.Profiler(mod, params, [3], tl.TensorSupplyType.Normal) + mod, params = tilelang.lower(kernel()) + mod = tilelang.Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal) mod.assert_allclose(partial(ref_program, causal=is_causal), rtol=0.01, atol=0.01) diff --git a/testing/python/tilelang/test_tilelang_gemm.py b/testing/python/tilelang/test_tilelang_gemm.py index 38fc65a77..cf884be3b 100644 --- a/testing/python/tilelang/test_tilelang_gemm.py +++ b/testing/python/tilelang/test_tilelang_gemm.py @@ -3,7 +3,7 @@ from bitblas import tvm as tvm import bitblas.testing -from tvm import tl +from bitblas import tilelang as tilelang def matmul( @@ -26,7 +26,7 @@ def matmul( A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) - import tvm.tl.language as T + import tilelang.language as T @T.prim_func def main(A: T.Buffer(A_shape, in_dtype), B: T.Buffer(B_shape, in_dtype), C: T.Buffer( @@ -81,8 +81,8 @@ def run_gemm( num_stages, num_threads, ) - mod, params = tl.lower(program) - mod = tl.Profiler(mod, params, [2], tl.TensorSupplyType.Integer) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [2], tilelang.TensorSupplyType.Integer) def ref_program(A, B): import torch diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py index 4d1318960..ce9a7edef 100644 --- a/testing/python/tilelang/test_tilelang_macro_gemm.py +++ b/testing/python/tilelang/test_tilelang_macro_gemm.py @@ -6,8 +6,8 @@ from bitblas import tvm as tvm import bitblas.testing from tvm import DataType -from tvm import tl as TL -import tvm.tl.language as T +from bitblas import tilelang as tilelang +import tilelang.language as T from bitblas.tl.utils import get_swizzle_layout from bitblas.tl.macro_generator import ( 
TensorCoreIntrinEmitter, @@ -182,7 +182,7 @@ def main( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -192,7 +192,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -377,7 +377,7 @@ def main( def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): matmul = tl_matmul_with_block_reduce(M, N, K, in_dtype, out_dtype, accum_dtype) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -387,7 +387,7 @@ def assert_tl_matmul_with_block_reduce_correctness(M, N, K, in_dtype, out_dtype, B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, B, C) @@ -552,7 +552,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d matmul = tl_matmul_with_ladder_weight_only_transform(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -571,7 +571,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(M, N, K, in_d ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config) - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) LB = ladder_permutate(B.cpu()).cuda() @@ -809,7 +809,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) - mod, params = TL.lower(matmul) + mod, params = tilelang.lower(matmul) src_code = mod.imported_modules[0].get_source() # src_code is the generated cuda source @@ -848,7 +848,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct QLB = ladder_permutate(qB.cpu()).cuda() QLB = lop3_permutate(QLB.cpu()).cuda() - mod = TL.Profiler(mod, params, [], TL.TensorSupplyType.Integer) + mod = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer) mod(A, QLB, C)
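
Note (not part of the patch): for downstream code, the net effect of this change is that every `tvm.tl` entry point is now reached through the standalone `tilelang` package that `bitblas/__init__.py` puts on `sys.path`. The sketch below summarizes the migration and mirrors the lower/profile flow used in the updated tests; the `MatmulScheduler` constructor arguments, shapes, and dtypes are illustrative assumptions rather than an exact API reference.

# Minimal sketch of the migrated import and lowering flow, assuming the
# MatmulScheduler constructor accepts the parameters exercised by the tests.
#
# Migration summary (grounded in the hunks above):
#   from tvm import tl               ->  from bitblas import tilelang as tilelang
#   import tvm.tl.language as T      ->  import tilelang.language as T
#   tl.lower(...), tl.Profiler(...)  ->  tilelang.lower(...), tilelang.Profiler(...)
#   tvm.tl.transform.*               ->  tilelang.transform.*
import torch

from bitblas import tilelang as tilelang  # tile-lang re-exported by bitblas/__init__.py
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import MatmulScheduler

# Build a default-configured matmul PrimFunc (shapes/dtypes are assumptions).
matmul = MatmulScheduler(
    M=1024, N=1024, K=1024,
    trans_A=False, trans_B=True,
    in_dtype="float16", out_dtype="float16", accum_dtype="float16",
).with_default_config()

# Lower with tilelang instead of tvm.tl, then wrap the result in a Profiler to run it,
# exactly as the updated tests do.
mod, params = tilelang.lower(matmul)
profiler = tilelang.Profiler(mod, params, [], tilelang.TensorSupplyType.Integer)

A = torch.rand(1024, 1024, device="cuda", dtype=torch.float16)
B = torch.rand(1024, 1024, device="cuda", dtype=torch.float16)  # "nt" layout: B is (N, K)
C = torch.zeros(1024, 1024, device="cuda", dtype=torch.float16)
profiler(A, B, C)  # runs the generated CUDA kernel, writing the result into C in place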