diff --git a/docs/conf.py b/docs/conf.py
index fde38c490..1b1289038 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,12 +1,10 @@
-# -*- coding: utf-8 -*-
-
 # General information about the project.
 project = "Tile Language"
 author = "Tile Lang Contributors"
-copyright = "2025-2025, %s" % author
+copyright = f"2025-2025, {author}"

 # Version information.
-with open("../VERSION", "r") as f:
+with open("../VERSION") as f:
     version = f.read().strip()
 release = version
diff --git a/pyproject.toml b/pyproject.toml
index daa30406b..e76a267c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,17 @@ target-version = "py38"
 line-length = 100
 output-format = "full"

+exclude = [
+    "3rdparty",
+    "examples/deepseek_v32/inference",
+]
+
+[tool.ruff.lint.per-file-ignores]
+# Do not upgrade type hint in testing and examples.
+# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
+"testing/**.py" = ["UP", "FA"]
+"examples/**.py" = ["UP", "FA"]
+
 [tool.ruff.lint]
 select = [
     # pycodestyle
@@ -94,7 +105,7 @@ select = [
     # Pyflakes
     "F",
     # pyupgrade
-    # "UP",
+    "UP", "FA",
    # flake8-bugbear
    "B",
    # flake8-simplify
@@ -115,6 +126,8 @@ ignore = [
    "SIM108",
    # key in dict.keys()
    "SIM118",
+    # open file w.o. ctx manager
+    "SIM115",
    # memory leaks
    "B019",
    # zip without explicit strict
@@ -122,9 +135,6 @@
    # No such file or directory
    "E902",
 ]
-[tool.ruff.lint.per-file-ignores]
-"3rdparty/**/*" = ["ALL"]
-"examples/deepseek_v32/inference/**/*" = ["ALL"]

 [tool.pytest.ini_options]
 verbosity_assertions = 3
diff --git a/tilelang/autotuner/capture.py b/tilelang/autotuner/capture.py
index 78f937de8..27c24f14e 100644
--- a/tilelang/autotuner/capture.py
+++ b/tilelang/autotuner/capture.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
 import threading
-from typing import List, Any, Optional
+from typing import Any

 # Use thread local to store the stack
 # This is to avoid the cross-thread interference
@@ -87,7 +88,7 @@ class AutotuneInputsCapture:

     __slots__ = ("tensors")

-    def __init__(self, tensors: List[Any]):
+    def __init__(self, tensors: list[Any]):
         self.tensors = tensors

     def __enter__(self) -> None:
@@ -118,7 +119,7 @@ def set_autotune_inputs(*args) -> AutotuneInputsCapture:
     return AutotuneInputsCapture(tensors)


-def get_autotune_inputs() -> Optional[List[Any]]:
+def get_autotune_inputs() -> list[Any] | None:
     """
     Get the current autotune inputs from the stack.
     """
diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py
index aa8f6b9de..a486e9018 100644
--- a/tilelang/autotuner/param.py
+++ b/tilelang/autotuner/param.py
@@ -1,11 +1,12 @@
 """The auto-tune parameters.
 """
+from __future__ import annotations

 import tilelang
 from tilelang import tvm as tvm
 from tvm.tir import PrimFunc
 from tvm.target import Target
-from typing import Callable, List, Literal, Any, Optional, Union, Dict
+from typing import Callable, Literal, Any
 from dataclasses import dataclass
 from pathlib import Path
@@ -47,12 +48,12 @@ class CompileArgs:
             "tl.disable_safe_memory_legalize": bool, default: False
     """

-    out_idx: Optional[Union[List[int], int]] = None
+    out_idx: list[int] | int | None = None
     execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
     target: Literal['auto', 'cuda', 'hip'] = 'auto'
-    target_host: Union[str, Target] = None
+    target_host: str | Target = None
     verbose: bool = False
-    pass_configs: Optional[Dict[str, Any]] = None
+    pass_configs: dict[str, Any] | None = None

     def compile_program(self, program: PrimFunc):
         return tilelang.compile(
@@ -142,12 +143,12 @@ class AutotuneResult:
         func: Optimized function.
         kernel: Compiled kernel function.
     """
-    latency: Optional[float] = None
-    config: Optional[dict] = None
-    ref_latency: Optional[float] = None
-    libcode: Optional[str] = None
-    func: Optional[Callable] = None
-    kernel: Optional[Callable] = None
+    latency: float | None = None
+    config: dict | None = None
+    ref_latency: float | None = None
+    libcode: str | None = None
+    func: Callable | None = None
+    kernel: Callable | None = None

     def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
         """
@@ -211,9 +212,9 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo
     def _load_kernel_from_disk(
             self,
             cache_path: Path,
-            target: Union[str, Target] = "auto",
-            target_host: Union[str, Target] = None,
-            out_idx: Optional[Union[List[int], int]] = None,
+            target: str | Target = "auto",
+            target_host: str | Target = None,
+            out_idx: list[int] | int | None = None,
             execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
             pass_configs: dict = None,
             func: Callable = None,
@@ -239,14 +240,14 @@ def _load_kernel_from_disk(
         if not os.path.exists(cache_path):
             return None

-        kernel_global_source: Optional[str] = None
-        kernel_params: Optional[List[KernelParam]] = None
+        kernel_global_source: str | None = None
+        kernel_params: list[KernelParam] | None = None

        try:
            wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
            if verbose:
                logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
-            with open(wrapped_kernel_path, "r") as f:
+            with open(wrapped_kernel_path) as f:
                kernel_global_source = f.read()
        except Exception as e:
            logger.error(f"Error loading wrapped kernel source code from disk: {e}")
@@ -307,7 +308,7 @@ def save_to_disk(self, path: Path, verbose: bool = False):
         self._save_kernel_to_disk(path, self.kernel)

     @classmethod
-    def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResult':
+    def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult:
         if not os.path.exists(path):
             return None

@@ -315,7 +316,7 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResul
         # load best config
         if verbose:
             logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
-        with open(path / BEST_CONFIG_PATH, "r") as f:
+        with open(path / BEST_CONFIG_PATH) as f:
             config = json.load(f)

         # load function
@@ -327,7 +328,7 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResul
         # load latency
         if verbose:
             logger.debug(f"Loading latency from file: {path / LATENCY_PATH}")
-        with open(path / LATENCY_PATH, "r") as f:
+        with open(path / LATENCY_PATH) as f:
             latency = json.load(f)
         latency, ref_latency = latency["latency"], latency["ref_latency"]
diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py
index 2173a1392..e94ac7466 100644
--- a/tilelang/autotuner/tuner.py
+++ b/tilelang/autotuner/tuner.py
@@ -3,6 +3,7 @@
 This module provides functionality for auto-tuning tilelang programs,
 including JIT compilation and performance optimization through configuration search.
 """
+from __future__ import annotations

 import tilelang
 from tilelang import tvm as tvm
@@ -10,7 +11,7 @@ from tvm.target import Target

 import inspect
 from functools import partial
-from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
+from typing import (Callable, Literal, Any, overload)
 from tqdm import tqdm
 import logging
 import functools
@@ -103,8 +104,8 @@ class AutoTuner:
     compile_args = CompileArgs()
     profile_args = ProfileArgs()

-    _kernel_parameters: Optional[Tuple[str, ...]] = None
-    _function_parameters: Optional[Dict[str, Any]] = None
+    _kernel_parameters: tuple[str, ...] | None = None
+    _function_parameters: dict[str, Any] | None = None
     _lock = threading.Lock()  # For thread safety
     _memory_cache = {}  # In-memory cache dictionary
     cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
@@ -131,12 +132,12 @@ def from_kernel(cls, kernel: Callable, configs):
         return cls(kernel, configs)

     def set_compile_args(self,
-                         out_idx: Union[List[int], int, None] = None,
+                         out_idx: list[int] | int | None = None,
                          target: Literal['auto', 'cuda', 'hip'] = 'auto',
                          execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
-                         target_host: Union[str, Target] = None,
+                         target_host: str | Target = None,
                          verbose: bool = False,
-                         pass_configs: Optional[Dict[str, Any]] = None):
+                         pass_configs: dict[str, Any] | None = None):
         """Set compilation arguments for the auto-tuner.

         Args:
@@ -223,12 +224,12 @@ def set_profile_args(self,

         return self

-    def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
+    def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]):
         # for cache key generation
         self._kernel_parameters = k_parameters
         self._function_parameters = f_parameters

-    def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
+    def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None:
         """Generate a cache key for the auto-tuning process.
         """
@@ -307,8 +308,8 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
                 return result

         best_latency: float = 1e8
-        best_config: Optional[Dict[str, Any]] = None
-        best_kernel: Optional[tilelang.JITKernel] = None
+        best_config: dict[str, Any] | None = None
+        best_kernel: tilelang.JITKernel | None = None

         def _compile(**config_arg) -> tilelang.JITKernel:
             compile_args = self.compile_args
@@ -591,7 +592,7 @@ class _AutoTunerImplementation:
     warmup: int = 25
     rep: int = 100
     timeout: int = 100
-    configs: Union[Dict, Callable] = None
+    configs: dict | Callable = None
     supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
     ref_prog: Callable = None
     supply_prog: Callable = None
@@ -603,7 +604,7 @@ class _AutoTunerImplementation:
     cache_input_tensors: bool = False

     def __init__(self,
-                 configs: Union[Dict, Callable],
+                 configs: dict | Callable,
                  warmup: int = 25,
                  rep: int = 100,
                  timeout: int = 100,
@@ -653,12 +654,12 @@ def __init__(self,
         self.cache_input_tensors = cache_input_tensors  # Reuse inputs

         # Cache for storing tuned kernel implementations
-        self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {}  # (args, kwargs) -> compiled kernel
+        self._tuner_cache: dict[tuple, tilelang.JITKernel] = {}  # (args, kwargs) -> compiled kernel

     # This tells the type checker what the *wrapper* function will return.
     # this is for linting, please do not remove it.
     @overload
-    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
+    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
         ...

     @overload
@@ -720,9 +721,9 @@ def jit_compile(**config_arg):


 def autotune(  # This is the new public interface
-    func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
+    func: Callable[_P, _RProg] | PrimFunc | None = None,
     *,  # Indicates subsequent arguments are keyword-only
-    configs: Union[Dict, Callable],
+    configs: dict | Callable,
     # profile arguments
     warmup: int = 25,
     rep: int = 100,
diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py
index 72d003318..c338ce61d 100644
--- a/tilelang/cache/__init__.py
+++ b/tilelang/cache/__init__.py
@@ -1,6 +1,7 @@
 """The cache utils with class and database persistence - Init file"""
+from __future__ import annotations

-from typing import List, Union, Literal, Optional
+from typing import Literal
 from tvm.target import Target
 from tvm.tir import PrimFunc
 from tilelang.jit import JITKernel
@@ -13,14 +14,14 @@

 def cached(
         func: PrimFunc = None,
-        out_idx: List[int] = None,
+        out_idx: list[int] = None,
         *args,
-        target: Union[str, Target] = "auto",
-        target_host: Union[str, Target] = None,
-        execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
-        verbose: Optional[bool] = False,
-        pass_configs: Optional[dict] = None,
-        compile_flags: Optional[Union[List[str], str]] = None,
+        target: str | Target = "auto",
+        target_host: str | Target = None,
+        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] | None = "cython",
+        verbose: bool | None = False,
+        pass_configs: dict | None = None,
+        compile_flags: list[str] | str | None = None,
 ) -> JITKernel:
     """
     Caches and reuses compiled kernels (using KernelCache class).
diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py
index b6d2e77b7..d0a801fb4 100644
--- a/tilelang/cache/kernel_cache.py
+++ b/tilelang/cache/kernel_cache.py
@@ -1,4 +1,5 @@
 """The cache utils with class and database persistence - KernelCache Class"""
+from __future__ import annotations

 import json
 import logging
@@ -7,7 +8,7 @@
 import threading
 import uuid
 from hashlib import sha256
-from typing import Callable, List, Literal, Optional, Union
+from typing import Callable, Literal

 import cloudpickle
 from tvm.target import Target
@@ -67,13 +68,13 @@ def _create_dirs():
     def _generate_key(
             self,
             func: Callable,
-            out_idx: List[int],
+            out_idx: list[int],
             execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
             args=None,
-            target: Union[str, Target] = "auto",
-            target_host: Union[str, Target] = None,
+            target: str | Target = "auto",
+            target_host: str | Target = None,
             pass_configs: dict = None,
-            compile_flags: Optional[Union[List[str], str]] = None,
+            compile_flags: list[str] | str | None = None,
     ) -> str:
         """
         Generates a unique hash key for caching compiled kernels.
@@ -112,14 +113,14 @@ def _generate_key(
     def cached(
             self,
             func: PrimFunc = None,
-            out_idx: List[int] = None,
+            out_idx: list[int] = None,
             *args,
-            target: Union[str, Target] = "auto",
-            target_host: Union[str, Target] = None,
+            target: str | Target = "auto",
+            target_host: str | Target = None,
             execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
             verbose: bool = False,
             pass_configs: dict = None,
-            compile_flags: Optional[Union[List[str], str]] = None,
+            compile_flags: list[str] | str | None = None,
     ) -> JITKernel:
         """
         Caches and reuses compiled kernels to avoid redundant compilation.
@@ -322,15 +323,15 @@ def _save_kernel_to_disk(self,
     def _load_kernel_from_disk(
             self,
             key: str,
-            target: Union[str, Target] = "auto",
-            target_host: Union[str, Target] = None,
-            out_idx: List[int] = None,
+            target: str | Target = "auto",
+            target_host: str | Target = None,
+            out_idx: list[int] = None,
             execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
             pass_configs: dict = None,
-            compile_flags: Optional[Union[List[str], str]] = None,
+            compile_flags: list[str] | str | None = None,
             func: Callable = None,
             verbose: bool = False,
-    ) -> Optional[JITKernel]:
+    ) -> JITKernel | None:
         """
         Loads a previously compiled kernel from disk cache.
@@ -355,15 +356,15 @@ def _load_kernel_from_disk(
         if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
             return None

-        kernel_global_source: Optional[str] = None
-        kernel_params: Optional[List[KernelParam]] = None
+        kernel_global_source: str | None = None
+        kernel_params: list[KernelParam] | None = None

         # Load the kernel source file (optional)
         try:
             if verbose:
                 self.logger.debug(
                     f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
-            with open(wrapped_kernel_path, "r") as f:
+            with open(wrapped_kernel_path) as f:
                 kernel_global_source = f.read()
         except Exception as e:
             self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
diff --git a/tilelang/carver/analysis.py b/tilelang/carver/analysis.py
index 653392df7..96606e790 100644
--- a/tilelang/carver/analysis.py
+++ b/tilelang/carver/analysis.py
@@ -1,5 +1,5 @@
 """Analysis on TIR blocks, loops and functions."""
-from typing import List, Optional, Set, Union
+from __future__ import annotations
 from typing_extensions import Literal

 from tvm import ir, tir, DataType
@@ -31,7 +31,7 @@ def __init__(
         self.loop_rv = loop_rv

     @property
-    def dom(self) -> Union[int, tir.PrimExpr]:
+    def dom(self) -> int | tir.PrimExpr:
         """The iteration domain of the loop."""
         return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom

@@ -46,14 +46,14 @@ class BlockInfo:
     """Information about a TIR block."""

     name: str
-    iters: List[IterInfo]
+    iters: list[IterInfo]
     block_rv: tir.schedule.BlockRV
     _reduction_block: bool

     def __init__(
         self,
         name: str,
-        iters: List[IterInfo],
+        iters: list[IterInfo],
         block_rv: tir.schedule.BlockRV,
         reduction_block: bool = False,
     ):
@@ -63,7 +63,7 @@ def __init__(
         self.iters = iters
         self._reduction_block = reduction_block

-    def dom(self) -> List[Union[int, tir.PrimExpr]]:
+    def dom(self) -> list[int | tir.PrimExpr]:
         """The iteration domain of the block."""
         return [i.dom for i in self.iters]

@@ -118,7 +118,7 @@ def __repr__(self) -> str:
 _normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")


-def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
+def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
     """Normalize the primfunc to normal form"""
     try:
         result = _normalize_prim_func(sch)
@@ -133,7 +133,7 @@ def _iter_kind(i: tir.IterVar) -> str:
             tir.IterVar.CommReduce: "R",
         }.get(i.iter_type, "O")

-    blocks: List[BlockInfo] = []
+    blocks: list[BlockInfo] = []
     for block, loops, iters, is_reduction in zip(*result):
         blocks.append(
             BlockInfo(
@@ -203,7 +203,7 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:


 def collect_block_iter_vars_used_in_access_region(block: tir.Block,
-                                                  region: List[ir.Range]) -> Set[tir.Var]:
+                                                  region: list[ir.Range]) -> set[tir.Var]:
     """Collect the block iter variables used in the access region of a buffer region."""
     tir_vars = set()
     for expr in region:
@@ -214,7 +214,7 @@ def collect_block_iter_vars_used_in_access_region(block: tir.Block,
     return tir_vars


-def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]:
+def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> set[tir.Var]:
     """Collect the variables used in the PrimExpr."""
     tir_vars = set()

@@ -259,7 +259,7 @@ def is_broadcast_epilogue(


 def get_reduction_blocks(sch: tir.Schedule,
-                         blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]:
+                         blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
     # Get the main computation block
     def is_reduction(block: BlockRV) -> bool:
         block_stmt = sch.get(block)
@@ -286,7 +286,7 @@ def is_spatial(block: BlockRV) -> bool:
 def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int:
     # gpu memory prefer 128 bits coalesced access (e.g. four banks)
     # 128 bits
-    buffers: List[tir.Buffer] = []
+    buffers: list[tir.Buffer] = []
     for read in block_stmt.reads:
         buffers.append(read.buffer)
     for write in block_stmt.writes:
diff --git a/tilelang/carver/arch/__init__.py b/tilelang/carver/arch/__init__.py
index 3793d3a13..c2bc9c75d 100644
--- a/tilelang/carver/arch/__init__.py
+++ b/tilelang/carver/arch/__init__.py
@@ -1,14 +1,15 @@
+from __future__ import annotations
+
 from .arch_base import TileDevice
 from .cuda import *
 from .cpu import *
 from .cdna import *
 from .metal import *
-from typing import Union
 from tvm.target import Target

 import torch


-def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
+def get_arch(target: str | Target = "cuda") -> TileDevice:
     if isinstance(target, str):
         target = Target(target)
diff --git a/tilelang/carver/arch/arch_base.py b/tilelang/carver/arch/arch_base.py
index 06a614fb5..a10fa434d 100644
--- a/tilelang/carver/arch/arch_base.py
+++ b/tilelang/carver/arch/arch_base.py
@@ -1,4 +1,4 @@
-from typing import List
+from __future__ import annotations


 class TileDevice:
@@ -14,12 +14,12 @@ def __init__(self) -> None:
             0  # The size of a warp, a group of threads that execute instructions in lockstep
         )
         self.sm_partition: int = 0  # The number of streaming multiprocessor partitions
-        self.transaction_size: List[int] = [
+        self.transaction_size: list[int] = [
             0,
             0,
         ]  # The size of memory transactions, typically in bytes
         self.max_smem_usage: int = 0  # The maximum shared memory usage allowed
-        self.bandwidth: List[int] = [
+        self.bandwidth: list[int] = [
             0,
             0,
         ]  # Bandwidth specifications, possibly including peak and sustained rates
@@ -29,9 +29,9 @@ def __init__(self) -> None:
         )
         self.l2_cache_size_bytes: int = 0
         # the number of transaction size in bytes
-        self.transaction_size: List[int] = [0, 0]  # in bytes
+        self.transaction_size: list[int] = [0, 0]  # in bytes
         # bandwidth in MB/s, will be used for recommend basic tile size
-        self.bandwidth: List[int] = [0, 0]
+        self.bandwidth: list[int] = [0, 0]

     def get_avaliable_tensorintrin_shapes(self):
         raise NotImplementedError()
diff --git a/tilelang/carver/arch/cdna.py b/tilelang/carver/arch/cdna.py
index ed9848219..ec5aa905f 100644
--- a/tilelang/carver/arch/cdna.py
+++ b/tilelang/carver/arch/cdna.py
@@ -1,7 +1,7 @@
+from __future__ import annotations
 import tvm
 from tvm.target import Target
 from .arch_base import TileDevice
-from typing import List, Union


 def is_cdna_arch(arch: TileDevice) -> bool:
@@ -10,7 +10,7 @@ def is_cdna_arch(arch: TileDevice) -> bool:

 class CDNA(TileDevice):

-    def __init__(self, target: Union[Target, str]):
+    def __init__(self, target: Target | str):
         if isinstance(target, str):
             target = tvm.target.Target(target)
         self.target = target
@@ -27,9 +27,9 @@ def __init__(self, target: Union[Target, str]):
         self.max_smem_usage: int = 2 * self.smem_cap
         self.sm_partition: int = 4
         self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
-        self.transaction_size: List[int] = [32, 128]  # in bytes
+        self.transaction_size: list[int] = [32, 128]  # in bytes

-        self.bandwidth: List[int] = [1300, 14000]
+        self.bandwidth: list[int] = [1300, 14000]


 __all__ = [
diff --git a/tilelang/carver/arch/cuda.py b/tilelang/carver/arch/cuda.py
index ce5df4af4..4c7f98dff 100644
--- a/tilelang/carver/arch/cuda.py
+++ b/tilelang/carver/arch/cuda.py
@@ -1,7 +1,7 @@
+from __future__ import annotations
 import tvm
 from tvm.target import Target
 from .arch_base import TileDevice
-from typing import List, Union
 from .driver import cuda_driver


@@ -91,21 +91,21 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
     raise ValueError(f"Unsupported architecture: {arch}")


-class TensorInstruction(object):
+class TensorInstruction:

     def __init__(
         self,
         name: str,
-        shape: List[int],
+        shape: list[int],
     ):
         self.name: str = name
         # only hold the shape of M and N
-        self.shape: List[int] = shape
+        self.shape: list[int] = shape


 class CUDA(TileDevice):

-    def __init__(self, target: Union[Target, str]):
+    def __init__(self, target: Target | str):
         if isinstance(target, str):
             target = tvm.target.Target(target)
         self.target = target
@@ -126,15 +126,15 @@ def __init__(self, target: Union[Target, str]):
         self.sm_partition: int = 4
         self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
         # the number of transaction size in bytes
-        self.transaction_size: List[int] = [32, 128]  # in bytes
+        self.transaction_size: list[int] = [32, 128]  # in bytes
         # bandwidth in MB/s, will be used for recommend basic tile size
         # TODO(lei): find some way to get the real bandwidth
         # However, the ratio of bandwidth between different devices can
         # be similar. The bandwidth can work for another devices as well.
-        self.bandwidth: List[int] = [750, 12080]
+        self.bandwidth: list[int] = [750, 12080]
         # get the available tensor instructions during runtime to avoid
         # the dependency of the tensor intrinsics registration
-        self.available_tensor_instructions: List[TensorInstruction] = None
+        self.available_tensor_instructions: list[TensorInstruction] = None

     def get_avaliable_tensorintrin_shapes(self):
         self.available_tensor_instructions = (
diff --git a/tilelang/carver/arch/driver/cuda_driver.py b/tilelang/carver/arch/driver/cuda_driver.py
index 3e08e9afd..337987dd8 100644
--- a/tilelang/carver/arch/driver/cuda_driver.py
+++ b/tilelang/carver/arch/driver/cuda_driver.py
@@ -1,6 +1,6 @@
+from __future__ import annotations
 import ctypes
 import sys
-from typing import Optional


 class cudaDeviceProp(ctypes.Structure):
@@ -77,7 +77,7 @@ class cudaDeviceProp(ctypes.Structure):
     ]


-def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
+def get_cuda_device_properties(device_id: int = 0) -> cudaDeviceProp | None:
     if sys.platform == "win32":
         libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll")
@@ -95,7 +95,7 @@ def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
         raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}")


-def get_device_name(device_id: int = 0) -> Optional[str]:
+def get_device_name(device_id: int = 0) -> str | None:
     prop = get_cuda_device_properties(device_id)
     if prop:
         return prop.name.decode()
@@ -103,7 +103,7 @@ def get_device_name(device_id: int = 0) -> Optional[str]:
         raise RuntimeError("Failed to get device properties.")


-def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> Optional[int]:
+def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None:
     assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
     prop = get_cuda_device_properties(device_id)
     if prop:
@@ -143,7 +143,7 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int:
     return None


-def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> Optional[int]:
+def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> int | None:
     """
     Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
     """
diff --git a/tilelang/carver/arch/metal.py b/tilelang/carver/arch/metal.py
index 5650f7cc4..9cd1c4d1e 100644
--- a/tilelang/carver/arch/metal.py
+++ b/tilelang/carver/arch/metal.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from tvm.target import Target

 from .arch_base import TileDevice
diff --git a/tilelang/carver/common_schedules.py b/tilelang/carver/common_schedules.py
index 609d02b51..2766a15e3 100644
--- a/tilelang/carver/common_schedules.py
+++ b/tilelang/carver/common_schedules.py
@@ -19,7 +19,8 @@
 # Modifications Copyright (c) Microsoft.
 # The code below is mostly copied from apache/tvm common_schedules.py in dlight.
 """Common schedule strategies for TIR."""
-from typing import Callable, List
+from __future__ import annotations
+from typing import Callable

 from tvm import tir
 from .utils import retrieve_func_from_module
@@ -28,7 +29,7 @@

 def get_block(
     sch: tir.Schedule,
-    blocks: List[BlockInfo],
+    blocks: list[BlockInfo],
     name: str,
 ):
     """Get the target block from a schedule.
@@ -56,7 +57,7 @@ def get_block(

 def get_output_blocks(
     sch: tir.Schedule,
-    blocks: List[BlockInfo],
+    blocks: list[BlockInfo],
 ):
     """Get the output blocks of a schedule.
@@ -89,8 +90,8 @@ def get_output_blocks(

 def try_inline(
     sch: tir.Schedule,
-    blocks: List[BlockInfo],
-) -> List[BlockInfo]:
+    blocks: list[BlockInfo],
+) -> list[BlockInfo]:
     """Try to inline as many blocks as possible, and return the remaining blocks.

     Parameters
@@ -127,8 +128,8 @@ def _trial(func: Callable):

 def try_inline_contiguous_spatial(
     sch: tir.Schedule,
-    block_infos: List[BlockInfo],
-) -> List[BlockInfo]:
+    block_infos: list[BlockInfo],
+) -> list[BlockInfo]:
     """Try to inline contiguous spatial blocks in a schedule

     Parameters
diff --git a/tilelang/carver/matmul_analysis.py b/tilelang/carver/matmul_analysis.py
index dfc1a53e9..02a86cc78 100644
--- a/tilelang/carver/matmul_analysis.py
+++ b/tilelang/carver/matmul_analysis.py
@@ -1,8 +1,8 @@
 # pylint: disable=missing-docstring, invalid-name
 """A GEMM schedule rule for GPU operators."""
+from __future__ import annotations
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Optional, Set, Union, Tuple, Dict
 from tvm import tir
 from tvm.ir import Range
 from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap
@@ -57,7 +57,7 @@ def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
 def auto_inline_producers(
     sch: tir.Schedule,
     block: tir.schedule.BlockRV,
-    skip_blocks: Optional[List[tir.schedule.BlockRV]] = None,
+    skip_blocks: list[tir.schedule.BlockRV] | None = None,
 ):
     skip_blocks = skip_blocks or []
     while True:
@@ -118,7 +118,7 @@ def auto_inline_consumer_chain(


 # used to match the similar region with dequantize op.
-def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):
+def find_first_similar_region(regions: list[BufferRegion], buffer: tir.Buffer):
     for region in regions:
         if len(region.buffer.shape) == len(buffer.shape):
             return region
@@ -126,7 +126,7 @@ def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):


 # used to match the similar buffer with dequantize op.
-def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):
+def find_first_similar_buffer(regions: list[BufferRegion], buffer: tir.Buffer):
     for region in regions:
         if len(region.buffer.shape) == len(buffer.shape):
             return region.buffer
@@ -134,7 +134,7 @@ def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):


 # find the block that required to be reindex and scope.
-def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]:
+def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> BlockRV | None:
     # block that most near to the arguments
     block = main_block
     buffer = buffer
@@ -209,11 +209,11 @@ class IterTrait:


 def make_iter_fusion_index_map(
-    traits: List[IterTrait],
-    kind_order: List[IterKind],
+    traits: list[IterTrait],
+    kind_order: list[IterKind],
 ) -> tir.IndexMap:
-    fused_iters: Dict[IterKind, PrimExpr] = {}
-    input_iters: List[tir.Var] = []
+    fused_iters: dict[IterKind, PrimExpr] = {}
+    input_iters: list[tir.Var] = []
     for i, trait in enumerate(traits):
         v_i = tir.Var(f"i{i}", trait.extent.dtype)
         input_iters.append(v_i)
@@ -226,14 +226,14 @@ def make_iter_fusion_index_map(
         else:
             fused_iters[trait.kind] = v_i

-    final_indices: List[tir.PrimExpr] = [
+    final_indices: list[tir.PrimExpr] = [
         fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
     ]

     return tir.IndexMap(input_iters, final_indices, None)


-def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
+def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
     """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]

     Parameters
@@ -252,8 +252,8 @@ def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
     if len(block.reads) != 2 or len(block.writes) != 1:
         return None

-    def get_access_axes(region: List[Range]) -> Set[Var]:
-        axes: Set[Var] = set()
+    def get_access_axes(region: list[Range]) -> set[Var]:
+        axes: set[Var] = set()
         for r in region:
             if not _is_one(r.extent):
                 raise ValueError("Expect elemwise block access")
@@ -267,7 +267,7 @@ def get_access_axes(region: List[Range]) -> Set[Var]:
     except ValueError:
         return None

-    traits: Dict[Var, IterTrait] = {}
+    traits: dict[Var, IterTrait] = {}
     for iter_var in block.iter_vars:
         var = iter_var.var
         kind: IterKind
@@ -308,7 +308,7 @@ def get_access_axes(region: List[Range]) -> Set[Var]:


 def get_index_map(block: tir.Block,
-                  layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]:
+                  layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
     """Get index maps for the block

     Parameters
@@ -334,8 +334,8 @@ def get_index_map(block: tir.Block,
         return None
     A_traits, B_traits, C_traits, block_traits = traits

-    def get_ordered_axes(region: List[Range]) -> Set[Var]:
-        axes: List[Var] = []
+    def get_ordered_axes(region: list[Range]) -> set[Var]:
+        axes: list[Var] = []
         for r in region:
             if not _is_one(r.extent):
                 raise ValueError("Expect elemwise block access")
@@ -352,11 +352,11 @@ def has_common_reduce(var: Var) -> bool:
         vars = collect_vars_from_expr(var)
         return any(is_common_reduce(v) for v in vars)

-    def check_last_trait(region: List[Range]):
+    def check_last_trait(region: list[Range]):
         axes = get_ordered_axes(region)
         return has_common_reduce(axes[-1])

-    def infer_layout(layout: str, region: List[Range], kind: str = "A"):
+    def infer_layout(layout: str, region: list[Range], kind: str = "A"):
         """
         Infer the layout based on the region and the kind of buffer
         kind: "A", "B", "C"
@@ -409,7 +409,7 @@ def infer_layout(layout: str, region: List[Range], kind: str = "A"):
     )


-def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
+def get_in_out_dtypes(block: tir.Block) -> tuple[str]:
     """
     Detect In/Out data types for the given block based on the analysis if read/write buffers.
     """
@@ -419,7 +419,7 @@
     return (in_dtype, out_dtype)


-def get_dequantize_block(sch, blocks) -> Optional[BlockRV]:
+def get_dequantize_block(sch, blocks) -> BlockRV | None:
     # check at least two input and one output
     # at lease one input has uint dtype, and the output dtype is float
     def is_dequantize(block: BlockRV) -> bool:
@@ -445,8 +445,8 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
     if not isinstance(block_stmt.body.value, tir.BufferLoad):
         return False, False

-    def get_access_vars(region: List[Range]) -> List[Var]:
-        axes: List[Var] = []
+    def get_access_vars(region: list[Range]) -> list[Var]:
+        axes: list[Var] = []
         for r in region:
             if not _is_one(r.extent):
                 return None
@@ -475,7 +475,7 @@ def is_transpose_block(block_stmt: tir.Block) -> bool:
     return is_identity_or_transpose_block(block_stmt)[1]


-def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]):
+def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]):
     result_blocks = []
     for block in blocks:
         if not is_transpose_block(sch.get(block)):
@@ -493,7 +493,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]

 def normalize_to_matmul(sch: tir.Schedule,
                         main_block: BlockRV,
-                        layout: Optional[List[str]] = None) -> Optional[tir.Schedule]:
+                        layout: list[str] | None = None) -> tir.Schedule | None:
     if layout is None:
         layout = ["n", "t", "n"]
     block_stmt = sch.get(main_block)
@@ -521,10 +521,10 @@ def normalize_to_matmul(sch: tir.Schedule,
 def get_tensorized_func_and_tags(
     func: tir.PrimFunc,
     target: Target,
-    layout: Optional[List[str]] = None,
+    layout: list[str] | None = None,
     skip_normalize: bool = False,
     allow_gemv: bool = False,
-) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]:
+) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]:
     """
         transform function to matmul if necessary (e.g. transform conv2d with im2col)
     """
@@ -554,9 +553,8 @@ def check_sm_version(arch: str) -> int:
         sm_version = arch.replace("sm_", "")
         return int(sm_version) if sm_version.isdigit() else -1

-    def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
-                                 target: Target) -> Union[bool, Dict]:
-        tags: Dict[str, Union[List[int], int]] = {}
+    def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool | dict:
+        tags: dict[str, list[int] | int] = {}
         block_stmt = sch.get(block)

         # Nvidia Only Support Tensor Core for
@@ -584,8 +583,8 @@ def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
             tags["use_async_copy"] = True

         # analysis intrin information
-        def get_ordered_axes(region: List[Range]) -> Set[Var]:
-            axes: List[Var] = []
+        def get_ordered_axes(region: list[Range]) -> set[Var]:
+            axes: list[Var] = []
             for r in region:
                 if not _is_one(r.extent):
                     raise ValueError("Expect elemwise block access")
@@ -602,7 +601,7 @@ def has_common_reduce(var: Var) -> bool:
             vars = collect_vars_from_expr(var)
             return any(is_common_reduce(v) for v in vars)

-        def check_last_trait(region: List[Range]):
+        def check_last_trait(region: list[Range]):
             axes = get_ordered_axes(region)
             return has_common_reduce(axes[-1])
diff --git a/tilelang/carver/roller/bestfit.py b/tilelang/carver/roller/bestfit.py
index e8107112e..b66ceaae7 100644
--- a/tilelang/carver/roller/bestfit.py
+++ b/tilelang/carver/roller/bestfit.py
@@ -17,7 +17,7 @@ def merge(self, other):
         self.end = max(self.end, other.end)

     def __repr__(self) -> str:
-        return "".format(self.start, self.size())
+        return f""


 class BestFit:
diff --git a/tilelang/carver/roller/hint.py b/tilelang/carver/roller/hint.py
index 3b51b85c5..20d62f68f 100644
--- a/tilelang/carver/roller/hint.py
+++ b/tilelang/carver/roller/hint.py
@@ -1,6 +1,6 @@
 """Hint definition for schedule"""
+from __future__ import annotations
 from tvm import DataType
-from typing import Dict, List, Tuple
 from . import PrimFuncNode
 import numpy as np
 from .rasterization import *
@@ -13,17 +13,17 @@ class TensorCoreExtraConfig:

     def __init__(
         self,
-        AS_shape: Tuple[int],
-        BS_shape: Tuple[int],
-        AF_shape: Tuple[int],
-        BF_shape: Tuple[int],
-        tc_axis: Tuple[int],
+        AS_shape: tuple[int],
+        BS_shape: tuple[int],
+        AF_shape: tuple[int],
+        BF_shape: tuple[int],
+        tc_axis: tuple[int],
     ) -> None:
-        self.AS_shape: Tuple[int] = AS_shape
-        self.BS_shape: Tuple[int] = BS_shape
-        self.AF_shape: Tuple[int] = AF_shape
-        self.BF_shape: Tuple[int] = BF_shape
-        self.tc_axis: Tuple[int] = tc_axis
+        self.AS_shape: tuple[int] = AS_shape
+        self.BS_shape: tuple[int] = BS_shape
+        self.AF_shape: tuple[int] = AF_shape
+        self.BF_shape: tuple[int] = BF_shape
+        self.tc_axis: tuple[int] = tc_axis


 class Stride:
@@ -45,7 +45,7 @@ def ax(self) -> int:
     def stride(self) -> int:
         return self._stride

-    def compute_strides_from_shape(self, shape: List[int]) -> List[int]:
+    def compute_strides_from_shape(self, shape: list[int]) -> list[int]:
         ndim = len(shape)
         strides = [1 for _ in shape]
         for i in range(ndim - 2, -1, -1):
@@ -55,7 +55,7 @@ def compute_strides_from_shape(self, shape: List[int]) -> List[int]:
                 strides[i] = int(strides[i + 1] * shape[i + 1])
         return strides

-    def compute_elements_from_shape(self, shape: List[int]) -> int:
+    def compute_elements_from_shape(self, shape: list[int]) -> int:
         original_shape = np.prod(shape)
         if not self.is_valid():
             strided_elem = original_shape
@@ -94,10 +94,10 @@ def __init__(self, output_tile) -> None:
         self.grid_size = -1
         self.valid = True

-    def get_tile(self, func) -> List[int]:
+    def get_tile(self, func) -> list[int]:
         return self.tile_map[func]

-    def get_rstep(self, node) -> Dict[str, int]:
+    def get_rstep(self, node) -> dict[str, int]:
         return self.rstep_map[node]

     def __hash__(self) -> int:
@@ -147,7 +147,7 @@ def inter_transform_b(self) -> bool:
         return self.weight_transform_kind >= 1


-class Hint(object):
+class Hint:
     """
     Central configuration class for managing various parameters of computational tasks.
     """
@@ -178,15 +178,15 @@ def __init__(self) -> None:
         # Experimental
         self._raxis_order = []
         self._step = []
-        self.vectorize: Dict[str, int] = {}
+        self.vectorize: dict[str, int] = {}
         self.pipeline_stage = 1
         self.use_async = False
-        self.opt_shapes: Dict[str, int] = {}
+        self.opt_shapes: dict[str, int] = {}
         self.intrin_info = IntrinInfo("float16", "float16", True)
         self.shared_scope: str = "shared"
-        self.pass_context: Dict = {}
+        self.pass_context: dict = {}

-    def to_dict(self) -> Dict:
+    def to_dict(self) -> dict:
         dic = {}
         dic["block"] = self.block
         if self.use_tc:
@@ -218,7 +218,7 @@ def to_dict(self) -> Dict:
         return dic

     @classmethod
-    def from_dict(cls, dic: Dict) -> "Hint":
+    def from_dict(cls, dic: dict) -> Hint:
         hint = cls()
         for k, v in dic.items():
             setattr(hint, k, v)
@@ -231,13 +231,13 @@ def tensorcore_legalization(self):
         return self

     @property
-    def raxis_order(self) -> List[int]:
+    def raxis_order(self) -> list[int]:
         if self._raxis_order != []:
             return self._raxis_order
         return list(range(len(self.rstep)))

     @property
-    def step(self) -> List[int]:
+    def step(self) -> list[int]:
         if self._step != []:
             return self._step
         return [1 for _ in self.block]
diff --git a/tilelang/carver/roller/node.py b/tilelang/carver/roller/node.py
index 120b8a4b7..f9e38b168 100644
--- a/tilelang/carver/roller/node.py
+++ b/tilelang/carver/roller/node.py
@@ -1,9 +1,10 @@
 """PrimFunc Wrapper and Block information Analaysis"""
+from __future__ import annotations

 import tvm
 from tvm import tir
 from tvm.tir import IterVar, PrimFunc
-from typing import Any, Dict, List, Tuple, Optional
+from typing import Any
 from tvm.tir.schedule.schedule import BlockRV
 import numpy as np
 import functools
@@ -29,11 +30,11 @@ def _traverse(block):
     _traverse(block)


-class BlockAnalyzer(object):
+class BlockAnalyzer:

     def __init__(self, sch) -> None:
         self.sch: tir.Schedule = sch
-        self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch)
+        self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch)

     def get_block_name(self, block: BlockRV) -> str:
         return self.sch.get(block).name_hint
@@ -44,7 +45,7 @@ def get_block_info(self, block: BlockRV) -> BlockInfo:
                 return block_info
         return None

-    def get_spatial_axis(self, block: BlockRV) -> List[IterVar]:
+    def get_spatial_axis(self, block: BlockRV) -> list[IterVar]:
         block_info = self.get_block_info(block)
         axis = []
         for iter in block_info.iters:
@@ -52,7 +53,7 @@ def get_spatial_axis(self, block: BlockRV) -> List[IterVar]:
                 axis.append(iter)
         return axis

-    def get_reduce_axis(self, block: BlockRV) -> List[IterVar]:
+    def get_reduce_axis(self, block: BlockRV) -> list[IterVar]:
         block_info = self.get_block_info(block)
         raxis = []
         for iter in block_info.iters:
@@ -60,39 +61,39 @@ def get_reduce_axis(self, block: BlockRV) -> List[IterVar]:
                 raxis.append(iter)
         return raxis

-    def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]:
+    def get_input_buffers(self, block: BlockRV) -> list[tir.Buffer]:
         buffers = []
         for read in self.sch.get(block).reads:
             buffers.append(read.buffer)
         return buffers

-    def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]:
+    def get_output_buffers(self, block: BlockRV) -> list[tir.Buffer]:
         buffers = []
         for write in self.sch.get(block).writes:
             buffers.append(write.buffer)
         return buffers

-    def get_buffers(self, block: BlockRV) -> List[tir.Buffer]:
+    def get_buffers(self, block: BlockRV) -> list[tir.Buffer]:
         return self.get_input_buffers(block) + self.get_output_buffers(block)

-    def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]:
+    def get_producer_blocks(self, block: BlockRV) -> list[BlockRV]:
         return self.sch.get_producers(block)

-    def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]:
+    def get_consumer_blocks(self, block: BlockRV) -> list[BlockRV]:
         return self.sch.get_consumers(block)


 @dataclass
 class Edge:
-    src_node: 'Node'
-    dst_node: 'Node'
+    src_node: Node
+    dst_node: Node
     src_id: int
     dst_id: int


-class Node(object):
+class Node:

-    def __init__(self, tags: Optional[Dict] = None, name: str = "Node") -> None:
+    def __init__(self, tags: dict | None = None, name: str = "Node") -> None:
         self.name = name
         if tags is None:
             tags = {}
@@ -100,10 +101,10 @@ def __init__(self, tags: Optional[Dict] = None, name: str = "Node") -> None:
         self._in_edges = []
         self._shapes = []
         self._dtypes = []
-        self._tag: Dict = {}
+        self._tag: dict = {}
         self.update_tags(tags)

-    def update_tags(self, tags: Dict) -> None:
+    def update_tags(self, tags: dict) -> None:
         for tag in tags:
             self.add_tag(tag, tags[tag])

@@ -125,11 +126,11 @@ def is_output(self):
         return False

     @property
-    def inputs(self) -> List[Edge]:
+    def inputs(self) -> list[Edge]:
         return self._in_edges

     @property
-    def outputs(self) -> List[Edge]:
+    def outputs(self) -> list[Edge]:
         return self._out_edges

     def set_inputs(self, i: int, edge: Edge):
@@ -153,10 +154,10 @@ def set_dtype(self, dtype: tvm.DataType, id=0) -> None:
             assert self._dtypes[id] == dtype, (self._dtypes, dtype)
         self._dtypes[id] = dtype

-    def get_shape(self, id: int = 0) -> List[int]:
+    def get_shape(self, id: int = 0) -> list[int]:
         return self._shapes[id]

-    def set_shape(self, shape: List[int], id=0, overwrite=False) -> None:
+    def set_shape(self, shape: list[int], id=0, overwrite=False) -> None:
         if len(self._shapes) <= id:
             self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)])
         # elif self._shapes[id] is not None and not overwrite:
@@ -191,15 +192,15 @@ class PrimFuncNode(Node):

     def __init__(self,
                  prim_func: PrimFunc,
-                 tags: Optional[Dict] = None,
+                 tags: dict | None = None,
                  name: str = "PrimFuncNode") -> None:
         super().__init__(tags, name=name)
         self.prim_func = self._specialize_func(prim_func)
         self.sch: tir.Schedule = tir.Schedule(self.prim_func)
         self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch)
-        self.schedule_stages: List[BlockRV] = []
-        self.blocks: List[BlockRV] = []
-        self.output_blocks: List[BlockRV] = None
+        self.schedule_stages: list[BlockRV] = []
+        self.blocks: list[BlockRV] = []
+        self.output_blocks: list[BlockRV] = None
         self.reduction_block: BlockRV = None
         self.raxis = []
         self.input_buffers = []
@@ -219,7 +220,7 @@ def __init__(self,
             self.set_dtype(tvm.DataType(buffer.dtype), output_id)

     def _assign_placeholder_node(self):
-        inputs: List[Node] = []
+        inputs: list[Node] = []
         for buffer in self.input_buffers:
             inputs.append(PlaceHolderNode(buffer.name))

@@ -301,8 +302,8 @@ def extent_wrapper(self, value) -> int:
         else:
             return value

-    @functools.lru_cache()
-    def get_space_dim(self) -> List[int]:
+    @functools.lru_cache
+    def get_space_dim(self) -> list[int]:
         dim_size = []
         if self.reduction_block:
             block_info = self.block_analyzer.get_block_info(self.reduction_block)
@@ -333,7 +334,7 @@ def set_dtype(self, dtype: tvm.DataType, id=0) -> None:

     def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType:
         return tvm.DataType(buffer.dtype)

-    def propagate(self, tile, rstep: Optional[Dict] = None, targets=None):
+    def propagate(self, tile, rstep: dict | None = None, targets=None):
         if rstep is None:
             rstep = {}
         shape = {
@@ -343,7 +344,7 @@ def propagate(self, tile, rstep: Optional[Dict] = None, targets=None):
         }
         return self.ana.infer(shape, rstep, targets)

-    def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
+    def propagate_inputs(self, tile, rstep: dict | None = None) -> list[list[int]]:
         if rstep is None:
             rstep = {}
         read_idx_offset = len(self.input_buffers)
@@ -363,7 +364,7 @@ def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]
         return results

     # Propagate inputs only on reduction block
-    def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
+    def propagate_inputs_on_reduction(self, tile, rstep: dict | None = None) -> list[list[int]]:
         if rstep is None:
             rstep = {}
         reduction_block = self.reduction_block
@@ -386,7 +387,7 @@ def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> L
             results.append(trimmed_shape)
         return results

-    def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
+    def propagate_outputs(self, tile, rstep: dict | None = None) -> list[list[int]]:
         if rstep is None:
             rstep = {}
         read_idx_offset = len(self.input_buffers)
@@ -399,9 +400,7 @@ def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int
             results.append(trimmed_shape)
         return results

-    def propagate_reduction_inputs(self,
-                                   shape,
-                                   rstep: Optional[Dict] = None) -> Dict[str, List[int]]:
+    def propagate_reduction_inputs(self, shape, rstep: dict | None = None) -> dict[str, list[int]]:
         if rstep is None:
             rstep = {}
         if self.reduction_block is None:
@@ -418,8 +417,8 @@ def get_reduce_inputs_dtype(self):
             for b in self.block_analyzer.get_input_buffers(self.reduction_block)
         }

-    @functools.lru_cache()
-    def infer_tensorcore_axis(self) -> Tuple[int]:
+    @functools.lru_cache
+    def infer_tensorcore_axis(self) -> tuple[int]:
         # axis is fixed for one expression, so only inference and cached
         assert self.get_tag("tensorcore_config")

@@ -461,7 +460,7 @@ def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions):
         tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n)
         return tc_axis

-    def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int:
+    def footprint(self, shape, rstep, stride_map: dict | None = None) -> int:
         if stride_map is None:
             stride_map = {}
         result = 0
@@ -510,7 +509,7 @@ def is_after_reduce_stage(block):
             result += buffer_len
         return result, cached_tensor

-    def get_input_buffers(self) -> List[tir.Buffer]:
+    def get_input_buffers(self) -> list[tir.Buffer]:
         return self.block_analyzer.input_buffers


@@ -537,7 +536,7 @@ def get_ir(self) -> str:
         return "output"


-def topo_order(list_of_nodes) -> List[Node]:
+def topo_order(list_of_nodes) -> list[Node]:
     input_ready_count = {node: len(node.inputs) for node in list_of_nodes}
     ready = list(filter(lambda node: input_ready_count[node] == 0, list_of_nodes))
     output_list = []
@@ -557,7 +556,7 @@ def topo_order(list_of_nodes) -> List[Node]:
     return output_list


-def find_topo_sort_priority(output_node_list) -> List[Node]:
+def find_topo_sort_priority(output_node_list) -> list[Node]:
     import sys
     sys.setrecursionlimit(10000)
@@ -591,7 +590,7 @@ def topo_sort_dfs(node, visited, topo_order):
     return topo_order


-def find_topo_sort(output_node_list) -> List[Node]:
+def find_topo_sort(output_node_list) -> list[Node]:

     def topo_sort_dfs(node, visited, topo_order):
         if node in visited:
diff --git a/tilelang/carver/roller/policy/common.py b/tilelang/carver/roller/policy/common.py
index 0dadfa8a2..747dddbb0 100644
--- a/tilelang/carver/roller/policy/common.py
+++ b/tilelang/carver/roller/policy/common.py
@@ -1,8 +1,8 @@
-from typing import List
+from __future__ import annotations

 import numpy as np


-def get_all_factors(n: int) -> List[int]:
+def get_all_factors(n: int) -> list[int]:
     # Calculate the square root of n and round it up to the nearest integer
     n0 = int(np.ceil(np.sqrt(n)))
@@ -16,7 +16,7 @@ def get_all_factors(n: int) -> List[int]:
     return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])]


-def factorize(n: int) -> List[int]:
+def factorize(n: int) -> list[int]:
     i = 2  # Start with the smallest prime number
     result = []
@@ -30,7 +30,7 @@ def factorize(n: int) -> List[int]:
     return result


-def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
+def coalesced_factor(subtensor: list[int], tensor: list[int]) -> int:
     # If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension
     if subtensor[-1] != tensor[-1] or len(subtensor) == 1:
         return subtensor[-1]
@@ -39,7 +39,7 @@ def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
         return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1])


-def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int:
+def coalesced_tensor_shape(subtensor: list[int], tensor: list[int], transaction_size: int) -> int:
     # Calculate the total number of elements in the subtensor
     bytes = int(np.prod(subtensor))
diff --git a/tilelang/carver/roller/policy/default.py b/tilelang/carver/roller/policy/default.py
index 7837395d9..36d8f1f2c 100644
--- a/tilelang/carver/roller/policy/default.py
+++ b/tilelang/carver/roller/policy/default.py
@@ -1,8 +1,9 @@
 """Policy for cuda core schedule"""
+from __future__ import annotations
 import functools
 import math
 from queue import PriorityQueue
-from typing import Iterable, Dict, List, Optional
+from typing import Iterable

 import numpy as np
 import tvm
@@ -22,11 +23,11 @@ class DefaultPolicy:
     """

     func: tvm.tir.PrimFunc
-    nodes: List[PrimFuncNode] = []
+    nodes: list[PrimFuncNode] = []
     arch: TileDevice
-    tags: Dict
+    tags: dict

-    def __init__(self, arch: TileDevice, tags: Optional[Dict] = None) -> None:
+    def __init__(self, arch: TileDevice, tags: dict | None = None) -> None:
         if tags is None:
             tags = {}

@@ -38,20 +39,17 @@ def __init__(self, arch: TileDevice, tags: Optional[Dict] = None) -> None:
     def from_prim_func(cls,
                        func: tvm.tir.PrimFunc,
                        arch: TileDevice,
-                       tags: Optional[Dict] = None,
+                       tags: dict | None = None,
                        name: str = "PrimFuncNode"):
         return cls(arch, tags)._init_with_prim_func(func, name)

     @classmethod
-    def from_output_nodes(cls,
-                          nodes: List[OutputNode],
-                          arch: TileDevice,
-                          tags: Optional[Dict] = None):
+    def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None):
         return cls(arch, tags)._init_with_output_nodes(nodes)

     def _init_with_prim_func(self,
                              func: tvm.tir.PrimFunc,
-                             name: str = "PrimFuncNode") -> "DefaultPolicy":
+                             name: str = "PrimFuncNode") -> DefaultPolicy:
         if func is not None and isinstance(func, tvm.tir.PrimFunc):
             self.func = func
             self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)
@@ -61,7 +59,7 @@ def _init_with_prim_func(self,
             self._init_with_output_nodes(output_nodes)
         return self

-    def _init_with_output_nodes(self, output_nodes: List[OutputNode]):
+    def _init_with_output_nodes(self, output_nodes: list[OutputNode]):
         self.ordered_nodes = list(
             filter(lambda n: not n.is_placeholder() and not n.is_output(),
                    find_topo_sort(output_nodes)))
@@ -78,7 +76,7 @@ def _init_with_output_nodes(self, output_nodes: List[OutputNode]):
                 self.output_nodes.append(node)
         return self

-    def emit_config(self, topk: int) -> List[Hint]:
+    def emit_config(self, topk: int) -> list[Hint]:
         base_tile = self.get_base_tile()
         if base_tile is None:
             return []
@@ -557,7 +555,7 @@ def _compute_stride_map(self, td: TileDict):
             node, td)
         td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map

-    def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict:
+    def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict:
         """
         Computes and returns a TileDict object for a given output tile configuration and reduction step map.
@@ -624,7 +622,7 @@ def check_tile_shape_isvalid(self, td: TileDict) -> bool:

         return True

-    def recommend_block_size(self, td: TileDict) -> List[int]:
+    def recommend_block_size(self, td: TileDict) -> list[int]:
         """
         Recommends optimal block sizes based on the TileDict configuration.
diff --git a/tilelang/carver/roller/policy/tensorcore.py b/tilelang/carver/roller/policy/tensorcore.py
index 60edc930e..15bad4122 100644
--- a/tilelang/carver/roller/policy/tensorcore.py
+++ b/tilelang/carver/roller/policy/tensorcore.py
@@ -1,6 +1,6 @@
 """Policy for tensorcore schedule"""
+from __future__ import annotations
 import tvm
-from typing import Dict, List, Tuple, Optional
 import numpy as np
 import logging
 from ..hint import Hint, Stride, TileDict, IntrinInfo
@@ -19,9 +19,9 @@ class TensorCorePolicy(DefaultPolicy):
     wmma_k: int = 16
     pipeline_stage: int = 1
     use_async_copy: bool = False
-    block_reduction_depth: Optional[int] = None
+    block_reduction_depth: int | None = None

-    def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: Optional[str] = None):
+    def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str | None = None):
         super()._init_with_prim_func(func, name)
         self._legalize_info()
         return self
@@ -52,9 +52,9 @@ def _legalize_info(self):
     def _compute_tc_strides(
         self,
         node: PrimFuncNode,
-        tile: List[int],
-        rstep: Optional[Dict[str, int]] = None,
-    ) -> Tuple[Stride, Stride, Stride]:
+        tile: list[int],
+        rstep: dict[str, int] | None = None,
+    ) -> tuple[Stride, Stride, Stride]:
         if rstep is None:
             rstep = {}
         # strides was used for shared memory padding. which is necessary for avoiding
diff --git a/tilelang/carver/roller/rasterization.py b/tilelang/carver/roller/rasterization.py
index 3ead2e12e..39c603b6b 100644
--- a/tilelang/carver/roller/rasterization.py
+++ b/tilelang/carver/roller/rasterization.py
@@ -1,6 +1,5 @@
 """Rasteration Plan For L2 Cache Locality"""
-
-from typing import List
+from __future__ import annotations


 class Rasterization:
@@ -10,7 +9,7 @@ class Rasterization:
     def __init__(self) -> None:
         pass

-    def get_code(self) -> List[str]:
+    def get_code(self) -> list[str]:
         raise NotImplementedError()

     @property
@@ -27,7 +26,7 @@ def __init__(self) -> None:
     def __repr__(self) -> str:
         return ""

-    def get_code(self) -> List[str]:
+    def get_code(self) -> list[str]:
         return []

@@ -47,7 +46,7 @@ def __init__(self, panel_width=4) -> None:
     def __repr__(self) -> str:
         return f""

-    def get_code(self) -> List[str]:
+    def get_code(self) -> list[str]:
         raise NotImplementedError()

@@ -84,10 +83,10 @@ def get_device_function(self) -> str:
     }
     """

-    def get_code(self, panel_width: int = None) -> List[str]:
+    def get_code(self, panel_width: int = None) -> list[str]:
         if panel_width is None:
             panel_width = self.panel_width_
         return [
             self.get_device_function(),
-            "const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width),
+            f"const dim3 blockIdx = rasterization2DColumn({panel_width});\n",
         ]
diff --git a/tilelang/carver/roller/shape_inference/common.py b/tilelang/carver/roller/shape_inference/common.py
index a3a7a31d6..aaf59aed9 100644
--- a/tilelang/carver/roller/shape_inference/common.py
+++ b/tilelang/carver/roller/shape_inference/common.py
@@ -1,10 +1,10 @@
+from __future__ import annotations
 from collections import OrderedDict
-from typing import Dict, List

 from tvm import arith


-class Statement():
+class Statement:

     def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict,
                  range_map: OrderedDict):
@@ -18,12 +18,12 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
     return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))


-class InputShapeInference():
+class InputShapeInference:

-    def __init__(self, deps: List[Statement]):
+    def __init__(self, deps: list[Statement]):
         self.deps = deps

-    def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int]):
+    def _infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int]):
         shape = shape.copy()
         ana = arith.Analyzer()
         for dep in reversed(self.deps):
@@ -44,7 +44,7 @@ def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, i
             shape[name] = [c.max_value - c.min_value + 1 for c in bounds]
         return shape

-    def infer(self, shape, rstep: Dict[str, int] = None):
+    def infer(self, shape, rstep: dict[str, int] = None):
         if rstep is None:
             rstep = {}
         if isinstance(shape, (list, tuple)):
diff --git a/tilelang/carver/roller/shape_inference/tir.py b/tilelang/carver/roller/shape_inference/tir.py
index 8a744ec00..c1b97188a 100644
--- a/tilelang/carver/roller/shape_inference/tir.py
+++ b/tilelang/carver/roller/shape_inference/tir.py
@@ -1,4 +1,5 @@
-from typing import Dict, List, Tuple, Set, Mapping
+from __future__ import annotations
+from typing import Mapping
 from tvm.tir.schedule.schedule import BlockRV
 from tvm.ir import structural_equal
 from tvm import arith, tir
@@ -15,7 +16,7 @@ def __init__(self, block_analyzer, block: BlockRV):

         self.reverse_bound_inference = {}

-    def make_reverse(self, input_name: str, input_iter: List[tir.PrimExpr]):
+    def make_reverse(self, input_name: str, input_iter: list[tir.PrimExpr]):
str, input_iter: list[tir.PrimExpr]): if len(self.block_analyzer.get_reduce_axis(self.block)) > 0: return None if len(self.dependent_region[input_name]) != 1: @@ -47,7 +48,7 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) -class TensorDepNode(object): +class TensorDepNode: """ For tensor dependency analysis. """ @@ -76,7 +77,7 @@ def __repr__(self): return self.name -class DependencyAnalysis(object): +class DependencyAnalysis: def __init__(self, deps): self.deps = deps @@ -89,7 +90,7 @@ def _construct_unique_name2dep(self, deps): This is a workaround for the issue that we have two same ops' fuse case. See https://github.com/apache/tvm/issues/16433 """ - _names: Set = set() + _names: set = set() name2dep: Mapping = {} for dep in deps: output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0] @@ -168,7 +169,7 @@ def _find_path_recursive(self, current_node, target_name, visited, path): class InputShapeInference: - def __init__(self, deps: List[Statement]): + def __init__(self, deps: list[Statement]): self.deps = deps self.target_mapping = {} self.buffer_mapping = {} @@ -179,7 +180,7 @@ def __init__(self, deps: List[Statement]): self.dep_analysis = DependencyAnalysis(self.deps) self.dep_analysis.analyze() - def construct_dependency_target(self, targets: Tuple[str]): + def construct_dependency_target(self, targets: tuple[str]): if targets in self.target_mapping: return self.target_mapping[targets] # should be buffer name instead of block name @@ -242,8 +243,8 @@ def construct_dependency_target(self, targets: Tuple[str]): return input_vars, mapping def infer(self, - shape: Dict[str, List[arith.ConstIntBound]], - rstep: Dict[str, int] = None, + shape: dict[str, list[arith.ConstIntBound]], + rstep: dict[str, int] = None, targets=None): if rstep is None: rstep = {} @@ -351,10 +352,10 @@ def walk_indice(expr): elif isinstance(expr, tir.Call): return None else: - raise Exception("Unhandled node type in walk_indice(): %s" % expr) + raise Exception(f"Unhandled node type in walk_indice(): {expr}") -def _extract_dependent_region(block_analyzer, block: BlockRV) -> Dict[str, List[tir.PrimExpr]]: +def _extract_dependent_region(block_analyzer, block: BlockRV) -> dict[str, list[tir.PrimExpr]]: input_buffers = block_analyzer.get_input_buffers(block) dependent_region = {buffer.name: [] for buffer in input_buffers} diff --git a/tilelang/carver/template/base.py b/tilelang/carver/template/base.py index 0de3c5996..5aa5074c2 100644 --- a/tilelang/carver/template/base.py +++ b/tilelang/carver/template/base.py @@ -1,11 +1,11 @@ # Import necessary modules and classes +from __future__ import annotations from abc import ABC, abstractmethod # For defining abstract base classes from dataclasses import dataclass, field # For defining data classes from ..arch import ( # Import architecture-related utilities and classes TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch) from ..roller.hint import Hint # Import the Hint class from ..roller.node import OutputNode # Import the OutputNode class -from typing import List # For type hinting from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions @@ -24,10 +24,10 @@ class BaseTemplate(ABC): _func: PrimFunc = field(default=None, init=False, repr=False) # The outputs nodes associated with this template, initially None - _output_nodes: List[OutputNode] = field(default=None, init=False, repr=False) + 
_output_nodes: list[OutputNode] = field(default=None, init=False, repr=False) @abstractmethod - def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]: + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: """ Abstract method that must be implemented by subclasses. It should return a list of hardware-aware configurations (hints) @@ -42,7 +42,7 @@ def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> """ pass - def with_arch(self, arch: TileDevice) -> "BaseTemplate": + def with_arch(self, arch: TileDevice) -> BaseTemplate: """ Sets the architecture for this template and returns itself. @@ -110,7 +110,7 @@ def initialize_function(self) -> None: """ raise NotImplementedError("initialize_function is not implemented") - def set_function(self, func: PrimFunc) -> "BaseTemplate": + def set_function(self, func: PrimFunc) -> BaseTemplate: """ Sets the function for this template and returns itself. @@ -123,7 +123,7 @@ def set_function(self, func: PrimFunc) -> "BaseTemplate": self._func = func return self - def set_output_nodes(self, output_nodes: List[OutputNode]) -> "BaseTemplate": + def set_output_nodes(self, output_nodes: list[OutputNode]) -> BaseTemplate: """ Sets the output nodes for this template and returns itself. @@ -136,7 +136,7 @@ def set_output_nodes(self, output_nodes: List[OutputNode]) -> "BaseTemplate": self._output_nodes = output_nodes return self - def recommend_hints(self, topk: int = 10) -> List[Hint]: + def recommend_hints(self, topk: int = 10) -> list[Hint]: """ Provides a list of recommended hardware-aware configurations. @@ -159,7 +159,7 @@ def arch(self) -> TileDevice: return self._arch @property - def output_nodes(self) -> List[OutputNode]: + def output_nodes(self) -> list[OutputNode]: """ Returns the output nodes associated with this template. diff --git a/tilelang/carver/template/conv.py b/tilelang/carver/template/conv.py index 5931b2656..f180084d5 100644 --- a/tilelang/carver/template/conv.py +++ b/tilelang/carver/template/conv.py @@ -1,8 +1,8 @@ +from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te, tir from ..roller import Hint -from typing import List from ..utils import get_roller_hints_from_func @@ -44,7 +44,7 @@ class ConvTemplate(BaseTemplate): accum_dtype: str = "float16" # Data type for accumulation with_bias: bool = False # Whether to add a bias term - def get_hardware_aware_configs(self, arch=None, topk=10) -> List[Hint]: + def get_hardware_aware_configs(self, arch=None, topk=10) -> list[Hint]: """ Retrieves optimized hardware-aware configurations. 
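Note: the template hunks above all follow the same recipe. `from __future__ import annotations` (PEP 563) stores every annotation as a lazy string, so builtin generics such as `list[Hint]`, PEP 604 unions, and unquoted self-references like `-> BaseTemplate` parse cleanly even on Python 3.8, and the `typing.List`/`Optional` imports can be dropped. A minimal, self-contained sketch of that pattern (illustrative only, not part of this patch; `DemoTemplate` is an invented name):

from __future__ import annotations  # PEP 563: annotations become lazily stored strings

from dataclasses import dataclass, field


@dataclass
class DemoTemplate:
    # Builtin generics and `X | None` are legal here on Python 3.8 because
    # the annotation is never evaluated at runtime.
    shape: list[int] | None = None
    hints: list[str] = field(default_factory=list)

    def with_shape(self, shape: list[int]) -> DemoTemplate:
        # Unquoted forward reference to the enclosing class, same as
        # `with_arch(...) -> BaseTemplate` in the diff above.
        self.shape = shape
        return self


t = DemoTemplate().with_shape([16, 16])
print(DemoTemplate.__annotations__["shape"])  # prints the raw string "list[int] | None"

Nothing in the class body is evaluated at import time, which is why the patch can rewrite dataclass fields and return types wholesale without a runtime cost.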
diff --git a/tilelang/carver/template/elementwise.py b/tilelang/carver/template/elementwise.py index 311b75ccf..26d531529 100644 --- a/tilelang/carver/template/elementwise.py +++ b/tilelang/carver/template/elementwise.py @@ -1,10 +1,10 @@ # Import necessary modules +from __future__ import annotations from dataclasses import dataclass # Used for defining data classes from .base import BaseTemplate # Importing the base class for templates from tvm import te # Importing TVM's tensor expression module from ..arch import TileDevice # Importing TileDevice for hardware-specific configurations from ..roller import Hint # Importing Hint for optimization hints -from typing import List # Importing List type hint from ..utils import get_roller_hints_from_func # Function to obtain optimization hints @@ -19,10 +19,10 @@ class ElementwiseTemplate(BaseTemplate): """ # OP Related Config - shape: List[int] = None # Shape of the tensor + shape: list[int] = None # Shape of the tensor dtype: str = "float16" # Data type of the tensor - def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]: + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: """ Retrieves hardware-aware optimization configurations. diff --git a/tilelang/carver/template/flashattention.py b/tilelang/carver/template/flashattention.py index f9dc85b76..760b19817 100644 --- a/tilelang/carver/template/flashattention.py +++ b/tilelang/carver/template/flashattention.py @@ -1,17 +1,17 @@ +from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te from ..arch import TileDevice from ..roller import Hint from ..roller import PrimFuncNode, OutputNode, Edge -from typing import List from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_tags @dataclass class FlashAttentionTemplate(BaseTemplate): - _output_nodes: List[OutputNode] = None + _output_nodes: list[OutputNode] = None # Operation-related configuration parameters batch_size: int = 1 @@ -26,7 +26,7 @@ class FlashAttentionTemplate(BaseTemplate): out_dtype: str = "float16" accum_dtype: str = "float16" - def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]: + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: """ Retrieves optimized hardware-aware configurations. diff --git a/tilelang/carver/template/gemv.py b/tilelang/carver/template/gemv.py index a6e943a01..7195a0b87 100644 --- a/tilelang/carver/template/gemv.py +++ b/tilelang/carver/template/gemv.py @@ -1,9 +1,9 @@ +from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te from ..arch import TileDevice from ..roller import Hint -from typing import List from ..utils import get_roller_hints_from_func @@ -25,7 +25,7 @@ class GEMVTemplate(BaseTemplate): accum_dtype: str = "float16" # Accumulation data type with_bias: bool = False # Whether to add a bias term - def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]: + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: """ Retrieves optimized hardware-aware configurations. 
diff --git a/tilelang/carver/template/general_reduce.py b/tilelang/carver/template/general_reduce.py index 9eba86c63..a8da5fd6c 100644 --- a/tilelang/carver/template/general_reduce.py +++ b/tilelang/carver/template/general_reduce.py @@ -1,9 +1,9 @@ +from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te from ..arch import TileDevice from ..roller import Hint -from typing import List, Union from ..utils import get_roller_hints_from_func @@ -11,11 +11,11 @@ class GeneralReductionTemplate(BaseTemplate): # OP Related Config - structure: Union[str, List[str]] = None - shape: List[int] = None + structure: str | list[str] = None + shape: list[int] = None dtype: str = "float16" - def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]: + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: roller_hints = get_roller_hints_from_func( self._func, arch=arch, topk=topk, allow_gemv=False) return roller_hints diff --git a/tilelang/carver/template/matmul.py b/tilelang/carver/template/matmul.py index 24aa6ef91..4847cdb22 100644 --- a/tilelang/carver/template/matmul.py +++ b/tilelang/carver/template/matmul.py @@ -1,9 +1,9 @@ +from __future__ import annotations from dataclasses import dataclass from .base import BaseTemplate from tvm import te from ..arch import TileDevice from ..roller import Hint -from typing import List from ..utils import get_roller_hints_from_func @@ -38,7 +38,7 @@ class MatmulTemplate(BaseTemplate): accum_dtype: str = "float16" # Data type for accumulation with_bias: bool = False # Whether to add a bias term - def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]: + def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]: """ Retrieves optimized hardware-aware configurations. diff --git a/tilelang/carver/utils.py b/tilelang/carver/utils.py index 649b4388c..cedb7547a 100644 --- a/tilelang/carver/utils.py +++ b/tilelang/carver/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from __future__ import annotations from tvm import tir, IRModule from tvm.tir import PrimFunc from .arch import TileDevice @@ -26,11 +26,11 @@ def get_rasterization_code(pannel_width: int = 8) -> str: """ -def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], +def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule, arch: TileDevice, topk: int = 10, tensorcore_only: bool = False, - allow_gemv: bool = False) -> Optional[List[Hint]]: + allow_gemv: bool = False) -> list[Hint] | None: func = None if isinstance(func_or_module, tir.PrimFunc): func = func_or_module @@ -69,11 +69,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], return roller_hints -def get_roller_hints_from_output_nodes( - output_nodes: List[OutputNode], - arch: TileDevice, - topk: int = 10, - extra_tags: Optional[List[str]] = None) -> Optional[List[Hint]]: +def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode], + arch: TileDevice, + topk: int = 10, + extra_tags: list[str] | None = None) -> list[Hint] | None: assert isinstance(output_nodes, list), "The input should be a list of functions." 
lints = [] diff --git a/tilelang/contrib/cc.py b/tilelang/contrib/cc.py index 26bb419db..d5cba6c4e 100644 --- a/tilelang/contrib/cc.py +++ b/tilelang/contrib/cc.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Util to invoke C/C++ compilers in the system.""" +from __future__ import annotations import functools import os import shutil @@ -23,7 +24,6 @@ # pylint: disable=invalid-name import sys -from typing import Dict from tvm.base import py_str from tvm.contrib import tar as _tar @@ -208,7 +208,7 @@ def create_executable(output, objects, options=None, cc=None, cwd=None, ccache_e raise ValueError("Unsupported platform") -def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]: +def get_global_symbol_section_map(path, *, nm=None) -> dict[str, str]: """Get global symbols from a library via nm -g Parameters diff --git a/tilelang/contrib/hipcc.py b/tilelang/contrib/hipcc.py index afd381223..92fbcc8e3 100644 --- a/tilelang/contrib/hipcc.py +++ b/tilelang/contrib/hipcc.py @@ -54,7 +54,7 @@ def compile_hip(code, if target_format not in ["hsaco"]: raise ValueError("target_format must be hsaco") temp_code = temp.relpath("my_kernel.cc") - temp_target = temp.relpath("my_kernel.%s" % target_format) + temp_target = temp.relpath(f"my_kernel.{target_format}") with open(temp_code, "w") as out_file: out_file.write(code) diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py index 6b2e739a0..8e813d92b 100644 --- a/tilelang/contrib/nvcc.py +++ b/tilelang/contrib/nvcc.py @@ -2,11 +2,11 @@ # modified from apache tvm python/tvm/contrib/nvcc.py """Utility to invoke nvcc compiler in the system""" from __future__ import absolute_import as _abs +from __future__ import annotations import os import subprocess import warnings -from typing import Tuple from tilelang.env import CUDA_HOME import tvm.ffi @@ -299,7 +299,7 @@ def get_target_compute_version(target=None): "Try specifying it by adding '-arch=sm_xx' to your target.") -def parse_compute_version(compute_version) -> Tuple[int, int]: +def parse_compute_version(compute_version) -> tuple[int, int]: """Parse compute capability string to divide major and minor version Parameters diff --git a/tilelang/contrib/nvrtc.py b/tilelang/contrib/nvrtc.py index 0f07022c9..b69115549 100644 --- a/tilelang/contrib/nvrtc.py +++ b/tilelang/contrib/nvrtc.py @@ -1,10 +1,11 @@ +from __future__ import annotations import cuda.bindings.nvrtc as nvrtc -from typing import Literal, Union, List, Optional, Tuple +from typing import Literal from tvm.target import Target from .nvcc import get_target_compute_version, parse_compute_version -def get_nvrtc_version() -> Tuple[int, int]: +def get_nvrtc_version() -> tuple[int, int]: result, major, minor = nvrtc.nvrtcVersion() assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get NVRTC version: {result}" return (major, minor) @@ -12,8 +13,8 @@ def get_nvrtc_version() -> Tuple[int, int]: def compile_cuda(code: str, target_format: Literal["ptx", "cubin"] = "ptx", - arch: Optional[int] = None, - options: Optional[Union[str, List[str]]] = None, + arch: int | None = None, + options: str | list[str] | None = None, verbose: bool = False) -> bytearray: """Compile cuda code with NVRTC. 
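Note: the compiler-helper hunks (cc.py, hipcc.py, nvcc.py, nvrtc.py) combine two mechanical rewrites: `%`-formatting becomes f-strings, and `Tuple[int, int]` / `Optional[...]` become `tuple[int, int]` / `... | None` in signatures. A rough sketch of the resulting style, using hypothetical helper names rather than the real APIs:

from __future__ import annotations


def parse_compute_capability(version: str) -> tuple[int, int]:
    """Split a 'major.minor' string such as '8.0' into (8, 0)."""
    major, minor = version.split(".")
    return int(major), int(minor)


def temp_artifact_name(target_format: str = "cubin") -> str:
    # f-string instead of the old "my_kernel.%s" % target_format spelling.
    return f"my_kernel.{target_format}"


print(parse_compute_capability("8.0"))  # (8, 0)
print(temp_artifact_name("ptx"))        # my_kernel.ptx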
diff --git a/tilelang/engine/callback.py b/tilelang/engine/callback.py index 8d43e41d5..ee1c80693 100644 --- a/tilelang/engine/callback.py +++ b/tilelang/engine/callback.py @@ -1,4 +1,5 @@ -from typing import Callable, Union +from __future__ import annotations +from typing import Callable from tvm import register_func from tvm.target import Target @@ -25,7 +26,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T register_func("tilelang_callback_hip_postproc", f=func, override=override) -def register_cuda_postproc_callback(func: Union[Callable, bool] = None, override: bool = True): +def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True): """Decorator for registering CUDA post-processing callback function. Can be used with or without parentheses: @@ -58,7 +59,7 @@ def _register(fn: Callable[[str, Target], str]): raise TypeError("Invalid decorator usage") -def register_hip_postproc_callback(func: Union[Callable, bool] = None, override: bool = True): +def register_hip_postproc_callback(func: Callable | bool = None, override: bool = True): """Decorator for registering HIP post-processing callback function. Can be used with or without parentheses: diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py index 717a8ebd2..8738f58a1 100644 --- a/tilelang/engine/lower.py +++ b/tilelang/engine/lower.py @@ -1,8 +1,9 @@ """The compiler for TL programs.""" +from __future__ import annotations import os import os.path as osp -from typing import Union, Optional, Callable, List +from typing import Callable import tilelang.transform from tilelang import tvm as tvm from tvm import tir @@ -114,7 +115,7 @@ def tilelang_callback_hip_compile(code, target): return hsaco -def extrac_params(func: tir.PrimFunc) -> List[KernelParam]: +def extrac_params(func: tir.PrimFunc) -> list[KernelParam]: tensor_types = [] for var in func.params: if var in func.buffer_map: @@ -124,7 +125,7 @@ def extrac_params(func: tir.PrimFunc) -> List[KernelParam]: return tensor_types -def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]): +def canon_target_host(target: str | Target, target_host: str | Target | None): if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "c" @@ -190,9 +191,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) -> def lower( - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], - target: Union[str, Target] = "auto", - target_host: Optional[Union[str, Target]] = None, + func_or_mod: tir.PrimFunc | tvm.IRModule, + target: str | Target = "auto", + target_host: str | Target | None = None, runtime_only=False, enable_host_codegen=False, enable_device_compile=False, diff --git a/tilelang/engine/param.py b/tilelang/engine/param.py index 2db2d8391..de3c979ea 100644 --- a/tilelang/engine/param.py +++ b/tilelang/engine/param.py @@ -1,7 +1,7 @@ """The profiler and convert to torch utils""" +from __future__ import annotations from dataclasses import dataclass -from typing import List, Union, Optional import torch from tilelang import tvm as tvm from tvm.tir import Buffer, IntImm, Var, PrimExpr @@ -15,7 +15,7 @@ class KernelParam: Used to describe tensor or scalar parameters in TVM/PyTorch interop. 
""" dtype: torch.dtype # PyTorch data type of the parameter - shape: List[Union[int, Var]] # List of dimensions, can be integers or TVM variables + shape: list[int | Var] # List of dimensions, can be integers or TVM variables @classmethod def from_buffer(cls, buffer: Buffer): @@ -111,7 +111,6 @@ class CompiledArtifact: """ host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code - params: List[KernelParam] # List of parameters (tensors/scalars) used by the kernel + params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel kernel_source: str # Raw source code of the generated kernel - rt_mod: Optional[ - tvm.runtime.Module] = None # Runtime module for execution, may be lazily initialized + rt_mod: tvm.runtime.Module | None = None # Runtime module for execution, may be lazily initialized diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 7126186cc..10fd87d10 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -1,13 +1,13 @@ +from __future__ import annotations from tvm import tir, IRModule from tvm.target import Target import tilelang from tilelang.transform import PassContext from tilelang.contrib.nvcc import have_tma, is_hopper -from typing import Optional -def allow_warp_specialized(pass_ctx: Optional[PassContext] = None, - target: Optional[Target] = None) -> bool: +def allow_warp_specialized(pass_ctx: PassContext | None = None, + target: Target | None = None) -> bool: # avoid circular import from tilelang.jit.adapter.utils import is_cuda_target @@ -19,8 +19,8 @@ def allow_warp_specialized(pass_ctx: Optional[PassContext] = None, return not disable_warp_specialized -def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None, - target: Optional[Target] = None) -> bool: +def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None, + target: Target | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() if not have_tma(target): @@ -29,26 +29,26 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None, return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target) -def allow_fence_proxy(target: Optional[Target] = None) -> bool: +def allow_fence_proxy(target: Target | None = None) -> bool: return have_tma(target) -def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool: +def allow_vectorize(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False) return not disable_vectorize -def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None) -> bool: +def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() enable_global_thread_sync = pass_ctx.config.get("tir.detect_global_barrier", False) return enable_global_thread_sync -def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None, - target: Optional[Target] = None) -> bool: +def should_enable_aggressive_merge(pass_ctx: PassContext | None = None, + target: Target | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() enable_aggressive_merge = bool( @@ -61,7 +61,7 @@ def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None, return 
enable_aggressive_merge -def should_force_let_inline(pass_ctx: Optional[PassContext] = None) -> bool: +def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool: if pass_ctx is None: pass_ctx = tilelang.transform.get_pass_context() return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False)) diff --git a/tilelang/env.py b/tilelang/env.py index 08cf031ca..9d3f50a8e 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -1,3 +1,4 @@ +from __future__ import annotations import sys import os import pathlib @@ -5,7 +6,6 @@ import shutil import glob from dataclasses import dataclass -from typing import Optional logger = logging.getLogger(__name__) @@ -170,7 +170,7 @@ class Environment: key: str # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION") default: str # Default value if the environment variable is not set - _forced_value: Optional[str] = None # Temporary runtime override (mainly for tests/debugging) + _forced_value: str | None = None # Temporary runtime override (mainly for tests/debugging) def get(self): if self._forced_value is not None: diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py index 12551b193..aa369980f 100644 --- a/tilelang/intrinsics/mfma_macro_generator.py +++ b/tilelang/intrinsics/mfma_macro_generator.py @@ -1,17 +1,16 @@ +from __future__ import annotations from tilelang import tvm as tvm import tilelang.language as T -from typing import Tuple from tvm import DataType from tvm.tir import PrimExpr from tvm.runtime import convert -from typing import Optional from .utils import ( mfma_store_index_map,) lift = convert -class MatrixCoreIntrinEmitter(object): +class MatrixCoreIntrinEmitter: """ To eliminate Python syntax within TIR Macro. 
""" @@ -51,9 +50,9 @@ def __init__( chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, - k_pack: Optional[int] = None, - is_m_first: Optional[bool] = False, - b_preshuffle: Optional[bool] = False, + k_pack: int | None = None, + is_m_first: bool | None = False, + b_preshuffle: bool | None = False, ): self.a_dtype = a_dtype self.b_dtype = b_dtype @@ -135,15 +134,15 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): self.micro_size_y = n_dim self.micro_size_k = k_dim - def _initialize_k_pack(self, k_pack: Optional[int] = None): + def _initialize_k_pack(self, k_pack: int | None = None): if k_pack is not None: self.k_pack = k_pack - def _initialize_is_m_first(self, is_m_first: Optional[bool] = False): + def _initialize_is_m_first(self, is_m_first: bool | None = False): if is_m_first is not None: self.is_m_first = is_m_first - def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False): + def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False): if b_preshuffle is not None: self.b_preshuffle = b_preshuffle @@ -203,7 +202,7 @@ def get_ldmatrix_index_map(self, is_b=False): def extract_thread_binding(self, thread_id, - is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]: + is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: ''' is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] @@ -418,10 +417,10 @@ def __init__( chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, - k_pack: Optional[int] = None, - is_m_first: Optional[bool] = False, - a_preshuffle: Optional[bool] = False, - b_preshuffle: Optional[bool] = False, + k_pack: int | None = None, + is_m_first: bool | None = False, + a_preshuffle: bool | None = False, + b_preshuffle: bool | None = False, ): self.a_dtype = a_dtype diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py index 8ddd9f96d..1fec00584 100644 --- a/tilelang/intrinsics/mma_layout.py +++ b/tilelang/intrinsics/mma_layout.py @@ -1,4 +1,4 @@ -from typing import Union +from __future__ import annotations from tvm import arith, DataType import tilelang.language as T @@ -163,7 +163,7 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j): return (i * 2 + j // 16, j % 16) -def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None): +def get_swizzle_layout(row_idx, col_idx, row_size, dtype: DataType | str, swizzle_bytes=None): ana = arith.Analyzer() if isinstance(dtype, str): dtype = DataType(dtype) diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py index 65d2ab0ca..537cc762c 100644 --- a/tilelang/intrinsics/mma_macro_generator.py +++ b/tilelang/intrinsics/mma_macro_generator.py @@ -1,5 +1,6 @@ +from __future__ import annotations import tilelang.language as T -from typing import Union, Tuple, Optional, Literal, Callable +from typing import Literal, Callable from tilelang.common import TransformKind from tvm import DataType from tvm.tir import PrimExpr, IndexMap, Buffer, Var @@ -25,7 +26,7 @@ lift = convert -class TensorCoreIntrinEmitter(object): +class TensorCoreIntrinEmitter: """ To eliminate Python syntax within TIR Macro. 
""" @@ -62,8 +63,8 @@ def __init__( chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, - is_m_first: Optional[bool] = False, - thread_var: Optional[Var] = None, + is_m_first: bool | None = False, + thread_var: Var | None = None, ): self.a_dtype = a_dtype self.b_dtype = b_dtype @@ -144,7 +145,7 @@ def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): self.micro_size_x = m_dim self.micro_size_k = k_dim - def _initialize_is_m_first(self, is_m_first: Optional[bool] = False): + def _initialize_is_m_first(self, is_m_first: bool | None = False): if is_m_first is not None: self.is_m_first = is_m_first @@ -167,7 +168,7 @@ def get_store_index_map(self, inverse: bool = False) -> IndexMap: def extract_thread_binding( self, thread_id: PrimExpr, - is_m_first: Optional[bool] = None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]: + is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]: """ is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m) which represents [warp_size, block_row_warps (split n), block_col_warps (split m)] @@ -200,7 +201,7 @@ def ldmatrix_a(self, A_local_buf: Buffer, A_shared_buf: Buffer, ki: PrimExpr, - rk: Optional[PrimExpr] = 0): + rk: PrimExpr | None = 0): warp_row_tiles = self.warp_row_tiles warp_rows = self.warp_rows chunk = self.chunk @@ -264,7 +265,7 @@ def ldmatrix_b(self, B_local_buf: Buffer, B_shared_buf: Buffer, ki: PrimExpr, - rk: Optional[PrimExpr] = 0): + rk: PrimExpr | None = 0): warp_col_tiles = self.warp_col_tiles warp_cols = self.warp_cols chunk = self.chunk @@ -336,7 +337,7 @@ def mma(self, A_local_buf: Buffer, B_local_buf: Buffer, C_local_buf: Buffer, - k_inner: Optional[PrimExpr] = 0): + k_inner: PrimExpr | None = 0): warp_rows = self.warp_rows warp_cols = self.warp_cols local_size_a = self.local_size_a @@ -518,8 +519,7 @@ def make_mma_load_layout(self, else: raise ValueError(f"Unsupported matrix {matrix}") - assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format( - local_buf.scope()) + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" if matrix_is_a: micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k @@ -684,9 +684,9 @@ def __init__( chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, - is_m_first: Optional[bool] = False, - transform_kind_a: Union[int, TransformKind] = 0, - transform_kind_b: Union[int, TransformKind] = 0, + is_m_first: bool | None = False, + transform_kind_a: int | TransformKind = 0, + transform_kind_b: int | TransformKind = 0, ): super().__init__( a_dtype=a_dtype, diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index 9d64a15fe..d9d591f72 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -1,6 +1,7 @@ +from __future__ import annotations import tilelang.language as T from enum import IntEnum -from typing import Optional, Callable +from typing import Callable from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter from tvm import DataType from tvm.tir import PrimExpr, Buffer, Var, IndexMap @@ -86,8 +87,8 @@ def __init__( chunk: int = 16, reduce_k: int = 1, num_elems_per_byte: int = 1, - is_m_first: Optional[bool] = False, - thread_var: Optional[Var] = None, + is_m_first: bool | None = False, + thread_var: Var | None = None, ): super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps, block_col_warps, warp_row_tiles, 
warp_col_tiles, chunk, reduce_k, @@ -409,8 +410,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragme transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a( j, i) - assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format( - local_buf.scope()) + assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}" micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py index 447e43b2a..78454a558 100644 --- a/tilelang/jit/__init__.py +++ b/tilelang/jit/__init__.py @@ -3,17 +3,13 @@ It includes functionality to JIT-compile TileLang programs into a runnable kernel adapter using TVM. """ +from __future__ import annotations from typing import ( Any, - List, - Union, Callable, - Tuple, overload, Literal, - Dict, # For type hinting dicts - Optional, ) from tilelang import tvm as tvm from tilelang.jit.adapter.utils import is_metal_target @@ -33,13 +29,13 @@ def compile( func: PrimFunc = None, - out_idx: Union[List[int], int, None] = None, + out_idx: list[int] | int | None = None, execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", - target: Union[str, Target] = "auto", - target_host: Union[str, Target, None] = None, + target: str | Target = "auto", + target_host: str | Target | None = None, verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[Union[List[str], str]] = None, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | str | None = None, ) -> JITKernel: """ Compile the given TileLang PrimFunc with TVM and build a JITKernel. @@ -92,24 +88,24 @@ def compile( class _JitImplementation: - out_idx: Optional[Union[List[int], int]] - target: Union[str, Target] - target_host: Union[str, Target] + out_idx: list[int] | int | None + target: str | Target + target_host: str | Target execution_backend: Literal["dlpack", "ctypes", "cython"] verbose: bool - pass_configs: Optional[Dict[str, Any]] - debug_root_path: Optional[str] - compile_flags: Optional[Union[List[str], str]] + pass_configs: dict[str, Any] | None + debug_root_path: str | None + compile_flags: list[str] | str | None def __init__(self, out_idx: Any = None, - target: Union[str, Target] = "auto", - target_host: Union[str, Target] = None, + target: str | Target = "auto", + target_host: str | Target = None, execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - debug_root_path: Optional[str] = None, - compile_flags: Optional[Union[List[str], str]] = None): + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None): """ Initializes the JIT compiler decorator. @@ -162,12 +158,12 @@ def __init__(self, except NameError: self.debug_root_path = path.abspath(self.debug_root_path) - self._kernel_cache: Dict[tuple, Kernel] = {} + self._kernel_cache: dict[tuple, Kernel] = {} # This tells the type checker what the *wrapper* function will return. # this is for linting, please do not remove it. @overload - def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]: + def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, Kernel]]: ... 
@overload @@ -242,16 +238,16 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any: def jit( # This is the new public interface - func: Union[Callable[_P, _RProg], PrimFunc, None] = None, + func: Callable[_P, _RProg] | PrimFunc | None = None, *, # Indicates subsequent arguments are keyword-only out_idx: Any = None, - target: Union[str, Target] = "auto", - target_host: Union[str, Target] = None, + target: str | Target = "auto", + target_host: str | Target = None, execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - debug_root_path: Optional[str] = None, - compile_flags: Optional[Union[List[str], str]] = None): + pass_configs: dict[str, Any] | None = None, + debug_root_path: str | None = None, + compile_flags: list[str] | str | None = None): """ Just-In-Time (JIT) compiler decorator for TileLang functions. diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py index 1b584d71c..9d998bc96 100644 --- a/tilelang/jit/adapter/base.py +++ b/tilelang/jit/adapter/base.py @@ -1,21 +1,22 @@ """The profiler and convert to torch utils""" +from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List, Callable, Optional +from typing import Any, Callable from tilelang.engine.param import KernelParam class BaseKernelAdapter(ABC): - func: Optional[Callable] = None + func: Callable | None = None - def __init__(self, mod, params: List[KernelParam], result_idx: List[int]) -> None: + def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None: self.mod = mod self.params = params self.result_idx = self._legalize_result_idx(result_idx) self._post_init() - def _legalize_result_idx(self, result_idx: Optional[List[int]]) -> List[int]: + def _legalize_result_idx(self, result_idx: list[int] | None) -> list[int]: params = self.params # result_idx is a list of indices of the output tensors if result_idx is None: diff --git a/tilelang/jit/adapter/ctypes/adapter.py b/tilelang/jit/adapter/ctypes/adapter.py index 7ec6cef0d..648c66c1c 100644 --- a/tilelang/jit/adapter/ctypes/adapter.py +++ b/tilelang/jit/adapter/ctypes/adapter.py @@ -1,9 +1,10 @@ """The profiler and convert to torch utils""" +from __future__ import annotations import torch from ..base import BaseKernelAdapter import ctypes -from typing import List, Optional, Union, Callable, Dict, Tuple, Any +from typing import Callable, Any from tilelang import tvm as tvm from tvm.target import Target from tvm.relax import TensorType @@ -25,32 +26,32 @@ class CtypesKernelAdapter(BaseKernelAdapter): # Class attributes to store compiled kernel information target = "cuda" - ir_module: Optional[tvm.IRModule] = None + ir_module: tvm.IRModule | None = None # The global source code of the kernel -> global means the source code of the kernel # that is not wrapped by the wrapper code - kernel_global_source: Optional[str] = None - lib: Optional[ctypes.CDLL] = None # Compiled library handle - wrapped_source: Optional[str] = None # Generated C++ wrapper code + kernel_global_source: str | None = None + lib: ctypes.CDLL | None = None # Compiled library handle + wrapped_source: str | None = None # Generated C++ wrapper code # Maps symbolic variables to their corresponding buffer and shape indices - dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None + dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None # Pass configs for the compiler - pass_configs: Optional[Dict[str, Any]] = 
None + pass_configs: dict[str, Any] | None = None # Add new cache attributes - param_dtypes: Optional[List[torch.dtype]] = None # Cache for parameter dtypes - param_shapes: Optional[List[List]] = None # Cache for parameter shapes + param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes + param_shapes: list[list] | None = None # Cache for parameter shapes def __init__(self, - params: List[TensorType], - result_idx: List[int], + params: list[TensorType], + result_idx: list[int], target: str, - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], - host_mod: Optional[tvm.IRModule] = None, - device_mod: Optional[tvm.IRModule] = None, - kernel_global_source: Optional[str] = None, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + kernel_global_source: str | None = None, verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[List[str]] = None): + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None): """Initialize the adapter with the given TIR function or module. Args: @@ -107,15 +108,15 @@ def __init__(self, @classmethod def from_database(cls, - params: List[TensorType], - result_idx: List[int], + params: list[TensorType], + result_idx: list[int], target: str, - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], + func_or_mod: tir.PrimFunc | tvm.IRModule, kernel_global_source: str, kernel_lib_path: str, verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[List[str]] = None): + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -155,7 +156,7 @@ def from_database(cls, adapter._post_init() return adapter - def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]: + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]: """Extract information about dynamic shapes from the TIR function. Maps symbolic variables to their corresponding (id, buffer_index, dimension) @@ -182,7 +183,7 @@ def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]: dynamic_symbolic_map[stride] = (1, i, j) return dynamic_symbolic_map - def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None): + def _forward_from_prebuild_lib(self, *args, stream: int | None = None): """Low-level function to call the compiled CUDA kernel. Converts PyTorch tensor pointers to C void pointers for ctypes interface. @@ -193,9 +194,7 @@ def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None): ctypes_args.append(ctypes.c_void_p(stream)) self.lib.call(*ctypes_args) - def _wrap_forward_from_prebuild_lib(self, - *ins: List[torch.Tensor], - stream: Optional[int] = None): + def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None): """High-level wrapper for kernel execution. 
Handles: diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py index d210de46c..7857872cf 100644 --- a/tilelang/jit/adapter/cython/adapter.py +++ b/tilelang/jit/adapter/cython/adapter.py @@ -1,10 +1,11 @@ """The profiler and convert to torch utils""" +from __future__ import annotations import ctypes import logging import torch -from typing import List, Optional, Union, Callable, Dict, Tuple, Any +from typing import Callable, Any from tilelang import tvm as tvm from tvm.target import Target from tilelang.engine.param import KernelParam @@ -44,43 +45,43 @@ class CythonKernelAdapter(BaseKernelAdapter): """ # Class attributes to store compiled kernel information - target: Union[str, Target] = "cuda" - ir_module: Optional[tvm.IRModule] = None + target: str | Target = "cuda" + ir_module: tvm.IRModule | None = None # The global source code of the kernel -> global means the source code of the kernel # that is not wrapped by the wrapper code - kernel_global_source: Optional[str] = None - lib: Optional[ctypes.CDLL] = None # Compiled library handle - wrapped_source: Optional[str] = None # Generated C++ wrapper code + kernel_global_source: str | None = None + lib: ctypes.CDLL | None = None # Compiled library handle + wrapped_source: str | None = None # Generated C++ wrapper code # Maps symbolic variables to their corresponding buffer and shape indices - dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None + dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None # Maps pointer arguments to their corresponding (buffer_index, shape_dimension) - ptr_map: Optional[Dict[int, str]] = None + ptr_map: dict[int, str] | None = None # Maps buffer variables to their corresponding dtypes - buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None + buffer_dtype_map: dict[tir.Var, tuple[int, torch.dtype]] | None = None # Maps buffer variables to their corresponding static shapes and strides, # e.g., { # "A": [(0, 16), (1, 16)] -> represents A.shape/strides = (16, 16) # } - static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None - static_strides_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None + static_shape_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None + static_strides_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None # Contains contiguous buffers - static_contiguous_list: Optional[List[tir.Var]] = None + static_contiguous_list: list[tir.Var] | None = None # Maps buffer variables to their corresponding devices - buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None + buffer_device_map: dict[tir.Var, tuple[int, torch.device]] | None = None # Pass configs for the compiler - pass_configs: Optional[Dict[str, Any]] = None + pass_configs: dict[str, Any] | None = None def __init__(self, - params: List[KernelParam], - result_idx: List[int], - target: Union[str, Target], - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], - host_mod: Optional[tvm.IRModule] = None, - device_mod: Optional[tvm.IRModule] = None, - kernel_global_source: Optional[str] = None, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + kernel_global_source: str | None = None, verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[List[str]] = None): + pass_configs: 
dict[str, Any] | None = None, + compile_flags: list[str] | None = None): """Initialize the adapter with the given TIR function or module. Args: @@ -146,15 +147,15 @@ def __init__(self, @classmethod def from_database(cls, - params: List[TensorType], - result_idx: List[int], + params: list[TensorType], + result_idx: list[int], target: str, - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], + func_or_mod: tir.PrimFunc | tvm.IRModule, kernel_global_source: str, kernel_lib_path: str, verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[List[str]] = None): + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -205,7 +206,7 @@ def from_database(cls, adapter._post_init() return adapter - def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]: + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]: """Extract information about dynamic shapes from the TIR function. Maps symbolic variables to their corresponding (id, buffer_index, dimension) @@ -232,7 +233,7 @@ def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]: dynamic_symbolic_map[stride] = (1, i, j) return dynamic_symbolic_map - def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]: + def _process_buffer_dtype(self) -> dict[tir.Var, tuple[int, torch.dtype]]: """Extract information about buffer dtypes from the TIR function. Maps buffer variables to their corresponding dtypes. @@ -248,7 +249,7 @@ def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]: buffer_dtype_map[name] = (i, map_torch_type(dtype)) return buffer_dtype_map - def _process_ptr_map(self) -> Dict[int, str]: + def _process_ptr_map(self) -> dict[int, str]: """Extract information about pointer arguments from the TIR function. Maps pointer arguments to their corresponding (buffer_index, shape_dimension) @@ -263,9 +264,9 @@ def _process_ptr_map(self) -> Dict[int, str]: return ptr_map def _process_static_buffer_infos(self) -> \ - Tuple[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]], - Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]], - List[Tuple[tir.Var]]]: + tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]], + dict[tir.Var, tuple[int, list[tuple[int, int]]]], + list[tuple[tir.Var]]]: """Extract information about static shapes from the TIR function. Maps buffer variables to their corresponding static shapes. @@ -300,7 +301,7 @@ def _process_static_buffer_infos(self) -> \ static_contiguous_list.append((i, buffer.name)) return static_shape_map, static_strides_map, static_contiguous_list - def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]: + def _process_buffer_device(self) -> dict[tir.Var, tuple[int, torch.device]]: """Extract information about buffer devices from the TIR function. Maps buffer variables to their corresponding devices. @@ -326,7 +327,7 @@ def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]: buffer_device_map[name] = (i, device) return buffer_device_map - def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None): + def _forward_from_prebuild_lib(self, *args, stream: int | None = None): """Low-level function to call the compiled CUDA kernel. Converts PyTorch tensor pointers to C void pointers for ctypes interface. 
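Note: the ctypes and cython adapters also convert their class-level attribute annotations (`lib: ctypes.CDLL | None = None`, the `dynamic_symbolic_map` caches, and so on). Class-body annotations are postponed just like signature annotations: only the assigned default is executed, and the union expressions are kept as strings in `__annotations__`. A small illustrative sketch (the class name is invented, not from the patch):

from __future__ import annotations

import ctypes


class DemoAdapter:
    # These annotations are never evaluated; they live as strings in
    # DemoAdapter.__annotations__, so PEP 604 unions are safe on Python 3.8.
    lib: ctypes.CDLL | None = None
    dynamic_symbolic_map: dict[str, tuple[int, int]] | None = None

    def bind(self, lib: ctypes.CDLL) -> DemoAdapter:
        self.lib = lib
        return self


print(list(DemoAdapter.__annotations__))  # ['lib', 'dynamic_symbolic_map']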
diff --git a/tilelang/jit/adapter/dlpack.py b/tilelang/jit/adapter/dlpack.py index b45742433..9fa767f04 100644 --- a/tilelang/jit/adapter/dlpack.py +++ b/tilelang/jit/adapter/dlpack.py @@ -1,7 +1,7 @@ """The profiler and convert to torch utils""" +from __future__ import annotations import torch -from typing import List from tilelang.contrib.dlpack import to_pytorch_func from .base import BaseKernelAdapter @@ -11,7 +11,7 @@ class TorchDLPackKernelAdapter(BaseKernelAdapter): def _convert_torch_func(self) -> callable: torch_func = to_pytorch_func(self.mod) - def func(*ins: List[torch.Tensor]): + def func(*ins: list[torch.Tensor]): if len(ins) + len(self.result_idx) != len(self.params): raise ValueError( f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs" diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py index 5d1143a67..1e33ec040 100644 --- a/tilelang/jit/adapter/libgen.py +++ b/tilelang/jit/adapter/libgen.py @@ -1,3 +1,4 @@ +from __future__ import annotations import ctypes import importlib import logging @@ -5,7 +6,7 @@ import os.path as osp import subprocess import tempfile -from typing import Any, Dict, Optional, List +from typing import Any from tvm.target import Target @@ -29,21 +30,21 @@ is_nvrtc_available = False -class LibraryGenerator(object): - srcpath: Optional[str] = None - libpath: Optional[str] = None - lib_code: Optional[str] = None - pass_configs: Optional[Dict[str, Any]] = None - compile_flags: Optional[List[str]] = None +class LibraryGenerator: + srcpath: str | None = None + libpath: str | None = None + lib_code: str | None = None + pass_configs: dict[str, Any] | None = None + compile_flags: list[str] | None = None def __init__(self, target: Target, verbose: bool = False): self.target = target self.verbose = verbose - def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None): + def assign_pass_configs(self, pass_configs: dict[str, Any] | None = None): self.pass_configs = pass_configs - def assign_compile_flags(self, compile_flags: Optional[List[str]] = None): + def assign_compile_flags(self, compile_flags: list[str] | None = None): if compile_flags is None: compile_flags = [] self.compile_flags = compile_flags @@ -52,7 +53,7 @@ def update_lib_code(self, lib_code: str): self.lib_code = lib_code # Assume currently we only support CUDA compilation - def load_lib(self, lib_path: Optional[str] = None): + def load_lib(self, lib_path: str | None = None): if lib_path is None: lib_path = self.libpath else: @@ -185,7 +186,7 @@ def set_src_path(self, srcpath): class PyLibraryGenerator(LibraryGenerator): - host_func: Optional[str] = None + host_func: str | None = None culib = None pymodule = None @@ -206,7 +207,7 @@ def import_from_file(module_name, file_path): def update_host_func(self, host_func: str): self.host_func = host_func - def load_lib(self, lib_path: Optional[str] = None): + def load_lib(self, lib_path: str | None = None): if lib_path is None: lib_path = self.libpath diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py index d1fd9d421..d6723a031 100644 --- a/tilelang/jit/adapter/nvrtc/adapter.py +++ b/tilelang/jit/adapter/nvrtc/adapter.py @@ -1,5 +1,6 @@ +from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable import torch from tvm import tir @@ -26,16 +27,16 @@ class NVRTCKernelAdapter(BaseKernelAdapter): kernels = 
{} def __init__(self, - params: List[KernelParam], - result_idx: List[int], - target: Union[str, Target], - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], - host_mod: Optional[tvm.IRModule] = None, - device_mod: Optional[tvm.IRModule] = None, - kernel_global_source: Optional[str] = None, + params: list[KernelParam], + result_idx: list[int], + target: str | Target, + func_or_mod: tir.PrimFunc | tvm.IRModule, + host_mod: tvm.IRModule | None = None, + device_mod: tvm.IRModule | None = None, + kernel_global_source: str | None = None, verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[List[str]] = None): + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None): check_nvrtc_available() @@ -91,15 +92,15 @@ def __init__(self, @classmethod def from_database(cls, - params: List[KernelParam], - result_idx: List[int], + params: list[KernelParam], + result_idx: list[int], target: str, - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], + func_or_mod: tir.PrimFunc | tvm.IRModule, kernel_global_source: str, kernel_lib_path: str, verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[List[str]] = None): + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None): adapter = cls.__new__(cls) adapter.params = params adapter.result_idx = adapter._legalize_result_idx(result_idx) @@ -143,7 +144,7 @@ def from_database(cls, adapter._post_init() return adapter - def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]: + def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]: """Extract information about dynamic shapes from the TIR function. Maps symbolic variables to their corresponding (buffer_index, shape_dimension) @@ -165,7 +166,7 @@ def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]: dynamic_symbolic_map[shape] = (i, j) return dynamic_symbolic_map - def get_kernel_source(self) -> Optional[str]: + def get_kernel_source(self) -> str | None: """Get the CUDA kernel source code. Returns @@ -175,14 +176,12 @@ def get_kernel_source(self) -> Optional[str]: """ return self.kernel_global_source - def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None): + def _forward_from_prebuild_lib(self, *args, stream: int | None = None): """Low-level function to call the compiled CUDA kernel. """ return self.pymodule.call(self.kernels, *args, stream=stream) - def _wrap_forward_from_prebuild_lib(self, - *ins: List[torch.Tensor], - stream: Optional[int] = None): + def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None): """High-level wrapper for kernel execution. Handles: @@ -242,7 +241,7 @@ def _wrap_forward_from_prebuild_lib(self, else: return [args[i] for i in self.result_idx] - def _convert_torch_func(self) -> Callable[..., Union[torch.Tensor, List[torch.Tensor]]]: + def _convert_torch_func(self) -> Callable[..., torch.Tensor | list[torch.Tensor]]: """Convert to a PyTorch-compatible function. 
Returns diff --git a/tilelang/jit/adapter/torch/metal.py b/tilelang/jit/adapter/torch/metal.py index 9693fca06..30e84ad71 100644 --- a/tilelang/jit/adapter/torch/metal.py +++ b/tilelang/jit/adapter/torch/metal.py @@ -1,5 +1,6 @@ +from __future__ import annotations from functools import wraps -from typing import Callable, Optional, Union, List +from typing import Callable import torch from tvm import tir @@ -14,13 +15,13 @@ class MetalKernelAdapter(BaseKernelAdapter): def __init__( self, - params: List[KernelParam], - result_idx: List[int], + params: list[KernelParam], + result_idx: list[int], # target: Union[str, Target], - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], + func_or_mod: tir.PrimFunc | tvm.IRModule, # host_mod: Optional[tvm.IRModule] = None, - device_mod: Optional[tvm.IRModule] = None, - kernel_global_source: Optional[str] = None, + device_mod: tvm.IRModule | None = None, + kernel_global_source: str | None = None, verbose: bool = False, # pass_configs: Optional[Dict[str, Any]] = None, # compile_flags: Optional[List[str]] = None diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py index 6a09d6f6f..efc965e1b 100644 --- a/tilelang/jit/adapter/utils.py +++ b/tilelang/jit/adapter/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import Union, Optional, Literal, Dict +from typing import Literal from tilelang import tvm as tvm from tvm import IRModule, tir from tvm.target import Target @@ -65,11 +65,11 @@ def is_metal_target(target: Target) -> bool: def get_annotated_mod( - func_or_mod: Union[tir.PrimFunc, tvm.IRModule], - target: Union[str, Target] = "auto", - target_host: Optional[Union[str, Target]] = None, + func_or_mod: tir.PrimFunc | tvm.IRModule, + target: str | Target = "auto", + target_host: str | Target | None = None, model_type: Literal["device", "host", "all"] = "all", -) -> Union[IRModule, tuple[IRModule, IRModule]]: +) -> IRModule | tuple[IRModule, IRModule]: # Validate model_type early if model_type not in {"device", "host", "all"}: @@ -107,7 +107,7 @@ def get_annotated_mod( return dispatch[model_type](mod) -def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: Optional[Dict[str, str]] = None) -> str: +def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None) -> str: """ Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence. 
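Note: one limit of this modernization worth keeping in mind for the adapter utilities above: the future import only postpones annotations. Expressions evaluated at runtime, such as a type alias assigned to a variable or the second argument of `isinstance`, still need the `typing` spellings (or tuples of types) on Python 3.8. A hedged example of the distinction, not taken from the patch:

from __future__ import annotations

from typing import Optional  # still needed for values evaluated at runtime


def lookup(table: dict[str, int], key: str) -> int | None:
    # Fine on 3.8: this annotation is a lazy string and never evaluated.
    return table.get(key)


# Not covered by the future import: the right-hand side runs immediately,
# so spelling it `int | None` would raise TypeError on 3.8.
MaybeInt = Optional[int]

# isinstance also evaluates its second argument, so pass a tuple of types.
print(isinstance(lookup({"a": 1}, "a"), (int, type(None))))  # True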
diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py index 9c032826f..235e70135 100644 --- a/tilelang/jit/adapter/wrapper.py +++ b/tilelang/jit/adapter/wrapper.py @@ -1,6 +1,7 @@ +from __future__ import annotations from abc import ABC, abstractmethod from tilelang import tvm as tvm -from typing import Optional, List, Dict, Union, Any +from typing import Any from tvm import IRModule from tvm.target import Target from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target, @@ -176,7 +177,7 @@ def wrap(self, *args, **kwargs): logger = logging.getLogger(__name__) -class TLCUDASourceWrapper(object): +class TLCUDASourceWrapper: _TYPE_MAP = { "float32": "float", "float16": "half_t", @@ -196,33 +197,33 @@ class TLCUDASourceWrapper(object): } backend = "tl" - device_mod: Optional[IRModule] = None - host_mod: Optional[IRModule] = None - pass_configs: Optional[Dict[str, Any]] = None + device_mod: IRModule | None = None + host_mod: IRModule | None = None + pass_configs: dict[str, Any] | None = None def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target, - device_mod: Optional[IRModule] = None, - host_mod: Optional[IRModule] = None, - pass_configs: Optional[Dict[str, Any]] = None): + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None): self.mod = scheduled_ir_module self.target = target self.source = source self.pass_configs = pass_configs self.device_mod = device_mod self.host_mod = host_mod - self.function_names: Optional[str] = None - self.dynamic_smem_buf: Optional[int] = None - self.block_info: Union[List[int], Dict] = [1, 1, 1] - self.grid_info: Union[List[int], Dict] = [1, 1, 1] - self.tma_descriptor_args: Optional[Dict] = None - self.l2_persistent_map: Optional[Dict[str, Dict]] = {} + self.function_names: str | None = None + self.dynamic_smem_buf: int | None = None + self.block_info: list[int] | dict = [1, 1, 1] + self.grid_info: list[int] | dict = [1, 1, 1] + self.tma_descriptor_args: dict | None = None + self.l2_persistent_map: dict[str, dict] | None = {} self.parse_source_information() - self.srcpath: Optional[str] = None - self.libpath: Optional[str] = None - self.lib_code: Optional[str] = self.update_lib_code(source) + self.srcpath: str | None = None + self.libpath: str | None = None + self.lib_code: str | None = self.update_lib_code(source) def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str: return pythonic_expr(expr, self._TYPE_MAP) @@ -264,10 +265,10 @@ def create_dispatch_func(self, code, function_informations): def func_call_args(s, function_args, function_params, - desc_name_map: Optional[Dict[str, str]] = None, - desc_name_var_map: Optional[Dict[str, tvm.tir.Var]] = None): + desc_name_map: dict[str, str] | None = None, + desc_name_var_map: dict[str, tvm.tir.Var] | None = None): # Extract the function call arguments matching the function definition - def maybe_desc(name: str, matches: List[str], i: int): + def maybe_desc(name: str, matches: list[str], i: int): match = matches[i] if not (match == name + "_desc" or match.startswith(name + "_desc_")): return False @@ -305,8 +306,8 @@ def maybe_desc(name: str, matches: List[str], i: int): kernel_launch_code = """""" if has_l2_persistent_map: kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE - desc_name_map: Dict[str, str] = {} - desc_name_var_map: Dict[str, tvm.tir.Var] = {} + desc_name_map: dict[str, str] = {} + desc_name_var_map: dict[str, tvm.tir.Var] = {} for 
function_name, function_info in function_informations.items(): block_info = function_info["block_info"] grid_info = function_info["grid_info"] @@ -322,14 +323,8 @@ def maybe_desc(name: str, matches: List[str], i: int): # Identify the start of the function body to insert arguments index = code.index("{", index) - block_str = "dim3({}, {}, {})".format( - self._pythonic_expr(block_info[0]), - self._pythonic_expr(block_info[1]), - self._pythonic_expr(block_info[2]), - ) - grid_str = "dim3({}, {}, {})".format( - self._pythonic_expr(grid_info[0]), self._pythonic_expr(grid_info[1]), - self._pythonic_expr(grid_info[2])) + block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})" + grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})" smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf init_l2_persistent_map = self.generate_l2_persistent_map(function_name) kernel_launch_code += init_l2_persistent_map @@ -353,9 +348,8 @@ def maybe_desc(name: str, matches: List[str], i: int): args_list ), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments" call_args = ", ".join(args_list) - kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format( - function_name, grid_str, block_str, smem_str, call_args) - kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name) + kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n" + kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n" if has_l2_persistent_map: kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE @@ -386,8 +380,8 @@ def generate_l2_persistent_map(self, function_name: str) -> str: return init_l2_persistent_map - def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str], - desc_name_var_map: Dict[str, tvm.tir.Var]) -> str: + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str], + desc_name_var_map: dict[str, tvm.tir.Var]) -> str: tma_descripter_init = "" if self.tma_descriptor_args is None: return tma_descripter_init @@ -512,7 +506,7 @@ def parse_source_information(self): def get_dynamic_symbolic_set(self, prim_func): # Determine the set of dynamic symbols used in the function - dynamic_symbolic_set: List[str] = [] + dynamic_symbolic_set: list[str] = [] def unique_push_back(name: str): if name not in dynamic_symbolic_set: @@ -565,7 +559,7 @@ def update_lib_code(self, code: str): assert function_name in self.device_mod, f"Function {function_name} not found in device module" device_func = self.device_mod[function_name] kernel_params_cnt = len(device_func.params) - function_params: List[str] = None + function_params: list[str] = None def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): nonlocal function_params @@ -599,7 +593,7 @@ def visitor(node, fn=function_name, param_cnt=kernel_params_cnt): lib_code = self.source + init_func + host_func return lib_code - def get_stream_type(self) -> Dict[str, str]: + def get_stream_type(self) -> dict[str, str]: return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"} @property @@ -669,9 +663,9 @@ def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target, - device_mod: Optional[IRModule] = None, - host_mod: Optional[IRModule] = None, - pass_configs: Optional[Dict[str, Any]] = None): + device_mod: IRModule | None = None, + host_mod: IRModule | 
None = None, + pass_configs: dict[str, Any] | None = None): super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) def create_dispatch_func(self, code, function_informations): @@ -701,9 +695,9 @@ def create_dispatch_func(self, code, function_informations): # Format the function arguments for declaration def_args = ", ".join([f"{arg['name']}" for arg in function_args]) - def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None): + def func_call_args(s, function_args, desc_name_map: dict[str, str] | None = None): # Extract the function call arguments matching the function definition - def maybe_desc(name: str, matches: List[str], i: int): + def maybe_desc(name: str, matches: list[str], i: int): match = matches[i] if not (match == name + "_desc" or match.startswith(name + "_desc_")): return False @@ -729,7 +723,7 @@ def maybe_desc(name: str, matches: List[str], i: int): call_args.append((match, "None")) return call_args - desc_name_map: Dict[str, str] = {} + desc_name_map: dict[str, str] = {} device_index = 0 kernel_launch_code = """""" for function_name, function_info in function_informations.items(): @@ -766,7 +760,7 @@ def maybe_desc(name: str, matches: List[str], i: int): repr(list(function_informations.keys())), def_args, kernel_launch_code) return host_func - def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str: + def generate_tma_descriptor_args(self, desc_name_map: dict[str, str]) -> str: tma_descripter_init = "" if self.tma_descriptor_args is None: return tma_descripter_init @@ -844,7 +838,7 @@ def update_lib_code(self, code: str): self.host_func = self.create_dispatch_func(code, function_informations) return self.lib_code - def get_stream_type(self) -> Dict[str, str]: + def get_stream_type(self) -> dict[str, str]: return {"name": "stream=0", "type": "int"} @@ -877,9 +871,9 @@ def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target, - device_mod: Optional[IRModule] = None, - host_mod: Optional[IRModule] = None, - pass_configs: Optional[Dict[str, Any]] = None): + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None): super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs) def get_init_func(self): @@ -895,11 +889,11 @@ def get_init_func(self): init_funcs = PREDEF_INIT_FUNC.format(call_str) return init_funcs - def get_stream_type(self) -> Dict[str, str]: + def get_stream_type(self) -> dict[str, str]: return {"name": "stream=hipStreamDefault", "type": "hipStream_t"} -class TLCPUSourceWrapper(object): +class TLCPUSourceWrapper: _TYPE_MAP = { "float32": "float", "float16": "half", @@ -925,29 +919,29 @@ class TLCPUSourceWrapper(object): """) backend = "tl" - device_mod: Optional[IRModule] = None - host_mod: Optional[IRModule] = None - pass_configs: Optional[Dict[str, Any]] = None + device_mod: IRModule | None = None + host_mod: IRModule | None = None + pass_configs: dict[str, Any] | None = None def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target, - device_mod: Optional[IRModule] = None, - host_mod: Optional[IRModule] = None, - pass_configs: Optional[Dict[str, Any]] = None): + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None): self.mod = scheduled_ir_module self.target = target self.source = source self.device_mod = device_mod self.host_mod = host_mod self.pass_configs = pass_configs - 
self.function_names: Optional[str] = None - self.dynamic_smem_buf: Optional[int] = None + self.function_names: str | None = None + self.dynamic_smem_buf: int | None = None self.parse_source_information() - self.srcpath: Optional[str] = None - self.libpath: Optional[str] = None - self.lib_code: Optional[str] = self.update_lib_code(source) + self.srcpath: str | None = None + self.libpath: str | None = None + self.lib_code: str | None = self.update_lib_code(source) def create_call_func(self, code, function_informations): # Extract the set of dynamic symbolic names used in the primary function @@ -997,7 +991,7 @@ def func_call_args(s, function_args): index = code.index("{", index) call_args = ", ".join(func_call_args(declaration, function_args)) - _call_str += "{}({})".format(function_name, call_args) + _call_str += f"{function_name}({call_args})" # Wrap the kernel dispatch logic in an external C function host_func = self.CALL_PREFIX.format(def_args, _call_str) @@ -1018,7 +1012,7 @@ def parse_source_information(self): def get_dynamic_symbolic_set(self, prim_func): # Determine the set of dynamic symbols used in the function - dynamic_symbolic_set: List[str] = [] + dynamic_symbolic_set: list[str] = [] for param in prim_func.params: if param in prim_func.buffer_map: buffer = prim_func.buffer_map[param] @@ -1066,15 +1060,15 @@ def prim_func(self): raise ValueError("Cannot find primary function in the module.") -class TLMetalSourceWrapper(object): +class TLMetalSourceWrapper: def __init__(self, scheduled_ir_module: IRModule, source: str, target: Target, - device_mod: Optional[IRModule] = None, - host_mod: Optional[IRModule] = None, - pass_configs: Optional[Dict[str, Any]] = None): + device_mod: IRModule | None = None, + host_mod: IRModule | None = None, + pass_configs: dict[str, Any] | None = None): self.mod = scheduled_ir_module self.target = target self.source = source @@ -1092,11 +1086,11 @@ class TLWrapper(BaseWrapper): """ A wrapper class for the TileLang backend. """ - device_mod: Optional[IRModule] = None - host_mod: Optional[IRModule] = None - pass_configs: Optional[Dict[str, Any]] = None - target: Optional[Target] = None - lib: Optional[object] = None + device_mod: IRModule | None = None + host_mod: IRModule | None = None + pass_configs: dict[str, Any] | None = None + target: Target | None = None + lib: object | None = None def __init__(self, target: Target): super().__init__() @@ -1108,7 +1102,7 @@ def __init__(self, target: Target): def assign_optimized_module(self, scheduled_ir_module: IRModule): self.scheduled_ir_module = scheduled_ir_module - def assign_pass_configs(self, pass_configs: Dict[str, Any]): + def assign_pass_configs(self, pass_configs: dict[str, Any]): self.pass_configs = pass_configs def assign_host_module(self, host_mod: IRModule): diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py index 64fc7bdf1..71dafffb2 100644 --- a/tilelang/jit/kernel.py +++ b/tilelang/jit/kernel.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from __future__ import annotations +from typing import Any, Callable, Literal from tilelang.jit.adapter.utils import is_metal_target from tvm.target import Target @@ -17,7 +18,7 @@ logger = logging.getLogger(__name__) -class JITKernel(object): +class JITKernel: """ A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions. 
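# Illustrative only, with a hypothetical helper name: the wrapper hunks above swap
# `.format()` calls for f-strings when emitting the `dim3(...)` strings and the
# `<<<grid, block, smem, stream>>>` launch line. The two spellings below produce
# byte-identical output, so the change is purely cosmetic.
def emit_launch_line(function_name: str, grid_str: str, block_str: str, smem_str: int,
                     call_args: str) -> str:
    old = "\t{}<<<{}, {}, {}, stream>>>({});\n".format(function_name, grid_str, block_str,
                                                       smem_str, call_args)
    new = f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
    assert old == new
    return new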
@@ -37,20 +38,20 @@ class JITKernel(object): # tuner result latency: float = None - config: Dict[str, Any] = None + config: dict[str, Any] = None ref_latency: float = None def __init__( self, func: PrimFunc = None, - out_idx: Union[List[int], int] = None, + out_idx: list[int] | int = None, execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython", - target: Union[str, Target] = "auto", - target_host: Union[str, Target] = None, + target: str | Target = "auto", + target_host: str | Target = None, verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None, + pass_configs: dict[str, Any] | None = None, from_database: bool = False, - compile_flags: Optional[List[str]] = None, + compile_flags: list[str] | None = None, ): """ Initializes a TorchFunction instance. @@ -138,13 +139,13 @@ def from_database( func: PrimFunc, kernel_global_source: str, kernel_lib_path: str, - params: List[KernelParam], - target: Union[str, Target], - target_host: Union[str, Target], - out_idx: Union[List[int], int], + params: list[KernelParam], + target: str | Target, + target_host: str | Target, + out_idx: list[int] | int, execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"], - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[List[str]] = None, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None, ): """ Alternative constructor to create a TorchFunction directly from a database. @@ -192,7 +193,7 @@ def __call__(self, *args: Any, **kwds: Any) -> Any: return self.torch_function(*args, **kwds) def _compile_and_create_adapter(self, tilelang_func: PrimFunc, - out_idx: List[int]) -> BaseKernelAdapter: + out_idx: list[int]) -> BaseKernelAdapter: """ Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter. @@ -295,16 +296,15 @@ def _compile_and_create_adapter(self, tilelang_func: PrimFunc, return adapter - def _create_adapter_from_database( - self, - params: List[KernelParam], - result_idx: Union[List[int], int], - target: Union[str, Target], - func_or_mod: Union[PrimFunc, tvm.runtime.Module], - kernel_global_source: str, - kernel_lib_path: str, - pass_configs: Optional[Dict[str, Any]] = None, - compile_flags: Optional[List[str]] = None) -> BaseKernelAdapter: + def _create_adapter_from_database(self, + params: list[KernelParam], + result_idx: list[int] | int, + target: str | Target, + func_or_mod: PrimFunc | tvm.runtime.Module, + kernel_global_source: str, + kernel_lib_path: str, + pass_configs: dict[str, Any] | None = None, + compile_flags: list[str] | None = None) -> BaseKernelAdapter: target = self.target execution_backend = self.execution_backend @@ -405,11 +405,11 @@ def get_host_source(self) -> str: """ return str(self.artifact.host_mod) - def run_once(self, func: Optional[Callable] = None) -> None: + def run_once(self, func: Callable | None = None) -> None: return self.get_profiler().run_once(func) - def update_tuner_result(self, latency: float, config: Dict[str, Any], - ref_latency: float) -> "JITKernel": + def update_tuner_result(self, latency: float, config: dict[str, Any], + ref_latency: float) -> JITKernel: """ Updates the tuning results for this kernel. @@ -432,7 +432,7 @@ def update_tuner_result(self, latency: float, config: Dict[str, Any], return self - def get_tuner_result(self) -> Dict[str, Any]: + def get_tuner_result(self) -> dict[str, Any]: """ Gets the tuning results for this kernel. 
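# A small sketch of why quoted return types such as "JITKernel" and "Fragment" can be
# written unquoted in this patch: postponed evaluation makes every annotation lazy, so
# a method may reference its own class before the class object exists. The `Builder`
# class below is hypothetical, not a tilelang type.
from __future__ import annotations


class Builder:

    def with_option(self, key: str, value: int) -> Builder:
        # Annotating with the enclosing class works without quotes because the
        # annotation stays a string until something explicitly resolves it.
        setattr(self, key, value)
        return self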
@@ -454,11 +454,11 @@ def get_tuner_result(self) -> Dict[str, Any]: } @property - def out_idx(self) -> List[int]: + def out_idx(self) -> list[int]: return self.adapter.result_idx @property - def params(self) -> List[KernelParam]: + def params(self) -> list[KernelParam]: return self.artifact.params if self.artifact else self.adapter.params @property diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index 994f338f2..1a26b53d0 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -1,6 +1,6 @@ """The language interface for tl programs.""" +from __future__ import annotations -from typing import Optional # from .parser import * # now is fully compatible with the upstream # tir script @@ -90,6 +90,6 @@ ) -def import_source(source: Optional[str] = None): +def import_source(source: str | None = None): # source is the source code to be imported return block_attr({"pragma_import_c": source}) if source is not None else None diff --git a/tilelang/language/annotations.py b/tilelang/language/annotations.py index cee46ca2f..12d3af4d3 100644 --- a/tilelang/language/annotations.py +++ b/tilelang/language/annotations.py @@ -1,6 +1,7 @@ """Annotation helpers exposed on the TileLang language surface.""" +from __future__ import annotations -from typing import Callable, Dict +from typing import Callable from tilelang.layout import Layout from tvm.script.parser.tir import attr, block_attr @@ -21,7 +22,7 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True): return attr(None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>") -def annotate_layout(layout_map: Dict): +def annotate_layout(layout_map: dict): """Annotate the layout of the buffer.""" _layout_map = {} for buffer, layout in layout_map.items(): @@ -35,7 +36,7 @@ def annotate_layout(layout_map: Dict): return block_attr({"layout_map": _layout_map}) -def annotate_safe_value(safe_value_map: Dict): +def annotate_safe_value(safe_value_map: dict): """Annotate the safe value of the buffer.""" _safe_value_map = {} for buffer, safe_value in safe_value_map.items(): @@ -43,7 +44,7 @@ def annotate_safe_value(safe_value_map: Dict): return block_attr({"safe_value_map": _safe_value_map}) -def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict): +def annotate_l2_hit_ratio(l2_hit_ratio_map: dict): """Annotate the L2 hit ratio of the buffer.""" _l2_hit_ratio_map = {} for buffer, hit_ratio in l2_hit_ratio_map.items(): diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index eb2d18526..f1b37d236 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -1,11 +1,11 @@ # Copyright (c) Tile-AI Corporation. # Licensed under the MIT License. """Atomic operations for tilelang.""" +from __future__ import annotations import tilelang.language as T from tvm import ir, tir from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op -from typing import Optional from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region from tilelang.utils.language import get_buffer_region_from_load @@ -21,7 +21,7 @@ def atomic_max(dst: Buffer, value: PrimExpr, - memory_order: Optional[str] = None, + memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: """ Perform an atomic maximum on the value stored at dst with an optional memory-order. 
@@ -67,7 +67,7 @@ def atomic_max(dst: Buffer, def atomic_min(dst: Buffer, value: PrimExpr, - memory_order: Optional[str] = None, + memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: """ Atomically update the value at dst to the minimum of its current value and value. @@ -115,7 +115,7 @@ def atomic_min(dst: Buffer, def atomic_add(dst: Buffer, value: PrimExpr, - memory_order: Optional[str] = None, + memory_order: str | None = None, return_prev: bool = False, use_tma: bool = False) -> PrimExpr: """ diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index f9867f235..f0b223f46 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -1,17 +1,18 @@ """The language interface for tl programs.""" +from __future__ import annotations from tilelang import tvm as tvm from tilelang.language import ptx_arrive_barrier, evaluate from tilelang.language.kernel import get_thread_bindings, get_block_extents from tilelang.utils.target import check_hip_availability from tvm import tir -from typing import Union, Any, Optional +from typing import Any from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad _IS_HIP_AVAILABLE = check_hip_availability() -def _normalize_index_arg(value: Optional[Union[int, PrimExpr]]) -> Optional[PrimExpr]: +def _normalize_index_arg(value: int | PrimExpr | None) -> PrimExpr | None: """ Normalize warp sizing arguments so both Python ints and PrimExpr values are accepted uniformly. @@ -183,7 +184,7 @@ def disable_warp_group_reg_alloc(): return no_set_max_nreg() -def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]): +def mbarrier_wait_parity(mbarrier: int | PrimExpr | tir.Call, parity: int | Var): """Wait for memory barrier parity condition. Args: @@ -233,7 +234,7 @@ def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity) -def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]): +def mbarrier_arrive(mbarrier: int | PrimExpr | tir.Call): """Arrive at memory barrier. Args: @@ -294,7 +295,7 @@ def warpgroup_wait(num_mma: int): return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma) -def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: +def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: """Return the logical lane index of the calling thread within a warp. Parameters @@ -319,7 +320,7 @@ def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr) -def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: +def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr: """Return the canonical warp index, assuming the warp's threads are converged. Parameters @@ -343,7 +344,7 @@ def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> Prim return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr) -def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: +def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr: """Return the canonical warp index without synchronizing the warp. 
Parameters @@ -368,8 +369,8 @@ def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr: def get_warp_group_idx( - warp_size: Optional[Union[int, PrimExpr]] = None, - warps_per_group: Optional[Union[int, PrimExpr]] = None, + warp_size: int | PrimExpr | None = None, + warps_per_group: int | PrimExpr | None = None, ) -> PrimExpr: """Return the canonical warp group index for the calling thread. @@ -441,7 +442,7 @@ def wait_wgmma(id: int): return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), id) -def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, Var, None] = None): +def barrier_wait(barrier_id: int | PrimExpr | tir.Call, parity: int | Var | None = None): """Wait for a memory barrier to complete. Args: @@ -456,7 +457,7 @@ def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, return mbarrier_wait_parity(barrier_id, parity) -def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]): +def barrier_arrive(barrier_id: int | PrimExpr | tir.Call): """Arrive at a memory barrier. Args: @@ -466,7 +467,7 @@ def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]): return mbarrier_arrive(barrier_id) -def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): +def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): """Perform a shuffle operation with XOR offset. Args: @@ -483,7 +484,7 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset) -def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): +def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): """Perform a shuffle operation with down offset. Args: @@ -496,7 +497,7 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset) -def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]): +def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call): """Perform a shuffle operation with up offset. Args: @@ -601,7 +602,7 @@ def loop_break(): return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break")) -def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]): +def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call): """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. 
""" return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 0be3e21ac..84444b8c6 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -1,17 +1,18 @@ """The language interface for tl programs.""" +from __future__ import annotations -from typing import Union, Optional, Literal +from typing import Literal from tilelang import language as T from tilelang.utils.language import get_buffer_region_from_load from tvm import ir, tir from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region -def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion], - dst: Union[tir.Buffer, tir.BufferLoad], - coalesced_width: Optional[int] = None, +def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion, + dst: tir.Buffer | tir.BufferLoad, + coalesced_width: int | None = None, disable_tma: bool = False, - eviction_policy: Optional[Literal["evict_normal", "evict_first", "evict_last"]] = None): + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): """Copy data between memory regions. Args: @@ -94,8 +95,7 @@ def c2d_im2col(img: tir.Buffer, stride: int, dilation: int, pad: int, - eviction_policy: Optional[Literal["evict_normal", "evict_first", - "evict_last"]] = None): + eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None): """Perform im2col transformation for 2D convolution. Args: diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py index e31cce4a6..0830c22dc 100644 --- a/tilelang/language/customize.py +++ b/tilelang/language/customize.py @@ -1,8 +1,8 @@ """The language interface for tl programs.""" +from __future__ import annotations import tilelang.language as T from tvm.tir import PrimExpr, Buffer, op -from typing import List, Union from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401 @@ -36,7 +36,7 @@ def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr: return dst -def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer: +def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer: """Reshapes the input buffer to the specified shape. Args: @@ -49,9 +49,7 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer: return T.Tensor(shape, src.dtype, src.data) -def view(src: Buffer, - shape: Union[List[PrimExpr], None] = None, - dtype: Union[str, None] = None) -> Buffer: +def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = None) -> Buffer: """ Return a Tensor view of the input buffer with an optional new shape and dtype. 
diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py index 5cb6eb837..fc511c007 100644 --- a/tilelang/language/experimental/gemm_sp.py +++ b/tilelang/language/experimental/gemm_sp.py @@ -1,16 +1,16 @@ """The language interface for tl programs.""" +from __future__ import annotations from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T from tvm import tir -from typing import Union def gemm_sp( - A_sparse: Union[tir.Buffer, tir.Var], - E: Union[tir.Buffer, tir.Var], - B: Union[tir.Buffer, tir.Var], - C: Union[tir.Buffer, tir.Var], + A_sparse: tir.Buffer | tir.Var, + E: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, transpose_A: bool = False, transpose_B: bool = False, policy: GemmWarpPolicy = GemmWarpPolicy.Square, @@ -42,7 +42,7 @@ def gemm_sp( AssertionError: If the K dimensions of matrices A and B don't match """ - def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): + def legalize_arguments(arg: tir.Buffer | tir.Var): """Convert let-bound variables to their corresponding buffers. Args: diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py index de6b3cff3..95ef26746 100644 --- a/tilelang/language/fill.py +++ b/tilelang/language/fill.py @@ -1,12 +1,12 @@ """The language interface for tl programs.""" +from __future__ import annotations from tvm import tir -from typing import Union from tilelang.language import has_let_value, get_let_value from tilelang.utils.language import get_buffer_region_from_load -def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr): +def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr): """Fill a buffer or buffer region with a specified value. Args: @@ -21,7 +21,7 @@ def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr): return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value) -def clear(buffer: Union[tir.Buffer, tir.Var]): +def clear(buffer: tir.Buffer | tir.Var): """Clear a buffer by filling it with zeros. Args: diff --git a/tilelang/language/frame.py b/tilelang/language/frame.py index b82cfe5ef..8e6d59268 100644 --- a/tilelang/language/frame.py +++ b/tilelang/language/frame.py @@ -1,4 +1,5 @@ """Override the LetFrame to print a message when entering the frame.""" +from __future__ import annotations from tvm.ffi import register_object as _register_object from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion @@ -6,7 +7,6 @@ from tvm import DataType from tvm.script.ir_builder.tir.frame import TIRFrame from collections import deque -from typing import Optional import threading @@ -150,7 +150,7 @@ def __exit__(self, ptype, value, trace): super().__exit__(ptype, value, trace) @classmethod - def Current(cls) -> "LetFrame": + def Current(cls) -> LetFrame: """Get the current (topmost) let frame. Returns: @@ -198,7 +198,7 @@ def has_let_value(var: Var) -> bool: return _get_let_stack().has_value(var) -def get_let_value(var: Var) -> Optional[PrimExpr]: +def get_let_value(var: Var) -> PrimExpr | None: """Get the value bound to a variable in the current let frame stack. 
Args: diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index 3c4aa5452..bb8dc6ce8 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -1,23 +1,23 @@ """The language interface for tl programs.""" +from __future__ import annotations from tilelang.primitives.gemm.base import GemmWarpPolicy import tilelang.language as T from tvm import tir -from typing import Union, List, Optional from tilelang.utils.language import get_buffer_region_from_load def gemm( - A: Union[tir.Buffer, tir.Var], - B: Union[tir.Buffer, tir.Var], - C: Union[tir.Buffer, tir.Var], + A: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, transpose_A: bool = False, transpose_B: bool = False, policy: GemmWarpPolicy = GemmWarpPolicy.Square, clear_accum: bool = False, k_pack: int = 1, wg_wait: int = 0, - mbar: Optional[tir.Buffer] = None, + mbar: tir.Buffer | None = None, ): """Perform a General Matrix Multiplication (GEMM) operation. @@ -45,7 +45,7 @@ def gemm( AssertionError: If the K dimensions of matrices A and B don't match """ - def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): + def legalize_arguments(arg: tir.Buffer | tir.Var): """Convert let-bound variables to their corresponding buffers. Args: @@ -63,7 +63,7 @@ def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): C = legalize_arguments(C) mbar = legalize_arguments(mbar) if mbar is not None else None - def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: + def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]: if isinstance(object, tir.Buffer): return object.shape elif isinstance(object, tir.BufferRegion): @@ -82,7 +82,7 @@ def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: raise ValueError( f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}") - def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: + def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]: if isinstance(object, tir.Buffer): strides = [] stride = 1 @@ -137,8 +137,7 @@ def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: stride_a = A_stride[-2] stride_b = B_stride[-2] - def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion], - access_type: str = "r") -> tir.PrimExpr: + def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr: if isinstance(object, tir.Buffer): return object.access_ptr(access_type) elif isinstance(object, tir.BufferRegion): @@ -175,7 +174,7 @@ def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion], raise ValueError( f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}") - def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr: + def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr: """Retrieve the offset of the buffer or buffer region.""" if isinstance(object, tir.Buffer): return [0] * len(object.shape) @@ -214,9 +213,9 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr # experimental currently, for fast compilation def gemm_v2( - A: Union[tir.Buffer, tir.Var], - B: Union[tir.Buffer, tir.Var], - C: Union[tir.Buffer, tir.Var], + A: tir.Buffer | tir.Var, + B: tir.Buffer | tir.Var, + C: tir.Buffer | tir.Var, transpose_A: bool = False, transpose_B: bool = False, policy: GemmWarpPolicy = GemmWarpPolicy.Square, @@ -247,7 +246,7 @@ def gemm_v2( AssertionError: If the K dimensions of matrices A and B don't 
match """ - def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): + def legalize_arguments(arg: tir.Buffer | tir.Var): """Convert let-bound variables to their corresponding buffers. Args: @@ -264,7 +263,7 @@ def legalize_arguments(arg: Union[tir.Buffer, tir.Var]): B = legalize_arguments(B) C = legalize_arguments(C) - def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: + def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]: if isinstance(object, tir.Buffer): return object.shape elif isinstance(object, tir.BufferRegion): @@ -283,7 +282,7 @@ def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: raise ValueError( f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}") - def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: + def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]: if isinstance(object, tir.Buffer): strides = [] stride = 1 @@ -338,8 +337,7 @@ def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]: stride_a = A_stride[-2] stride_b = B_stride[-2] - def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion], - access_type: str = "r") -> tir.PrimExpr: + def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr: if isinstance(object, tir.Buffer): return object.access_ptr(access_type) elif isinstance(object, tir.BufferRegion): @@ -376,7 +374,7 @@ def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion], raise ValueError( f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}") - def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr: + def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr: """Retrieve the offset of the buffer or buffer region.""" if isinstance(object, tir.Buffer): return [0] * len(object.shape) diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py index 303e88a94..54b78d3d9 100644 --- a/tilelang/language/kernel.py +++ b/tilelang/language/kernel.py @@ -1,6 +1,6 @@ """The language interface for tl programs.""" +from __future__ import annotations -from typing import Union, List, Tuple, Optional from collections import deque from tvm import tir from tvm.tir import Var @@ -80,7 +80,7 @@ def _get_current_stack() -> FrameStack: return _local.kernel_launch_frame_stack -def _normalize_bindings(bindings: List[Var]) -> Union[Var, List[Var]]: +def _normalize_bindings(bindings: list[Var]) -> Var | list[Var]: """ Return a bare Var when we only have a single binding so that users may write either `with T.Kernel(...) as pid:` or `with T.Kernel(...) as (pid,)`. @@ -98,7 +98,7 @@ class KernelLaunchFrame(TIRFrame): and handles the entry and exit of the kernel launch scope. """ - def __enter__(self) -> Union[Var, List[Var]]: + def __enter__(self) -> Var | list[Var]: """ Enters the KernelLaunchFrame scope and pushes this frame onto the stack. Returns one Var if we detect exactly 5 frames (meaning there is a single @@ -132,7 +132,7 @@ def __exit__(self, ptype, value, trace): super().__exit__(ptype, value, trace) @classmethod - def Current(cls) -> Optional["KernelLaunchFrame"]: + def Current(cls) -> KernelLaunchFrame | None: """ Returns the topmost (current) KernelLaunchFrame from the stack if it exists, or None if the stack is empty. 
@@ -148,7 +148,7 @@ def get_block_extent(self, dim: int) -> int: iter_var = self.frames[dim].iter_var return int(iter_var.dom.extent) - def get_block_extents(self) -> List[int]: + def get_block_extents(self) -> list[int]: """ Returns the block extents for all three dimensions. """ @@ -162,7 +162,7 @@ def get_thread_extent(self, dim: int) -> int: iter_var = self.frames[-4 + dim].iter_var return int(iter_var.dom.extent) - def get_thread_extents(self) -> List[int]: + def get_thread_extents(self) -> list[int]: """ Returns the thread extents for all three dimensions. """ @@ -175,7 +175,7 @@ def get_thread_binding(self, dim: int = 0) -> Var: """ return self.frames[-4 + dim].iter_var.var - def get_thread_bindings(self) -> List[Var]: + def get_thread_bindings(self) -> list[Var]: """ Returns the thread binding for the given dimension. dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z. @@ -198,21 +198,21 @@ def get_block_binding(self, dim: int = 0) -> Var: """ return self.frames[dim].iter_var.var - def get_block_bindings(self) -> List[Var]: + def get_block_bindings(self) -> list[Var]: """ Returns all three block bindings. """ return [frame.iter_var.var for frame in self.frames[0:-4]] @property - def blocks(self) -> List[Var]: + def blocks(self) -> list[Var]: """ Returns the block indices from the topmost frame. """ return [frame.iter_var.var for frame in self.frames[0:-4]] @property - def threads(self) -> List[Var]: + def threads(self) -> list[Var]: """ Returns the thread indices from the topmost frame. """ @@ -227,10 +227,10 @@ def num_threads(self) -> int: def Kernel( - *blocks: List[tir.PrimExpr], - threads: Optional[Union[int, List[int], Tuple]] = None, + *blocks: list[tir.PrimExpr], + threads: int | list[int] | tuple | None = None, is_cpu: bool = False, - prelude: Optional[str] = None, + prelude: str | None = None, ): """Tools to quickly construct a GPU kernel launch frame. @@ -310,7 +310,7 @@ def get_thread_binding(dim: int = 0) -> Var: return KernelLaunchFrame.Current().get_thread_binding(dim) -def get_thread_bindings() -> List[Var]: +def get_thread_bindings() -> list[Var]: """Returns all three thread bindings. """ assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" @@ -324,7 +324,7 @@ def get_block_binding(dim: int = 0) -> Var: return KernelLaunchFrame.Current().get_block_binding(dim) -def get_block_bindings() -> List[Var]: +def get_block_bindings() -> list[Var]: """Returns all three block bindings. """ assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" @@ -338,7 +338,7 @@ def get_thread_extent(dim: int = 0) -> int: return KernelLaunchFrame.Current().get_thread_extent(dim) -def get_thread_extents() -> List[int]: +def get_thread_extents() -> list[int]: """Returns all three thread extents. """ assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" @@ -352,7 +352,7 @@ def get_block_extent(dim: int = 0) -> int: return KernelLaunchFrame.Current().get_block_extent(dim) -def get_block_extents() -> List[int]: +def get_block_extents() -> list[int]: """Returns all three block extents. 
""" assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized" diff --git a/tilelang/language/logical.py b/tilelang/language/logical.py index a08627203..a09088e68 100644 --- a/tilelang/language/logical.py +++ b/tilelang/language/logical.py @@ -1,13 +1,13 @@ """The language interface for tl programs.""" +from __future__ import annotations from tilelang import language as T from tvm.tir import Buffer, BufferRegion, BufferLoad from tvm import tir -from typing import Union from tilelang.utils.language import get_buffer_elems -def any_of(buffer: Union[T.Tensor, BufferRegion]): +def any_of(buffer: T.Tensor | BufferRegion): """Check if any element in the buffer is true. Args: @@ -42,7 +42,7 @@ def any_of(buffer: Union[T.Tensor, BufferRegion]): raise ValueError(f"Invalid buffer type: {type(buffer)}") -def all_of(buffer: Union[T.Tensor, BufferRegion]): +def all_of(buffer: T.Tensor | BufferRegion): """Check if all elements in the buffer are true. Args: diff --git a/tilelang/language/overrides/parser.py b/tilelang/language/overrides/parser.py index 5a9343650..01d59b607 100644 --- a/tilelang/language/overrides/parser.py +++ b/tilelang/language/overrides/parser.py @@ -1,7 +1,7 @@ """TVMScript parser overrides tailored for TileLang.""" +from __future__ import annotations from functools import partial -from typing import Tuple from tvm.script.ir_builder import tir as T from tvm.script.parser._core import dispatch, doc @@ -10,7 +10,7 @@ from tvm.script.parser.tir import parser as tvm_tir_parser -def _get_node_span(node: doc.AST) -> Tuple[int, int, int, int]: +def _get_node_span(node: doc.AST) -> tuple[int, int, int, int]: """Return the span (lineno, col, end_lineno, end_col) for a doc node.""" return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset) diff --git a/tilelang/language/parallel.py b/tilelang/language/parallel.py index a70846a62..8173675a8 100644 --- a/tilelang/language/parallel.py +++ b/tilelang/language/parallel.py @@ -1,11 +1,12 @@ """The language interface for tl programs.""" +from __future__ import annotations -from typing import Optional, Dict, Any +from typing import Any from tvm import tir from tilelang import _ffi_api -def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None): +def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None): """Tools to construct nested parallel for loop. This can be used to create element-wise tensor expression. @@ -22,7 +23,7 @@ def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None): res : frame.ForFrame The ForFrame. """ - annotations: Dict[str, Any] = {} + annotations: dict[str, Any] = {} if coalesced_width is not None: annotations.update({"coalesced_width": coalesced_width}) return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/tilelang/language/parser/operation.py b/tilelang/language/parser/operation.py index e16fa261b..43774947e 100644 --- a/tilelang/language/parser/operation.py +++ b/tilelang/language/parser/operation.py @@ -17,8 +17,7 @@ # This file is modified from the original version, # which is part of the TVM project (https://tvm.apache.org/). 
"""The tir expression operation registration""" - -from typing import Type +from __future__ import annotations from tvm import tir from tvm.ffi.runtime_ctypes import DataType, DataTypeCode @@ -28,7 +27,7 @@ from tvm.script.parser._core import OpMethod, doc, register_op -def _register_expr_op(ty: Type): # pylint: disable=invalid-name +def _register_expr_op(ty: type): # pylint: disable=invalid-name ty._dispatch_type = ty # pylint: disable=protected-access def _and(a, b): @@ -115,7 +114,7 @@ def _gt(a, b): def _ge(a, b): return _auto_broadcast(a, b, tir.GE) - def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name + def r(op: type, i: int, m: OpMethod): # pylint: disable=invalid-name register_op(ty, op, i)(m) for i in [0, 1]: diff --git a/tilelang/language/persistent.py b/tilelang/language/persistent.py index 1761cfa53..0ee7f112a 100644 --- a/tilelang/language/persistent.py +++ b/tilelang/language/persistent.py @@ -1,15 +1,15 @@ """The language interface for tl programs.""" +from __future__ import annotations -from typing import List, Optional from tvm import tir from tilelang import _ffi_api def Persistent( - domain: List[tir.PrimExpr], + domain: list[tir.PrimExpr], wave_size: tir.PrimExpr, index: tir.PrimExpr, - group_size: Optional[tir.PrimExpr] = 8, + group_size: tir.PrimExpr | None = 8, ): """Tools to construct persistent for loop. diff --git a/tilelang/language/pipeline.py b/tilelang/language/pipeline.py index 85fd90cc0..895ed914a 100644 --- a/tilelang/language/pipeline.py +++ b/tilelang/language/pipeline.py @@ -1,6 +1,6 @@ """The language interface for tl programs.""" +from __future__ import annotations -from typing import List, Optional from tvm import tir from tvm.tir import IntImm from tilelang import _ffi_api @@ -10,10 +10,10 @@ def Pipelined( start: tir.PrimExpr, stop: tir.PrimExpr = None, num_stages: int = 0, - order: Optional[List[int]] = None, - stage: Optional[List[int]] = None, - sync: Optional[List[List[int]]] = None, - group: Optional[List[List[int]]] = None, + order: list[int] | None = None, + stage: list[int] | None = None, + sync: list[list[int]] | None = None, + group: list[list[int]] | None = None, ): """Tools to construct pipelined for loop. 
diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py index 83513f7a1..539c1d94c 100644 --- a/tilelang/language/proxy.py +++ b/tilelang/language/proxy.py @@ -1,7 +1,7 @@ """The language interface for tl programs.""" from __future__ import annotations -from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING, Tuple, Union +from typing import Any, Sequence, SupportsIndex, TYPE_CHECKING from typing_extensions import Self from tvm import tir @@ -143,7 +143,7 @@ class TensorProxy(BaseTensorProxy): """ @staticmethod - def _construct_strides(shape: Tuple[Any]): + def _construct_strides(shape: tuple[Any]): s, strides = 1, [1] for dim in shape[:0:-1]: s *= dim @@ -151,7 +151,7 @@ def _construct_strides(shape: Tuple[Any]): return tuple(reversed(strides)) def __call__(self, - shape: Union[Tuple[Any], PrimExpr, int], + shape: tuple[Any] | PrimExpr | int, dtype: str = "float32", data=None, scope=None) -> tir.Buffer: @@ -172,8 +172,8 @@ class StridedTensorProxy(BaseTensorProxy): """ def __call__(self, - shape: Tuple[Any], - strides: Tuple[Any], + shape: tuple[Any], + strides: tuple[Any], dtype: str = "float32", scope=None) -> tir.Buffer: if len(shape) != len(strides): @@ -270,7 +270,7 @@ class LocalBuffer(BaseTensor): LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name -def ptr(dtype: Optional[str] = None, +def ptr(dtype: str | None = None, storage_scope: str = "global", *, is_size_var: bool = False) -> Var: diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 5cfca850b..55ac2bb0d 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -1,7 +1,7 @@ """The language interface for tl programs.""" +from __future__ import annotations from tvm import tir -from typing import Optional from tilelang.language import copy, macro, alloc_shared @@ -199,7 +199,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - copy(cumsum_smem, dst) -def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False): +def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse: bool = False): """ Compute the cumulative sum of `src` along `dim`, writing results to `dst`. diff --git a/tilelang/language/tir/entry.py b/tilelang/language/tir/entry.py index ade36b81c..22702ae43 100644 --- a/tilelang/language/tir/entry.py +++ b/tilelang/language/tir/entry.py @@ -1,14 +1,15 @@ +from __future__ import annotations import inspect -from typing import Callable, Optional, Union +from typing import Callable import tvm.script.parser.tir.entry as _tir_entry from tvm.tir.function import PrimFunc from tvm.script.parser._core import parse, scan_macro, utils -def prim_func(func: Optional[Callable] = None, +def prim_func(func: Callable | None = None, private: bool = False, - check_well_formed: bool = False) -> Union[PrimFunc, Callable]: + check_well_formed: bool = False) -> PrimFunc | Callable: """The parsing method for tir prim func, by using `@prim_func` as decorator. 
Parameters diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py index 1143f2a9e..0c0d167e0 100644 --- a/tilelang/language/tir/ir.py +++ b/tilelang/language/tir/ir.py @@ -1,7 +1,8 @@ +from __future__ import annotations import tvm.script.ir_builder.tir.ir as _ir from tvm.script.ir_builder.tir import frame from tvm.tir import PrimExpr -from typing import Any, Dict +from typing import Any import tilelang.language.tir.op as _tir_op import functools @@ -9,7 +10,7 @@ def serial(start: PrimExpr, stop: PrimExpr = None, *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: + annotations: dict[str, Any] = None) -> frame.ForFrame: """The serial For statement. Parameters @@ -34,7 +35,7 @@ def serial(start: PrimExpr, def parallel(start: PrimExpr, stop: PrimExpr = None, *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: + annotations: dict[str, Any] = None) -> frame.ForFrame: """The parallel For statement. Parameters @@ -59,7 +60,7 @@ def parallel(start: PrimExpr, def vectorized(start: PrimExpr, stop: PrimExpr = None, *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: + annotations: dict[str, Any] = None) -> frame.ForFrame: """The vectorized For statement. Parameters @@ -84,7 +85,7 @@ def vectorized(start: PrimExpr, def unroll(start: PrimExpr, stop: PrimExpr = None, *, - annotations: Dict[str, Any] = None) -> frame.ForFrame: + annotations: dict[str, Any] = None) -> frame.ForFrame: """The unrolled For statement. Parameters @@ -111,7 +112,7 @@ def thread_binding( stop: PrimExpr = None, thread: str = None, *, - annotations: Dict[str, Any] = None, + annotations: dict[str, Any] = None, ) -> frame.ForFrame: """The thread-binding For statement. diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index 10ca7ca93..925665609 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -1,4 +1,5 @@ -from typing import Any, Optional +from __future__ import annotations +from typing import Any import tvm from tvm.ir import PrimExpr from tvm.ir.base import Span @@ -1857,7 +1858,7 @@ def min_value(dtype, span=None): return _tvm_op.min_value(dtype, span) -def max_value(dtype: str, span: Optional[Span] = None) -> Any: +def max_value(dtype: str, span: Span | None = None) -> Any: """maximum value of dtype Parameters @@ -1876,7 +1877,7 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any: return _tvm_op.max_value(dtype, span) -def infinity(dtype: str, span: Optional[Span] = None) -> Any: +def infinity(dtype: str, span: Span | None = None) -> Any: """infinity value of dtype Parameters @@ -1895,7 +1896,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any: return _tvm_op.infinity(dtype, span) -def reinterpret(dtype, value, span: Optional[Span] = None) -> Any: +def reinterpret(dtype, value, span: Span | None = None) -> Any: """infinity value of dtype Parameters diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 9b21596bb..caed14aa4 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -1,5 +1,5 @@ +from __future__ import annotations from tilelang import tvm as tvm -from typing import List from tvm import tir from tvm.tir import PrimExpr, Buffer, BufferLoad, op from tilelang import language as T @@ -42,7 +42,7 @@ def buffer_to_tile_region(buffer: Buffer, access_type: str): return region(T.BufferLoad(buffer, mins), access_type, *extents) -def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]): +def buffer_load_to_tile_region(load: BufferLoad, 
access_type: str, extents: list[PrimExpr]): """Convert a buffer load operation to a tile region descriptor. Args: @@ -69,7 +69,7 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str, - extents: List[tir.PrimExpr]): + extents: list[tir.PrimExpr]): """Convert a buffer region to a tile region descriptor. Args: @@ -88,7 +88,7 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents) -def index_to_coordinates(index, shape) -> List[PrimExpr]: +def index_to_coordinates(index, shape) -> list[PrimExpr]: """ Convert a flat (linear) index into multi-dimensional coordinates for a given shape. diff --git a/tilelang/language/warpgroup.py b/tilelang/language/warpgroup.py index 2e64d66fa..872d30010 100644 --- a/tilelang/language/warpgroup.py +++ b/tilelang/language/warpgroup.py @@ -1,10 +1,10 @@ """The language interface for tl programs.""" +from __future__ import annotations from tvm.script.ir_builder.tir.frame import TIRFrame from tvm.ffi import register_object from tilelang import _ffi_api from .kernel import get_thread_bindings, get_thread_extents -from typing import List @register_object("tl.WarpSpecializeFrame") @@ -45,7 +45,7 @@ def WarpSpecialize(*warp_group_idx): # only available for nvidia gpus. warp_group_size = 128 - warp_group_ids: List[int] = [] + warp_group_ids: list[int] = [] for warp_group_id in warp_group_idx: warp_group_ids.append(warp_group_id) diff --git a/tilelang/layout/fragment.py b/tilelang/layout/fragment.py index b26affaa2..b9c2b10ec 100644 --- a/tilelang/layout/fragment.py +++ b/tilelang/layout/fragment.py @@ -1,12 +1,12 @@ """Wrapping Layouts.""" # pylint: disable=invalid-name, unsupported-binary-operation +from __future__ import annotations import tvm from tvm.ir import Range from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tilelang import _ffi_api from tilelang.layout import Layout -from typing import List @tvm.ffi.register_object("tl.Fragment") @@ -123,7 +123,7 @@ def get_thread_size(self): def repeat(self, repeats, repeat_on_thread: bool = False, - lower_dim_first: bool = True) -> "Fragment": + lower_dim_first: bool = True) -> Fragment: """ Returns a new Fragment that repeats the iteration space a given number of times. @@ -143,7 +143,7 @@ def repeat(self, """ return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first) - def replicate(self, replicate: int) -> "Fragment": + def replicate(self, replicate: int) -> Fragment: """ Replicate the Fragment across a new thread dimension. @@ -159,7 +159,7 @@ def replicate(self, replicate: int) -> "Fragment": """ return _ffi_api.Fragment_replicate(self, replicate) - def condense_rep_var(self) -> "Fragment": + def condense_rep_var(self) -> Fragment: """ Condense or fold the replicate variable into the existing iteration space. This operation may be used to reduce dimensionality if the replicate variable @@ -172,7 +172,7 @@ def condense_rep_var(self) -> "Fragment": """ return _ffi_api.Fragment_condense_rep_var(self) - def map_forward_thread(self, indices: List[PrimExpr]) -> PrimExpr: + def map_forward_thread(self, indices: list[PrimExpr]) -> PrimExpr: """ Get the thread mapping expression for a given set of argument indices. 
@@ -206,7 +206,7 @@ def __repr__(self): """ return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>" - def is_equal(self, other: "Fragment") -> bool: + def is_equal(self, other: Fragment) -> bool: """ Check if the current fragment is equal to another fragment. """ diff --git a/tilelang/layout/gemm_sp.py b/tilelang/layout/gemm_sp.py index 1417d1b73..2fd58cd2e 100644 --- a/tilelang/layout/gemm_sp.py +++ b/tilelang/layout/gemm_sp.py @@ -1,17 +1,16 @@ """Wrapping Layouts.""" # pylint: disable=invalid-name, unsupported-binary-operation +from __future__ import annotations -from typing import Optional import tvm import tilelang.language as T import warnings from tilelang.contrib import nvcc -from typing import List from math import prod -def decompose_col_major(index_1d: int, basis: List[int]) -> List[int]: +def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]: res = [] for x in basis: res.append(index_1d % x) @@ -136,7 +135,7 @@ def ColumnMajorInterleaved(i: int, j: int) -> int: def make_metadata_layout(buffer: tvm.tir.Buffer, mma_dtype: str = "float16", backend: str = 'cutlass', - arch: Optional[str] = None, + arch: str | None = None, **extra_args): if arch is None: arch = nvcc.get_target_compute_version() diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py index fd8e31225..dd0f11709 100644 --- a/tilelang/layout/layout.py +++ b/tilelang/layout/layout.py @@ -1,11 +1,11 @@ """Wrapping Layouts.""" # pylint: disable=invalid-name, unsupported-binary-operation +from __future__ import annotations import tvm from tvm.ir import Node, Range from tvm.tir import IterVar, Var, PrimExpr, IndexMap from tilelang import _ffi_api -from typing import List # Register the Layout class as a TVM object under the name "tl.Layout" @@ -92,7 +92,7 @@ def get_forward_vars(self): def get_forward_index(self): return self.index - def map_forward_index(self, indices: List[PrimExpr]) -> PrimExpr: + def map_forward_index(self, indices: list[PrimExpr]) -> PrimExpr: """ Compute the forward index mapping for a given set of input indices. @@ -122,7 +122,7 @@ def map_forward_index(self, indices: List[PrimExpr]) -> PrimExpr: # Map the provided indices using the constructed index mapping return index_map.map_indices(indices) - def inverse(self) -> "Layout": + def inverse(self) -> Layout: """ Compute the inverse of the current layout transformation. @@ -133,7 +133,7 @@ def inverse(self) -> "Layout": """ return _ffi_api.Layout_inverse(self) - def is_equal(self, other: "Layout") -> bool: + def is_equal(self, other: Layout) -> bool: """ Check if the current layout is equal to another layout. 
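Context for the annotation changes just above (e.g. `-> "Fragment"` becoming `-> Fragment`, `-> "Layout"` becoming `-> Layout`): these rely on `from __future__ import annotations` (PEP 563), under which annotations are stored as strings and never evaluated at runtime, so unquoted self-references, built-in generics, and PEP 604 unions stay valid on Python 3.8, the repository's ruff target. A minimal standalone sketch, not taken from tilelang:

from __future__ import annotations


class Layout:
    # Toy stand-in for a layout-like class; not tilelang.layout.Layout.
    def inverse(self) -> Layout:           # unquoted self-reference, no NameError
        return self

    def is_equal(self, other: Layout) -> bool:
        return self is other


def first_or_none(values: list[int] | None = None) -> int | None:
    # list[int] and "| None" are fine in annotations on 3.8 because the
    # future import keeps them as strings instead of evaluating them.
    return values[0] if values else None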
diff --git a/tilelang/primitives/gemm/__init__.py b/tilelang/primitives/gemm/__init__.py index 64f108957..ee9436d15 100644 --- a/tilelang/primitives/gemm/__init__.py +++ b/tilelang/primitives/gemm/__init__.py @@ -1,4 +1,5 @@ -from typing import Optional +from __future__ import annotations + from tvm import tir from tilelang.utils import is_local, is_fragment, is_shared from tilelang.primitives.gemm.base import GemmWarpPolicy @@ -12,11 +13,11 @@ def gemm( C: tir.Buffer, transpose_A: bool = False, transpose_B: bool = False, - block_row_warps: Optional[int] = None, - block_col_warps: Optional[int] = None, - warp_row_tiles: Optional[int] = None, - warp_col_tiles: Optional[int] = None, - chunk: Optional[int] = None, + block_row_warps: int | None = None, + block_col_warps: int | None = None, + warp_row_tiles: int | None = None, + warp_col_tiles: int | None = None, + chunk: int | None = None, policy: GemmWarpPolicy = GemmWarpPolicy.Square, k_pack: int = 1, ): diff --git a/tilelang/primitives/gemm/base.py b/tilelang/primitives/gemm/base.py index d79961635..827ff78f9 100644 --- a/tilelang/primitives/gemm/base.py +++ b/tilelang/primitives/gemm/base.py @@ -1,7 +1,7 @@ +from __future__ import annotations from enum import IntEnum from dataclasses import dataclass -from typing import Optional from tvm import tir @@ -161,7 +161,7 @@ def compute_warp_partition(self, M, N, num_warps): return m_warp, n_warp @classmethod - def from_warp_partition(cls, m_warp: int, n_warp: int) -> 'GemmWarpPolicy': + def from_warp_partition(cls, m_warp: int, n_warp: int) -> GemmWarpPolicy: """ Determine the warp policy based on the given warp partitioning. @@ -197,11 +197,11 @@ class GemmBaseParams: transpose_A: bool = False transpose_B: bool = False - block_row_warps: Optional[int] = None - block_col_warps: Optional[int] = None - warp_row_tiles: Optional[int] = None - warp_col_tiles: Optional[int] = None - chunk: Optional[int] = None + block_row_warps: int | None = None + block_col_warps: int | None = None + warp_row_tiles: int | None = None + warp_col_tiles: int | None = None + chunk: int | None = None policy: GemmWarpPolicy = GemmWarpPolicy.Square, k_pack: int = 1 @@ -226,7 +226,7 @@ def params_as_dict(self): "k_pack": self.k_pack, } - def infer_block_partition(self, threads: Optional[int]) -> None: + def infer_block_partition(self, threads: int | None) -> None: """ Infer and set block partition parameters (e.g., block_row_warps, block_col_warps, warp_row_tiles, warp_col_tiles, chunk) based on the diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py index 4f4f710d0..c681ee976 100644 --- a/tilelang/profiler/__init__.py +++ b/tilelang/profiler/__init__.py @@ -1,6 +1,7 @@ """The profiler and convert to torch utils""" +from __future__ import annotations -from typing import List, Optional, Callable, Any, Literal +from typing import Callable, Any, Literal from functools import partial import torch from contextlib import suppress @@ -28,17 +29,17 @@ class Profiler: adapter: Optional kernel adapter for interfacing with different backends """ - params: List[KernelParam] - result_idx: List[int] + params: list[KernelParam] + result_idx: list[int] supply_type: TensorSupplyType - adapter: Optional[BaseKernelAdapter] = None + adapter: BaseKernelAdapter | None = None def __post_init__(self): """Initialize tensor supply after dataclass initialization""" self.result_idx = self._legalize_result_idx(self.result_idx) self.supply = get_tensor_supply(self.supply_type) - def _legalize_result_idx(self, result_idx: 
Optional[List[int]] = None) -> List[int]: + def _legalize_result_idx(self, result_idx: list[int] | None = None) -> list[int]: params = self.params # result_idx is a list of indices of the output tensors if result_idx is None: @@ -55,7 +56,7 @@ def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[i return result_idx - def with_default_adapter(self, adapter: BaseKernelAdapter) -> "Profiler": + def with_default_adapter(self, adapter: BaseKernelAdapter) -> Profiler: self.adapter = adapter return self @@ -76,7 +77,7 @@ def _get_params(self, with_output=False): def assert_allclose( self, reference_program: Callable, - input_tensors: Optional[List[torch.Tensor]] = None, + input_tensors: list[torch.Tensor] | None = None, atol: float = 1e-2, rtol: float = 1e-2, max_mismatched_ratio=0.01, @@ -147,7 +148,7 @@ def is_float8(tensor: torch.Tensor) -> bool: def manual_assert_close( self, reference_program: Callable, - input_tensors: Optional[List[torch.Tensor]] = None, + input_tensors: list[torch.Tensor] | None = None, manual_check_prog: Callable = None, ): """Validates kernel output against a reference implementation. @@ -194,13 +195,13 @@ def assert_consistent(self, repeat=10): rhs, ] - def run_once(self, func: Optional[Callable] = None): + def run_once(self, func: Callable | None = None): ins = self._get_inputs() if not func: func = self.__call__ return func(*ins) - def determine_profiler(self, func: Optional[Callable] = None): + def determine_profiler(self, func: Callable | None = None): """Determines which profiler backend to use based on function type. Args: @@ -217,14 +218,14 @@ def determine_profiler(self, func: Optional[Callable] = None): def do_bench( self, - func: Optional[Callable] = None, + func: Callable | None = None, warmup: int = 25, rep: int = 100, n_warmup: int = 1, n_repeat: int = 1, - input_tensors: List[torch.Tensor] = None, + input_tensors: list[torch.Tensor] = None, backend: Literal["event", "cupti"] = "event", - quantiles: Optional[List[float]] = None, + quantiles: list[float] | None = None, return_mode: Literal["min", "max", "mean", "median"] = "mean", ) -> float: """Benchmarks the execution time of a given function. diff --git a/tilelang/profiler/bench.py b/tilelang/profiler/bench.py index d6f8c0820..a851ceb3d 100644 --- a/tilelang/profiler/bench.py +++ b/tilelang/profiler/bench.py @@ -1,8 +1,9 @@ """Profiler and benchmarking utilities for PyTorch functions.""" +from __future__ import annotations import os import sys -from typing import Callable, List, Literal, Optional, Union +from typing import Callable, Literal import torch @@ -65,11 +66,11 @@ def do_bench( rep: float = 100, _n_warmup: int = 0, _n_repeat: int = 0, - quantiles: Optional[List[float]] = None, + quantiles: list[float] | None = None, fast_flush: bool = True, backend: Literal["event", "cupti"] = "event", return_mode: Literal["min", "max", "mean", "median"] = "mean", -) -> Union[float, List[float]]: +) -> float | list[float]: """Benchmark the runtime of a PyTorch function with L2 cache management. 
This function provides accurate GPU kernel timing by: @@ -138,9 +139,9 @@ def _bench_with_cuda_events( fn: Callable, cache: torch.Tensor, n_repeat: int, - quantiles: Optional[List[float]], + quantiles: list[float] | None, return_mode: str, -) -> Union[float, List[float]]: +) -> float | list[float]: """Benchmark using CUDA events for timing.""" # Create timing events start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)] diff --git a/tilelang/quantize/lop3.py b/tilelang/quantize/lop3.py index f1bc6910f..47d91f056 100644 --- a/tilelang/quantize/lop3.py +++ b/tilelang/quantize/lop3.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Dict, Literal +from __future__ import annotations +from typing import Literal decode_i4_to_f16 = """ template @@ -1096,7 +1097,7 @@ def get_lop3_intrin_group( with_zeros: bool = False, zeros_mode: Literal["original", "rescale", "quantized"] = "original", storage_scope: str = "local", -) -> Dict[str, str]: +) -> dict[str, str]: """ This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of @@ -1186,9 +1187,9 @@ def get_lop3_intrin_group( elif out_dtype == "int4": d4f = "i4s" else: - raise ValueError("Unsupported target dtype: {}".format(target_dtype)) + raise ValueError(f"Unsupported target dtype: {target_dtype}") source_symbol = "u" if source_format == "uint" else "s" - func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f) + func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}" if with_scaling: func_name += "_scale" if with_zeros: diff --git a/tilelang/quantize/mxfp.py b/tilelang/quantize/mxfp.py index 552f3db3c..0425c549d 100644 --- a/tilelang/quantize/mxfp.py +++ b/tilelang/quantize/mxfp.py @@ -1,4 +1,5 @@ -from typing import Literal, Dict +from __future__ import annotations +from typing import Literal # Implementation asm for fp4 to bf16, using twiddling # Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18 @@ -54,7 +55,7 @@ def get_mxfp_intrin_group( source_bit: int = 4, storage_dtype: Literal["int32", "int8", "uint8"] = "uint8", use_twiddling: bool = False, -) -> Dict[str, str]: +) -> dict[str, str]: """ Return metadata for an MXFP decoding intrinsic: function name and C source string. 
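The do_bench / _bench_with_cuda_events hunks above touch the event-based timing path whose docstring promises "accurate GPU kernel timing" with L2 cache management. A hedged sketch of that general pattern using only public torch.cuda APIs; the helper name and structure are illustrative, not the tilelang implementation:

import torch


def time_once_ms(fn, cache: torch.Tensor) -> float:
    # Time one call to fn with CUDA events, flushing L2 via a scratch tensor.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    cache.zero_()                    # touch a large buffer so fn starts with a cold L2
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()         # wait for the kernel before reading the timers
    return start.elapsed_time(end)   # elapsed time in milliseconds

# Example scratch buffer (size is an arbitrary choice for illustration):
# cache = torch.empty(64 * 1024 * 1024, dtype=torch.int8, device="cuda")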
diff --git a/tilelang/quantize/quantization.py b/tilelang/quantize/quantization.py index bc0ea47bf..db9d2349d 100644 --- a/tilelang/quantize/quantization.py +++ b/tilelang/quantize/quantization.py @@ -223,7 +223,7 @@ def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): e4 = val & tir.const(0x40, "uint16") prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), tir.const(0x4000, "uint16")) - e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix + e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | prefix return tir.reinterpret("float16", s_f16 | e_f16) @@ -232,7 +232,7 @@ def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert dtype == "float16" s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") e4 = val & tir.const(0x40, "uint16") - e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16")) + e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16")) e_f16 = e_f16 ^ tir.const(0x2000, "uint16") return tir.reinterpret("float16", s_f16 | e_f16) diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py index 849b6d33a..4968b09f4 100644 --- a/tilelang/tileop/gemm/gemm_base.py +++ b/tilelang/tileop/gemm/gemm_base.py @@ -9,7 +9,7 @@ @dataclass -class GemmBase(object): +class GemmBase: gemm_node: Node def infer_layout(self, target: Target, thread_nums: int): diff --git a/tilelang/tools/Analyzer.py b/tilelang/tools/Analyzer.py index 379dfc119..205c647e3 100644 --- a/tilelang/tools/Analyzer.py +++ b/tilelang/tools/Analyzer.py @@ -1,9 +1,9 @@ +from __future__ import annotations import numpy as np from dataclasses import dataclass from tilelang import tvm from tvm.tir.stmt_functor import ir_transform import logging -from typing import Optional # Configuration for different hardware architectures. # Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count) ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)} @@ -168,7 +168,7 @@ def calculate(self) -> AnalysisResult: AnalysisResult: The calculated performance metrics. """ - def get_peak_tflops(device) -> Optional[float]: + def get_peak_tflops(device) -> float | None: """ Get the peak TFLOPS for the target device. Args: diff --git a/tilelang/transform/add_bufstore_wrapper.py b/tilelang/transform/add_bufstore_wrapper.py index 1b3b4cd4c..7ccab4707 100644 --- a/tilelang/transform/add_bufstore_wrapper.py +++ b/tilelang/transform/add_bufstore_wrapper.py @@ -1,7 +1,7 @@ +from __future__ import annotations from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm) from tvm.tir.stmt_functor import ir_transform, post_order_visit from tvm.tir.transform import prim_func_pass -from typing import Tuple, List, Dict def AddWrapperForSingleBufStore(): @@ -42,7 +42,7 @@ def visit_variable(node): post_order_visit(operation, visit_variable) return used_variables - def collect_buffer_accesses(statement) -> Tuple[List[Buffer], List[Buffer]]: + def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]: """ Categorizes buffers accessed in the statement by their scope. 
@@ -69,7 +69,7 @@ def visit_buffer_access(node): local_buffers.append(buffer) return local_buffers, fragment_buffers - def collect_buffer_indices(statement) -> Dict[Buffer, List[int]]: + def collect_buffer_indices(statement) -> dict[Buffer, list[int]]: """ Maps each buffer to its access indices. diff --git a/tilelang/transform/simplify.py b/tilelang/transform/simplify.py index 6b8fedfc3..7e0c5062b 100644 --- a/tilelang/transform/simplify.py +++ b/tilelang/transform/simplify.py @@ -1,7 +1,8 @@ +from __future__ import annotations from tilelang import tvm as tvm from tvm import IRModule from tvm.tir import PrimFunc -from typing import Union, Callable +from typing import Callable from . import _ffi_api @@ -27,8 +28,7 @@ def Simplify(simplify_arguments: bool = False): return _ffi_api.Simplify(simplify_arguments) # type: ignore -def _Simplify(stmt: Union[PrimFunc, IRModule], - inline_let: bool = False) -> Union[PrimFunc, IRModule]: +def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule: if isinstance(stmt, PrimFunc): if inline_let: mod = LetInline()(IRModule.from_expr(stmt)) @@ -53,13 +53,12 @@ def _Simplify(stmt: Union[PrimFunc, IRModule], def simplify_prim_func(func: Callable) -> Callable: def wrapper(*args, **kwargs): - stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs) + stmt: PrimFunc | IRModule = (func)(*args, **kwargs) return _Simplify(stmt) return wrapper -def apply_simplify(stmt: Union[PrimFunc, IRModule], - inline_let: bool = False) -> Union[PrimFunc, IRModule]: +def apply_simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule: """Apply Simplify pass to a PrimFunc or IRModule.""" return _Simplify(stmt, inline_let) diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py index 2c0b4efad..0972175a8 100644 --- a/tilelang/utils/language.py +++ b/tilelang/utils/language.py @@ -1,5 +1,5 @@ +from __future__ import annotations from tvm.tir import Buffer -from typing import List, Optional from functools import reduce from tvm import IRModule from tvm.tir import PrimFunc @@ -85,7 +85,7 @@ def get_buffer_elems(buffer: Buffer) -> int: return reduce(lambda x, y: x * y, buffer.shape) -def array_reduce(array: List[int]) -> int: +def array_reduce(array: list[int]) -> int: """ Reduce an array of integers to a single integer. @@ -121,7 +121,7 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc: return func -def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.BufferRegion]: +def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion | None: """ Get the buffer region from a buffer load. 
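For reference, index_to_coordinates in tilelang/language/utils.py (retyped earlier in this patch) converts a flat linear index into per-dimension coordinates. A plain-Python analogue of that idea, assuming row-major order with the last axis varying fastest; an illustration only, not the TIR-expression implementation:

def index_to_coordinates(index: int, shape: tuple[int, ...]) -> list[int]:
    coords = []
    for dim in reversed(shape):  # peel off the fastest-varying axis first
        coords.append(index % dim)
        index //= dim
    coords.reverse()             # return coordinates in the same order as shape
    return coords

# index_to_coordinates(7, (2, 3, 4)) == [0, 1, 3]   since 0*12 + 1*4 + 3 == 7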
diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py index 22cd95f21..cd364b8bb 100644 --- a/tilelang/utils/sparse.py +++ b/tilelang/utils/sparse.py @@ -1,7 +1,7 @@ +from __future__ import annotations import os import torch import warnings -from typing import Optional, Tuple from tilelang.contrib import nvcc from torch.utils.cpp_extension import load, _import_module_from_library from tilelang import env @@ -44,7 +44,7 @@ def _get_cached_lib(): def compress_sm90(A: torch.Tensor, block_k: int, - transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]: + transposed: bool) -> tuple[torch.Tensor, torch.Tensor]: if block_k > 128: block_k = 128 # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 @@ -56,7 +56,7 @@ def compress_sm90(A: torch.Tensor, block_k: int, return compress_lib.compress_sm90(A, block_k, transposed) -def compress_sm80(A: torch.Tensor, transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]: try: from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor except ImportError as err: @@ -75,8 +75,8 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> Tuple[torch.Tensor, torc def compress(A: torch.Tensor, transposed: bool, - arch: Optional[str] = None, - **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: + arch: str | None = None, + **kwargs) -> tuple[torch.Tensor, torch.Tensor]: """ Compress a tensor using the appropriate method based on the CUDA architecture. """ diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py index 948308b81..094c099fe 100644 --- a/tilelang/utils/target.py +++ b/tilelang/utils/target.py @@ -1,12 +1,13 @@ +from __future__ import annotations from platform import mac_ver -from typing import Dict, Literal, Union +from typing import Literal from tilelang import tvm as tvm from tilelang import _ffi_api from tvm.target import Target from tvm.contrib import rocm from tilelang.contrib import nvcc -SUPPORTED_TARGETS: Dict[str, str] = { +SUPPORTED_TARGETS: dict[str, str] = { "auto": "Auto-detect CUDA/HIP/Metal based on availability.", "cuda": "CUDA GPU target (supports options such as `cuda -arch=sm_80`).", "hip": "ROCm HIP target (supports options like `hip -mcpu=gfx90a`).", @@ -17,7 +18,7 @@ } -def describe_supported_targets() -> Dict[str, str]: +def describe_supported_targets() -> dict[str, str]: """ Return a mapping of supported target names to usage descriptions. """ @@ -58,8 +59,8 @@ def check_metal_availability() -> bool: return arch == 'arm64' -def determine_target(target: Union[str, Target, Literal["auto"]] = "auto", - return_object: bool = False) -> Union[str, Target]: +def determine_target(target: str | Target | Literal["auto"] = "auto", + return_object: bool = False) -> str | Target: """ Determine the appropriate target for compilation (CUDA, HIP, or manual selection). @@ -76,7 +77,7 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto", AssertionError: If the target is invalid. 
""" - return_var: Union[str, Target] = target + return_var: str | Target = target if target == "auto": target = tvm.target.Target.current(allow_none=True) diff --git a/version_provider.py b/version_provider.py index c5aa42210..31a7e8ad5 100644 --- a/version_provider.py +++ b/version_provider.py @@ -3,7 +3,6 @@ import os import platform import subprocess -from typing import Optional from pathlib import Path ROOT = Path(__file__).parent @@ -17,13 +16,12 @@ def _read_cmake_bool(i: str | None, default=False): return i.lower() not in ('0', 'false', 'off', 'no', 'n', '') -def get_git_commit_id() -> Optional[str]: +def get_git_commit_id() -> str | None: """Get the current git commit hash by running git in the current file's directory.""" r = subprocess.run(['git', 'rev-parse', 'HEAD'], cwd=ROOT, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, + capture_output=True, encoding='utf-8') if r.returncode == 0: return r.stdout.strip()