diff --git a/docs/conf.py b/docs/conf.py
index fde38c490..1b1289038 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,12 +1,10 @@
-# -*- coding: utf-8 -*-
-
# General information about the project.
project = "Tile Language
"
author = "Tile Lang Contributors"
-copyright = "2025-2025, %s" % author
+copyright = f"2025-2025, {author}"
# Version information.
-with open("../VERSION", "r") as f:
+with open("../VERSION") as f:
version = f.read().strip()
release = version
diff --git a/pyproject.toml b/pyproject.toml
index daa30406b..e76a267c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -87,6 +87,17 @@ target-version = "py38"
line-length = 100
output-format = "full"
+exclude = [
+ "3rdparty",
+ "examples/deepseek_v32/inference",
+]
+
+[tool.ruff.lint.per-file-ignores]
+# Do not upgrade type hints in testing and examples.
+# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
+"testing/**.py" = ["UP", "FA"]
+"examples/**.py" = ["UP", "FA"]
+
[tool.ruff.lint]
select = [
# pycodestyle
@@ -94,7 +105,7 @@ select = [
# Pyflakes
"F",
# pyupgrade
- # "UP",
+ "UP", "FA",
# flake8-bugbear
"B",
# flake8-simplify
@@ -115,6 +126,8 @@ ignore = [
"SIM108",
# key in dict.keys()
"SIM118",
+    # open file without a context manager
+ "SIM115",
# memory leaks
"B019",
# zip without explicit strict
@@ -122,9 +135,6 @@ ignore = [
# No such file or directory
"E902",
]
-[tool.ruff.lint.per-file-ignores]
-"3rdparty/**/*" = ["ALL"]
-"examples/deepseek_v32/inference/**/*" = ["ALL"]
[tool.pytest.ini_options]
verbosity_assertions = 3
diff --git a/tilelang/autotuner/capture.py b/tilelang/autotuner/capture.py
index 78f937de8..27c24f14e 100644
--- a/tilelang/autotuner/capture.py
+++ b/tilelang/autotuner/capture.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
import threading
-from typing import List, Any, Optional
+from typing import Any
# Use thread local to store the stack
# This is to avoid the cross-thread interference
@@ -87,7 +88,7 @@ class AutotuneInputsCapture:
__slots__ = ("tensors")
- def __init__(self, tensors: List[Any]):
+ def __init__(self, tensors: list[Any]):
self.tensors = tensors
def __enter__(self) -> None:
@@ -118,7 +119,7 @@ def set_autotune_inputs(*args) -> AutotuneInputsCapture:
return AutotuneInputsCapture(tensors)
-def get_autotune_inputs() -> Optional[List[Any]]:
+def get_autotune_inputs() -> list[Any] | None:
"""
Get the current autotune inputs from the stack.
"""
diff --git a/tilelang/autotuner/param.py b/tilelang/autotuner/param.py
index aa8f6b9de..a486e9018 100644
--- a/tilelang/autotuner/param.py
+++ b/tilelang/autotuner/param.py
@@ -1,11 +1,12 @@
"""The auto-tune parameters.
"""
+from __future__ import annotations
import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
-from typing import Callable, List, Literal, Any, Optional, Union, Dict
+from typing import Callable, Literal, Any
from dataclasses import dataclass
from pathlib import Path
@@ -47,12 +48,12 @@ class CompileArgs:
"tl.disable_safe_memory_legalize": bool, default: False
"""
- out_idx: Optional[Union[List[int], int]] = None
+ out_idx: list[int] | int | None = None
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
target: Literal['auto', 'cuda', 'hip'] = 'auto'
- target_host: Union[str, Target] = None
+ target_host: str | Target = None
verbose: bool = False
- pass_configs: Optional[Dict[str, Any]] = None
+ pass_configs: dict[str, Any] | None = None
def compile_program(self, program: PrimFunc):
return tilelang.compile(
@@ -142,12 +143,12 @@ class AutotuneResult:
func: Optimized function.
kernel: Compiled kernel function.
"""
- latency: Optional[float] = None
- config: Optional[dict] = None
- ref_latency: Optional[float] = None
- libcode: Optional[str] = None
- func: Optional[Callable] = None
- kernel: Optional[Callable] = None
+ latency: float | None = None
+ config: dict | None = None
+ ref_latency: float | None = None
+ libcode: str | None = None
+ func: Callable | None = None
+ kernel: Callable | None = None
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
"""
@@ -211,9 +212,9 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo
def _load_kernel_from_disk(
self,
cache_path: Path,
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target] = None,
- out_idx: Optional[Union[List[int], int]] = None,
+ target: str | Target = "auto",
+ target_host: str | Target = None,
+ out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None,
@@ -239,14 +240,14 @@ def _load_kernel_from_disk(
if not os.path.exists(cache_path):
return None
- kernel_global_source: Optional[str] = None
- kernel_params: Optional[List[KernelParam]] = None
+ kernel_global_source: str | None = None
+ kernel_params: list[KernelParam] | None = None
try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
if verbose:
logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
- with open(wrapped_kernel_path, "r") as f:
+ with open(wrapped_kernel_path) as f:
kernel_global_source = f.read()
except Exception as e:
logger.error(f"Error loading wrapped kernel source code from disk: {e}")
@@ -307,7 +308,7 @@ def save_to_disk(self, path: Path, verbose: bool = False):
self._save_kernel_to_disk(path, self.kernel)
@classmethod
- def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResult':
+ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult:
if not os.path.exists(path):
return None
@@ -315,7 +316,7 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResul
# load best config
if verbose:
logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
- with open(path / BEST_CONFIG_PATH, "r") as f:
+ with open(path / BEST_CONFIG_PATH) as f:
config = json.load(f)
# load function
@@ -327,7 +328,7 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResul
# load latency
if verbose:
logger.debug(f"Loading latency from file: {path / LATENCY_PATH}")
- with open(path / LATENCY_PATH, "r") as f:
+ with open(path / LATENCY_PATH) as f:
latency = json.load(f)
latency, ref_latency = latency["latency"], latency["ref_latency"]
diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py
index 2173a1392..e94ac7466 100644
--- a/tilelang/autotuner/tuner.py
+++ b/tilelang/autotuner/tuner.py
@@ -3,6 +3,7 @@
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
+from __future__ import annotations
import tilelang
from tilelang import tvm as tvm
@@ -10,7 +11,7 @@
from tvm.target import Target
import inspect
from functools import partial
-from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
+from typing import (Callable, Literal, Any, overload)
from tqdm import tqdm
import logging
import functools
@@ -103,8 +104,8 @@ class AutoTuner:
compile_args = CompileArgs()
profile_args = ProfileArgs()
- _kernel_parameters: Optional[Tuple[str, ...]] = None
- _function_parameters: Optional[Dict[str, Any]] = None
+ _kernel_parameters: tuple[str, ...] | None = None
+ _function_parameters: dict[str, Any] | None = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
@@ -131,12 +132,12 @@ def from_kernel(cls, kernel: Callable, configs):
return cls(kernel, configs)
def set_compile_args(self,
- out_idx: Union[List[int], int, None] = None,
+ out_idx: list[int] | int | None = None,
target: Literal['auto', 'cuda', 'hip'] = 'auto',
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
- target_host: Union[str, Target] = None,
+ target_host: str | Target = None,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None):
+ pass_configs: dict[str, Any] | None = None):
"""Set compilation arguments for the auto-tuner.
Args:
@@ -223,12 +224,12 @@ def set_profile_args(self,
return self
- def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
+ def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]):
# for cache key generation
self._kernel_parameters = k_parameters
self._function_parameters = f_parameters
- def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
+ def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None:
"""Generate a cache key for the auto-tuning process.
"""
@@ -307,8 +308,8 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
return result
best_latency: float = 1e8
- best_config: Optional[Dict[str, Any]] = None
- best_kernel: Optional[tilelang.JITKernel] = None
+ best_config: dict[str, Any] | None = None
+ best_kernel: tilelang.JITKernel | None = None
def _compile(**config_arg) -> tilelang.JITKernel:
compile_args = self.compile_args
@@ -591,7 +592,7 @@ class _AutoTunerImplementation:
warmup: int = 25
rep: int = 100
timeout: int = 100
- configs: Union[Dict, Callable] = None
+ configs: dict | Callable = None
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
@@ -603,7 +604,7 @@ class _AutoTunerImplementation:
cache_input_tensors: bool = False
def __init__(self,
- configs: Union[Dict, Callable],
+ configs: dict | Callable,
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
@@ -653,12 +654,12 @@ def __init__(self,
self.cache_input_tensors = cache_input_tensors # Reuse inputs
# Cache for storing tuned kernel implementations
- self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
+ self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
- def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
+ def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
...
@overload
@@ -720,9 +721,9 @@ def jit_compile(**config_arg):
def autotune( # This is the new public interface
- func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
+ func: Callable[_P, _RProg] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
- configs: Union[Dict, Callable],
+ configs: dict | Callable,
# profile arguments
warmup: int = 25,
rep: int = 100,
diff --git a/tilelang/cache/__init__.py b/tilelang/cache/__init__.py
index 72d003318..c338ce61d 100644
--- a/tilelang/cache/__init__.py
+++ b/tilelang/cache/__init__.py
@@ -1,6 +1,7 @@
"""The cache utils with class and database persistence - Init file"""
+from __future__ import annotations
-from typing import List, Union, Literal, Optional
+from typing import Literal
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
@@ -13,14 +14,14 @@
def cached(
func: PrimFunc = None,
- out_idx: List[int] = None,
+ out_idx: list[int] = None,
*args,
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target] = None,
- execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
- verbose: Optional[bool] = False,
- pass_configs: Optional[dict] = None,
- compile_flags: Optional[Union[List[str], str]] = None,
+ target: str | Target = "auto",
+ target_host: str | Target = None,
+ execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] | None = "cython",
+ verbose: bool | None = False,
+ pass_configs: dict | None = None,
+ compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels (using KernelCache class).
diff --git a/tilelang/cache/kernel_cache.py b/tilelang/cache/kernel_cache.py
index b6d2e77b7..d0a801fb4 100644
--- a/tilelang/cache/kernel_cache.py
+++ b/tilelang/cache/kernel_cache.py
@@ -1,4 +1,5 @@
"""The cache utils with class and database persistence - KernelCache Class"""
+from __future__ import annotations
import json
import logging
@@ -7,7 +8,7 @@
import threading
import uuid
from hashlib import sha256
-from typing import Callable, List, Literal, Optional, Union
+from typing import Callable, Literal
import cloudpickle
from tvm.target import Target
@@ -67,13 +68,13 @@ def _create_dirs():
def _generate_key(
self,
func: Callable,
- out_idx: List[int],
+ out_idx: list[int],
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
args=None,
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target] = None,
+ target: str | Target = "auto",
+ target_host: str | Target = None,
pass_configs: dict = None,
- compile_flags: Optional[Union[List[str], str]] = None,
+ compile_flags: list[str] | str | None = None,
) -> str:
"""
Generates a unique hash key for caching compiled kernels.
@@ -112,14 +113,14 @@ def _generate_key(
def cached(
self,
func: PrimFunc = None,
- out_idx: List[int] = None,
+ out_idx: list[int] = None,
*args,
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target] = None,
+ target: str | Target = "auto",
+ target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict = None,
- compile_flags: Optional[Union[List[str], str]] = None,
+ compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels to avoid redundant compilation.
@@ -322,15 +323,15 @@ def _save_kernel_to_disk(self,
def _load_kernel_from_disk(
self,
key: str,
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target] = None,
- out_idx: List[int] = None,
+ target: str | Target = "auto",
+ target_host: str | Target = None,
+ out_idx: list[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
pass_configs: dict = None,
- compile_flags: Optional[Union[List[str], str]] = None,
+ compile_flags: list[str] | str | None = None,
func: Callable = None,
verbose: bool = False,
- ) -> Optional[JITKernel]:
+ ) -> JITKernel | None:
"""
Loads a previously compiled kernel from disk cache.
@@ -355,15 +356,15 @@ def _load_kernel_from_disk(
if not all([os.path.exists(file) for file in (kernel_lib_path, params_path)]):
return None
- kernel_global_source: Optional[str] = None
- kernel_params: Optional[List[KernelParam]] = None
+ kernel_global_source: str | None = None
+ kernel_params: list[KernelParam] | None = None
# Load the kernel source file (optional)
try:
if verbose:
self.logger.debug(
f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
- with open(wrapped_kernel_path, "r") as f:
+ with open(wrapped_kernel_path) as f:
kernel_global_source = f.read()
except Exception as e:
self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")
diff --git a/tilelang/carver/analysis.py b/tilelang/carver/analysis.py
index 653392df7..96606e790 100644
--- a/tilelang/carver/analysis.py
+++ b/tilelang/carver/analysis.py
@@ -1,5 +1,5 @@
"""Analysis on TIR blocks, loops and functions."""
-from typing import List, Optional, Set, Union
+from __future__ import annotations
from typing_extensions import Literal
from tvm import ir, tir, DataType
@@ -31,7 +31,7 @@ def __init__(
self.loop_rv = loop_rv
@property
- def dom(self) -> Union[int, tir.PrimExpr]:
+ def dom(self) -> int | tir.PrimExpr:
"""The iteration domain of the loop."""
return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom
@@ -46,14 +46,14 @@ class BlockInfo:
"""Information about a TIR block."""
name: str
- iters: List[IterInfo]
+ iters: list[IterInfo]
block_rv: tir.schedule.BlockRV
_reduction_block: bool
def __init__(
self,
name: str,
- iters: List[IterInfo],
+ iters: list[IterInfo],
block_rv: tir.schedule.BlockRV,
reduction_block: bool = False,
):
@@ -63,7 +63,7 @@ def __init__(
self.iters = iters
self._reduction_block = reduction_block
- def dom(self) -> List[Union[int, tir.PrimExpr]]:
+ def dom(self) -> list[int | tir.PrimExpr]:
"""The iteration domain of the block."""
return [i.dom for i in self.iters]
@@ -118,7 +118,7 @@ def __repr__(self) -> str:
_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc")
-def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]:
+def normalize_prim_func(sch: tir.Schedule) -> list[BlockInfo] | None:
"""Normalize the primfunc to normal form"""
try:
result = _normalize_prim_func(sch)
@@ -133,7 +133,7 @@ def _iter_kind(i: tir.IterVar) -> str:
tir.IterVar.CommReduce: "R",
}.get(i.iter_type, "O")
- blocks: List[BlockInfo] = []
+ blocks: list[BlockInfo] = []
for block, loops, iters, is_reduction in zip(*result):
blocks.append(
BlockInfo(
@@ -203,7 +203,7 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
def collect_block_iter_vars_used_in_access_region(block: tir.Block,
- region: List[ir.Range]) -> Set[tir.Var]:
+ region: list[ir.Range]) -> set[tir.Var]:
"""Collect the block iter variables used in the access region of a buffer region."""
tir_vars = set()
for expr in region:
@@ -214,7 +214,7 @@ def collect_block_iter_vars_used_in_access_region(block: tir.Block,
return tir_vars
-def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]:
+def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> set[tir.Var]:
"""Collect the variables used in the PrimExpr."""
tir_vars = set()
@@ -259,7 +259,7 @@ def is_broadcast_epilogue(
def get_reduction_blocks(sch: tir.Schedule,
- blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]:
+ blocks: list[tir.schedule.BlockRV]) -> list[tir.schedule.BlockRV]:
# Get the main computation block
def is_reduction(block: BlockRV) -> bool:
block_stmt = sch.get(block)
@@ -286,7 +286,7 @@ def is_spatial(block: BlockRV) -> bool:
def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int:
# gpu memory prefer 128 bits coalesced access (e.g. four banks)
# 128 bits
- buffers: List[tir.Buffer] = []
+ buffers: list[tir.Buffer] = []
for read in block_stmt.reads:
buffers.append(read.buffer)
for write in block_stmt.writes:
diff --git a/tilelang/carver/arch/__init__.py b/tilelang/carver/arch/__init__.py
index 3793d3a13..c2bc9c75d 100644
--- a/tilelang/carver/arch/__init__.py
+++ b/tilelang/carver/arch/__init__.py
@@ -1,14 +1,15 @@
+from __future__ import annotations
+
from .arch_base import TileDevice
from .cuda import *
from .cpu import *
from .cdna import *
from .metal import *
-from typing import Union
from tvm.target import Target
import torch
-def get_arch(target: Union[str, Target] = "cuda") -> TileDevice:
+def get_arch(target: str | Target = "cuda") -> TileDevice:
if isinstance(target, str):
target = Target(target)
diff --git a/tilelang/carver/arch/arch_base.py b/tilelang/carver/arch/arch_base.py
index 06a614fb5..a10fa434d 100644
--- a/tilelang/carver/arch/arch_base.py
+++ b/tilelang/carver/arch/arch_base.py
@@ -1,4 +1,4 @@
-from typing import List
+from __future__ import annotations
class TileDevice:
@@ -14,12 +14,12 @@ def __init__(self) -> None:
0 # The size of a warp, a group of threads that execute instructions in lockstep
)
self.sm_partition: int = 0 # The number of streaming multiprocessor partitions
- self.transaction_size: List[int] = [
+ self.transaction_size: list[int] = [
0,
0,
] # The size of memory transactions, typically in bytes
self.max_smem_usage: int = 0 # The maximum shared memory usage allowed
- self.bandwidth: List[int] = [
+ self.bandwidth: list[int] = [
0,
0,
] # Bandwidth specifications, possibly including peak and sustained rates
@@ -29,9 +29,9 @@ def __init__(self) -> None:
)
self.l2_cache_size_bytes: int = 0
# the number of transaction size in bytes
- self.transaction_size: List[int] = [0, 0] # in bytes
+ self.transaction_size: list[int] = [0, 0] # in bytes
# bandwidth in MB/s, will be used for recommend basic tile size
- self.bandwidth: List[int] = [0, 0]
+ self.bandwidth: list[int] = [0, 0]
def get_avaliable_tensorintrin_shapes(self):
raise NotImplementedError()
diff --git a/tilelang/carver/arch/cdna.py b/tilelang/carver/arch/cdna.py
index ed9848219..ec5aa905f 100644
--- a/tilelang/carver/arch/cdna.py
+++ b/tilelang/carver/arch/cdna.py
@@ -1,7 +1,7 @@
+from __future__ import annotations
import tvm
from tvm.target import Target
from .arch_base import TileDevice
-from typing import List, Union
def is_cdna_arch(arch: TileDevice) -> bool:
@@ -10,7 +10,7 @@ def is_cdna_arch(arch: TileDevice) -> bool:
class CDNA(TileDevice):
- def __init__(self, target: Union[Target, str]):
+ def __init__(self, target: Target | str):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
@@ -27,9 +27,9 @@ def __init__(self, target: Union[Target, str]):
self.max_smem_usage: int = 2 * self.smem_cap
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
- self.transaction_size: List[int] = [32, 128] # in bytes
+ self.transaction_size: list[int] = [32, 128] # in bytes
- self.bandwidth: List[int] = [1300, 14000]
+ self.bandwidth: list[int] = [1300, 14000]
__all__ = [
diff --git a/tilelang/carver/arch/cuda.py b/tilelang/carver/arch/cuda.py
index ce5df4af4..4c7f98dff 100644
--- a/tilelang/carver/arch/cuda.py
+++ b/tilelang/carver/arch/cuda.py
@@ -1,7 +1,7 @@
+from __future__ import annotations
import tvm
from tvm.target import Target
from .arch_base import TileDevice
-from typing import List, Union
from .driver import cuda_driver
@@ -91,21 +91,21 @@ def is_tensorcore_supported_precision(in_dtype: str, accum_dtype: str, arch: Til
raise ValueError(f"Unsupported architecture: {arch}")
-class TensorInstruction(object):
+class TensorInstruction:
def __init__(
self,
name: str,
- shape: List[int],
+ shape: list[int],
):
self.name: str = name
# only hold the shape of M and N
- self.shape: List[int] = shape
+ self.shape: list[int] = shape
class CUDA(TileDevice):
- def __init__(self, target: Union[Target, str]):
+ def __init__(self, target: Target | str):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
@@ -126,15 +126,15 @@ def __init__(self, target: Union[Target, str]):
self.sm_partition: int = 4
self.l2_cache_size_bytes: int = target.l2_cache_size_bytes
# the number of transaction size in bytes
- self.transaction_size: List[int] = [32, 128] # in bytes
+ self.transaction_size: list[int] = [32, 128] # in bytes
# bandwidth in MB/s, will be used for recommend basic tile size
# TODO(lei): find some way to get the real bandwidth
# However, the ratio of bandwidth between different devices can
# be similar. The bandwidth can work for another devices as well.
- self.bandwidth: List[int] = [750, 12080]
+ self.bandwidth: list[int] = [750, 12080]
# get the available tensor instructions during runtime to avoid
# the dependency of the tensor intrinsics registration
- self.available_tensor_instructions: List[TensorInstruction] = None
+ self.available_tensor_instructions: list[TensorInstruction] = None
def get_avaliable_tensorintrin_shapes(self):
self.available_tensor_instructions = (
diff --git a/tilelang/carver/arch/driver/cuda_driver.py b/tilelang/carver/arch/driver/cuda_driver.py
index 3e08e9afd..337987dd8 100644
--- a/tilelang/carver/arch/driver/cuda_driver.py
+++ b/tilelang/carver/arch/driver/cuda_driver.py
@@ -1,6 +1,6 @@
+from __future__ import annotations
import ctypes
import sys
-from typing import Optional
class cudaDeviceProp(ctypes.Structure):
@@ -77,7 +77,7 @@ class cudaDeviceProp(ctypes.Structure):
]
-def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
+def get_cuda_device_properties(device_id: int = 0) -> cudaDeviceProp | None:
if sys.platform == "win32":
libcudart = ctypes.windll.LoadLibrary("cudart64_110.dll")
@@ -95,7 +95,7 @@ def get_cuda_device_properties(device_id: int = 0) -> Optional[cudaDeviceProp]:
raise RuntimeError(f"cudaGetDeviceProperties failed with error {ret}")
-def get_device_name(device_id: int = 0) -> Optional[str]:
+def get_device_name(device_id: int = 0) -> str | None:
prop = get_cuda_device_properties(device_id)
if prop:
return prop.name.decode()
@@ -103,7 +103,7 @@ def get_device_name(device_id: int = 0) -> Optional[str]:
raise RuntimeError("Failed to get device properties.")
-def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> Optional[int]:
+def get_shared_memory_per_block(device_id: int = 0, format: str = "bytes") -> int | None:
assert format in ["bytes", "kb", "mb"], "Invalid format. Must be one of: bytes, kb, mb"
prop = get_cuda_device_properties(device_id)
if prop:
@@ -143,7 +143,7 @@ def get_device_attribute(attr: int, device_id: int = 0) -> int:
return None
-def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> Optional[int]:
+def get_max_dynamic_shared_size_bytes(device_id: int = 0, format: str = "bytes") -> int | None:
"""
Get the maximum dynamic shared memory size in bytes, kilobytes, or megabytes.
"""
diff --git a/tilelang/carver/arch/metal.py b/tilelang/carver/arch/metal.py
index 5650f7cc4..9cd1c4d1e 100644
--- a/tilelang/carver/arch/metal.py
+++ b/tilelang/carver/arch/metal.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
from tvm.target import Target
from .arch_base import TileDevice
diff --git a/tilelang/carver/common_schedules.py b/tilelang/carver/common_schedules.py
index 609d02b51..2766a15e3 100644
--- a/tilelang/carver/common_schedules.py
+++ b/tilelang/carver/common_schedules.py
@@ -19,7 +19,8 @@
# Modifications Copyright (c) Microsoft.
# The code below is mostly copied from apache/tvm common_schedules.py in dlight.
"""Common schedule strategies for TIR."""
-from typing import Callable, List
+from __future__ import annotations
+from typing import Callable
from tvm import tir
from .utils import retrieve_func_from_module
@@ -28,7 +29,7 @@
def get_block(
sch: tir.Schedule,
- blocks: List[BlockInfo],
+ blocks: list[BlockInfo],
name: str,
):
"""Get the target block from a schedule.
@@ -56,7 +57,7 @@ def get_block(
def get_output_blocks(
sch: tir.Schedule,
- blocks: List[BlockInfo],
+ blocks: list[BlockInfo],
):
"""Get the output blocks of a schedule.
@@ -89,8 +90,8 @@ def get_output_blocks(
def try_inline(
sch: tir.Schedule,
- blocks: List[BlockInfo],
-) -> List[BlockInfo]:
+ blocks: list[BlockInfo],
+) -> list[BlockInfo]:
"""Try to inline as many blocks as possible, and return the remaining blocks.
Parameters
@@ -127,8 +128,8 @@ def _trial(func: Callable):
def try_inline_contiguous_spatial(
sch: tir.Schedule,
- block_infos: List[BlockInfo],
-) -> List[BlockInfo]:
+ block_infos: list[BlockInfo],
+) -> list[BlockInfo]:
"""Try to inline contiguous spatial blocks in a schedule
Parameters
diff --git a/tilelang/carver/matmul_analysis.py b/tilelang/carver/matmul_analysis.py
index dfc1a53e9..02a86cc78 100644
--- a/tilelang/carver/matmul_analysis.py
+++ b/tilelang/carver/matmul_analysis.py
@@ -1,8 +1,8 @@
# pylint: disable=missing-docstring, invalid-name
"""A GEMM schedule rule for GPU operators."""
+from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
-from typing import List, Optional, Set, Union, Tuple, Dict
from tvm import tir
from tvm.ir import Range
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap
@@ -57,7 +57,7 @@ def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV):
def auto_inline_producers(
sch: tir.Schedule,
block: tir.schedule.BlockRV,
- skip_blocks: Optional[List[tir.schedule.BlockRV]] = None,
+ skip_blocks: list[tir.schedule.BlockRV] | None = None,
):
skip_blocks = skip_blocks or []
while True:
@@ -118,7 +118,7 @@ def auto_inline_consumer_chain(
# used to match the similar region with dequantize op.
-def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):
+def find_first_similar_region(regions: list[BufferRegion], buffer: tir.Buffer):
for region in regions:
if len(region.buffer.shape) == len(buffer.shape):
return region
@@ -126,7 +126,7 @@ def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer):
# used to match the similar buffer with dequantize op.
-def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):
+def find_first_similar_buffer(regions: list[BufferRegion], buffer: tir.Buffer):
for region in regions:
if len(region.buffer.shape) == len(buffer.shape):
return region.buffer
@@ -134,7 +134,7 @@ def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer):
# find the block that required to be reindex and scope.
-def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]:
+def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> BlockRV | None:
# block that most near to the arguments
block = main_block
buffer = buffer
@@ -209,11 +209,11 @@ class IterTrait:
def make_iter_fusion_index_map(
- traits: List[IterTrait],
- kind_order: List[IterKind],
+ traits: list[IterTrait],
+ kind_order: list[IterKind],
) -> tir.IndexMap:
- fused_iters: Dict[IterKind, PrimExpr] = {}
- input_iters: List[tir.Var] = []
+ fused_iters: dict[IterKind, PrimExpr] = {}
+ input_iters: list[tir.Var] = []
for i, trait in enumerate(traits):
v_i = tir.Var(f"i{i}", trait.extent.dtype)
input_iters.append(v_i)
@@ -226,14 +226,14 @@ def make_iter_fusion_index_map(
else:
fused_iters[trait.kind] = v_i
- final_indices: List[tir.PrimExpr] = [
+ final_indices: list[tir.PrimExpr] = [
fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order
]
return tir.IndexMap(input_iters, final_indices, None)
-def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
+def detect_iter_traits(block: tir.Block) -> tuple[list[IterTrait]] | None:
"""Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K]
Parameters
@@ -252,8 +252,8 @@ def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]:
if len(block.reads) != 2 or len(block.writes) != 1:
return None
- def get_access_axes(region: List[Range]) -> Set[Var]:
- axes: Set[Var] = set()
+ def get_access_axes(region: list[Range]) -> set[Var]:
+ axes: set[Var] = set()
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
@@ -267,7 +267,7 @@ def get_access_axes(region: List[Range]) -> Set[Var]:
except ValueError:
return None
- traits: Dict[Var, IterTrait] = {}
+ traits: dict[Var, IterTrait] = {}
for iter_var in block.iter_vars:
var = iter_var.var
kind: IterKind
@@ -308,7 +308,7 @@ def get_access_axes(region: List[Range]) -> Set[Var]:
def get_index_map(block: tir.Block,
- layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]:
+ layout: list[str] | None = None) -> tuple[tir.IndexMap, ...] | None:
"""Get index maps for the block
Parameters
@@ -334,8 +334,8 @@ def get_index_map(block: tir.Block,
return None
A_traits, B_traits, C_traits, block_traits = traits
- def get_ordered_axes(region: List[Range]) -> Set[Var]:
- axes: List[Var] = []
+ def get_ordered_axes(region: list[Range]) -> set[Var]:
+ axes: list[Var] = []
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
@@ -352,11 +352,11 @@ def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)
- def check_last_trait(region: List[Range]):
+ def check_last_trait(region: list[Range]):
axes = get_ordered_axes(region)
return has_common_reduce(axes[-1])
- def infer_layout(layout: str, region: List[Range], kind: str = "A"):
+ def infer_layout(layout: str, region: list[Range], kind: str = "A"):
"""
Infer the layout based on the region and the kind of buffer
kind: "A", "B", "C"
@@ -409,7 +409,7 @@ def infer_layout(layout: str, region: List[Range], kind: str = "A"):
)
-def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
+def get_in_out_dtypes(block: tir.Block) -> tuple[str]:
"""
    Detect In/Out data types for the given block based on the analysis of read/write buffers.
"""
@@ -419,7 +419,7 @@ def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
return (in_dtype, out_dtype)
-def get_dequantize_block(sch, blocks) -> Optional[BlockRV]:
+def get_dequantize_block(sch, blocks) -> BlockRV | None:
# check at least two input and one output
    # at least one input has uint dtype, and the output dtype is float
def is_dequantize(block: BlockRV) -> bool:
@@ -445,8 +445,8 @@ def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool:
if not isinstance(block_stmt.body.value, tir.BufferLoad):
return False, False
- def get_access_vars(region: List[Range]) -> List[Var]:
- axes: List[Var] = []
+ def get_access_vars(region: list[Range]) -> list[Var]:
+ axes: list[Var] = []
for r in region:
if not _is_one(r.extent):
return None
@@ -475,7 +475,7 @@ def is_transpose_block(block_stmt: tir.Block) -> bool:
return is_identity_or_transpose_block(block_stmt)[1]
-def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]):
+def inline_transpose_block(sch: tir.Schedule, blocks: list[tir.schedule.BlockRV]):
result_blocks = []
for block in blocks:
if not is_transpose_block(sch.get(block)):
@@ -493,7 +493,7 @@ def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]
def normalize_to_matmul(sch: tir.Schedule,
main_block: BlockRV,
- layout: Optional[List[str]] = None) -> Optional[tir.Schedule]:
+ layout: list[str] | None = None) -> tir.Schedule | None:
if layout is None:
layout = ["n", "t", "n"]
block_stmt = sch.get(main_block)
@@ -521,10 +521,10 @@ def normalize_to_matmul(sch: tir.Schedule,
def get_tensorized_func_and_tags(
func: tir.PrimFunc,
target: Target,
- layout: Optional[List[str]] = None,
+ layout: list[str] | None = None,
skip_normalize: bool = False,
allow_gemv: bool = False,
-) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]:
+) -> tuple[tir.PrimFunc, dict[str, list[int] | int]]:
"""
transform function to matmul if necessary (e.g. transform conv2d with im2col)
"""
@@ -554,9 +554,8 @@ def check_sm_version(arch: str) -> int:
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1
- def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
- target: Target) -> Union[bool, Dict]:
- tags: Dict[str, Union[List[int], int]] = {}
+ def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool | dict:
+ tags: dict[str, list[int] | int] = {}
block_stmt = sch.get(block)
# Nvidia Only Support Tensor Core for
@@ -584,8 +583,8 @@ def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
tags["use_async_copy"] = True
# analysis intrin information
- def get_ordered_axes(region: List[Range]) -> Set[Var]:
- axes: List[Var] = []
+ def get_ordered_axes(region: list[Range]) -> set[Var]:
+ axes: list[Var] = []
for r in region:
if not _is_one(r.extent):
raise ValueError("Expect elemwise block access")
@@ -602,7 +601,7 @@ def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)
- def check_last_trait(region: List[Range]):
+ def check_last_trait(region: list[Range]):
axes = get_ordered_axes(region)
return has_common_reduce(axes[-1])
diff --git a/tilelang/carver/roller/bestfit.py b/tilelang/carver/roller/bestfit.py
index e8107112e..b66ceaae7 100644
--- a/tilelang/carver/roller/bestfit.py
+++ b/tilelang/carver/roller/bestfit.py
@@ -17,7 +17,7 @@ def merge(self, other):
self.end = max(self.end, other.end)
def __repr__(self) -> str:
- return "".format(self.start, self.size())
+ return f""
class BestFit:
diff --git a/tilelang/carver/roller/hint.py b/tilelang/carver/roller/hint.py
index 3b51b85c5..20d62f68f 100644
--- a/tilelang/carver/roller/hint.py
+++ b/tilelang/carver/roller/hint.py
@@ -1,6 +1,6 @@
"""Hint definition for schedule"""
+from __future__ import annotations
from tvm import DataType
-from typing import Dict, List, Tuple
from . import PrimFuncNode
import numpy as np
from .rasterization import *
@@ -13,17 +13,17 @@ class TensorCoreExtraConfig:
def __init__(
self,
- AS_shape: Tuple[int],
- BS_shape: Tuple[int],
- AF_shape: Tuple[int],
- BF_shape: Tuple[int],
- tc_axis: Tuple[int],
+ AS_shape: tuple[int],
+ BS_shape: tuple[int],
+ AF_shape: tuple[int],
+ BF_shape: tuple[int],
+ tc_axis: tuple[int],
) -> None:
- self.AS_shape: Tuple[int] = AS_shape
- self.BS_shape: Tuple[int] = BS_shape
- self.AF_shape: Tuple[int] = AF_shape
- self.BF_shape: Tuple[int] = BF_shape
- self.tc_axis: Tuple[int] = tc_axis
+ self.AS_shape: tuple[int] = AS_shape
+ self.BS_shape: tuple[int] = BS_shape
+ self.AF_shape: tuple[int] = AF_shape
+ self.BF_shape: tuple[int] = BF_shape
+ self.tc_axis: tuple[int] = tc_axis
class Stride:
@@ -45,7 +45,7 @@ def ax(self) -> int:
def stride(self) -> int:
return self._stride
- def compute_strides_from_shape(self, shape: List[int]) -> List[int]:
+ def compute_strides_from_shape(self, shape: list[int]) -> list[int]:
ndim = len(shape)
strides = [1 for _ in shape]
for i in range(ndim - 2, -1, -1):
@@ -55,7 +55,7 @@ def compute_strides_from_shape(self, shape: List[int]) -> List[int]:
strides[i] = int(strides[i + 1] * shape[i + 1])
return strides
- def compute_elements_from_shape(self, shape: List[int]) -> int:
+ def compute_elements_from_shape(self, shape: list[int]) -> int:
original_shape = np.prod(shape)
if not self.is_valid():
strided_elem = original_shape
@@ -94,10 +94,10 @@ def __init__(self, output_tile) -> None:
self.grid_size = -1
self.valid = True
- def get_tile(self, func) -> List[int]:
+ def get_tile(self, func) -> list[int]:
return self.tile_map[func]
- def get_rstep(self, node) -> Dict[str, int]:
+ def get_rstep(self, node) -> dict[str, int]:
return self.rstep_map[node]
def __hash__(self) -> int:
@@ -147,7 +147,7 @@ def inter_transform_b(self) -> bool:
return self.weight_transform_kind >= 1
-class Hint(object):
+class Hint:
"""
Central configuration class for managing various parameters of computational tasks.
"""
@@ -178,15 +178,15 @@ def __init__(self) -> None:
# Experimental
self._raxis_order = []
self._step = []
- self.vectorize: Dict[str, int] = {}
+ self.vectorize: dict[str, int] = {}
self.pipeline_stage = 1
self.use_async = False
- self.opt_shapes: Dict[str, int] = {}
+ self.opt_shapes: dict[str, int] = {}
self.intrin_info = IntrinInfo("float16", "float16", True)
self.shared_scope: str = "shared"
- self.pass_context: Dict = {}
+ self.pass_context: dict = {}
- def to_dict(self) -> Dict:
+ def to_dict(self) -> dict:
dic = {}
dic["block"] = self.block
if self.use_tc:
@@ -218,7 +218,7 @@ def to_dict(self) -> Dict:
return dic
@classmethod
- def from_dict(cls, dic: Dict) -> "Hint":
+ def from_dict(cls, dic: dict) -> Hint:
hint = cls()
for k, v in dic.items():
setattr(hint, k, v)
@@ -231,13 +231,13 @@ def tensorcore_legalization(self):
return self
@property
- def raxis_order(self) -> List[int]:
+ def raxis_order(self) -> list[int]:
if self._raxis_order != []:
return self._raxis_order
return list(range(len(self.rstep)))
@property
- def step(self) -> List[int]:
+ def step(self) -> list[int]:
if self._step != []:
return self._step
return [1 for _ in self.block]
diff --git a/tilelang/carver/roller/node.py b/tilelang/carver/roller/node.py
index 120b8a4b7..f9e38b168 100644
--- a/tilelang/carver/roller/node.py
+++ b/tilelang/carver/roller/node.py
@@ -1,9 +1,10 @@
"""PrimFunc Wrapper and Block information Analaysis"""
+from __future__ import annotations
import tvm
from tvm import tir
from tvm.tir import IterVar, PrimFunc
-from typing import Any, Dict, List, Tuple, Optional
+from typing import Any
from tvm.tir.schedule.schedule import BlockRV
import numpy as np
import functools
@@ -29,11 +30,11 @@ def _traverse(block):
_traverse(block)
-class BlockAnalyzer(object):
+class BlockAnalyzer:
def __init__(self, sch) -> None:
self.sch: tir.Schedule = sch
- self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch)
+ self.block_infos: list[BlockInfo] = normalize_prim_func(self.sch)
def get_block_name(self, block: BlockRV) -> str:
return self.sch.get(block).name_hint
@@ -44,7 +45,7 @@ def get_block_info(self, block: BlockRV) -> BlockInfo:
return block_info
return None
- def get_spatial_axis(self, block: BlockRV) -> List[IterVar]:
+ def get_spatial_axis(self, block: BlockRV) -> list[IterVar]:
block_info = self.get_block_info(block)
axis = []
for iter in block_info.iters:
@@ -52,7 +53,7 @@ def get_spatial_axis(self, block: BlockRV) -> List[IterVar]:
axis.append(iter)
return axis
- def get_reduce_axis(self, block: BlockRV) -> List[IterVar]:
+ def get_reduce_axis(self, block: BlockRV) -> list[IterVar]:
block_info = self.get_block_info(block)
raxis = []
for iter in block_info.iters:
@@ -60,39 +61,39 @@ def get_reduce_axis(self, block: BlockRV) -> List[IterVar]:
raxis.append(iter)
return raxis
- def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]:
+ def get_input_buffers(self, block: BlockRV) -> list[tir.Buffer]:
buffers = []
for read in self.sch.get(block).reads:
buffers.append(read.buffer)
return buffers
- def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]:
+ def get_output_buffers(self, block: BlockRV) -> list[tir.Buffer]:
buffers = []
for write in self.sch.get(block).writes:
buffers.append(write.buffer)
return buffers
- def get_buffers(self, block: BlockRV) -> List[tir.Buffer]:
+ def get_buffers(self, block: BlockRV) -> list[tir.Buffer]:
return self.get_input_buffers(block) + self.get_output_buffers(block)
- def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]:
+ def get_producer_blocks(self, block: BlockRV) -> list[BlockRV]:
return self.sch.get_producers(block)
- def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]:
+ def get_consumer_blocks(self, block: BlockRV) -> list[BlockRV]:
return self.sch.get_consumers(block)
@dataclass
class Edge:
- src_node: 'Node'
- dst_node: 'Node'
+ src_node: Node
+ dst_node: Node
src_id: int
dst_id: int
-class Node(object):
+class Node:
- def __init__(self, tags: Optional[Dict] = None, name: str = "Node") -> None:
+ def __init__(self, tags: dict | None = None, name: str = "Node") -> None:
self.name = name
if tags is None:
tags = {}
@@ -100,10 +101,10 @@ def __init__(self, tags: Optional[Dict] = None, name: str = "Node") -> None:
self._in_edges = []
self._shapes = []
self._dtypes = []
- self._tag: Dict = {}
+ self._tag: dict = {}
self.update_tags(tags)
- def update_tags(self, tags: Dict) -> None:
+ def update_tags(self, tags: dict) -> None:
for tag in tags:
self.add_tag(tag, tags[tag])
@@ -125,11 +126,11 @@ def is_output(self):
return False
@property
- def inputs(self) -> List[Edge]:
+ def inputs(self) -> list[Edge]:
return self._in_edges
@property
- def outputs(self) -> List[Edge]:
+ def outputs(self) -> list[Edge]:
return self._out_edges
def set_inputs(self, i: int, edge: Edge):
@@ -153,10 +154,10 @@ def set_dtype(self, dtype: tvm.DataType, id=0) -> None:
assert self._dtypes[id] == dtype, (self._dtypes, dtype)
self._dtypes[id] = dtype
- def get_shape(self, id: int = 0) -> List[int]:
+ def get_shape(self, id: int = 0) -> list[int]:
return self._shapes[id]
- def set_shape(self, shape: List[int], id=0, overwrite=False) -> None:
+ def set_shape(self, shape: list[int], id=0, overwrite=False) -> None:
if len(self._shapes) <= id:
self._shapes.extend([None for _ in range(id - len(self._shapes) + 1)])
# elif self._shapes[id] is not None and not overwrite:
@@ -191,15 +192,15 @@ class PrimFuncNode(Node):
def __init__(self,
prim_func: PrimFunc,
- tags: Optional[Dict] = None,
+ tags: dict | None = None,
name: str = "PrimFuncNode") -> None:
super().__init__(tags, name=name)
self.prim_func = self._specialize_func(prim_func)
self.sch: tir.Schedule = tir.Schedule(self.prim_func)
self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch)
- self.schedule_stages: List[BlockRV] = []
- self.blocks: List[BlockRV] = []
- self.output_blocks: List[BlockRV] = None
+ self.schedule_stages: list[BlockRV] = []
+ self.blocks: list[BlockRV] = []
+ self.output_blocks: list[BlockRV] = None
self.reduction_block: BlockRV = None
self.raxis = []
self.input_buffers = []
@@ -219,7 +220,7 @@ def __init__(self,
self.set_dtype(tvm.DataType(buffer.dtype), output_id)
def _assign_placeholder_node(self):
- inputs: List[Node] = []
+ inputs: list[Node] = []
for buffer in self.input_buffers:
inputs.append(PlaceHolderNode(buffer.name))
@@ -301,8 +302,8 @@ def extent_wrapper(self, value) -> int:
else:
return value
- @functools.lru_cache()
- def get_space_dim(self) -> List[int]:
+ @functools.lru_cache
+ def get_space_dim(self) -> list[int]:
dim_size = []
if self.reduction_block:
block_info = self.block_analyzer.get_block_info(self.reduction_block)
@@ -333,7 +334,7 @@ def set_dtype(self, dtype: tvm.DataType, id=0) -> None:
def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType:
return tvm.DataType(buffer.dtype)
- def propagate(self, tile, rstep: Optional[Dict] = None, targets=None):
+ def propagate(self, tile, rstep: dict | None = None, targets=None):
if rstep is None:
rstep = {}
shape = {
@@ -343,7 +344,7 @@ def propagate(self, tile, rstep: Optional[Dict] = None, targets=None):
}
return self.ana.infer(shape, rstep, targets)
- def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
+ def propagate_inputs(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None:
rstep = {}
read_idx_offset = len(self.input_buffers)
@@ -363,7 +364,7 @@ def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]
return results
# Propagate inputs only on reduction block
- def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
+ def propagate_inputs_on_reduction(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None:
rstep = {}
reduction_block = self.reduction_block
@@ -386,7 +387,7 @@ def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> L
results.append(trimmed_shape)
return results
- def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]:
+ def propagate_outputs(self, tile, rstep: dict | None = None) -> list[list[int]]:
if rstep is None:
rstep = {}
read_idx_offset = len(self.input_buffers)
@@ -399,9 +400,7 @@ def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int
results.append(trimmed_shape)
return results
- def propagate_reduction_inputs(self,
- shape,
- rstep: Optional[Dict] = None) -> Dict[str, List[int]]:
+ def propagate_reduction_inputs(self, shape, rstep: dict | None = None) -> dict[str, list[int]]:
if rstep is None:
rstep = {}
if self.reduction_block is None:
@@ -418,8 +417,8 @@ def get_reduce_inputs_dtype(self):
for b in self.block_analyzer.get_input_buffers(self.reduction_block)
}
- @functools.lru_cache()
- def infer_tensorcore_axis(self) -> Tuple[int]:
+ @functools.lru_cache
+ def infer_tensorcore_axis(self) -> tuple[int]:
# axis is fixed for one expression, so only inference and cached
assert self.get_tag("tensorcore_config")
@@ -461,7 +460,7 @@ def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions):
tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n)
return tc_axis
- def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int:
+ def footprint(self, shape, rstep, stride_map: dict | None = None) -> int:
if stride_map is None:
stride_map = {}
result = 0
@@ -510,7 +509,7 @@ def is_after_reduce_stage(block):
result += buffer_len
return result, cached_tensor
- def get_input_buffers(self) -> List[tir.Buffer]:
+ def get_input_buffers(self) -> list[tir.Buffer]:
return self.block_analyzer.input_buffers
@@ -537,7 +536,7 @@ def get_ir(self) -> str:
return "output"
-def topo_order(list_of_nodes) -> List[Node]:
+def topo_order(list_of_nodes) -> list[Node]:
input_ready_count = {node: len(node.inputs) for node in list_of_nodes}
ready = list(filter(lambda node: input_ready_count[node] == 0, list_of_nodes))
output_list = []
@@ -557,7 +556,7 @@ def topo_order(list_of_nodes) -> List[Node]:
return output_list
-def find_topo_sort_priority(output_node_list) -> List[Node]:
+def find_topo_sort_priority(output_node_list) -> list[Node]:
import sys
sys.setrecursionlimit(10000)
@@ -591,7 +590,7 @@ def topo_sort_dfs(node, visited, topo_order):
return topo_order
-def find_topo_sort(output_node_list) -> List[Node]:
+def find_topo_sort(output_node_list) -> list[Node]:
def topo_sort_dfs(node, visited, topo_order):
if node in visited:
diff --git a/tilelang/carver/roller/policy/common.py b/tilelang/carver/roller/policy/common.py
index 0dadfa8a2..747dddbb0 100644
--- a/tilelang/carver/roller/policy/common.py
+++ b/tilelang/carver/roller/policy/common.py
@@ -1,8 +1,8 @@
-from typing import List
+from __future__ import annotations
import numpy as np
-def get_all_factors(n: int) -> List[int]:
+def get_all_factors(n: int) -> list[int]:
# Calculate the square root of n and round it up to the nearest integer
n0 = int(np.ceil(np.sqrt(n)))
@@ -16,7 +16,7 @@ def get_all_factors(n: int) -> List[int]:
return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])]
-def factorize(n: int) -> List[int]:
+def factorize(n: int) -> list[int]:
i = 2 # Start with the smallest prime number
result = []
@@ -30,7 +30,7 @@ def factorize(n: int) -> List[int]:
return result
-def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
+def coalesced_factor(subtensor: list[int], tensor: list[int]) -> int:
# If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension
if subtensor[-1] != tensor[-1] or len(subtensor) == 1:
return subtensor[-1]
@@ -39,7 +39,7 @@ def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int:
return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1])
-def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int:
+def coalesced_tensor_shape(subtensor: list[int], tensor: list[int], transaction_size: int) -> int:
# Calculate the total number of elements in the subtensor
bytes = int(np.prod(subtensor))
diff --git a/tilelang/carver/roller/policy/default.py b/tilelang/carver/roller/policy/default.py
index 7837395d9..36d8f1f2c 100644
--- a/tilelang/carver/roller/policy/default.py
+++ b/tilelang/carver/roller/policy/default.py
@@ -1,8 +1,9 @@
"""Policy for cuda core schedule"""
+from __future__ import annotations
import functools
import math
from queue import PriorityQueue
-from typing import Iterable, Dict, List, Optional
+from typing import Iterable
import numpy as np
import tvm
@@ -22,11 +23,11 @@ class DefaultPolicy:
"""
func: tvm.tir.PrimFunc
- nodes: List[PrimFuncNode] = []
+ nodes: list[PrimFuncNode] = []
arch: TileDevice
- tags: Dict
+ tags: dict
- def __init__(self, arch: TileDevice, tags: Optional[Dict] = None) -> None:
+ def __init__(self, arch: TileDevice, tags: dict | None = None) -> None:
if tags is None:
tags = {}
@@ -38,20 +39,17 @@ def __init__(self, arch: TileDevice, tags: Optional[Dict] = None) -> None:
def from_prim_func(cls,
func: tvm.tir.PrimFunc,
arch: TileDevice,
- tags: Optional[Dict] = None,
+ tags: dict | None = None,
name: str = "PrimFuncNode"):
return cls(arch, tags)._init_with_prim_func(func, name)
@classmethod
- def from_output_nodes(cls,
- nodes: List[OutputNode],
- arch: TileDevice,
- tags: Optional[Dict] = None):
+ def from_output_nodes(cls, nodes: list[OutputNode], arch: TileDevice, tags: dict | None = None):
return cls(arch, tags)._init_with_output_nodes(nodes)
def _init_with_prim_func(self,
func: tvm.tir.PrimFunc,
- name: str = "PrimFuncNode") -> "DefaultPolicy":
+ name: str = "PrimFuncNode") -> DefaultPolicy:
if func is not None and isinstance(func, tvm.tir.PrimFunc):
self.func = func
self.prim_func_node = PrimFuncNode(self.func, tags=self.tags, name=name)
@@ -61,7 +59,7 @@ def _init_with_prim_func(self,
self._init_with_output_nodes(output_nodes)
return self
- def _init_with_output_nodes(self, output_nodes: List[OutputNode]):
+ def _init_with_output_nodes(self, output_nodes: list[OutputNode]):
self.ordered_nodes = list(
filter(lambda n: not n.is_placeholder() and not n.is_output(),
find_topo_sort(output_nodes)))
@@ -78,7 +76,7 @@ def _init_with_output_nodes(self, output_nodes: List[OutputNode]):
self.output_nodes.append(node)
return self
- def emit_config(self, topk: int) -> List[Hint]:
+ def emit_config(self, topk: int) -> list[Hint]:
base_tile = self.get_base_tile()
if base_tile is None:
return []
@@ -557,7 +555,7 @@ def _compute_stride_map(self, td: TileDict):
node, td)
td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map
- def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict:
+ def compute_tile_dict(self, output_tile: list[int], rstep_map) -> TileDict:
"""
Computes and returns a TileDict object for a given output tile configuration and reduction step map.
@@ -624,7 +622,7 @@ def check_tile_shape_isvalid(self, td: TileDict) -> bool:
return True
- def recommend_block_size(self, td: TileDict) -> List[int]:
+ def recommend_block_size(self, td: TileDict) -> list[int]:
"""
Recommends optimal block sizes based on the TileDict configuration.
diff --git a/tilelang/carver/roller/policy/tensorcore.py b/tilelang/carver/roller/policy/tensorcore.py
index 60edc930e..15bad4122 100644
--- a/tilelang/carver/roller/policy/tensorcore.py
+++ b/tilelang/carver/roller/policy/tensorcore.py
@@ -1,6 +1,6 @@
"""Policy for tensorcore schedule"""
+from __future__ import annotations
import tvm
-from typing import Dict, List, Tuple, Optional
import numpy as np
import logging
from ..hint import Hint, Stride, TileDict, IntrinInfo
@@ -19,9 +19,9 @@ class TensorCorePolicy(DefaultPolicy):
wmma_k: int = 16
pipeline_stage: int = 1
use_async_copy: bool = False
- block_reduction_depth: Optional[int] = None
+ block_reduction_depth: int | None = None
- def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: Optional[str] = None):
+ def _init_with_prim_func(self, func: tvm.tir.PrimFunc, name: str | None = None):
super()._init_with_prim_func(func, name)
self._legalize_info()
return self
@@ -52,9 +52,9 @@ def _legalize_info(self):
def _compute_tc_strides(
self,
node: PrimFuncNode,
- tile: List[int],
- rstep: Optional[Dict[str, int]] = None,
- ) -> Tuple[Stride, Stride, Stride]:
+ tile: list[int],
+ rstep: dict[str, int] | None = None,
+ ) -> tuple[Stride, Stride, Stride]:
if rstep is None:
rstep = {}
# strides was used for shared memory padding. which is necessary for avoiding
diff --git a/tilelang/carver/roller/rasterization.py b/tilelang/carver/roller/rasterization.py
index 3ead2e12e..39c603b6b 100644
--- a/tilelang/carver/roller/rasterization.py
+++ b/tilelang/carver/roller/rasterization.py
@@ -1,6 +1,5 @@
"""Rasteration Plan For L2 Cache Locality"""
-
-from typing import List
+from __future__ import annotations
class Rasterization:
@@ -10,7 +9,7 @@ class Rasterization:
def __init__(self) -> None:
pass
- def get_code(self) -> List[str]:
+ def get_code(self) -> list[str]:
raise NotImplementedError()
@property
@@ -27,7 +26,7 @@ def __init__(self) -> None:
def __repr__(self) -> str:
return ""
- def get_code(self) -> List[str]:
+ def get_code(self) -> list[str]:
return []
@@ -47,7 +46,7 @@ def __init__(self, panel_width=4) -> None:
def __repr__(self) -> str:
return f""
- def get_code(self) -> List[str]:
+ def get_code(self) -> list[str]:
raise NotImplementedError()
@@ -84,10 +83,10 @@ def get_device_function(self) -> str:
}
"""
- def get_code(self, panel_width: int = None) -> List[str]:
+ def get_code(self, panel_width: int = None) -> list[str]:
if panel_width is None:
panel_width = self.panel_width_
return [
self.get_device_function(),
- "const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width),
+ f"const dim3 blockIdx = rasterization2DColumn({panel_width});\n",
]
diff --git a/tilelang/carver/roller/shape_inference/common.py b/tilelang/carver/roller/shape_inference/common.py
index a3a7a31d6..aaf59aed9 100644
--- a/tilelang/carver/roller/shape_inference/common.py
+++ b/tilelang/carver/roller/shape_inference/common.py
@@ -1,10 +1,10 @@
+from __future__ import annotations
from collections import OrderedDict
-from typing import Dict, List
from tvm import arith
-class Statement():
+class Statement:
def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict,
range_map: OrderedDict):
@@ -18,12 +18,12 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))
-class InputShapeInference():
+class InputShapeInference:
- def __init__(self, deps: List[Statement]):
+ def __init__(self, deps: list[Statement]):
self.deps = deps
- def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int]):
+ def _infer(self, shape: dict[str, list[arith.ConstIntBound]], rstep: dict[str, int]):
shape = shape.copy()
ana = arith.Analyzer()
for dep in reversed(self.deps):
@@ -44,7 +44,7 @@ def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, i
shape[name] = [c.max_value - c.min_value + 1 for c in bounds]
return shape
- def infer(self, shape, rstep: Dict[str, int] = None):
+ def infer(self, shape, rstep: dict[str, int] = None):
if rstep is None:
rstep = {}
if isinstance(shape, (list, tuple)):
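
The `_merge_two_bounds` helper above takes the convex hull of two constant-integer bounds, which is what InputShapeInference needs when several accesses to the same buffer each contribute a bound. A short runnable restatement of that step, with arbitrary example values and assuming a local TVM install:

    from tvm import arith


    def merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound) -> arith.ConstIntBound:
        # Same shape as the _merge_two_bounds helper in common.py: keep the
        # smallest min and the largest max so both ranges are covered.
        return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))


    a = arith.ConstIntBound(0, 15)
    b = arith.ConstIntBound(8, 31)
    merged = merge_two_bounds(a, b)
    print(merged.min_value, merged.max_value)  # 0 31
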
diff --git a/tilelang/carver/roller/shape_inference/tir.py b/tilelang/carver/roller/shape_inference/tir.py
index 8a744ec00..c1b97188a 100644
--- a/tilelang/carver/roller/shape_inference/tir.py
+++ b/tilelang/carver/roller/shape_inference/tir.py
@@ -1,4 +1,5 @@
-from typing import Dict, List, Tuple, Set, Mapping
+from __future__ import annotations
+from typing import Mapping
from tvm.tir.schedule.schedule import BlockRV
from tvm.ir import structural_equal
from tvm import arith, tir
@@ -15,7 +16,7 @@ def __init__(self, block_analyzer, block: BlockRV):
self.reverse_bound_inference = {}
- def make_reverse(self, input_name: str, input_iter: List[tir.PrimExpr]):
+ def make_reverse(self, input_name: str, input_iter: list[tir.PrimExpr]):
if len(self.block_analyzer.get_reduce_axis(self.block)) > 0:
return None
if len(self.dependent_region[input_name]) != 1:
@@ -47,7 +48,7 @@ def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound):
return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value))
-class TensorDepNode(object):
+class TensorDepNode:
"""
For tensor dependency analysis.
"""
@@ -76,7 +77,7 @@ def __repr__(self):
return self.name
-class DependencyAnalysis(object):
+class DependencyAnalysis:
def __init__(self, deps):
self.deps = deps
@@ -89,7 +90,7 @@ def _construct_unique_name2dep(self, deps):
This is a workaround for the issue that we have two same ops' fuse case.
See https://github.com/apache/tvm/issues/16433
"""
- _names: Set = set()
+ _names: set = set()
name2dep: Mapping = {}
for dep in deps:
output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0]
@@ -168,7 +169,7 @@ def _find_path_recursive(self, current_node, target_name, visited, path):
class InputShapeInference:
- def __init__(self, deps: List[Statement]):
+ def __init__(self, deps: list[Statement]):
self.deps = deps
self.target_mapping = {}
self.buffer_mapping = {}
@@ -179,7 +180,7 @@ def __init__(self, deps: List[Statement]):
self.dep_analysis = DependencyAnalysis(self.deps)
self.dep_analysis.analyze()
- def construct_dependency_target(self, targets: Tuple[str]):
+ def construct_dependency_target(self, targets: tuple[str]):
if targets in self.target_mapping:
return self.target_mapping[targets]
# should be buffer name instead of block name
@@ -242,8 +243,8 @@ def construct_dependency_target(self, targets: Tuple[str]):
return input_vars, mapping
def infer(self,
- shape: Dict[str, List[arith.ConstIntBound]],
- rstep: Dict[str, int] = None,
+ shape: dict[str, list[arith.ConstIntBound]],
+ rstep: dict[str, int] = None,
targets=None):
if rstep is None:
rstep = {}
@@ -351,10 +352,10 @@ def walk_indice(expr):
elif isinstance(expr, tir.Call):
return None
else:
- raise Exception("Unhandled node type in walk_indice(): %s" % expr)
+ raise Exception(f"Unhandled node type in walk_indice(): {expr}")
-def _extract_dependent_region(block_analyzer, block: BlockRV) -> Dict[str, List[tir.PrimExpr]]:
+def _extract_dependent_region(block_analyzer, block: BlockRV) -> dict[str, list[tir.PrimExpr]]:
input_buffers = block_analyzer.get_input_buffers(block)
dependent_region = {buffer.name: [] for buffer in input_buffers}
diff --git a/tilelang/carver/template/base.py b/tilelang/carver/template/base.py
index 0de3c5996..5aa5074c2 100644
--- a/tilelang/carver/template/base.py
+++ b/tilelang/carver/template/base.py
@@ -1,11 +1,11 @@
# Import necessary modules and classes
+from __future__ import annotations
from abc import ABC, abstractmethod # For defining abstract base classes
from dataclasses import dataclass, field # For defining data classes
from ..arch import ( # Import architecture-related utilities and classes
TileDevice, is_volta_arch, is_ampere_arch, is_cdna_arch, auto_infer_current_arch)
from ..roller.hint import Hint # Import the Hint class
from ..roller.node import OutputNode # Import the OutputNode class
-from typing import List # For type hinting
from tvm.tir import PrimFunc # Import PrimFunc for handling tensor IR functions
@@ -24,10 +24,10 @@ class BaseTemplate(ABC):
_func: PrimFunc = field(default=None, init=False, repr=False)
# The outputs nodes associated with this template, initially None
- _output_nodes: List[OutputNode] = field(default=None, init=False, repr=False)
+ _output_nodes: list[OutputNode] = field(default=None, init=False, repr=False)
@abstractmethod
- def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+ def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Abstract method that must be implemented by subclasses.
It should return a list of hardware-aware configurations (hints)
@@ -42,7 +42,7 @@ def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) ->
"""
pass
- def with_arch(self, arch: TileDevice) -> "BaseTemplate":
+ def with_arch(self, arch: TileDevice) -> BaseTemplate:
"""
Sets the architecture for this template and returns itself.
@@ -110,7 +110,7 @@ def initialize_function(self) -> None:
"""
raise NotImplementedError("initialize_function is not implemented")
- def set_function(self, func: PrimFunc) -> "BaseTemplate":
+ def set_function(self, func: PrimFunc) -> BaseTemplate:
"""
Sets the function for this template and returns itself.
@@ -123,7 +123,7 @@ def set_function(self, func: PrimFunc) -> "BaseTemplate":
self._func = func
return self
- def set_output_nodes(self, output_nodes: List[OutputNode]) -> "BaseTemplate":
+ def set_output_nodes(self, output_nodes: list[OutputNode]) -> BaseTemplate:
"""
Sets the output nodes for this template and returns itself.
@@ -136,7 +136,7 @@ def set_output_nodes(self, output_nodes: List[OutputNode]) -> "BaseTemplate":
self._output_nodes = output_nodes
return self
- def recommend_hints(self, topk: int = 10) -> List[Hint]:
+ def recommend_hints(self, topk: int = 10) -> list[Hint]:
"""
Provides a list of recommended hardware-aware configurations.
@@ -159,7 +159,7 @@ def arch(self) -> TileDevice:
return self._arch
@property
- def output_nodes(self) -> List[OutputNode]:
+ def output_nodes(self) -> list[OutputNode]:
"""
Returns the output nodes associated with this template.
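
`with_arch`, `set_function`, and `set_output_nodes` above all return `self`, so a template can be configured as one chained expression. A pure-Python sketch of that fluent-setter pattern; the class below is illustrative only, not the tilelang implementation:

    from __future__ import annotations
    from dataclasses import dataclass, field


    @dataclass
    class TemplateSketch:
        _arch: str | None = field(default=None, init=False)
        _func: str | None = field(default=None, init=False)

        def with_arch(self, arch: str) -> TemplateSketch:
            self._arch = arch
            return self  # returning self enables chaining

        def set_function(self, func: str) -> TemplateSketch:
            self._func = func
            return self


    t = TemplateSketch().with_arch("sm_90").set_function("matmul")
    print(t._arch, t._func)  # sm_90 matmul
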
diff --git a/tilelang/carver/template/conv.py b/tilelang/carver/template/conv.py
index 5931b2656..f180084d5 100644
--- a/tilelang/carver/template/conv.py
+++ b/tilelang/carver/template/conv.py
@@ -1,8 +1,8 @@
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te, tir
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -44,7 +44,7 @@ class ConvTemplate(BaseTemplate):
accum_dtype: str = "float16" # Data type for accumulation
with_bias: bool = False # Whether to add a bias term
- def get_hardware_aware_configs(self, arch=None, topk=10) -> List[Hint]:
+ def get_hardware_aware_configs(self, arch=None, topk=10) -> list[Hint]:
"""
Retrieves optimized hardware-aware configurations.
diff --git a/tilelang/carver/template/elementwise.py b/tilelang/carver/template/elementwise.py
index 311b75ccf..26d531529 100644
--- a/tilelang/carver/template/elementwise.py
+++ b/tilelang/carver/template/elementwise.py
@@ -1,10 +1,10 @@
# Import necessary modules
+from __future__ import annotations
from dataclasses import dataclass # Used for defining data classes
from .base import BaseTemplate # Importing the base class for templates
from tvm import te # Importing TVM's tensor expression module
from ..arch import TileDevice # Importing TileDevice for hardware-specific configurations
from ..roller import Hint # Importing Hint for optimization hints
-from typing import List # Importing List type hint
from ..utils import get_roller_hints_from_func # Function to obtain optimization hints
@@ -19,10 +19,10 @@ class ElementwiseTemplate(BaseTemplate):
"""
# OP Related Config
- shape: List[int] = None # Shape of the tensor
+ shape: list[int] = None # Shape of the tensor
dtype: str = "float16" # Data type of the tensor
- def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+ def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Retrieves hardware-aware optimization configurations.
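
A note on the `shape: list[int] = None` default above: dataclasses reject mutable class-level defaults, so the usual options are a `None` sentinel (as here) or `field(default_factory=...)`. A small illustrative sketch, not project code:

    from __future__ import annotations
    from dataclasses import dataclass, field


    @dataclass
    class Example:
        shape: list[int] = None                        # sentinel, filled in later
        tags: list[str] = field(default_factory=list)  # fresh list per instance

    # A literal default would be rejected at class-creation time:
    #     shape: list[int] = []   ->  ValueError: mutable default ... not allowed

    e = Example(shape=[1024, 1024])
    print(e.shape, e.tags)  # [1024, 1024] []
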
diff --git a/tilelang/carver/template/flashattention.py b/tilelang/carver/template/flashattention.py
index f9dc85b76..760b19817 100644
--- a/tilelang/carver/template/flashattention.py
+++ b/tilelang/carver/template/flashattention.py
@@ -1,17 +1,17 @@
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
from ..roller import PrimFuncNode, OutputNode, Edge
-from typing import List
from ..utils import get_roller_hints_from_output_nodes, get_tensorized_func_and_tags
@dataclass
class FlashAttentionTemplate(BaseTemplate):
- _output_nodes: List[OutputNode] = None
+ _output_nodes: list[OutputNode] = None
# Operation-related configuration parameters
batch_size: int = 1
@@ -26,7 +26,7 @@ class FlashAttentionTemplate(BaseTemplate):
out_dtype: str = "float16"
accum_dtype: str = "float16"
- def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+ def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Retrieves optimized hardware-aware configurations.
diff --git a/tilelang/carver/template/gemv.py b/tilelang/carver/template/gemv.py
index a6e943a01..7195a0b87 100644
--- a/tilelang/carver/template/gemv.py
+++ b/tilelang/carver/template/gemv.py
@@ -1,9 +1,9 @@
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -25,7 +25,7 @@ class GEMVTemplate(BaseTemplate):
accum_dtype: str = "float16" # Accumulation data type
with_bias: bool = False # Whether to add a bias term
- def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+ def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Retrieves optimized hardware-aware configurations.
diff --git a/tilelang/carver/template/general_reduce.py b/tilelang/carver/template/general_reduce.py
index 9eba86c63..a8da5fd6c 100644
--- a/tilelang/carver/template/general_reduce.py
+++ b/tilelang/carver/template/general_reduce.py
@@ -1,9 +1,9 @@
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List, Union
from ..utils import get_roller_hints_from_func
@@ -11,11 +11,11 @@
class GeneralReductionTemplate(BaseTemplate):
# OP Related Config
- structure: Union[str, List[str]] = None
- shape: List[int] = None
+ structure: str | list[str] = None
+ shape: list[int] = None
dtype: str = "float16"
- def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+ def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
roller_hints = get_roller_hints_from_func(
self._func, arch=arch, topk=topk, allow_gemv=False)
return roller_hints
diff --git a/tilelang/carver/template/matmul.py b/tilelang/carver/template/matmul.py
index 24aa6ef91..4847cdb22 100644
--- a/tilelang/carver/template/matmul.py
+++ b/tilelang/carver/template/matmul.py
@@ -1,9 +1,9 @@
+from __future__ import annotations
from dataclasses import dataclass
from .base import BaseTemplate
from tvm import te
from ..arch import TileDevice
from ..roller import Hint
-from typing import List
from ..utils import get_roller_hints_from_func
@@ -38,7 +38,7 @@ class MatmulTemplate(BaseTemplate):
accum_dtype: str = "float16" # Data type for accumulation
with_bias: bool = False # Whether to add a bias term
- def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> List[Hint]:
+ def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
"""
Retrieves optimized hardware-aware configurations.
diff --git a/tilelang/carver/utils.py b/tilelang/carver/utils.py
index 649b4388c..cedb7547a 100644
--- a/tilelang/carver/utils.py
+++ b/tilelang/carver/utils.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from __future__ import annotations
from tvm import tir, IRModule
from tvm.tir import PrimFunc
from .arch import TileDevice
@@ -26,11 +26,11 @@ def get_rasterization_code(pannel_width: int = 8) -> str:
"""
-def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
+def get_roller_hints_from_func(func_or_module: tir.PrimFunc | IRModule,
arch: TileDevice,
topk: int = 10,
tensorcore_only: bool = False,
- allow_gemv: bool = False) -> Optional[List[Hint]]:
+ allow_gemv: bool = False) -> list[Hint] | None:
func = None
if isinstance(func_or_module, tir.PrimFunc):
func = func_or_module
@@ -69,11 +69,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
return roller_hints
-def get_roller_hints_from_output_nodes(
- output_nodes: List[OutputNode],
- arch: TileDevice,
- topk: int = 10,
- extra_tags: Optional[List[str]] = None) -> Optional[List[Hint]]:
+def get_roller_hints_from_output_nodes(output_nodes: list[OutputNode],
+ arch: TileDevice,
+ topk: int = 10,
+ extra_tags: list[str] | None = None) -> list[Hint] | None:
assert isinstance(output_nodes, list), "The input should be a list of functions."
lints = []
diff --git a/tilelang/contrib/cc.py b/tilelang/contrib/cc.py
index 26bb419db..d5cba6c4e 100644
--- a/tilelang/contrib/cc.py
+++ b/tilelang/contrib/cc.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Util to invoke C/C++ compilers in the system."""
+from __future__ import annotations
import functools
import os
import shutil
@@ -23,7 +24,6 @@
# pylint: disable=invalid-name
import sys
-from typing import Dict
from tvm.base import py_str
from tvm.contrib import tar as _tar
@@ -208,7 +208,7 @@ def create_executable(output, objects, options=None, cc=None, cwd=None, ccache_e
raise ValueError("Unsupported platform")
-def get_global_symbol_section_map(path, *, nm=None) -> Dict[str, str]:
+def get_global_symbol_section_map(path, *, nm=None) -> dict[str, str]:
"""Get global symbols from a library via nm -g
Parameters
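
For context on the `dict[str, str]` return hint above: `nm -g` prints one global symbol per line as "<address> <type-letter> <name>", and the one-letter type (T, D, B, U, ...) is what gets mapped per symbol. A hedged stand-alone sketch of that parsing, assuming a POSIX `nm` on PATH; it illustrates the idea rather than reproducing the tilelang implementation:

    from __future__ import annotations
    import subprocess


    def global_symbol_sections(path: str, nm: str = "nm") -> dict[str, str]:
        out = subprocess.check_output([nm, "-g", path], text=True)
        symbols: dict[str, str] = {}
        for line in out.splitlines():
            parts = line.split()
            if len(parts) < 2:
                continue
            # Last field is the symbol name, the field before it is the type letter.
            symbols[parts[-1]] = parts[-2]
        return symbols


    # Example call (the library path is hypothetical):
    # print(global_symbol_sections("/usr/lib/x86_64-linux-gnu/libm.so.6"))
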
diff --git a/tilelang/contrib/hipcc.py b/tilelang/contrib/hipcc.py
index afd381223..92fbcc8e3 100644
--- a/tilelang/contrib/hipcc.py
+++ b/tilelang/contrib/hipcc.py
@@ -54,7 +54,7 @@ def compile_hip(code,
if target_format not in ["hsaco"]:
raise ValueError("target_format must be hsaco")
temp_code = temp.relpath("my_kernel.cc")
- temp_target = temp.relpath("my_kernel.%s" % target_format)
+ temp_target = temp.relpath(f"my_kernel.{target_format}")
with open(temp_code, "w") as out_file:
out_file.write(code)
diff --git a/tilelang/contrib/nvcc.py b/tilelang/contrib/nvcc.py
index 6b2e739a0..8e813d92b 100644
--- a/tilelang/contrib/nvcc.py
+++ b/tilelang/contrib/nvcc.py
@@ -2,11 +2,11 @@
# modified from apache tvm python/tvm/contrib/nvcc.py
"""Utility to invoke nvcc compiler in the system"""
from __future__ import absolute_import as _abs
+from __future__ import annotations
import os
import subprocess
import warnings
-from typing import Tuple
from tilelang.env import CUDA_HOME
import tvm.ffi
@@ -299,7 +299,7 @@ def get_target_compute_version(target=None):
"Try specifying it by adding '-arch=sm_xx' to your target.")
-def parse_compute_version(compute_version) -> Tuple[int, int]:
+def parse_compute_version(compute_version) -> tuple[int, int]:
"""Parse compute capability string to divide major and minor version
Parameters
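
The `tuple[int, int]` hint above reflects that a compute-capability string such as "8.6" is split into a (major, minor) pair, which can then be compared element-wise. A trivial stand-alone sketch of that parsing (not the tilelang function itself):

    from __future__ import annotations


    def split_compute_version(compute_version: str) -> tuple[int, int]:
        major, minor = compute_version.split(".")[:2]
        return int(major), int(minor)


    assert split_compute_version("8.6") == (8, 6)
    # Tuples compare lexicographically, e.g. a Hopper-or-newer style check:
    assert split_compute_version("9.0") >= (9, 0)
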
diff --git a/tilelang/contrib/nvrtc.py b/tilelang/contrib/nvrtc.py
index 0f07022c9..b69115549 100644
--- a/tilelang/contrib/nvrtc.py
+++ b/tilelang/contrib/nvrtc.py
@@ -1,10 +1,11 @@
+from __future__ import annotations
import cuda.bindings.nvrtc as nvrtc
-from typing import Literal, Union, List, Optional, Tuple
+from typing import Literal
from tvm.target import Target
from .nvcc import get_target_compute_version, parse_compute_version
-def get_nvrtc_version() -> Tuple[int, int]:
+def get_nvrtc_version() -> tuple[int, int]:
result, major, minor = nvrtc.nvrtcVersion()
assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f"Failed to get NVRTC version: {result}"
return (major, minor)
@@ -12,8 +13,8 @@ def get_nvrtc_version() -> Tuple[int, int]:
def compile_cuda(code: str,
target_format: Literal["ptx", "cubin"] = "ptx",
- arch: Optional[int] = None,
- options: Optional[Union[str, List[str]]] = None,
+ arch: int | None = None,
+ options: str | list[str] | None = None,
verbose: bool = False) -> bytearray:
"""Compile cuda code with NVRTC.
diff --git a/tilelang/engine/callback.py b/tilelang/engine/callback.py
index 8d43e41d5..ee1c80693 100644
--- a/tilelang/engine/callback.py
+++ b/tilelang/engine/callback.py
@@ -1,4 +1,5 @@
-from typing import Callable, Union
+from __future__ import annotations
+from typing import Callable
from tvm import register_func
from tvm.target import Target
@@ -25,7 +26,7 @@ def register_hip_postproc(func: Callable[[str, Target], str], override: bool = T
register_func("tilelang_callback_hip_postproc", f=func, override=override)
-def register_cuda_postproc_callback(func: Union[Callable, bool] = None, override: bool = True):
+def register_cuda_postproc_callback(func: Callable | bool = None, override: bool = True):
"""Decorator for registering CUDA post-processing callback function.
Can be used with or without parentheses:
@@ -58,7 +59,7 @@ def _register(fn: Callable[[str, Target], str]):
raise TypeError("Invalid decorator usage")
-def register_hip_postproc_callback(func: Union[Callable, bool] = None, override: bool = True):
+def register_hip_postproc_callback(func: Callable | bool = None, override: bool = True):
"""Decorator for registering HIP post-processing callback function.
Can be used with or without parentheses:
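
The `Callable | bool = None` signature above is what lets these decorators be used both bare (`@register_cuda_postproc_callback`) and with arguments (`@register_cuda_postproc_callback(override=False)`). A generic sketch of that dual-form decorator pattern, with illustrative names only:

    from __future__ import annotations
    from typing import Callable

    _REGISTRY: dict[str, Callable] = {}


    def register_postproc(func: Callable | None = None, override: bool = True):
        def _register(fn: Callable) -> Callable:
            if override or fn.__name__ not in _REGISTRY:
                _REGISTRY[fn.__name__] = fn
            return fn

        if callable(func):
            return _register(func)  # used bare: func is the decorated function
        return _register            # used with arguments: return the real decorator


    @register_postproc
    def add_banner(code: str) -> str:
        return "// generated\n" + code


    @register_postproc(override=False)
    def keep_code(code: str) -> str:
        return code


    print(sorted(_REGISTRY))  # ['add_banner', 'keep_code']
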
diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py
index 717a8ebd2..8738f58a1 100644
--- a/tilelang/engine/lower.py
+++ b/tilelang/engine/lower.py
@@ -1,8 +1,9 @@
"""The compiler for TL programs."""
+from __future__ import annotations
import os
import os.path as osp
-from typing import Union, Optional, Callable, List
+from typing import Callable
import tilelang.transform
from tilelang import tvm as tvm
from tvm import tir
@@ -114,7 +115,7 @@ def tilelang_callback_hip_compile(code, target):
return hsaco
-def extrac_params(func: tir.PrimFunc) -> List[KernelParam]:
+def extrac_params(func: tir.PrimFunc) -> list[KernelParam]:
tensor_types = []
for var in func.params:
if var in func.buffer_map:
@@ -124,7 +125,7 @@ def extrac_params(func: tir.PrimFunc) -> List[KernelParam]:
return tensor_types
-def canon_target_host(target: Union[str, Target], target_host: Optional[Union[str, Target]]):
+def canon_target_host(target: str | Target, target_host: str | Target | None):
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "c"
@@ -190,9 +191,9 @@ def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target) ->
def lower(
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
- target: Union[str, Target] = "auto",
- target_host: Optional[Union[str, Target]] = None,
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
+ target: str | Target = "auto",
+ target_host: str | Target | None = None,
runtime_only=False,
enable_host_codegen=False,
enable_device_compile=False,
diff --git a/tilelang/engine/param.py b/tilelang/engine/param.py
index 2db2d8391..de3c979ea 100644
--- a/tilelang/engine/param.py
+++ b/tilelang/engine/param.py
@@ -1,7 +1,7 @@
"""The profiler and convert to torch utils"""
+from __future__ import annotations
from dataclasses import dataclass
-from typing import List, Union, Optional
import torch
from tilelang import tvm as tvm
from tvm.tir import Buffer, IntImm, Var, PrimExpr
@@ -15,7 +15,7 @@ class KernelParam:
Used to describe tensor or scalar parameters in TVM/PyTorch interop.
"""
dtype: torch.dtype # PyTorch data type of the parameter
- shape: List[Union[int, Var]] # List of dimensions, can be integers or TVM variables
+ shape: list[int | Var] # List of dimensions, can be integers or TVM variables
@classmethod
def from_buffer(cls, buffer: Buffer):
@@ -111,7 +111,6 @@ class CompiledArtifact:
"""
host_mod: tvm.IRModule # Host-side TVM IR module for managing kernel execution
device_mod: tvm.IRModule # Device-side TVM IR module containing the actual kernel code
- params: List[KernelParam] # List of parameters (tensors/scalars) used by the kernel
+ params: list[KernelParam] # List of parameters (tensors/scalars) used by the kernel
kernel_source: str # Raw source code of the generated kernel
- rt_mod: Optional[
- tvm.runtime.Module] = None # Runtime module for execution, may be lazily initialized
+ rt_mod: tvm.runtime.Module | None = None # Runtime module for execution, may be lazily initialized
diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py
index 7126186cc..10fd87d10 100644
--- a/tilelang/engine/phase.py
+++ b/tilelang/engine/phase.py
@@ -1,13 +1,13 @@
+from __future__ import annotations
from tvm import tir, IRModule
from tvm.target import Target
import tilelang
from tilelang.transform import PassContext
from tilelang.contrib.nvcc import have_tma, is_hopper
-from typing import Optional
-def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
- target: Optional[Target] = None) -> bool:
+def allow_warp_specialized(pass_ctx: PassContext | None = None,
+ target: Target | None = None) -> bool:
# avoid circular import
from tilelang.jit.adapter.utils import is_cuda_target
@@ -19,8 +19,8 @@ def allow_warp_specialized(pass_ctx: Optional[PassContext] = None,
return not disable_warp_specialized
-def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
- target: Optional[Target] = None) -> bool:
+def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None,
+ target: Target | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if not have_tma(target):
@@ -29,26 +29,26 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target)
-def allow_fence_proxy(target: Optional[Target] = None) -> bool:
+def allow_fence_proxy(target: Target | None = None) -> bool:
return have_tma(target)
-def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
+def allow_vectorize(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False)
return not disable_vectorize
-def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None) -> bool:
+def allow_global_thread_synchronization(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enable_global_thread_sync = pass_ctx.config.get("tir.detect_global_barrier", False)
return enable_global_thread_sync
-def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
- target: Optional[Target] = None) -> bool:
+def should_enable_aggressive_merge(pass_ctx: PassContext | None = None,
+ target: Target | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
enable_aggressive_merge = bool(
@@ -61,7 +61,7 @@ def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
return enable_aggressive_merge
-def should_force_let_inline(pass_ctx: Optional[PassContext] = None) -> bool:
+def should_force_let_inline(pass_ctx: PassContext | None = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
return bool(pass_ctx and pass_ctx.config.get(tilelang.PassConfigKey.TL_FORCE_LET_INLINE, False))
diff --git a/tilelang/env.py b/tilelang/env.py
index 08cf031ca..9d3f50a8e 100644
--- a/tilelang/env.py
+++ b/tilelang/env.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import sys
import os
import pathlib
@@ -5,7 +6,6 @@
import shutil
import glob
from dataclasses import dataclass
-from typing import Optional
logger = logging.getLogger(__name__)
@@ -170,7 +170,7 @@ class Environment:
key: str # Environment variable name (e.g. "TILELANG_PRINT_ON_COMPILATION")
default: str # Default value if the environment variable is not set
- _forced_value: Optional[str] = None # Temporary runtime override (mainly for tests/debugging)
+ _forced_value: str | None = None # Temporary runtime override (mainly for tests/debugging)
def get(self):
if self._forced_value is not None:
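
The `_forced_value` attribute above suggests a three-step lookup: a forced runtime override wins, then the process environment, then the default. A sketch of that order under those assumptions (illustrative only; the real class lives in tilelang/env.py):

    from __future__ import annotations
    import os
    from dataclasses import dataclass


    @dataclass
    class EnvVarSketch:
        key: str
        default: str
        _forced_value: str | None = None

        def get(self) -> str:
            if self._forced_value is not None:
                return self._forced_value          # temporary override, e.g. in tests
            return os.environ.get(self.key, self.default)


    flag = EnvVarSketch("TILELANG_PRINT_ON_COMPILATION", "0")
    print(flag.get())           # "0" unless the environment variable is set
    flag._forced_value = "1"
    print(flag.get())           # "1"
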
diff --git a/tilelang/intrinsics/mfma_macro_generator.py b/tilelang/intrinsics/mfma_macro_generator.py
index 12551b193..aa369980f 100644
--- a/tilelang/intrinsics/mfma_macro_generator.py
+++ b/tilelang/intrinsics/mfma_macro_generator.py
@@ -1,17 +1,16 @@
+from __future__ import annotations
from tilelang import tvm as tvm
import tilelang.language as T
-from typing import Tuple
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.runtime import convert
-from typing import Optional
from .utils import (
mfma_store_index_map,)
lift = convert
-class MatrixCoreIntrinEmitter(object):
+class MatrixCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
@@ -51,9 +50,9 @@ def __init__(
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
- k_pack: Optional[int] = None,
- is_m_first: Optional[bool] = False,
- b_preshuffle: Optional[bool] = False,
+ k_pack: int | None = None,
+ is_m_first: bool | None = False,
+ b_preshuffle: bool | None = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
@@ -135,15 +134,15 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
self.micro_size_y = n_dim
self.micro_size_k = k_dim
- def _initialize_k_pack(self, k_pack: Optional[int] = None):
+ def _initialize_k_pack(self, k_pack: int | None = None):
if k_pack is not None:
self.k_pack = k_pack
- def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
+ def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
- def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False):
+ def _initialize_b_preshuffle(self, b_preshuffle: bool | None = False):
if b_preshuffle is not None:
self.b_preshuffle = b_preshuffle
@@ -203,7 +202,7 @@ def get_ldmatrix_index_map(self, is_b=False):
def extract_thread_binding(self,
thread_id,
- is_m_first=None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
+ is_m_first=None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
'''
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
@@ -418,10 +417,10 @@ def __init__(
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
- k_pack: Optional[int] = None,
- is_m_first: Optional[bool] = False,
- a_preshuffle: Optional[bool] = False,
- b_preshuffle: Optional[bool] = False,
+ k_pack: int | None = None,
+ is_m_first: bool | None = False,
+ a_preshuffle: bool | None = False,
+ b_preshuffle: bool | None = False,
):
self.a_dtype = a_dtype
diff --git a/tilelang/intrinsics/mma_layout.py b/tilelang/intrinsics/mma_layout.py
index 8ddd9f96d..1fec00584 100644
--- a/tilelang/intrinsics/mma_layout.py
+++ b/tilelang/intrinsics/mma_layout.py
@@ -1,4 +1,4 @@
-from typing import Union
+from __future__ import annotations
from tvm import arith, DataType
import tilelang.language as T
@@ -163,7 +163,7 @@ def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16)
-def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str], swizzle_bytes=None):
+def get_swizzle_layout(row_idx, col_idx, row_size, dtype: DataType | str, swizzle_bytes=None):
ana = arith.Analyzer()
if isinstance(dtype, str):
dtype = DataType(dtype)
diff --git a/tilelang/intrinsics/mma_macro_generator.py b/tilelang/intrinsics/mma_macro_generator.py
index 65d2ab0ca..537cc762c 100644
--- a/tilelang/intrinsics/mma_macro_generator.py
+++ b/tilelang/intrinsics/mma_macro_generator.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
import tilelang.language as T
-from typing import Union, Tuple, Optional, Literal, Callable
+from typing import Literal, Callable
from tilelang.common import TransformKind
from tvm import DataType
from tvm.tir import PrimExpr, IndexMap, Buffer, Var
@@ -25,7 +26,7 @@
lift = convert
-class TensorCoreIntrinEmitter(object):
+class TensorCoreIntrinEmitter:
"""
To eliminate Python syntax within TIR Macro.
"""
@@ -62,8 +63,8 @@ def __init__(
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
- is_m_first: Optional[bool] = False,
- thread_var: Optional[Var] = None,
+ is_m_first: bool | None = False,
+ thread_var: Var | None = None,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
@@ -144,7 +145,7 @@ def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
self.micro_size_x = m_dim
self.micro_size_k = k_dim
- def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
+ def _initialize_is_m_first(self, is_m_first: bool | None = False):
if is_m_first is not None:
self.is_m_first = is_m_first
@@ -167,7 +168,7 @@ def get_store_index_map(self, inverse: bool = False) -> IndexMap:
def extract_thread_binding(
self,
thread_id: PrimExpr,
- is_m_first: Optional[bool] = None) -> Tuple[PrimExpr, PrimExpr, PrimExpr]:
+ is_m_first: bool | None = None) -> tuple[PrimExpr, PrimExpr, PrimExpr]:
"""
is_m_first: True if the thread binding is in the form of (tx, warp_n, warp_m)
which represents [warp_size, block_row_warps (split n), block_col_warps (split m)]
@@ -200,7 +201,7 @@ def ldmatrix_a(self,
A_local_buf: Buffer,
A_shared_buf: Buffer,
ki: PrimExpr,
- rk: Optional[PrimExpr] = 0):
+ rk: PrimExpr | None = 0):
warp_row_tiles = self.warp_row_tiles
warp_rows = self.warp_rows
chunk = self.chunk
@@ -264,7 +265,7 @@ def ldmatrix_b(self,
B_local_buf: Buffer,
B_shared_buf: Buffer,
ki: PrimExpr,
- rk: Optional[PrimExpr] = 0):
+ rk: PrimExpr | None = 0):
warp_col_tiles = self.warp_col_tiles
warp_cols = self.warp_cols
chunk = self.chunk
@@ -336,7 +337,7 @@ def mma(self,
A_local_buf: Buffer,
B_local_buf: Buffer,
C_local_buf: Buffer,
- k_inner: Optional[PrimExpr] = 0):
+ k_inner: PrimExpr | None = 0):
warp_rows = self.warp_rows
warp_cols = self.warp_cols
local_size_a = self.local_size_a
@@ -518,8 +519,7 @@ def make_mma_load_layout(self,
else:
raise ValueError(f"Unsupported matrix {matrix}")
- assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
- local_buf.scope())
+ assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
if matrix_is_a:
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
@@ -684,9 +684,9 @@ def __init__(
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
- is_m_first: Optional[bool] = False,
- transform_kind_a: Union[int, TransformKind] = 0,
- transform_kind_b: Union[int, TransformKind] = 0,
+ is_m_first: bool | None = False,
+ transform_kind_a: int | TransformKind = 0,
+ transform_kind_b: int | TransformKind = 0,
):
super().__init__(
a_dtype=a_dtype,
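
`extract_thread_binding` above is index arithmetic: a flat thread id is split into a lane id plus two warp coordinates, and `is_m_first` only decides which warp coordinate varies fastest. A plain-Python illustration with made-up sizes, hedged because the exact ordering is defined by the emitter's docstring rather than reproduced here:

    WARP_SIZE = 32


    def decompose(thread_id: int, block_row_warps: int, block_col_warps: int, m_first: bool):
        lane = thread_id % WARP_SIZE
        warp = thread_id // WARP_SIZE
        if m_first:
            warp_n, warp_m = warp % block_col_warps, warp // block_col_warps
        else:
            warp_m, warp_n = warp % block_row_warps, warp // block_row_warps
        return lane, warp_m, warp_n


    # 2x2 warps (128 threads): thread 70 lives in warp 2.
    print(decompose(70, block_row_warps=2, block_col_warps=2, m_first=False))  # (6, 0, 1)
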
diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py
index 9d64a15fe..d9d591f72 100644
--- a/tilelang/intrinsics/wgmma_macro_generator.py
+++ b/tilelang/intrinsics/wgmma_macro_generator.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
import tilelang.language as T
from enum import IntEnum
-from typing import Optional, Callable
+from typing import Callable
from .mma_macro_generator import TensorCoreIntrinEmitter as MMAIntrinEmitter
from tvm import DataType
from tvm.tir import PrimExpr, Buffer, Var, IndexMap
@@ -86,8 +87,8 @@ def __init__(
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
- is_m_first: Optional[bool] = False,
- thread_var: Optional[Var] = None,
+ is_m_first: bool | None = False,
+ thread_var: Var | None = None,
):
super().__init__(a_dtype, b_dtype, accum_dtype, a_transposed, b_transposed, block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk, reduce_k,
@@ -409,8 +410,7 @@ def make_mma_load_layout(self, local_buf: Buffer, matrix: str = "A") -> T.Fragme
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
j, i)
- assert is_fragment(local_buf), "local_buf must be a fragment, but got {}".format(
- local_buf.scope())
+ assert is_fragment(local_buf), f"local_buf must be a fragment, but got {local_buf.scope()}"
micro_size_s, micro_size_r = self.micro_size_x, self.micro_size_k
diff --git a/tilelang/jit/__init__.py b/tilelang/jit/__init__.py
index 447e43b2a..78454a558 100644
--- a/tilelang/jit/__init__.py
+++ b/tilelang/jit/__init__.py
@@ -3,17 +3,13 @@
It includes functionality to JIT-compile TileLang programs into a runnable
kernel adapter using TVM.
"""
+from __future__ import annotations
from typing import (
Any,
- List,
- Union,
Callable,
- Tuple,
overload,
Literal,
- Dict, # For type hinting dicts
- Optional,
)
from tilelang import tvm as tvm
from tilelang.jit.adapter.utils import is_metal_target
@@ -33,13 +29,13 @@
def compile(
func: PrimFunc = None,
- out_idx: Union[List[int], int, None] = None,
+ out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target, None] = None,
+ target: str | Target = "auto",
+ target_host: str | Target | None = None,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[Union[List[str], str]] = None,
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Compile the given TileLang PrimFunc with TVM and build a JITKernel.
@@ -92,24 +88,24 @@ def compile(
class _JitImplementation:
- out_idx: Optional[Union[List[int], int]]
- target: Union[str, Target]
- target_host: Union[str, Target]
+ out_idx: list[int] | int | None
+ target: str | Target
+ target_host: str | Target
execution_backend: Literal["dlpack", "ctypes", "cython"]
verbose: bool
- pass_configs: Optional[Dict[str, Any]]
- debug_root_path: Optional[str]
- compile_flags: Optional[Union[List[str], str]]
+ pass_configs: dict[str, Any] | None
+ debug_root_path: str | None
+ compile_flags: list[str] | str | None
def __init__(self,
out_idx: Any = None,
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target] = None,
+ target: str | Target = "auto",
+ target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- debug_root_path: Optional[str] = None,
- compile_flags: Optional[Union[List[str], str]] = None):
+ pass_configs: dict[str, Any] | None = None,
+ debug_root_path: str | None = None,
+ compile_flags: list[str] | str | None = None):
"""
Initializes the JIT compiler decorator.
@@ -162,12 +158,12 @@ def __init__(self,
except NameError:
self.debug_root_path = path.abspath(self.debug_root_path)
- self._kernel_cache: Dict[tuple, Kernel] = {}
+ self._kernel_cache: dict[tuple, Kernel] = {}
# This tells the type checker what the *wrapper* function will return.
# this is for linting, please do not remove it.
@overload
- def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, Kernel]]:
+ def __call__(self, func: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, Kernel]]:
...
@overload
@@ -242,16 +238,16 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> Any:
def jit( # This is the new public interface
- func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
+ func: Callable[_P, _RProg] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
out_idx: Any = None,
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target] = None,
+ target: str | Target = "auto",
+ target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- debug_root_path: Optional[str] = None,
- compile_flags: Optional[Union[List[str], str]] = None):
+ pass_configs: dict[str, Any] | None = None,
+ debug_root_path: str | None = None,
+ compile_flags: list[str] | str | None = None):
"""
Just-In-Time (JIT) compiler decorator for TileLang functions.
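
A hedged usage sketch for the two entry points in this file, assuming tilelang is installed. The option names come from the signatures above; `my_program` is a placeholder whose body (a real TileLang program) is elided:

    import tilelang


    @tilelang.jit(out_idx=[-1], target="auto", execution_backend="cython", verbose=False)
    def my_program(M: int, N: int):
        ...  # build and return a TileLang program (PrimFunc) here


    # Or compile an already-built PrimFunc directly:
    # kernel = tilelang.compile(prim_func, out_idx=[-1], target="auto",
    #                           pass_configs=None, compile_flags=None)
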
diff --git a/tilelang/jit/adapter/base.py b/tilelang/jit/adapter/base.py
index 1b584d71c..9d998bc96 100644
--- a/tilelang/jit/adapter/base.py
+++ b/tilelang/jit/adapter/base.py
@@ -1,21 +1,22 @@
"""The profiler and convert to torch utils"""
+from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import Any, List, Callable, Optional
+from typing import Any, Callable
from tilelang.engine.param import KernelParam
class BaseKernelAdapter(ABC):
- func: Optional[Callable] = None
+ func: Callable | None = None
- def __init__(self, mod, params: List[KernelParam], result_idx: List[int]) -> None:
+ def __init__(self, mod, params: list[KernelParam], result_idx: list[int]) -> None:
self.mod = mod
self.params = params
self.result_idx = self._legalize_result_idx(result_idx)
self._post_init()
- def _legalize_result_idx(self, result_idx: Optional[List[int]]) -> List[int]:
+ def _legalize_result_idx(self, result_idx: list[int] | None) -> list[int]:
params = self.params
# result_idx is a list of indices of the output tensors
if result_idx is None:
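
`_legalize_result_idx` above only shows its first branch, so the sketch below is an assumption about what a normalization of `result_idx` has to do (accept None, a single index, or a list; resolve negatives; bounds-check), not a copy of the tilelang code:

    from __future__ import annotations


    def legalize_result_idx(result_idx: list[int] | int | None, num_params: int) -> list[int]:
        # Assumed behaviour, for illustration only.
        if result_idx is None:
            return []                       # treat as "no designated outputs" in this sketch
        if isinstance(result_idx, int):
            result_idx = [result_idx]
        out = [i + num_params if i < 0 else i for i in result_idx]   # wrap negative indices
        assert all(0 <= i < num_params for i in out), "result_idx out of range"
        return out


    print(legalize_result_idx(-1, num_params=3))    # [2]
    print(legalize_result_idx(None, num_params=3))  # []
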
diff --git a/tilelang/jit/adapter/ctypes/adapter.py b/tilelang/jit/adapter/ctypes/adapter.py
index 7ec6cef0d..648c66c1c 100644
--- a/tilelang/jit/adapter/ctypes/adapter.py
+++ b/tilelang/jit/adapter/ctypes/adapter.py
@@ -1,9 +1,10 @@
"""The profiler and convert to torch utils"""
+from __future__ import annotations
import torch
from ..base import BaseKernelAdapter
import ctypes
-from typing import List, Optional, Union, Callable, Dict, Tuple, Any
+from typing import Callable, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tvm.relax import TensorType
@@ -25,32 +26,32 @@ class CtypesKernelAdapter(BaseKernelAdapter):
# Class attributes to store compiled kernel information
target = "cuda"
- ir_module: Optional[tvm.IRModule] = None
+ ir_module: tvm.IRModule | None = None
# The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code
- kernel_global_source: Optional[str] = None
- lib: Optional[ctypes.CDLL] = None # Compiled library handle
- wrapped_source: Optional[str] = None # Generated C++ wrapper code
+ kernel_global_source: str | None = None
+ lib: ctypes.CDLL | None = None # Compiled library handle
+ wrapped_source: str | None = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
- dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
+ dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None
# Pass configs for the compiler
- pass_configs: Optional[Dict[str, Any]] = None
+ pass_configs: dict[str, Any] | None = None
# Add new cache attributes
- param_dtypes: Optional[List[torch.dtype]] = None # Cache for parameter dtypes
- param_shapes: Optional[List[List]] = None # Cache for parameter shapes
+ param_dtypes: list[torch.dtype] | None = None # Cache for parameter dtypes
+ param_shapes: list[list] | None = None # Cache for parameter shapes
def __init__(self,
- params: List[TensorType],
- result_idx: List[int],
+ params: list[TensorType],
+ result_idx: list[int],
target: str,
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
- host_mod: Optional[tvm.IRModule] = None,
- device_mod: Optional[tvm.IRModule] = None,
- kernel_global_source: Optional[str] = None,
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
+ host_mod: tvm.IRModule | None = None,
+ device_mod: tvm.IRModule | None = None,
+ kernel_global_source: str | None = None,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[List[str]] = None):
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | None = None):
"""Initialize the adapter with the given TIR function or module.
Args:
@@ -107,15 +108,15 @@ def __init__(self,
@classmethod
def from_database(cls,
- params: List[TensorType],
- result_idx: List[int],
+ params: list[TensorType],
+ result_idx: list[int],
target: str,
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[List[str]] = None):
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | None = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
@@ -155,7 +156,7 @@ def from_database(cls,
adapter._post_init()
return adapter
- def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
+ def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (id, buffer_index, dimension)
@@ -182,7 +183,7 @@ def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
- def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
+ def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel.
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
@@ -193,9 +194,7 @@ def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
ctypes_args.append(ctypes.c_void_p(stream))
self.lib.call(*ctypes_args)
- def _wrap_forward_from_prebuild_lib(self,
- *ins: List[torch.Tensor],
- stream: Optional[int] = None):
+ def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
"""High-level wrapper for kernel execution.
Handles:
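
The ctypes call path above reduces to pointer marshalling: each torch.Tensor is converted to a C void pointer (presumably via `tensor.data_ptr()`), and the CUDA stream handle is appended last before `lib.call(...)` is invoked, both visible in the hunk; how non-tensor scalars are forwarded is assumed here. A hedged stand-alone sketch of just that marshalling step:

    from __future__ import annotations
    import ctypes
    import torch


    def marshal_args(args, stream: int | None = None) -> list:
        ctypes_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                ctypes_args.append(ctypes.c_void_p(arg.data_ptr()))  # raw data pointer
            else:
                ctypes_args.append(arg)  # assumed: e.g. ints for dynamic shape arguments
        ctypes_args.append(ctypes.c_void_p(stream))  # stream handle goes last
        return ctypes_args


    a = torch.zeros(4)
    print(marshal_args([a, 128], stream=0))
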
diff --git a/tilelang/jit/adapter/cython/adapter.py b/tilelang/jit/adapter/cython/adapter.py
index d210de46c..7857872cf 100644
--- a/tilelang/jit/adapter/cython/adapter.py
+++ b/tilelang/jit/adapter/cython/adapter.py
@@ -1,10 +1,11 @@
"""The profiler and convert to torch utils"""
+from __future__ import annotations
import ctypes
import logging
import torch
-from typing import List, Optional, Union, Callable, Dict, Tuple, Any
+from typing import Callable, Any
from tilelang import tvm as tvm
from tvm.target import Target
from tilelang.engine.param import KernelParam
@@ -44,43 +45,43 @@ class CythonKernelAdapter(BaseKernelAdapter):
"""
# Class attributes to store compiled kernel information
- target: Union[str, Target] = "cuda"
- ir_module: Optional[tvm.IRModule] = None
+ target: str | Target = "cuda"
+ ir_module: tvm.IRModule | None = None
# The global source code of the kernel -> global means the source code of the kernel
# that is not wrapped by the wrapper code
- kernel_global_source: Optional[str] = None
- lib: Optional[ctypes.CDLL] = None # Compiled library handle
- wrapped_source: Optional[str] = None # Generated C++ wrapper code
+ kernel_global_source: str | None = None
+ lib: ctypes.CDLL | None = None # Compiled library handle
+ wrapped_source: str | None = None # Generated C++ wrapper code
# Maps symbolic variables to their corresponding buffer and shape indices
- dynamic_symbolic_map: Optional[Dict[tir.Var, Tuple[int, int]]] = None
+ dynamic_symbolic_map: dict[tir.Var, tuple[int, int]] | None = None
# Maps pointer arguments to their corresponding (buffer_index, shape_dimension)
- ptr_map: Optional[Dict[int, str]] = None
+ ptr_map: dict[int, str] | None = None
# Maps buffer variables to their corresponding dtypes
- buffer_dtype_map: Optional[Dict[tir.Var, Tuple[int, torch.dtype]]] = None
+ buffer_dtype_map: dict[tir.Var, tuple[int, torch.dtype]] | None = None
# Maps buffer variables to their corresponding static shapes and strides,
# e.g., {
# "A": [(0, 16), (1, 16)] -> represents A.shape/strides = (16, 16)
# }
- static_shape_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
- static_strides_map: Optional[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]]] = None
+ static_shape_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None
+ static_strides_map: dict[tir.Var, tuple[int, list[tuple[int, int]]]] | None = None
# Contains contiguous buffers
- static_contiguous_list: Optional[List[tir.Var]] = None
+ static_contiguous_list: list[tir.Var] | None = None
# Maps buffer variables to their corresponding devices
- buffer_device_map: Optional[Dict[tir.Var, Tuple[int, torch.device]]] = None
+ buffer_device_map: dict[tir.Var, tuple[int, torch.device]] | None = None
# Pass configs for the compiler
- pass_configs: Optional[Dict[str, Any]] = None
+ pass_configs: dict[str, Any] | None = None
def __init__(self,
- params: List[KernelParam],
- result_idx: List[int],
- target: Union[str, Target],
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
- host_mod: Optional[tvm.IRModule] = None,
- device_mod: Optional[tvm.IRModule] = None,
- kernel_global_source: Optional[str] = None,
+ params: list[KernelParam],
+ result_idx: list[int],
+ target: str | Target,
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
+ host_mod: tvm.IRModule | None = None,
+ device_mod: tvm.IRModule | None = None,
+ kernel_global_source: str | None = None,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[List[str]] = None):
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | None = None):
"""Initialize the adapter with the given TIR function or module.
Args:
@@ -146,15 +147,15 @@ def __init__(self,
@classmethod
def from_database(cls,
- params: List[TensorType],
- result_idx: List[int],
+ params: list[TensorType],
+ result_idx: list[int],
target: str,
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[List[str]] = None):
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | None = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
@@ -205,7 +206,7 @@ def from_database(cls,
adapter._post_init()
return adapter
- def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
+ def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (id, buffer_index, dimension)
@@ -232,7 +233,7 @@ def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int, int]]:
dynamic_symbolic_map[stride] = (1, i, j)
return dynamic_symbolic_map
- def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]:
+ def _process_buffer_dtype(self) -> dict[tir.Var, tuple[int, torch.dtype]]:
"""Extract information about buffer dtypes from the TIR function.
Maps buffer variables to their corresponding dtypes.
@@ -248,7 +249,7 @@ def _process_buffer_dtype(self) -> Dict[tir.Var, Tuple[int, torch.dtype]]:
buffer_dtype_map[name] = (i, map_torch_type(dtype))
return buffer_dtype_map
- def _process_ptr_map(self) -> Dict[int, str]:
+ def _process_ptr_map(self) -> dict[int, str]:
"""Extract information about pointer arguments from the TIR function.
Maps pointer arguments to their corresponding (buffer_index, shape_dimension)
@@ -263,9 +264,9 @@ def _process_ptr_map(self) -> Dict[int, str]:
return ptr_map
def _process_static_buffer_infos(self) -> \
- Tuple[Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
- Dict[tir.Var, Tuple[int, List[Tuple[int, int]]]],
- List[Tuple[tir.Var]]]:
+ tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]],
+ dict[tir.Var, tuple[int, list[tuple[int, int]]]],
+ list[tuple[tir.Var]]]:
"""Extract information about static shapes from the TIR function.
Maps buffer variables to their corresponding static shapes.
@@ -300,7 +301,7 @@ def _process_static_buffer_infos(self) -> \
static_contiguous_list.append((i, buffer.name))
return static_shape_map, static_strides_map, static_contiguous_list
- def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]:
+ def _process_buffer_device(self) -> dict[tir.Var, tuple[int, torch.device]]:
"""Extract information about buffer devices from the TIR function.
Maps buffer variables to their corresponding devices.
@@ -326,7 +327,7 @@ def _process_buffer_device(self) -> Dict[tir.Var, Tuple[int, torch.device]]:
buffer_device_map[name] = (i, device)
return buffer_device_map
- def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
+ def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel.
Converts PyTorch tensor pointers to C void pointers for ctypes interface.
diff --git a/tilelang/jit/adapter/dlpack.py b/tilelang/jit/adapter/dlpack.py
index b45742433..9fa767f04 100644
--- a/tilelang/jit/adapter/dlpack.py
+++ b/tilelang/jit/adapter/dlpack.py
@@ -1,7 +1,7 @@
"""The profiler and convert to torch utils"""
+from __future__ import annotations
import torch
-from typing import List
from tilelang.contrib.dlpack import to_pytorch_func
from .base import BaseKernelAdapter
@@ -11,7 +11,7 @@ class TorchDLPackKernelAdapter(BaseKernelAdapter):
def _convert_torch_func(self) -> callable:
torch_func = to_pytorch_func(self.mod)
- def func(*ins: List[torch.Tensor]):
+ def func(*ins: list[torch.Tensor]):
if len(ins) + len(self.result_idx) != len(self.params):
raise ValueError(
f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
diff --git a/tilelang/jit/adapter/libgen.py b/tilelang/jit/adapter/libgen.py
index 5d1143a67..1e33ec040 100644
--- a/tilelang/jit/adapter/libgen.py
+++ b/tilelang/jit/adapter/libgen.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
import ctypes
import importlib
import logging
@@ -5,7 +6,7 @@
import os.path as osp
import subprocess
import tempfile
-from typing import Any, Dict, Optional, List
+from typing import Any
from tvm.target import Target
@@ -29,21 +30,21 @@
is_nvrtc_available = False
-class LibraryGenerator(object):
- srcpath: Optional[str] = None
- libpath: Optional[str] = None
- lib_code: Optional[str] = None
- pass_configs: Optional[Dict[str, Any]] = None
- compile_flags: Optional[List[str]] = None
+class LibraryGenerator:
+ srcpath: str | None = None
+ libpath: str | None = None
+ lib_code: str | None = None
+ pass_configs: dict[str, Any] | None = None
+ compile_flags: list[str] | None = None
def __init__(self, target: Target, verbose: bool = False):
self.target = target
self.verbose = verbose
- def assign_pass_configs(self, pass_configs: Optional[Dict[str, Any]] = None):
+ def assign_pass_configs(self, pass_configs: dict[str, Any] | None = None):
self.pass_configs = pass_configs
- def assign_compile_flags(self, compile_flags: Optional[List[str]] = None):
+ def assign_compile_flags(self, compile_flags: list[str] | None = None):
if compile_flags is None:
compile_flags = []
self.compile_flags = compile_flags
@@ -52,7 +53,7 @@ def update_lib_code(self, lib_code: str):
self.lib_code = lib_code
# Assume currently we only support CUDA compilation
- def load_lib(self, lib_path: Optional[str] = None):
+ def load_lib(self, lib_path: str | None = None):
if lib_path is None:
lib_path = self.libpath
else:
@@ -185,7 +186,7 @@ def set_src_path(self, srcpath):
class PyLibraryGenerator(LibraryGenerator):
- host_func: Optional[str] = None
+ host_func: str | None = None
culib = None
pymodule = None
@@ -206,7 +207,7 @@ def import_from_file(module_name, file_path):
def update_host_func(self, host_func: str):
self.host_func = host_func
- def load_lib(self, lib_path: Optional[str] = None):
+ def load_lib(self, lib_path: str | None = None):
if lib_path is None:
lib_path = self.libpath
diff --git a/tilelang/jit/adapter/nvrtc/adapter.py b/tilelang/jit/adapter/nvrtc/adapter.py
index d1fd9d421..d6723a031 100644
--- a/tilelang/jit/adapter/nvrtc/adapter.py
+++ b/tilelang/jit/adapter/nvrtc/adapter.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable
import torch
from tvm import tir
@@ -26,16 +27,16 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
kernels = {}
def __init__(self,
- params: List[KernelParam],
- result_idx: List[int],
- target: Union[str, Target],
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
- host_mod: Optional[tvm.IRModule] = None,
- device_mod: Optional[tvm.IRModule] = None,
- kernel_global_source: Optional[str] = None,
+ params: list[KernelParam],
+ result_idx: list[int],
+ target: str | Target,
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
+ host_mod: tvm.IRModule | None = None,
+ device_mod: tvm.IRModule | None = None,
+ kernel_global_source: str | None = None,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[List[str]] = None):
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | None = None):
check_nvrtc_available()
@@ -91,15 +92,15 @@ def __init__(self,
@classmethod
def from_database(cls,
- params: List[KernelParam],
- result_idx: List[int],
+ params: list[KernelParam],
+ result_idx: list[int],
target: str,
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
kernel_global_source: str,
kernel_lib_path: str,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[List[str]] = None):
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | None = None):
adapter = cls.__new__(cls)
adapter.params = params
adapter.result_idx = adapter._legalize_result_idx(result_idx)
@@ -143,7 +144,7 @@ def from_database(cls,
adapter._post_init()
return adapter
- def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
+ def _process_dynamic_symbolic(self) -> dict[tir.Var, tuple[int, int]]:
"""Extract information about dynamic shapes from the TIR function.
Maps symbolic variables to their corresponding (buffer_index, shape_dimension)
@@ -165,7 +166,7 @@ def _process_dynamic_symbolic(self) -> Dict[tir.Var, Tuple[int, int]]:
dynamic_symbolic_map[shape] = (i, j)
return dynamic_symbolic_map
- def get_kernel_source(self) -> Optional[str]:
+ def get_kernel_source(self) -> str | None:
"""Get the CUDA kernel source code.
Returns
@@ -175,14 +176,12 @@ def get_kernel_source(self) -> Optional[str]:
"""
return self.kernel_global_source
- def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
+ def _forward_from_prebuild_lib(self, *args, stream: int | None = None):
"""Low-level function to call the compiled CUDA kernel.
"""
return self.pymodule.call(self.kernels, *args, stream=stream)
- def _wrap_forward_from_prebuild_lib(self,
- *ins: List[torch.Tensor],
- stream: Optional[int] = None):
+ def _wrap_forward_from_prebuild_lib(self, *ins: list[torch.Tensor], stream: int | None = None):
"""High-level wrapper for kernel execution.
Handles:
@@ -242,7 +241,7 @@ def _wrap_forward_from_prebuild_lib(self,
else:
return [args[i] for i in self.result_idx]
- def _convert_torch_func(self) -> Callable[..., Union[torch.Tensor, List[torch.Tensor]]]:
+ def _convert_torch_func(self) -> Callable[..., torch.Tensor | list[torch.Tensor]]:
"""Convert to a PyTorch-compatible function.
Returns
diff --git a/tilelang/jit/adapter/torch/metal.py b/tilelang/jit/adapter/torch/metal.py
index 9693fca06..30e84ad71 100644
--- a/tilelang/jit/adapter/torch/metal.py
+++ b/tilelang/jit/adapter/torch/metal.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
from functools import wraps
-from typing import Callable, Optional, Union, List
+from typing import Callable
import torch
from tvm import tir
@@ -14,13 +15,13 @@ class MetalKernelAdapter(BaseKernelAdapter):
def __init__(
self,
- params: List[KernelParam],
- result_idx: List[int],
+ params: list[KernelParam],
+ result_idx: list[int],
# target: Union[str, Target],
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
# host_mod: Optional[tvm.IRModule] = None,
- device_mod: Optional[tvm.IRModule] = None,
- kernel_global_source: Optional[str] = None,
+ device_mod: tvm.IRModule | None = None,
+ kernel_global_source: str | None = None,
verbose: bool = False,
# pass_configs: Optional[Dict[str, Any]] = None,
# compile_flags: Optional[List[str]] = None
diff --git a/tilelang/jit/adapter/utils.py b/tilelang/jit/adapter/utils.py
index 6a09d6f6f..efc965e1b 100644
--- a/tilelang/jit/adapter/utils.py
+++ b/tilelang/jit/adapter/utils.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import re
-from typing import Union, Optional, Literal, Dict
+from typing import Literal
from tilelang import tvm as tvm
from tvm import IRModule, tir
from tvm.target import Target
@@ -65,11 +65,11 @@ def is_metal_target(target: Target) -> bool:
def get_annotated_mod(
- func_or_mod: Union[tir.PrimFunc, tvm.IRModule],
- target: Union[str, Target] = "auto",
- target_host: Optional[Union[str, Target]] = None,
+ func_or_mod: tir.PrimFunc | tvm.IRModule,
+ target: str | Target = "auto",
+ target_host: str | Target | None = None,
model_type: Literal["device", "host", "all"] = "all",
-) -> Union[IRModule, tuple[IRModule, IRModule]]:
+) -> IRModule | tuple[IRModule, IRModule]:
# Validate model_type early
if model_type not in {"device", "host", "all"}:
@@ -107,7 +107,7 @@ def get_annotated_mod(
return dispatch[model_type](mod)
-def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: Optional[Dict[str, str]] = None) -> str:
+def pythonic_expr(expr: tvm.tir.PrimExpr, dtype_map: dict[str, str] | None = None) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.
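
`pythonic_expr` keeps the common `dict[str, str] | None = None` signature after the rewrite. A hedged, tilelang-free sketch of the usual idiom behind such optional-mapping parameters; `render_dtype` and its default table are illustrative names only:

    from __future__ import annotations

    _DEFAULT_DTYPES = {"float32": "float", "float16": "half"}

    def render_dtype(name: str, dtype_map: dict[str, str] | None = None) -> str:
        # Fall back to a shared read-only table rather than a mutable default argument.
        table = _DEFAULT_DTYPES if dtype_map is None else dtype_map
        return table.get(name, name)

    assert render_dtype("float16") == "half"
    assert render_dtype("float16", {"float16": "__half"}) == "__half"
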
diff --git a/tilelang/jit/adapter/wrapper.py b/tilelang/jit/adapter/wrapper.py
index 9c032826f..235e70135 100644
--- a/tilelang/jit/adapter/wrapper.py
+++ b/tilelang/jit/adapter/wrapper.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
from abc import ABC, abstractmethod
from tilelang import tvm as tvm
-from typing import Optional, List, Dict, Union, Any
+from typing import Any
from tvm import IRModule
from tvm.target import Target
from .utils import (is_metal_target, match_declare_kernel, match_declare_kernel_cpu, is_cuda_target,
@@ -176,7 +177,7 @@ def wrap(self, *args, **kwargs):
logger = logging.getLogger(__name__)
-class TLCUDASourceWrapper(object):
+class TLCUDASourceWrapper:
_TYPE_MAP = {
"float32": "float",
"float16": "half_t",
@@ -196,33 +197,33 @@ class TLCUDASourceWrapper(object):
}
backend = "tl"
- device_mod: Optional[IRModule] = None
- host_mod: Optional[IRModule] = None
- pass_configs: Optional[Dict[str, Any]] = None
+ device_mod: IRModule | None = None
+ host_mod: IRModule | None = None
+ pass_configs: dict[str, Any] | None = None
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
- device_mod: Optional[IRModule] = None,
- host_mod: Optional[IRModule] = None,
- pass_configs: Optional[Dict[str, Any]] = None):
+ device_mod: IRModule | None = None,
+ host_mod: IRModule | None = None,
+ pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.pass_configs = pass_configs
self.device_mod = device_mod
self.host_mod = host_mod
- self.function_names: Optional[str] = None
- self.dynamic_smem_buf: Optional[int] = None
- self.block_info: Union[List[int], Dict] = [1, 1, 1]
- self.grid_info: Union[List[int], Dict] = [1, 1, 1]
- self.tma_descriptor_args: Optional[Dict] = None
- self.l2_persistent_map: Optional[Dict[str, Dict]] = {}
+ self.function_names: str | None = None
+ self.dynamic_smem_buf: int | None = None
+ self.block_info: list[int] | dict = [1, 1, 1]
+ self.grid_info: list[int] | dict = [1, 1, 1]
+ self.tma_descriptor_args: dict | None = None
+ self.l2_persistent_map: dict[str, dict] | None = {}
self.parse_source_information()
- self.srcpath: Optional[str] = None
- self.libpath: Optional[str] = None
- self.lib_code: Optional[str] = self.update_lib_code(source)
+ self.srcpath: str | None = None
+ self.libpath: str | None = None
+ self.lib_code: str | None = self.update_lib_code(source)
def _pythonic_expr(self, expr: tvm.tir.PrimExpr) -> str:
return pythonic_expr(expr, self._TYPE_MAP)
@@ -264,10 +265,10 @@ def create_dispatch_func(self, code, function_informations):
def func_call_args(s,
function_args,
function_params,
- desc_name_map: Optional[Dict[str, str]] = None,
- desc_name_var_map: Optional[Dict[str, tvm.tir.Var]] = None):
+ desc_name_map: dict[str, str] | None = None,
+ desc_name_var_map: dict[str, tvm.tir.Var] | None = None):
# Extract the function call arguments matching the function definition
- def maybe_desc(name: str, matches: List[str], i: int):
+ def maybe_desc(name: str, matches: list[str], i: int):
match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
@@ -305,8 +306,8 @@ def maybe_desc(name: str, matches: List[str], i: int):
kernel_launch_code = """"""
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_CREATE_HANDLE
- desc_name_map: Dict[str, str] = {}
- desc_name_var_map: Dict[str, tvm.tir.Var] = {}
+ desc_name_map: dict[str, str] = {}
+ desc_name_var_map: dict[str, tvm.tir.Var] = {}
for function_name, function_info in function_informations.items():
block_info = function_info["block_info"]
grid_info = function_info["grid_info"]
@@ -322,14 +323,8 @@ def maybe_desc(name: str, matches: List[str], i: int):
# Identify the start of the function body to insert arguments
index = code.index("{", index)
- block_str = "dim3({}, {}, {})".format(
- self._pythonic_expr(block_info[0]),
- self._pythonic_expr(block_info[1]),
- self._pythonic_expr(block_info[2]),
- )
- grid_str = "dim3({}, {}, {})".format(
- self._pythonic_expr(grid_info[0]), self._pythonic_expr(grid_info[1]),
- self._pythonic_expr(grid_info[2]))
+ block_str = f"dim3({self._pythonic_expr(block_info[0])}, {self._pythonic_expr(block_info[1])}, {self._pythonic_expr(block_info[2])})"
+ grid_str = f"dim3({self._pythonic_expr(grid_info[0])}, {self._pythonic_expr(grid_info[1])}, {self._pythonic_expr(grid_info[2])})"
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map
@@ -353,9 +348,8 @@ def maybe_desc(name: str, matches: List[str], i: int):
args_list
), f"Function {function_name} has {len(function_params)} parameters, but {len(args_list)} arguments"
call_args = ", ".join(args_list)
- kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
- function_name, grid_str, block_str, smem_str, call_args)
- kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
+ kernel_launch_code += f"\t{function_name}<<<{grid_str}, {block_str}, {smem_str}, stream>>>({call_args});\n"
+ kernel_launch_code += f"\tTILELANG_CHECK_LAST_ERROR(\"{function_name}\");\n"
if has_l2_persistent_map:
kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
@@ -386,8 +380,8 @@ def generate_l2_persistent_map(self, function_name: str) -> str:
return init_l2_persistent_map
- def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
- desc_name_var_map: Dict[str, tvm.tir.Var]) -> str:
+ def generate_tma_descriptor_args(self, desc_name_map: dict[str, str],
+ desc_name_var_map: dict[str, tvm.tir.Var]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
@@ -512,7 +506,7 @@ def parse_source_information(self):
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
- dynamic_symbolic_set: List[str] = []
+ dynamic_symbolic_set: list[str] = []
def unique_push_back(name: str):
if name not in dynamic_symbolic_set:
@@ -565,7 +559,7 @@ def update_lib_code(self, code: str):
assert function_name in self.device_mod, f"Function {function_name} not found in device module"
device_func = self.device_mod[function_name]
kernel_params_cnt = len(device_func.params)
- function_params: List[str] = None
+ function_params: list[str] = None
def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
nonlocal function_params
@@ -599,7 +593,7 @@ def visitor(node, fn=function_name, param_cnt=kernel_params_cnt):
lib_code = self.source + init_func + host_func
return lib_code
- def get_stream_type(self) -> Dict[str, str]:
+ def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=cudaStreamDefault", "type": "cudaStream_t"}
@property
@@ -669,9 +663,9 @@ def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
- device_mod: Optional[IRModule] = None,
- host_mod: Optional[IRModule] = None,
- pass_configs: Optional[Dict[str, Any]] = None):
+ device_mod: IRModule | None = None,
+ host_mod: IRModule | None = None,
+ pass_configs: dict[str, Any] | None = None):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def create_dispatch_func(self, code, function_informations):
@@ -701,9 +695,9 @@ def create_dispatch_func(self, code, function_informations):
# Format the function arguments for declaration
def_args = ", ".join([f"{arg['name']}" for arg in function_args])
- def func_call_args(s, function_args, desc_name_map: Optional[Dict[str, str]] = None):
+ def func_call_args(s, function_args, desc_name_map: dict[str, str] | None = None):
# Extract the function call arguments matching the function definition
- def maybe_desc(name: str, matches: List[str], i: int):
+ def maybe_desc(name: str, matches: list[str], i: int):
match = matches[i]
if not (match == name + "_desc" or match.startswith(name + "_desc_")):
return False
@@ -729,7 +723,7 @@ def maybe_desc(name: str, matches: List[str], i: int):
call_args.append((match, "None"))
return call_args
- desc_name_map: Dict[str, str] = {}
+ desc_name_map: dict[str, str] = {}
device_index = 0
kernel_launch_code = """"""
for function_name, function_info in function_informations.items():
@@ -766,7 +760,7 @@ def maybe_desc(name: str, matches: List[str], i: int):
repr(list(function_informations.keys())), def_args, kernel_launch_code)
return host_func
- def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
+ def generate_tma_descriptor_args(self, desc_name_map: dict[str, str]) -> str:
tma_descripter_init = ""
if self.tma_descriptor_args is None:
return tma_descripter_init
@@ -844,7 +838,7 @@ def update_lib_code(self, code: str):
self.host_func = self.create_dispatch_func(code, function_informations)
return self.lib_code
- def get_stream_type(self) -> Dict[str, str]:
+ def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=0", "type": "int"}
@@ -877,9 +871,9 @@ def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
- device_mod: Optional[IRModule] = None,
- host_mod: Optional[IRModule] = None,
- pass_configs: Optional[Dict[str, Any]] = None):
+ device_mod: IRModule | None = None,
+ host_mod: IRModule | None = None,
+ pass_configs: dict[str, Any] | None = None):
super().__init__(scheduled_ir_module, source, target, device_mod, host_mod, pass_configs)
def get_init_func(self):
@@ -895,11 +889,11 @@ def get_init_func(self):
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
- def get_stream_type(self) -> Dict[str, str]:
+ def get_stream_type(self) -> dict[str, str]:
return {"name": "stream=hipStreamDefault", "type": "hipStream_t"}
-class TLCPUSourceWrapper(object):
+class TLCPUSourceWrapper:
_TYPE_MAP = {
"float32": "float",
"float16": "half",
@@ -925,29 +919,29 @@ class TLCPUSourceWrapper(object):
""")
backend = "tl"
- device_mod: Optional[IRModule] = None
- host_mod: Optional[IRModule] = None
- pass_configs: Optional[Dict[str, Any]] = None
+ device_mod: IRModule | None = None
+ host_mod: IRModule | None = None
+ pass_configs: dict[str, Any] | None = None
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
- device_mod: Optional[IRModule] = None,
- host_mod: Optional[IRModule] = None,
- pass_configs: Optional[Dict[str, Any]] = None):
+ device_mod: IRModule | None = None,
+ host_mod: IRModule | None = None,
+ pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
self.device_mod = device_mod
self.host_mod = host_mod
self.pass_configs = pass_configs
- self.function_names: Optional[str] = None
- self.dynamic_smem_buf: Optional[int] = None
+ self.function_names: str | None = None
+ self.dynamic_smem_buf: int | None = None
self.parse_source_information()
- self.srcpath: Optional[str] = None
- self.libpath: Optional[str] = None
- self.lib_code: Optional[str] = self.update_lib_code(source)
+ self.srcpath: str | None = None
+ self.libpath: str | None = None
+ self.lib_code: str | None = self.update_lib_code(source)
def create_call_func(self, code, function_informations):
# Extract the set of dynamic symbolic names used in the primary function
@@ -997,7 +991,7 @@ def func_call_args(s, function_args):
index = code.index("{", index)
call_args = ", ".join(func_call_args(declaration, function_args))
- _call_str += "{}({})".format(function_name, call_args)
+ _call_str += f"{function_name}({call_args})"
# Wrap the kernel dispatch logic in an external C function
host_func = self.CALL_PREFIX.format(def_args, _call_str)
@@ -1018,7 +1012,7 @@ def parse_source_information(self):
def get_dynamic_symbolic_set(self, prim_func):
# Determine the set of dynamic symbols used in the function
- dynamic_symbolic_set: List[str] = []
+ dynamic_symbolic_set: list[str] = []
for param in prim_func.params:
if param in prim_func.buffer_map:
buffer = prim_func.buffer_map[param]
@@ -1066,15 +1060,15 @@ def prim_func(self):
raise ValueError("Cannot find primary function in the module.")
-class TLMetalSourceWrapper(object):
+class TLMetalSourceWrapper:
def __init__(self,
scheduled_ir_module: IRModule,
source: str,
target: Target,
- device_mod: Optional[IRModule] = None,
- host_mod: Optional[IRModule] = None,
- pass_configs: Optional[Dict[str, Any]] = None):
+ device_mod: IRModule | None = None,
+ host_mod: IRModule | None = None,
+ pass_configs: dict[str, Any] | None = None):
self.mod = scheduled_ir_module
self.target = target
self.source = source
@@ -1092,11 +1086,11 @@ class TLWrapper(BaseWrapper):
"""
A wrapper class for the TileLang backend.
"""
- device_mod: Optional[IRModule] = None
- host_mod: Optional[IRModule] = None
- pass_configs: Optional[Dict[str, Any]] = None
- target: Optional[Target] = None
- lib: Optional[object] = None
+ device_mod: IRModule | None = None
+ host_mod: IRModule | None = None
+ pass_configs: dict[str, Any] | None = None
+ target: Target | None = None
+ lib: object | None = None
def __init__(self, target: Target):
super().__init__()
@@ -1108,7 +1102,7 @@ def __init__(self, target: Target):
def assign_optimized_module(self, scheduled_ir_module: IRModule):
self.scheduled_ir_module = scheduled_ir_module
- def assign_pass_configs(self, pass_configs: Dict[str, Any]):
+ def assign_pass_configs(self, pass_configs: dict[str, Any]):
self.pass_configs = pass_configs
def assign_host_module(self, host_mod: IRModule):
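
Beyond the type hints, this file also moves the kernel-launch string construction from `.format()` to f-strings. A short sketch of that formatting pattern with made-up values (only the string building is shown, not the real wrapper logic):

    def launch_line(function_name: str, grid: tuple[int, int, int],
                    block: tuple[int, int, int], smem: int, call_args: str) -> str:
        # Same text the .format() version produced, but the pieces are inline and readable.
        grid_str = f"dim3({grid[0]}, {grid[1]}, {grid[2]})"
        block_str = f"dim3({block[0]}, {block[1]}, {block[2]})"
        return f"\t{function_name}<<<{grid_str}, {block_str}, {smem}, stream>>>({call_args});\n"

    print(launch_line("main_kernel", (128, 1, 1), (256, 1, 1), 0, "A, B, C"))
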
diff --git a/tilelang/jit/kernel.py b/tilelang/jit/kernel.py
index 64fc7bdf1..71dafffb2 100644
--- a/tilelang/jit/kernel.py
+++ b/tilelang/jit/kernel.py
@@ -1,4 +1,5 @@
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from __future__ import annotations
+from typing import Any, Callable, Literal
from tilelang.jit.adapter.utils import is_metal_target
from tvm.target import Target
@@ -17,7 +18,7 @@
logger = logging.getLogger(__name__)
-class JITKernel(object):
+class JITKernel:
"""
A wrapper class for compiling and invoking TileLang (TVM TIR) functions as PyTorch-compatible functions.
@@ -37,20 +38,20 @@ class JITKernel(object):
# tuner result
latency: float = None
- config: Dict[str, Any] = None
+ config: dict[str, Any] = None
ref_latency: float = None
def __init__(
self,
func: PrimFunc = None,
- out_idx: Union[List[int], int] = None,
+ out_idx: list[int] | int = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
- target: Union[str, Target] = "auto",
- target_host: Union[str, Target] = None,
+ target: str | Target = "auto",
+ target_host: str | Target = None,
verbose: bool = False,
- pass_configs: Optional[Dict[str, Any]] = None,
+ pass_configs: dict[str, Any] | None = None,
from_database: bool = False,
- compile_flags: Optional[List[str]] = None,
+ compile_flags: list[str] | None = None,
):
"""
Initializes a TorchFunction instance.
@@ -138,13 +139,13 @@ def from_database(
func: PrimFunc,
kernel_global_source: str,
kernel_lib_path: str,
- params: List[KernelParam],
- target: Union[str, Target],
- target_host: Union[str, Target],
- out_idx: Union[List[int], int],
+ params: list[KernelParam],
+ target: str | Target,
+ target_host: str | Target,
+ out_idx: list[int] | int,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"],
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[List[str]] = None,
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | None = None,
):
"""
Alternative constructor to create a TorchFunction directly from a database.
@@ -192,7 +193,7 @@ def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.torch_function(*args, **kwds)
def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
- out_idx: List[int]) -> BaseKernelAdapter:
+ out_idx: list[int]) -> BaseKernelAdapter:
"""
Compiles the given TileLang PrimFunc using TVM and creates a kernel adapter.
@@ -295,16 +296,15 @@ def _compile_and_create_adapter(self, tilelang_func: PrimFunc,
return adapter
- def _create_adapter_from_database(
- self,
- params: List[KernelParam],
- result_idx: Union[List[int], int],
- target: Union[str, Target],
- func_or_mod: Union[PrimFunc, tvm.runtime.Module],
- kernel_global_source: str,
- kernel_lib_path: str,
- pass_configs: Optional[Dict[str, Any]] = None,
- compile_flags: Optional[List[str]] = None) -> BaseKernelAdapter:
+ def _create_adapter_from_database(self,
+ params: list[KernelParam],
+ result_idx: list[int] | int,
+ target: str | Target,
+ func_or_mod: PrimFunc | tvm.runtime.Module,
+ kernel_global_source: str,
+ kernel_lib_path: str,
+ pass_configs: dict[str, Any] | None = None,
+ compile_flags: list[str] | None = None) -> BaseKernelAdapter:
target = self.target
execution_backend = self.execution_backend
@@ -405,11 +405,11 @@ def get_host_source(self) -> str:
"""
return str(self.artifact.host_mod)
- def run_once(self, func: Optional[Callable] = None) -> None:
+ def run_once(self, func: Callable | None = None) -> None:
return self.get_profiler().run_once(func)
- def update_tuner_result(self, latency: float, config: Dict[str, Any],
- ref_latency: float) -> "JITKernel":
+ def update_tuner_result(self, latency: float, config: dict[str, Any],
+ ref_latency: float) -> JITKernel:
"""
Updates the tuning results for this kernel.
@@ -432,7 +432,7 @@ def update_tuner_result(self, latency: float, config: Dict[str, Any],
return self
- def get_tuner_result(self) -> Dict[str, Any]:
+ def get_tuner_result(self) -> dict[str, Any]:
"""
Gets the tuning results for this kernel.
@@ -454,11 +454,11 @@ def get_tuner_result(self) -> Dict[str, Any]:
}
@property
- def out_idx(self) -> List[int]:
+ def out_idx(self) -> list[int]:
return self.adapter.result_idx
@property
- def params(self) -> List[KernelParam]:
+ def params(self) -> list[KernelParam]:
return self.artifact.params if self.artifact else self.adapter.params
@property
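
With `from __future__ import annotations`, the `-> JITKernel` return type on `update_tuner_result` no longer needs quotes: a method can name its own class before the class definition is complete because the annotation is resolved lazily. A minimal sketch using a hypothetical `TunerRecord` class (not a tilelang type):

    from __future__ import annotations
    from typing import Any

    class TunerRecord:
        def __init__(self) -> None:
            self.latency: float | None = None
            self.config: dict[str, Any] | None = None

        def update(self, latency: float, config: dict[str, Any]) -> TunerRecord:
            # The forward reference is stored as a string, so no quoting is required.
            self.latency = latency
            self.config = config
            return self  # fluent style, mirroring update_tuner_result

    record = TunerRecord().update(0.12, {"block_M": 128})
    print(record.latency, record.config)
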
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 994f338f2..1a26b53d0 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -1,6 +1,6 @@
"""The language interface for tl programs."""
+from __future__ import annotations
-from typing import Optional
# from .parser import *
# now is fully compatible with the upstream
# tir script
@@ -90,6 +90,6 @@
)
-def import_source(source: Optional[str] = None):
+def import_source(source: str | None = None):
# source is the source code to be imported
return block_attr({"pragma_import_c": source}) if source is not None else None
diff --git a/tilelang/language/annotations.py b/tilelang/language/annotations.py
index cee46ca2f..12d3af4d3 100644
--- a/tilelang/language/annotations.py
+++ b/tilelang/language/annotations.py
@@ -1,6 +1,7 @@
"""Annotation helpers exposed on the TileLang language surface."""
+from __future__ import annotations
-from typing import Callable, Dict
+from typing import Callable
from tilelang.layout import Layout
from tvm.script.parser.tir import attr, block_attr
@@ -21,7 +22,7 @@ def use_swizzle(panel_size: int, order: str = "row", enable: bool = True):
return attr(None, "threadblock_swizzle_pattern", f"tl::{device_func}<{panel_size}>")
-def annotate_layout(layout_map: Dict):
+def annotate_layout(layout_map: dict):
"""Annotate the layout of the buffer."""
_layout_map = {}
for buffer, layout in layout_map.items():
@@ -35,7 +36,7 @@ def annotate_layout(layout_map: Dict):
return block_attr({"layout_map": _layout_map})
-def annotate_safe_value(safe_value_map: Dict):
+def annotate_safe_value(safe_value_map: dict):
"""Annotate the safe value of the buffer."""
_safe_value_map = {}
for buffer, safe_value in safe_value_map.items():
@@ -43,7 +44,7 @@ def annotate_safe_value(safe_value_map: Dict):
return block_attr({"safe_value_map": _safe_value_map})
-def annotate_l2_hit_ratio(l2_hit_ratio_map: Dict):
+def annotate_l2_hit_ratio(l2_hit_ratio_map: dict):
"""Annotate the L2 hit ratio of the buffer."""
_l2_hit_ratio_map = {}
for buffer, hit_ratio in l2_hit_ratio_map.items():
diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py
index eb2d18526..f1b37d236 100644
--- a/tilelang/language/atomic.py
+++ b/tilelang/language/atomic.py
@@ -1,11 +1,11 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
"""Atomic operations for tilelang."""
+from __future__ import annotations
import tilelang.language as T
from tvm import ir, tir
from tvm.tir import PrimExpr, Buffer, BufferRegion, Var, op
-from typing import Optional
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
from tilelang.utils.language import get_buffer_region_from_load
@@ -21,7 +21,7 @@
def atomic_max(dst: Buffer,
value: PrimExpr,
- memory_order: Optional[str] = None,
+ memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
@@ -67,7 +67,7 @@ def atomic_max(dst: Buffer,
def atomic_min(dst: Buffer,
value: PrimExpr,
- memory_order: Optional[str] = None,
+ memory_order: str | None = None,
return_prev: bool = False) -> PrimExpr:
"""
Atomically update the value at dst to the minimum of its current value and value.
@@ -115,7 +115,7 @@ def atomic_min(dst: Buffer,
def atomic_add(dst: Buffer,
value: PrimExpr,
- memory_order: Optional[str] = None,
+ memory_order: str | None = None,
return_prev: bool = False,
use_tma: bool = False) -> PrimExpr:
"""
diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py
index f9867f235..f0b223f46 100644
--- a/tilelang/language/builtin.py
+++ b/tilelang/language/builtin.py
@@ -1,17 +1,18 @@
"""The language interface for tl programs."""
+from __future__ import annotations
from tilelang import tvm as tvm
from tilelang.language import ptx_arrive_barrier, evaluate
from tilelang.language.kernel import get_thread_bindings, get_block_extents
from tilelang.utils.target import check_hip_availability
from tvm import tir
-from typing import Union, Any, Optional
+from typing import Any
from tvm.tir import PrimExpr, Var, Call, Buffer, BufferLoad
_IS_HIP_AVAILABLE = check_hip_availability()
-def _normalize_index_arg(value: Optional[Union[int, PrimExpr]]) -> Optional[PrimExpr]:
+def _normalize_index_arg(value: int | PrimExpr | None) -> PrimExpr | None:
"""
Normalize warp sizing arguments so both Python ints and PrimExpr values
are accepted uniformly.
@@ -183,7 +184,7 @@ def disable_warp_group_reg_alloc():
return no_set_max_nreg()
-def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
+def mbarrier_wait_parity(mbarrier: int | PrimExpr | tir.Call, parity: int | Var):
"""Wait for memory barrier parity condition.
Args:
@@ -233,7 +234,7 @@ def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union
return tir.call_intrin("handle", tir.op.Op.get("tl.mbarrier_wait_parity"), mbarrier, parity)
-def mbarrier_arrive(mbarrier: Union[int, PrimExpr, tir.Call]):
+def mbarrier_arrive(mbarrier: int | PrimExpr | tir.Call):
"""Arrive at memory barrier.
Args:
@@ -294,7 +295,7 @@ def warpgroup_wait(num_mma: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.warpgroup_wait"), num_mma)
-def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
+def get_lane_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the logical lane index of the calling thread within a warp.
Parameters
@@ -319,7 +320,7 @@ def get_lane_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
return tir.call_intrin("int32", tir.op.Op.get("tl.get_lane_idx"), warp_size_expr)
-def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
+def get_warp_idx_sync(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the canonical warp index, assuming the warp's threads are converged.
Parameters
@@ -343,7 +344,7 @@ def get_warp_idx_sync(warp_size: Optional[Union[int, PrimExpr]] = None,) -> Prim
return tir.call_intrin("int32", tir.op.Op.get("tl.get_warp_idx_sync"), warp_size_expr)
-def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
+def get_warp_idx(warp_size: int | PrimExpr | None = None,) -> PrimExpr:
"""Return the canonical warp index without synchronizing the warp.
Parameters
@@ -368,8 +369,8 @@ def get_warp_idx(warp_size: Optional[Union[int, PrimExpr]] = None,) -> PrimExpr:
def get_warp_group_idx(
- warp_size: Optional[Union[int, PrimExpr]] = None,
- warps_per_group: Optional[Union[int, PrimExpr]] = None,
+ warp_size: int | PrimExpr | None = None,
+ warps_per_group: int | PrimExpr | None = None,
) -> PrimExpr:
"""Return the canonical warp group index for the calling thread.
@@ -441,7 +442,7 @@ def wait_wgmma(id: int):
return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), id)
-def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, Var, None] = None):
+def barrier_wait(barrier_id: int | PrimExpr | tir.Call, parity: int | Var | None = None):
"""Wait for a memory barrier to complete.
Args:
@@ -456,7 +457,7 @@ def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int,
return mbarrier_wait_parity(barrier_id, parity)
-def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]):
+def barrier_arrive(barrier_id: int | PrimExpr | tir.Call):
"""Arrive at a memory barrier.
Args:
@@ -466,7 +467,7 @@ def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]):
return mbarrier_arrive(barrier_id)
-def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
+def shfl_xor(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with XOR offset.
Args:
@@ -483,7 +484,7 @@ def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
-def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
+def shfl_down(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with down offset.
Args:
@@ -496,7 +497,7 @@ def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
-def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
+def shfl_up(value: int | PrimExpr | tir.Call, offset: int | PrimExpr | tir.Call):
"""Perform a shuffle operation with up offset.
Args:
@@ -601,7 +602,7 @@ def loop_break():
return tir.call_intrin("handle", tir.op.Op.get("tl.loop_break"))
-def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
+def cp_async_barrier_noinc(barrier_id: int | PrimExpr | tir.Call):
"""Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
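
Several helpers in this file take `int | PrimExpr | None` and normalise the argument before use. A hedged, TVM-free sketch of that normalisation idea, with a trivial `Expr` stand-in playing the role of PrimExpr and an assumed default of 32:

    from __future__ import annotations
    from dataclasses import dataclass

    @dataclass
    class Expr:
        value: int

    def normalize(value: int | Expr | None, default: int = 32) -> Expr:
        if value is None:
            return Expr(default)   # fall back to the assumed default
        if isinstance(value, int):
            return Expr(value)     # wrap raw Python ints
        return value               # already an expression: pass through unchanged

    print(normalize(None), normalize(64), normalize(Expr(16)))
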
diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py
index 0be3e21ac..84444b8c6 100644
--- a/tilelang/language/copy.py
+++ b/tilelang/language/copy.py
@@ -1,17 +1,18 @@
"""The language interface for tl programs."""
+from __future__ import annotations
-from typing import Union, Optional, Literal
+from typing import Literal
from tilelang import language as T
from tilelang.utils.language import get_buffer_region_from_load
from tvm import ir, tir
from tilelang.language.utils import buffer_to_tile_region, buffer_region_to_tile_region, buffer_load_to_tile_region
-def copy(src: Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion],
- dst: Union[tir.Buffer, tir.BufferLoad],
- coalesced_width: Optional[int] = None,
+def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
+ dst: tir.Buffer | tir.BufferLoad,
+ coalesced_width: int | None = None,
disable_tma: bool = False,
- eviction_policy: Optional[Literal["evict_normal", "evict_first", "evict_last"]] = None):
+ eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
"""Copy data between memory regions.
Args:
@@ -94,8 +95,7 @@ def c2d_im2col(img: tir.Buffer,
stride: int,
dilation: int,
pad: int,
- eviction_policy: Optional[Literal["evict_normal", "evict_first",
- "evict_last"]] = None):
+ eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
"""Perform im2col transformation for 2D convolution.
Args:
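
Note that `Literal` still has to come from `typing` even after the builtin-generics rewrite; only the `Optional`/`Union` spellings go away. A small illustrative sketch (the validation below is not tilelang's):

    from __future__ import annotations
    from typing import Literal

    EvictionPolicy = Literal["evict_normal", "evict_first", "evict_last"]

    def check_policy(policy: EvictionPolicy | None = None) -> str:
        # Treat None as the hardware default and reject anything outside the literal set.
        if policy is None:
            return "evict_normal"
        if policy not in ("evict_normal", "evict_first", "evict_last"):
            raise ValueError(f"unknown eviction policy: {policy}")
        return policy

    print(check_policy(), check_policy("evict_last"))
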
diff --git a/tilelang/language/customize.py b/tilelang/language/customize.py
index e31cce4a6..0830c22dc 100644
--- a/tilelang/language/customize.py
+++ b/tilelang/language/customize.py
@@ -1,8 +1,8 @@
"""The language interface for tl programs."""
+from __future__ import annotations
import tilelang.language as T
from tvm.tir import PrimExpr, Buffer, op
-from typing import List, Union
from .atomic import atomic_max, atomic_min, atomic_add, atomic_addx2, atomic_addx4, atomic_load, atomic_store # noqa: F401
@@ -36,7 +36,7 @@ def clamp(dst: PrimExpr, min_val: PrimExpr, max_val: PrimExpr) -> PrimExpr:
return dst
-def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
+def reshape(src: Buffer, shape: list[PrimExpr]) -> Buffer:
"""Reshapes the input buffer to the specified shape.
Args:
@@ -49,9 +49,7 @@ def reshape(src: Buffer, shape: List[PrimExpr]) -> Buffer:
return T.Tensor(shape, src.dtype, src.data)
-def view(src: Buffer,
- shape: Union[List[PrimExpr], None] = None,
- dtype: Union[str, None] = None) -> Buffer:
+def view(src: Buffer, shape: list[PrimExpr] | None = None, dtype: str | None = None) -> Buffer:
"""
Return a Tensor view of the input buffer with an optional new shape and dtype.
diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py
index 5cb6eb837..fc511c007 100644
--- a/tilelang/language/experimental/gemm_sp.py
+++ b/tilelang/language/experimental/gemm_sp.py
@@ -1,16 +1,16 @@
"""The language interface for tl programs."""
+from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
-from typing import Union
def gemm_sp(
- A_sparse: Union[tir.Buffer, tir.Var],
- E: Union[tir.Buffer, tir.Var],
- B: Union[tir.Buffer, tir.Var],
- C: Union[tir.Buffer, tir.Var],
+ A_sparse: tir.Buffer | tir.Var,
+ E: tir.Buffer | tir.Var,
+ B: tir.Buffer | tir.Var,
+ C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
@@ -42,7 +42,7 @@ def gemm_sp(
AssertionError: If the K dimensions of matrices A and B don't match
"""
- def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
+ def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers.
Args:
diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py
index de6b3cff3..95ef26746 100644
--- a/tilelang/language/fill.py
+++ b/tilelang/language/fill.py
@@ -1,12 +1,12 @@
"""The language interface for tl programs."""
+from __future__ import annotations
from tvm import tir
-from typing import Union
from tilelang.language import has_let_value, get_let_value
from tilelang.utils.language import get_buffer_region_from_load
-def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
+def fill(buffer: tir.Buffer | tir.BufferRegion, value: tir.PrimExpr):
"""Fill a buffer or buffer region with a specified value.
Args:
@@ -21,7 +21,7 @@ def fill(buffer: Union[tir.Buffer, tir.BufferRegion], value: tir.PrimExpr):
return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), buffer, value)
-def clear(buffer: Union[tir.Buffer, tir.Var]):
+def clear(buffer: tir.Buffer | tir.Var):
"""Clear a buffer by filling it with zeros.
Args:
diff --git a/tilelang/language/frame.py b/tilelang/language/frame.py
index b82cfe5ef..8e6d59268 100644
--- a/tilelang/language/frame.py
+++ b/tilelang/language/frame.py
@@ -1,4 +1,5 @@
"""Override the LetFrame to print a message when entering the frame."""
+from __future__ import annotations
from tvm.ffi import register_object as _register_object
from tvm.tir import Var, PrimExpr, BufferLoad, BufferRegion
@@ -6,7 +7,6 @@
from tvm import DataType
from tvm.script.ir_builder.tir.frame import TIRFrame
from collections import deque
-from typing import Optional
import threading
@@ -150,7 +150,7 @@ def __exit__(self, ptype, value, trace):
super().__exit__(ptype, value, trace)
@classmethod
- def Current(cls) -> "LetFrame":
+ def Current(cls) -> LetFrame:
"""Get the current (topmost) let frame.
Returns:
@@ -198,7 +198,7 @@ def has_let_value(var: Var) -> bool:
return _get_let_stack().has_value(var)
-def get_let_value(var: Var) -> Optional[PrimExpr]:
+def get_let_value(var: Var) -> PrimExpr | None:
"""Get the value bound to a variable in the current let frame stack.
Args:
diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py
index 3c4aa5452..bb8dc6ce8 100644
--- a/tilelang/language/gemm.py
+++ b/tilelang/language/gemm.py
@@ -1,23 +1,23 @@
"""The language interface for tl programs."""
+from __future__ import annotations
from tilelang.primitives.gemm.base import GemmWarpPolicy
import tilelang.language as T
from tvm import tir
-from typing import Union, List, Optional
from tilelang.utils.language import get_buffer_region_from_load
def gemm(
- A: Union[tir.Buffer, tir.Var],
- B: Union[tir.Buffer, tir.Var],
- C: Union[tir.Buffer, tir.Var],
+ A: tir.Buffer | tir.Var,
+ B: tir.Buffer | tir.Var,
+ C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
clear_accum: bool = False,
k_pack: int = 1,
wg_wait: int = 0,
- mbar: Optional[tir.Buffer] = None,
+ mbar: tir.Buffer | None = None,
):
"""Perform a General Matrix Multiplication (GEMM) operation.
@@ -45,7 +45,7 @@ def gemm(
AssertionError: If the K dimensions of matrices A and B don't match
"""
- def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
+ def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers.
Args:
@@ -63,7 +63,7 @@ def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
C = legalize_arguments(C)
mbar = legalize_arguments(mbar) if mbar is not None else None
- def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
+ def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
return object.shape
elif isinstance(object, tir.BufferRegion):
@@ -82,7 +82,7 @@ def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
- def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
+ def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
strides = []
stride = 1
@@ -137,8 +137,7 @@ def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
stride_a = A_stride[-2]
stride_b = B_stride[-2]
- def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
- access_type: str = "r") -> tir.PrimExpr:
+ def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer):
return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion):
@@ -175,7 +174,7 @@ def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
- def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
+ def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer):
return [0] * len(object.shape)
@@ -214,9 +213,9 @@ def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr
# experimental currently, for fast compilation
def gemm_v2(
- A: Union[tir.Buffer, tir.Var],
- B: Union[tir.Buffer, tir.Var],
- C: Union[tir.Buffer, tir.Var],
+ A: tir.Buffer | tir.Var,
+ B: tir.Buffer | tir.Var,
+ C: tir.Buffer | tir.Var,
transpose_A: bool = False,
transpose_B: bool = False,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
@@ -247,7 +246,7 @@ def gemm_v2(
AssertionError: If the K dimensions of matrices A and B don't match
"""
- def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
+ def legalize_arguments(arg: tir.Buffer | tir.Var):
"""Convert let-bound variables to their corresponding buffers.
Args:
@@ -264,7 +263,7 @@ def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
B = legalize_arguments(B)
C = legalize_arguments(C)
- def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
+ def retrieve_shape(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
return object.shape
elif isinstance(object, tir.BufferRegion):
@@ -283,7 +282,7 @@ def retrieve_shape(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
raise ValueError(
f"Unsupported retrieve_shape argument type: {type(object)} for buffer {object}")
- def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
+ def retrieve_stride(object: tir.Buffer | tir.BufferRegion) -> list[int]:
if isinstance(object, tir.Buffer):
strides = []
stride = 1
@@ -338,8 +337,7 @@ def retrieve_stride(object: Union[tir.Buffer, tir.BufferRegion]) -> List[int]:
stride_a = A_stride[-2]
stride_b = B_stride[-2]
- def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
- access_type: str = "r") -> tir.PrimExpr:
+ def retrieve_ptr(object: tir.Buffer | tir.BufferRegion, access_type: str = "r") -> tir.PrimExpr:
if isinstance(object, tir.Buffer):
return object.access_ptr(access_type)
elif isinstance(object, tir.BufferRegion):
@@ -376,7 +374,7 @@ def retrieve_ptr(object: Union[tir.Buffer, tir.BufferRegion],
raise ValueError(
f"Unsupported retrieve_ptr argument type: {type(object)} for buffer {object}")
- def retrieve_offset(object: Union[tir.Buffer, tir.BufferRegion]) -> tir.PrimExpr:
+ def retrieve_offset(object: tir.Buffer | tir.BufferRegion) -> tir.PrimExpr:
"""Retrieve the offset of the buffer or buffer region."""
if isinstance(object, tir.Buffer):
return [0] * len(object.shape)
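
The `retrieve_shape` / `retrieve_stride` / `retrieve_ptr` helpers above all dispatch on a `tir.Buffer | tir.BufferRegion` union with an isinstance ladder. A hedged sketch of that dispatch style, written against stand-in classes so it runs without TVM:

    from __future__ import annotations
    from dataclasses import dataclass

    @dataclass
    class FakeBuffer:
        shape: list[int]

    @dataclass
    class FakeRegion:
        buffer: FakeBuffer
        extents: list[int]

    def retrieve_shape(obj: FakeBuffer | FakeRegion) -> list[int]:
        # Regions report their sliced extents, plain buffers their full shape,
        # and anything else is rejected, mirroring the ladder above.
        if isinstance(obj, FakeBuffer):
            return obj.shape
        if isinstance(obj, FakeRegion):
            return obj.extents
        raise ValueError(f"Unsupported argument type: {type(obj)}")

    buf = FakeBuffer([128, 256])
    print(retrieve_shape(buf), retrieve_shape(FakeRegion(buf, [64, 256])))
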
diff --git a/tilelang/language/kernel.py b/tilelang/language/kernel.py
index 303e88a94..54b78d3d9 100644
--- a/tilelang/language/kernel.py
+++ b/tilelang/language/kernel.py
@@ -1,6 +1,6 @@
"""The language interface for tl programs."""
+from __future__ import annotations
-from typing import Union, List, Tuple, Optional
from collections import deque
from tvm import tir
from tvm.tir import Var
@@ -80,7 +80,7 @@ def _get_current_stack() -> FrameStack:
return _local.kernel_launch_frame_stack
-def _normalize_bindings(bindings: List[Var]) -> Union[Var, List[Var]]:
+def _normalize_bindings(bindings: list[Var]) -> Var | list[Var]:
"""
Return a bare Var when we only have a single binding so that users may write either
`with T.Kernel(...) as pid:` or `with T.Kernel(...) as (pid,)`.
@@ -98,7 +98,7 @@ class KernelLaunchFrame(TIRFrame):
and handles the entry and exit of the kernel launch scope.
"""
- def __enter__(self) -> Union[Var, List[Var]]:
+ def __enter__(self) -> Var | list[Var]:
"""
Enters the KernelLaunchFrame scope and pushes this frame onto the stack.
Returns one Var if we detect exactly 5 frames (meaning there is a single
@@ -132,7 +132,7 @@ def __exit__(self, ptype, value, trace):
super().__exit__(ptype, value, trace)
@classmethod
- def Current(cls) -> Optional["KernelLaunchFrame"]:
+ def Current(cls) -> KernelLaunchFrame | None:
"""
Returns the topmost (current) KernelLaunchFrame from the stack if it exists,
or None if the stack is empty.
@@ -148,7 +148,7 @@ def get_block_extent(self, dim: int) -> int:
iter_var = self.frames[dim].iter_var
return int(iter_var.dom.extent)
- def get_block_extents(self) -> List[int]:
+ def get_block_extents(self) -> list[int]:
"""
Returns the block extents for all three dimensions.
"""
@@ -162,7 +162,7 @@ def get_thread_extent(self, dim: int) -> int:
iter_var = self.frames[-4 + dim].iter_var
return int(iter_var.dom.extent)
- def get_thread_extents(self) -> List[int]:
+ def get_thread_extents(self) -> list[int]:
"""
Returns the thread extents for all three dimensions.
"""
@@ -175,7 +175,7 @@ def get_thread_binding(self, dim: int = 0) -> Var:
"""
return self.frames[-4 + dim].iter_var.var
- def get_thread_bindings(self) -> List[Var]:
+ def get_thread_bindings(self) -> list[Var]:
"""
Returns the thread binding for the given dimension.
dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z.
@@ -198,21 +198,21 @@ def get_block_binding(self, dim: int = 0) -> Var:
"""
return self.frames[dim].iter_var.var
- def get_block_bindings(self) -> List[Var]:
+ def get_block_bindings(self) -> list[Var]:
"""
Returns all three block bindings.
"""
return [frame.iter_var.var for frame in self.frames[0:-4]]
@property
- def blocks(self) -> List[Var]:
+ def blocks(self) -> list[Var]:
"""
Returns the block indices from the topmost frame.
"""
return [frame.iter_var.var for frame in self.frames[0:-4]]
@property
- def threads(self) -> List[Var]:
+ def threads(self) -> list[Var]:
"""
Returns the thread indices from the topmost frame.
"""
@@ -227,10 +227,10 @@ def num_threads(self) -> int:
def Kernel(
- *blocks: List[tir.PrimExpr],
- threads: Optional[Union[int, List[int], Tuple]] = None,
+ *blocks: list[tir.PrimExpr],
+ threads: int | list[int] | tuple | None = None,
is_cpu: bool = False,
- prelude: Optional[str] = None,
+ prelude: str | None = None,
):
"""Tools to quickly construct a GPU kernel launch frame.
@@ -310,7 +310,7 @@ def get_thread_binding(dim: int = 0) -> Var:
return KernelLaunchFrame.Current().get_thread_binding(dim)
-def get_thread_bindings() -> List[Var]:
+def get_thread_bindings() -> list[Var]:
"""Returns all three thread bindings.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
@@ -324,7 +324,7 @@ def get_block_binding(dim: int = 0) -> Var:
return KernelLaunchFrame.Current().get_block_binding(dim)
-def get_block_bindings() -> List[Var]:
+def get_block_bindings() -> list[Var]:
"""Returns all three block bindings.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
@@ -338,7 +338,7 @@ def get_thread_extent(dim: int = 0) -> int:
return KernelLaunchFrame.Current().get_thread_extent(dim)
-def get_thread_extents() -> List[int]:
+def get_thread_extents() -> list[int]:
"""Returns all three thread extents.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
@@ -352,7 +352,7 @@ def get_block_extent(dim: int = 0) -> int:
return KernelLaunchFrame.Current().get_block_extent(dim)
-def get_block_extents() -> List[int]:
+def get_block_extents() -> list[int]:
"""Returns all three block extents.
"""
assert KernelLaunchFrame.Current() is not None, "KernelLaunchFrame is not initialized"
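
`Kernel` accepts `threads: int | list[int] | tuple | None`, which is typically flattened to three extents before launch. A hedged sketch of one way such a union-typed parameter can be normalised (the default of 128 and the padding rule are assumptions, not tilelang's exact behaviour):

    from __future__ import annotations

    def normalize_threads(threads: int | list[int] | tuple | None) -> tuple[int, int, int]:
        if threads is None:
            threads = [128]                        # assumed default, for illustration only
        if isinstance(threads, int):
            threads = [threads]
        x, y, z = (list(threads) + [1, 1, 1])[:3]  # pad missing dimensions with 1
        return (x, y, z)

    print(normalize_threads(256), normalize_threads((32, 4)), normalize_threads(None))
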
diff --git a/tilelang/language/logical.py b/tilelang/language/logical.py
index a08627203..a09088e68 100644
--- a/tilelang/language/logical.py
+++ b/tilelang/language/logical.py
@@ -1,13 +1,13 @@
"""The language interface for tl programs."""
+from __future__ import annotations
from tilelang import language as T
from tvm.tir import Buffer, BufferRegion, BufferLoad
from tvm import tir
-from typing import Union
from tilelang.utils.language import get_buffer_elems
-def any_of(buffer: Union[T.Tensor, BufferRegion]):
+def any_of(buffer: T.Tensor | BufferRegion):
"""Check if any element in the buffer is true.
Args:
@@ -42,7 +42,7 @@ def any_of(buffer: Union[T.Tensor, BufferRegion]):
raise ValueError(f"Invalid buffer type: {type(buffer)}")
-def all_of(buffer: Union[T.Tensor, BufferRegion]):
+def all_of(buffer: T.Tensor | BufferRegion):
"""Check if all elements in the buffer are true.
Args:
diff --git a/tilelang/language/overrides/parser.py b/tilelang/language/overrides/parser.py
index 5a9343650..01d59b607 100644
--- a/tilelang/language/overrides/parser.py
+++ b/tilelang/language/overrides/parser.py
@@ -1,7 +1,7 @@
"""TVMScript parser overrides tailored for TileLang."""
+from __future__ import annotations
from functools import partial
-from typing import Tuple
from tvm.script.ir_builder import tir as T
from tvm.script.parser._core import dispatch, doc
@@ -10,7 +10,7 @@
from tvm.script.parser.tir import parser as tvm_tir_parser
-def _get_node_span(node: doc.AST) -> Tuple[int, int, int, int]:
+def _get_node_span(node: doc.AST) -> tuple[int, int, int, int]:
"""Return the span (lineno, col, end_lineno, end_col) for a doc node."""
return (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset)
diff --git a/tilelang/language/parallel.py b/tilelang/language/parallel.py
index a70846a62..8173675a8 100644
--- a/tilelang/language/parallel.py
+++ b/tilelang/language/parallel.py
@@ -1,11 +1,12 @@
"""The language interface for tl programs."""
+from __future__ import annotations
-from typing import Optional, Dict, Any
+from typing import Any
from tvm import tir
from tilelang import _ffi_api
-def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None):
+def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
"""Tools to construct nested parallel for loop.
This can be used to create element-wise tensor expression.
@@ -22,7 +23,7 @@ def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None):
res : frame.ForFrame
The ForFrame.
"""
- annotations: Dict[str, Any] = {}
+ annotations: dict[str, Any] = {}
if coalesced_width is not None:
annotations.update({"coalesced_width": coalesced_width})
return _ffi_api.Parallel(extents, annotations) # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/tilelang/language/parser/operation.py b/tilelang/language/parser/operation.py
index e16fa261b..43774947e 100644
--- a/tilelang/language/parser/operation.py
+++ b/tilelang/language/parser/operation.py
@@ -17,8 +17,7 @@
# This file is modified from the original version,
# which is part of the TVM project (https://tvm.apache.org/).
"""The tir expression operation registration"""
-
-from typing import Type
+from __future__ import annotations
from tvm import tir
from tvm.ffi.runtime_ctypes import DataType, DataTypeCode
@@ -28,7 +27,7 @@
from tvm.script.parser._core import OpMethod, doc, register_op
-def _register_expr_op(ty: Type): # pylint: disable=invalid-name
+def _register_expr_op(ty: type): # pylint: disable=invalid-name
ty._dispatch_type = ty # pylint: disable=protected-access
def _and(a, b):
@@ -115,7 +114,7 @@ def _gt(a, b):
def _ge(a, b):
return _auto_broadcast(a, b, tir.GE)
- def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name
+ def r(op: type, i: int, m: OpMethod): # pylint: disable=invalid-name
register_op(ty, op, i)(m)
for i in [0, 1]:
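
The `Type` -> `type` change here uses the builtin class object directly as an annotation; parameterised forms such as `dict[type, str]` likewise stay unevaluated under postponed evaluation. A tiny sketch with a hypothetical `register_stub` helper (not a TVM or tilelang API):

    from __future__ import annotations

    _registry: dict[type, str] = {}

    def register_stub(ty: type, label: str) -> None:
        # Metadata keyed by the class object itself.
        _registry[ty] = label

    class Foo:
        pass

    register_stub(Foo, "foo-ops")
    print(_registry[Foo])
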
diff --git a/tilelang/language/persistent.py b/tilelang/language/persistent.py
index 1761cfa53..0ee7f112a 100644
--- a/tilelang/language/persistent.py
+++ b/tilelang/language/persistent.py
@@ -1,15 +1,15 @@
"""The language interface for tl programs."""
+from __future__ import annotations
-from typing import List, Optional
from tvm import tir
from tilelang import _ffi_api
def Persistent(
- domain: List[tir.PrimExpr],
+ domain: list[tir.PrimExpr],
wave_size: tir.PrimExpr,
index: tir.PrimExpr,
- group_size: Optional[tir.PrimExpr] = 8,
+ group_size: tir.PrimExpr | None = 8,
):
"""Tools to construct persistent for loop.
diff --git a/tilelang/language/pipeline.py b/tilelang/language/pipeline.py
index 85fd90cc0..895ed914a 100644
--- a/tilelang/language/pipeline.py
+++ b/tilelang/language/pipeline.py
@@ -1,6 +1,6 @@
"""The language interface for tl programs."""
+from __future__ import annotations
-from typing import List, Optional
from tvm import tir
from tvm.tir import IntImm
from tilelang import _ffi_api
@@ -10,10 +10,10 @@ def Pipelined(
start: tir.PrimExpr,
stop: tir.PrimExpr = None,
num_stages: int = 0,
- order: Optional[List[int]] = None,
- stage: Optional[List[int]] = None,
- sync: Optional[List[List[int]]] = None,
- group: Optional[List[List[int]]] = None,
+ order: list[int] | None = None,
+ stage: list[int] | None = None,
+ sync: list[list[int]] | None = None,
+ group: list[list[int]] | None = None,
):
"""Tools to construct pipelined for loop.
diff --git a/tilelang/language/proxy.py b/tilelang/language/proxy.py
index 83513f7a1..539c1d94c 100644
--- a/tilelang/language/proxy.py
+++ b/tilelang/language/proxy.py
@@ -1,7 +1,7 @@
"""The language interface for tl programs."""
from __future__ import annotations
-from typing import Any, Optional, Sequence, SupportsIndex, TYPE_CHECKING, Tuple, Union
+from typing import Any, Sequence, SupportsIndex, TYPE_CHECKING
from typing_extensions import Self
from tvm import tir
@@ -143,7 +143,7 @@ class TensorProxy(BaseTensorProxy):
"""
@staticmethod
- def _construct_strides(shape: Tuple[Any]):
+ def _construct_strides(shape: tuple[Any]):
s, strides = 1, [1]
for dim in shape[:0:-1]:
s *= dim
@@ -151,7 +151,7 @@ def _construct_strides(shape: Tuple[Any]):
return tuple(reversed(strides))
def __call__(self,
- shape: Union[Tuple[Any], PrimExpr, int],
+ shape: tuple[Any] | PrimExpr | int,
dtype: str = "float32",
data=None,
scope=None) -> tir.Buffer:
@@ -172,8 +172,8 @@ class StridedTensorProxy(BaseTensorProxy):
"""
def __call__(self,
- shape: Tuple[Any],
- strides: Tuple[Any],
+ shape: tuple[Any],
+ strides: tuple[Any],
dtype: str = "float32",
scope=None) -> tir.Buffer:
if len(shape) != len(strides):
@@ -270,7 +270,7 @@ class LocalBuffer(BaseTensor):
LocalBuffer = LocalBufferProxy() # pylint: disable=invalid-name
-def ptr(dtype: Optional[str] = None,
+def ptr(dtype: str | None = None,
storage_scope: str = "global",
*,
is_size_var: bool = False) -> Var:
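
One detail worth noting about the `Tuple[Any]` -> `tuple[Any]` conversion in this file: to a type checker both spellings mean a one-element tuple, while a variable-length shape is normally written with an ellipsis. The sketch below is purely illustrative and not a change this patch makes:

    from __future__ import annotations
    from typing import Any

    def describe(shape: tuple[Any, ...]) -> str:
        # Accepts (128,), (128, 256), (m, n, k), ... of any length.
        return "x".join(str(dim) for dim in shape)

    print(describe((128, 256)), describe((64,)))
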
diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py
index 5cfca850b..55ac2bb0d 100644
--- a/tilelang/language/reduce.py
+++ b/tilelang/language/reduce.py
@@ -1,7 +1,7 @@
"""The language interface for tl programs."""
+from __future__ import annotations
from tvm import tir
-from typing import Optional
from tilelang.language import copy, macro, alloc_shared
@@ -199,7 +199,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) -
copy(cumsum_smem, dst)
-def cumsum(src: tir.Buffer, dst: Optional[tir.Buffer] = None, dim: int = 0, reverse: bool = False):
+def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse: bool = False):
"""
Compute the cumulative sum of `src` along `dim`, writing results to `dst`.
diff --git a/tilelang/language/tir/entry.py b/tilelang/language/tir/entry.py
index ade36b81c..22702ae43 100644
--- a/tilelang/language/tir/entry.py
+++ b/tilelang/language/tir/entry.py
@@ -1,14 +1,15 @@
+from __future__ import annotations
import inspect
-from typing import Callable, Optional, Union
+from typing import Callable
import tvm.script.parser.tir.entry as _tir_entry
from tvm.tir.function import PrimFunc
from tvm.script.parser._core import parse, scan_macro, utils
-def prim_func(func: Optional[Callable] = None,
+def prim_func(func: Callable | None = None,
private: bool = False,
- check_well_formed: bool = False) -> Union[PrimFunc, Callable]:
+ check_well_formed: bool = False) -> PrimFunc | Callable:
"""The parsing method for tir prim func, by using `@prim_func` as decorator.
Parameters
diff --git a/tilelang/language/tir/ir.py b/tilelang/language/tir/ir.py
index 1143f2a9e..0c0d167e0 100644
--- a/tilelang/language/tir/ir.py
+++ b/tilelang/language/tir/ir.py
@@ -1,7 +1,8 @@
+from __future__ import annotations
import tvm.script.ir_builder.tir.ir as _ir
from tvm.script.ir_builder.tir import frame
from tvm.tir import PrimExpr
-from typing import Any, Dict
+from typing import Any
import tilelang.language.tir.op as _tir_op
import functools
@@ -9,7 +10,7 @@
def serial(start: PrimExpr,
stop: PrimExpr = None,
*,
- annotations: Dict[str, Any] = None) -> frame.ForFrame:
+ annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The serial For statement.
Parameters
@@ -34,7 +35,7 @@ def serial(start: PrimExpr,
def parallel(start: PrimExpr,
stop: PrimExpr = None,
*,
- annotations: Dict[str, Any] = None) -> frame.ForFrame:
+ annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The parallel For statement.
Parameters
@@ -59,7 +60,7 @@ def parallel(start: PrimExpr,
def vectorized(start: PrimExpr,
stop: PrimExpr = None,
*,
- annotations: Dict[str, Any] = None) -> frame.ForFrame:
+ annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The vectorized For statement.
Parameters
@@ -84,7 +85,7 @@ def vectorized(start: PrimExpr,
def unroll(start: PrimExpr,
stop: PrimExpr = None,
*,
- annotations: Dict[str, Any] = None) -> frame.ForFrame:
+ annotations: dict[str, Any] = None) -> frame.ForFrame:
"""The unrolled For statement.
Parameters
@@ -111,7 +112,7 @@ def thread_binding(
stop: PrimExpr = None,
thread: str = None,
*,
- annotations: Dict[str, Any] = None,
+ annotations: dict[str, Any] = None,
) -> frame.ForFrame:
"""The thread-binding For statement.
diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py
index 10ca7ca93..925665609 100644
--- a/tilelang/language/tir/op.py
+++ b/tilelang/language/tir/op.py
@@ -1,4 +1,5 @@
-from typing import Any, Optional
+from __future__ import annotations
+from typing import Any
import tvm
from tvm.ir import PrimExpr
from tvm.ir.base import Span
@@ -1857,7 +1858,7 @@ def min_value(dtype, span=None):
return _tvm_op.min_value(dtype, span)
-def max_value(dtype: str, span: Optional[Span] = None) -> Any:
+def max_value(dtype: str, span: Span | None = None) -> Any:
"""maximum value of dtype
Parameters
@@ -1876,7 +1877,7 @@ def max_value(dtype: str, span: Optional[Span] = None) -> Any:
return _tvm_op.max_value(dtype, span)
-def infinity(dtype: str, span: Optional[Span] = None) -> Any:
+def infinity(dtype: str, span: Span | None = None) -> Any:
"""infinity value of dtype
Parameters
@@ -1895,7 +1896,7 @@ def infinity(dtype: str, span: Optional[Span] = None) -> Any:
return _tvm_op.infinity(dtype, span)
-def reinterpret(dtype, value, span: Optional[Span] = None) -> Any:
+def reinterpret(dtype, value, span: Span | None = None) -> Any:
"""infinity value of dtype
Parameters
diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py
index 9b21596bb..caed14aa4 100644
--- a/tilelang/language/utils.py
+++ b/tilelang/language/utils.py
@@ -1,5 +1,5 @@
+from __future__ import annotations
from tilelang import tvm as tvm
-from typing import List
from tvm import tir
from tvm.tir import PrimExpr, Buffer, BufferLoad, op
from tilelang import language as T
@@ -42,7 +42,7 @@ def buffer_to_tile_region(buffer: Buffer, access_type: str):
return region(T.BufferLoad(buffer, mins), access_type, *extents)
-def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List[PrimExpr]):
+def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]):
"""Convert a buffer load operation to a tile region descriptor.
Args:
@@ -69,7 +69,7 @@ def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: List
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
- extents: List[tir.PrimExpr]):
+ extents: list[tir.PrimExpr]):
"""Convert a buffer region to a tile region descriptor.
Args:
@@ -88,7 +88,7 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
-def index_to_coordinates(index, shape) -> List[PrimExpr]:
+def index_to_coordinates(index, shape) -> list[PrimExpr]:
"""
Convert a flat (linear) index into multi-dimensional coordinates for a given shape.
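
For reference only, a pure-Python sketch of the unflattening that index_to_coordinates describes, assuming the conventional row-major order; the real helper operates on PrimExpr and is not reproduced here.

def index_to_coordinates_ref(index: int, shape: tuple[int, ...]) -> list[int]:
    # Peel off the innermost (fastest-varying) dimension first, then reverse.
    coords = []
    for dim in reversed(shape):
        coords.append(index % dim)
        index //= dim
    coords.reverse()
    return coords

assert index_to_coordinates_ref(5, (2, 3)) == [1, 2]   # 5 == 1 * 3 + 2
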
diff --git a/tilelang/language/warpgroup.py b/tilelang/language/warpgroup.py
index 2e64d66fa..872d30010 100644
--- a/tilelang/language/warpgroup.py
+++ b/tilelang/language/warpgroup.py
@@ -1,10 +1,10 @@
"""The language interface for tl programs."""
+from __future__ import annotations
from tvm.script.ir_builder.tir.frame import TIRFrame
from tvm.ffi import register_object
from tilelang import _ffi_api
from .kernel import get_thread_bindings, get_thread_extents
-from typing import List
@register_object("tl.WarpSpecializeFrame")
@@ -45,7 +45,7 @@ def WarpSpecialize(*warp_group_idx):
# only available for nvidia gpus.
warp_group_size = 128
- warp_group_ids: List[int] = []
+ warp_group_ids: list[int] = []
for warp_group_id in warp_group_idx:
warp_group_ids.append(warp_group_id)
diff --git a/tilelang/layout/fragment.py b/tilelang/layout/fragment.py
index b26affaa2..b9c2b10ec 100644
--- a/tilelang/layout/fragment.py
+++ b/tilelang/layout/fragment.py
@@ -1,12 +1,12 @@
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
+from __future__ import annotations
import tvm
from tvm.ir import Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
from tilelang.layout import Layout
-from typing import List
@tvm.ffi.register_object("tl.Fragment")
@@ -123,7 +123,7 @@ def get_thread_size(self):
def repeat(self,
repeats,
repeat_on_thread: bool = False,
- lower_dim_first: bool = True) -> "Fragment":
+ lower_dim_first: bool = True) -> Fragment:
"""
Returns a new Fragment that repeats the iteration space a given number of times.
@@ -143,7 +143,7 @@ def repeat(self,
"""
return _ffi_api.Fragment_repeat(self, repeats, repeat_on_thread, lower_dim_first)
- def replicate(self, replicate: int) -> "Fragment":
+ def replicate(self, replicate: int) -> Fragment:
"""
Replicate the Fragment across a new thread dimension.
@@ -159,7 +159,7 @@ def replicate(self, replicate: int) -> "Fragment":
"""
return _ffi_api.Fragment_replicate(self, replicate)
- def condense_rep_var(self) -> "Fragment":
+ def condense_rep_var(self) -> Fragment:
"""
Condense or fold the replicate variable into the existing iteration space.
This operation may be used to reduce dimensionality if the replicate variable
@@ -172,7 +172,7 @@ def condense_rep_var(self) -> "Fragment":
"""
return _ffi_api.Fragment_condense_rep_var(self)
- def map_forward_thread(self, indices: List[PrimExpr]) -> PrimExpr:
+ def map_forward_thread(self, indices: list[PrimExpr]) -> PrimExpr:
"""
Get the thread mapping expression for a given set of argument indices.
@@ -206,7 +206,7 @@ def __repr__(self):
"""
return f"Fragment<{self.get_input_shape()}->{self.get_output_shape()}, thread={self.thread}, index={self.index}>"
- def is_equal(self, other: "Fragment") -> bool:
+ def is_equal(self, other: Fragment) -> bool:
"""
Check if the current fragment is equal to another fragment.
"""
diff --git a/tilelang/layout/gemm_sp.py b/tilelang/layout/gemm_sp.py
index 1417d1b73..2fd58cd2e 100644
--- a/tilelang/layout/gemm_sp.py
+++ b/tilelang/layout/gemm_sp.py
@@ -1,17 +1,16 @@
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
+from __future__ import annotations
-from typing import Optional
import tvm
import tilelang.language as T
import warnings
from tilelang.contrib import nvcc
-from typing import List
from math import prod
-def decompose_col_major(index_1d: int, basis: List[int]) -> List[int]:
+def decompose_col_major(index_1d: int, basis: list[int]) -> list[int]:
res = []
for x in basis:
res.append(index_1d % x)
@@ -136,7 +135,7 @@ def ColumnMajorInterleaved(i: int, j: int) -> int:
def make_metadata_layout(buffer: tvm.tir.Buffer,
mma_dtype: str = "float16",
backend: str = 'cutlass',
- arch: Optional[str] = None,
+ arch: str | None = None,
**extra_args):
if arch is None:
arch = nvcc.get_target_compute_version()
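
The hunk above cuts decompose_col_major off after the modulo step. Purely as an illustration (this completion is an assumption, not taken from the patch), the usual mixed-radix decomposition continues by dividing out each basis entry, with the first entry varying fastest:

def decompose_col_major_ref(index_1d: int, basis: list[int]) -> list[int]:
    res = []
    for x in basis:
        res.append(index_1d % x)   # digit along this axis
        index_1d //= x             # advance to the next, slower-varying axis
    return res

assert decompose_col_major_ref(7, [4, 2]) == [3, 1]   # 7 == 3 + 4 * 1
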
diff --git a/tilelang/layout/layout.py b/tilelang/layout/layout.py
index fd8e31225..dd0f11709 100644
--- a/tilelang/layout/layout.py
+++ b/tilelang/layout/layout.py
@@ -1,11 +1,11 @@
"""Wrapping Layouts."""
# pylint: disable=invalid-name, unsupported-binary-operation
+from __future__ import annotations
import tvm
from tvm.ir import Node, Range
from tvm.tir import IterVar, Var, PrimExpr, IndexMap
from tilelang import _ffi_api
-from typing import List
# Register the Layout class as a TVM object under the name "tl.Layout"
@@ -92,7 +92,7 @@ def get_forward_vars(self):
def get_forward_index(self):
return self.index
- def map_forward_index(self, indices: List[PrimExpr]) -> PrimExpr:
+ def map_forward_index(self, indices: list[PrimExpr]) -> PrimExpr:
"""
Compute the forward index mapping for a given set of input indices.
@@ -122,7 +122,7 @@ def map_forward_index(self, indices: List[PrimExpr]) -> PrimExpr:
# Map the provided indices using the constructed index mapping
return index_map.map_indices(indices)
- def inverse(self) -> "Layout":
+ def inverse(self) -> Layout:
"""
Compute the inverse of the current layout transformation.
@@ -133,7 +133,7 @@ def inverse(self) -> "Layout":
"""
return _ffi_api.Layout_inverse(self)
- def is_equal(self, other: "Layout") -> bool:
+ def is_equal(self, other: Layout) -> bool:
"""
Check if the current layout is equal to another layout.
diff --git a/tilelang/primitives/gemm/__init__.py b/tilelang/primitives/gemm/__init__.py
index 64f108957..ee9436d15 100644
--- a/tilelang/primitives/gemm/__init__.py
+++ b/tilelang/primitives/gemm/__init__.py
@@ -1,4 +1,5 @@
-from typing import Optional
+from __future__ import annotations
+
from tvm import tir
from tilelang.utils import is_local, is_fragment, is_shared
from tilelang.primitives.gemm.base import GemmWarpPolicy
@@ -12,11 +13,11 @@ def gemm(
C: tir.Buffer,
transpose_A: bool = False,
transpose_B: bool = False,
- block_row_warps: Optional[int] = None,
- block_col_warps: Optional[int] = None,
- warp_row_tiles: Optional[int] = None,
- warp_col_tiles: Optional[int] = None,
- chunk: Optional[int] = None,
+ block_row_warps: int | None = None,
+ block_col_warps: int | None = None,
+ warp_row_tiles: int | None = None,
+ warp_col_tiles: int | None = None,
+ chunk: int | None = None,
policy: GemmWarpPolicy = GemmWarpPolicy.Square,
k_pack: int = 1,
):
diff --git a/tilelang/primitives/gemm/base.py b/tilelang/primitives/gemm/base.py
index d79961635..827ff78f9 100644
--- a/tilelang/primitives/gemm/base.py
+++ b/tilelang/primitives/gemm/base.py
@@ -1,7 +1,7 @@
+from __future__ import annotations
from enum import IntEnum
from dataclasses import dataclass
-from typing import Optional
from tvm import tir
@@ -161,7 +161,7 @@ def compute_warp_partition(self, M, N, num_warps):
return m_warp, n_warp
@classmethod
- def from_warp_partition(cls, m_warp: int, n_warp: int) -> 'GemmWarpPolicy':
+ def from_warp_partition(cls, m_warp: int, n_warp: int) -> GemmWarpPolicy:
"""
Determine the warp policy based on the given warp partitioning.
@@ -197,11 +197,11 @@ class GemmBaseParams:
transpose_A: bool = False
transpose_B: bool = False
- block_row_warps: Optional[int] = None
- block_col_warps: Optional[int] = None
- warp_row_tiles: Optional[int] = None
- warp_col_tiles: Optional[int] = None
- chunk: Optional[int] = None
+ block_row_warps: int | None = None
+ block_col_warps: int | None = None
+ warp_row_tiles: int | None = None
+ warp_col_tiles: int | None = None
+ chunk: int | None = None
policy: GemmWarpPolicy = GemmWarpPolicy.Square
k_pack: int = 1
@@ -226,7 +226,7 @@ def params_as_dict(self):
"k_pack": self.k_pack,
}
- def infer_block_partition(self, threads: Optional[int]) -> None:
+ def infer_block_partition(self, threads: int | None) -> None:
"""
Infer and set block partition parameters (e.g., block_row_warps,
block_col_warps, warp_row_tiles, warp_col_tiles, chunk) based on the
diff --git a/tilelang/profiler/__init__.py b/tilelang/profiler/__init__.py
index 4f4f710d0..c681ee976 100644
--- a/tilelang/profiler/__init__.py
+++ b/tilelang/profiler/__init__.py
@@ -1,6 +1,7 @@
"""The profiler and convert to torch utils"""
+from __future__ import annotations
-from typing import List, Optional, Callable, Any, Literal
+from typing import Callable, Any, Literal
from functools import partial
import torch
from contextlib import suppress
@@ -28,17 +29,17 @@ class Profiler:
adapter: Optional kernel adapter for interfacing with different backends
"""
- params: List[KernelParam]
- result_idx: List[int]
+ params: list[KernelParam]
+ result_idx: list[int]
supply_type: TensorSupplyType
- adapter: Optional[BaseKernelAdapter] = None
+ adapter: BaseKernelAdapter | None = None
def __post_init__(self):
"""Initialize tensor supply after dataclass initialization"""
self.result_idx = self._legalize_result_idx(self.result_idx)
self.supply = get_tensor_supply(self.supply_type)
- def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[int]:
+ def _legalize_result_idx(self, result_idx: list[int] | None = None) -> list[int]:
params = self.params
# result_idx is a list of indices of the output tensors
if result_idx is None:
@@ -55,7 +56,7 @@ def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[i
return result_idx
- def with_default_adapter(self, adapter: BaseKernelAdapter) -> "Profiler":
+ def with_default_adapter(self, adapter: BaseKernelAdapter) -> Profiler:
self.adapter = adapter
return self
@@ -76,7 +77,7 @@ def _get_params(self, with_output=False):
def assert_allclose(
self,
reference_program: Callable,
- input_tensors: Optional[List[torch.Tensor]] = None,
+ input_tensors: list[torch.Tensor] | None = None,
atol: float = 1e-2,
rtol: float = 1e-2,
max_mismatched_ratio=0.01,
@@ -147,7 +148,7 @@ def is_float8(tensor: torch.Tensor) -> bool:
def manual_assert_close(
self,
reference_program: Callable,
- input_tensors: Optional[List[torch.Tensor]] = None,
+ input_tensors: list[torch.Tensor] | None = None,
manual_check_prog: Callable = None,
):
"""Validates kernel output against a reference implementation.
@@ -194,13 +195,13 @@ def assert_consistent(self, repeat=10):
rhs,
]
- def run_once(self, func: Optional[Callable] = None):
+ def run_once(self, func: Callable | None = None):
ins = self._get_inputs()
if not func:
func = self.__call__
return func(*ins)
- def determine_profiler(self, func: Optional[Callable] = None):
+ def determine_profiler(self, func: Callable | None = None):
"""Determines which profiler backend to use based on function type.
Args:
@@ -217,14 +218,14 @@ def determine_profiler(self, func: Optional[Callable] = None):
def do_bench(
self,
- func: Optional[Callable] = None,
+ func: Callable | None = None,
warmup: int = 25,
rep: int = 100,
n_warmup: int = 1,
n_repeat: int = 1,
- input_tensors: List[torch.Tensor] = None,
+ input_tensors: list[torch.Tensor] = None,
backend: Literal["event", "cupti"] = "event",
- quantiles: Optional[List[float]] = None,
+ quantiles: list[float] | None = None,
return_mode: Literal["min", "max", "mean", "median"] = "mean",
) -> float:
"""Benchmarks the execution time of a given function.
diff --git a/tilelang/profiler/bench.py b/tilelang/profiler/bench.py
index d6f8c0820..a851ceb3d 100644
--- a/tilelang/profiler/bench.py
+++ b/tilelang/profiler/bench.py
@@ -1,8 +1,9 @@
"""Profiler and benchmarking utilities for PyTorch functions."""
+from __future__ import annotations
import os
import sys
-from typing import Callable, List, Literal, Optional, Union
+from typing import Callable, Literal
import torch
@@ -65,11 +66,11 @@ def do_bench(
rep: float = 100,
_n_warmup: int = 0,
_n_repeat: int = 0,
- quantiles: Optional[List[float]] = None,
+ quantiles: list[float] | None = None,
fast_flush: bool = True,
backend: Literal["event", "cupti"] = "event",
return_mode: Literal["min", "max", "mean", "median"] = "mean",
-) -> Union[float, List[float]]:
+) -> float | list[float]:
"""Benchmark the runtime of a PyTorch function with L2 cache management.
This function provides accurate GPU kernel timing by:
@@ -138,9 +139,9 @@ def _bench_with_cuda_events(
fn: Callable,
cache: torch.Tensor,
n_repeat: int,
- quantiles: Optional[List[float]],
+ quantiles: list[float] | None,
return_mode: str,
-) -> Union[float, List[float]]:
+) -> float | list[float]:
"""Benchmark using CUDA events for timing."""
# Create timing events
start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
diff --git a/tilelang/quantize/lop3.py b/tilelang/quantize/lop3.py
index f1bc6910f..47d91f056 100644
--- a/tilelang/quantize/lop3.py
+++ b/tilelang/quantize/lop3.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-from typing import Dict, Literal
+from __future__ import annotations
+from typing import Literal
decode_i4_to_f16 = """
template
@@ -1096,7 +1097,7 @@ def get_lop3_intrin_group(
with_zeros: bool = False,
zeros_mode: Literal["original", "rescale", "quantized"] = "original",
storage_scope: str = "local",
-) -> Dict[str, str]:
+) -> dict[str, str]:
"""
This function is used to get the intrinsic group of the LOP3 operation, which uses fast decoding to avoid the overhead of naive type conversion.
LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of
@@ -1186,9 +1187,9 @@ def get_lop3_intrin_group(
elif out_dtype == "int4":
d4f = "i4s"
else:
- raise ValueError("Unsupported target dtype: {}".format(target_dtype))
+ raise ValueError(f"Unsupported target dtype: {target_dtype}")
source_symbol = "u" if source_format == "uint" else "s"
- func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f)
+ func_name = f"decode_i{source_bit}{source_symbol}_to_{d4f}"
if with_scaling:
func_name += "_scale"
if with_zeros:
diff --git a/tilelang/quantize/mxfp.py b/tilelang/quantize/mxfp.py
index 552f3db3c..0425c549d 100644
--- a/tilelang/quantize/mxfp.py
+++ b/tilelang/quantize/mxfp.py
@@ -1,4 +1,5 @@
-from typing import Literal, Dict
+from __future__ import annotations
+from typing import Literal
# Implementation asm for fp4 to bf16, using twiddling
# Reference: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py#L11-L18
@@ -54,7 +55,7 @@ def get_mxfp_intrin_group(
source_bit: int = 4,
storage_dtype: Literal["int32", "int8", "uint8"] = "uint8",
use_twiddling: bool = False,
-) -> Dict[str, str]:
+) -> dict[str, str]:
"""
Return metadata for an MXFP decoding intrinsic: function name and C source string.
diff --git a/tilelang/quantize/quantization.py b/tilelang/quantize/quantization.py
index bc0ea47bf..db9d2349d 100644
--- a/tilelang/quantize/quantization.py
+++ b/tilelang/quantize/quantization.py
@@ -223,7 +223,7 @@ def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str):
e4 = val & tir.const(0x40, "uint16")
prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"),
tir.const(0x4000, "uint16"))
- e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix
+ e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | prefix
return tir.reinterpret("float16", s_f16 | e_f16)
@@ -232,7 +232,7 @@ def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str):
assert dtype == "float16"
s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16")
e4 = val & tir.const(0x40, "uint16")
- e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
+ e_f16 = ((val & tir.const(63, "uint16")) << tir.const(7, "uint16")) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16"))
e_f16 = e_f16 ^ tir.const(0x2000, "uint16")
return tir.reinterpret("float16", s_f16 | e_f16)
diff --git a/tilelang/tileop/gemm/gemm_base.py b/tilelang/tileop/gemm/gemm_base.py
index 849b6d33a..4968b09f4 100644
--- a/tilelang/tileop/gemm/gemm_base.py
+++ b/tilelang/tileop/gemm/gemm_base.py
@@ -9,7 +9,7 @@
@dataclass
-class GemmBase(object):
+class GemmBase:
gemm_node: Node
def infer_layout(self, target: Target, thread_nums: int):
diff --git a/tilelang/tools/Analyzer.py b/tilelang/tools/Analyzer.py
index 379dfc119..205c647e3 100644
--- a/tilelang/tools/Analyzer.py
+++ b/tilelang/tools/Analyzer.py
@@ -1,9 +1,9 @@
+from __future__ import annotations
import numpy as np
from dataclasses import dataclass
from tilelang import tvm
from tvm.tir.stmt_functor import ir_transform
import logging
-from typing import Optional
# Configuration for different hardware architectures.
# Each entry contains: (cores per SM, default clock (GHz), FLOPs per cycle, max SM count)
ARCH_CONFIGS = {"80": (128, 1.41, 2, 108), "86": (128, 1.70, 2, 84), "89": (128, 2.52, 2, 128)}
@@ -168,7 +168,7 @@ def calculate(self) -> AnalysisResult:
AnalysisResult: The calculated performance metrics.
"""
- def get_peak_tflops(device) -> Optional[float]:
+ def get_peak_tflops(device) -> float | None:
"""
Get the peak TFLOPS for the target device.
Args:
diff --git a/tilelang/transform/add_bufstore_wrapper.py b/tilelang/transform/add_bufstore_wrapper.py
index 1b3b4cd4c..7ccab4707 100644
--- a/tilelang/transform/add_bufstore_wrapper.py
+++ b/tilelang/transform/add_bufstore_wrapper.py
@@ -1,7 +1,7 @@
+from __future__ import annotations
from tvm.tir import (BufferStore, For, AttrStmt, ForKind, Var, PrimFunc, BufferLoad, Buffer, IntImm)
from tvm.tir.stmt_functor import ir_transform, post_order_visit
from tvm.tir.transform import prim_func_pass
-from typing import Tuple, List, Dict
def AddWrapperForSingleBufStore():
@@ -42,7 +42,7 @@ def visit_variable(node):
post_order_visit(operation, visit_variable)
return used_variables
- def collect_buffer_accesses(statement) -> Tuple[List[Buffer], List[Buffer]]:
+ def collect_buffer_accesses(statement) -> tuple[list[Buffer], list[Buffer]]:
"""
Categorizes buffers accessed in the statement by their scope.
@@ -69,7 +69,7 @@ def visit_buffer_access(node):
local_buffers.append(buffer)
return local_buffers, fragment_buffers
- def collect_buffer_indices(statement) -> Dict[Buffer, List[int]]:
+ def collect_buffer_indices(statement) -> dict[Buffer, list[int]]:
"""
Maps each buffer to its access indices.
diff --git a/tilelang/transform/simplify.py b/tilelang/transform/simplify.py
index 6b8fedfc3..7e0c5062b 100644
--- a/tilelang/transform/simplify.py
+++ b/tilelang/transform/simplify.py
@@ -1,7 +1,8 @@
+from __future__ import annotations
from tilelang import tvm as tvm
from tvm import IRModule
from tvm.tir import PrimFunc
-from typing import Union, Callable
+from typing import Callable
from . import _ffi_api
@@ -27,8 +28,7 @@ def Simplify(simplify_arguments: bool = False):
return _ffi_api.Simplify(simplify_arguments) # type: ignore
-def _Simplify(stmt: Union[PrimFunc, IRModule],
- inline_let: bool = False) -> Union[PrimFunc, IRModule]:
+def _Simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule:
if isinstance(stmt, PrimFunc):
if inline_let:
mod = LetInline()(IRModule.from_expr(stmt))
@@ -53,13 +53,12 @@ def _Simplify(stmt: Union[PrimFunc, IRModule],
def simplify_prim_func(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
- stmt: Union[PrimFunc, IRModule] = (func)(*args, **kwargs)
+ stmt: PrimFunc | IRModule = (func)(*args, **kwargs)
return _Simplify(stmt)
return wrapper
-def apply_simplify(stmt: Union[PrimFunc, IRModule],
- inline_let: bool = False) -> Union[PrimFunc, IRModule]:
+def apply_simplify(stmt: PrimFunc | IRModule, inline_let: bool = False) -> PrimFunc | IRModule:
"""Apply Simplify pass to a PrimFunc or IRModule."""
return _Simplify(stmt, inline_let)
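
A short end-to-end sketch, assuming the bundled TVMScript accepts this standard form (an assumption, not taken from the patch): apply_simplify should fold the `+ 1 - 1` away, and simplify_prim_func is the decorator flavour of the same pass for builder functions that return a PrimFunc.

from tvm.script import tir as T
from tilelang.transform.simplify import apply_simplify

@T.prim_func
def copy_plus_zero(A: T.Buffer((8,), "int32"), B: T.Buffer((8,), "int32")):
    for i in range(8):
        B[i] = A[i] + 1 - 1          # constant sub-expression for Simplify to fold

simplified = apply_simplify(copy_plus_zero, inline_let=True)
print(simplified.script())           # body should reduce to B[i] = A[i]
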
diff --git a/tilelang/utils/language.py b/tilelang/utils/language.py
index 2c0b4efad..0972175a8 100644
--- a/tilelang/utils/language.py
+++ b/tilelang/utils/language.py
@@ -1,5 +1,5 @@
+from __future__ import annotations
from tvm.tir import Buffer
-from typing import List, Optional
from functools import reduce
from tvm import IRModule
from tvm.tir import PrimFunc
@@ -85,7 +85,7 @@ def get_buffer_elems(buffer: Buffer) -> int:
return reduce(lambda x, y: x * y, buffer.shape)
-def array_reduce(array: List[int]) -> int:
+def array_reduce(array: list[int]) -> int:
"""
Reduce an array of integers to a single integer.
@@ -121,7 +121,7 @@ def retrieve_func_from_module(ir_module: IRModule) -> PrimFunc:
return func
-def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> Optional[tir.BufferRegion]:
+def get_buffer_region_from_load(buffer_load: tir.BufferLoad) -> tir.BufferRegion | None:
"""
Get the buffer region from a buffer load.
diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py
index 22cd95f21..cd364b8bb 100644
--- a/tilelang/utils/sparse.py
+++ b/tilelang/utils/sparse.py
@@ -1,7 +1,7 @@
+from __future__ import annotations
import os
import torch
import warnings
-from typing import Optional, Tuple
from tilelang.contrib import nvcc
from torch.utils.cpp_extension import load, _import_module_from_library
from tilelang import env
@@ -44,7 +44,7 @@ def _get_cached_lib():
def compress_sm90(A: torch.Tensor, block_k: int,
- transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
+ transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
if block_k > 128:
block_k = 128
# Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
@@ -56,7 +56,7 @@ def compress_sm90(A: torch.Tensor, block_k: int,
return compress_lib.compress_sm90(A, block_k, transposed)
-def compress_sm80(A: torch.Tensor, transposed: bool) -> Tuple[torch.Tensor, torch.Tensor]:
+def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
try:
from torch.sparse import to_sparse_semi_structured, SparseSemiStructuredTensor
except ImportError as err:
@@ -75,8 +75,8 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> Tuple[torch.Tensor, torc
def compress(A: torch.Tensor,
transposed: bool,
- arch: Optional[str] = None,
- **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
+ arch: str | None = None,
+ **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compress a tensor using the appropriate method based on the CUDA architecture.
"""
diff --git a/tilelang/utils/target.py b/tilelang/utils/target.py
index 948308b81..094c099fe 100644
--- a/tilelang/utils/target.py
+++ b/tilelang/utils/target.py
@@ -1,12 +1,13 @@
+from __future__ import annotations
from platform import mac_ver
-from typing import Dict, Literal, Union
+from typing import Literal
from tilelang import tvm as tvm
from tilelang import _ffi_api
from tvm.target import Target
from tvm.contrib import rocm
from tilelang.contrib import nvcc
-SUPPORTED_TARGETS: Dict[str, str] = {
+SUPPORTED_TARGETS: dict[str, str] = {
"auto": "Auto-detect CUDA/HIP/Metal based on availability.",
"cuda": "CUDA GPU target (supports options such as `cuda -arch=sm_80`).",
"hip": "ROCm HIP target (supports options like `hip -mcpu=gfx90a`).",
@@ -17,7 +18,7 @@
}
-def describe_supported_targets() -> Dict[str, str]:
+def describe_supported_targets() -> dict[str, str]:
"""
Return a mapping of supported target names to usage descriptions.
"""
@@ -58,8 +59,8 @@ def check_metal_availability() -> bool:
return arch == 'arm64'
-def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
- return_object: bool = False) -> Union[str, Target]:
+def determine_target(target: str | Target | Literal["auto"] = "auto",
+ return_object: bool = False) -> str | Target:
"""
Determine the appropriate target for compilation (CUDA, HIP, or manual selection).
@@ -76,7 +77,7 @@ def determine_target(target: Union[str, Target, Literal["auto"]] = "auto",
AssertionError: If the target is invalid.
"""
- return_var: Union[str, Target] = target
+ return_var: str | Target = target
if target == "auto":
target = tvm.target.Target.current(allow_none=True)
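
A hedged usage sketch for the two helpers above; what "auto" resolves to depends on the machine, and return_object=True asks for a tvm.target.Target rather than the plain string.

from tilelang.utils.target import describe_supported_targets, determine_target

for name, description in describe_supported_targets().items():
    print(f"{name:6s} {description}")

target = determine_target("auto", return_object=True)   # resolves to a CUDA/HIP/Metal Target on a supported machine
print(type(target), target)
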
diff --git a/version_provider.py b/version_provider.py
index c5aa42210..31a7e8ad5 100644
--- a/version_provider.py
+++ b/version_provider.py
@@ -3,7 +3,6 @@
import os
import platform
import subprocess
-from typing import Optional
from pathlib import Path
ROOT = Path(__file__).parent
@@ -17,13 +16,12 @@ def _read_cmake_bool(i: str | None, default=False):
return i.lower() not in ('0', 'false', 'off', 'no', 'n', '')
-def get_git_commit_id() -> Optional[str]:
+def get_git_commit_id() -> str | None:
"""Get the current git commit hash by running git in the current file's directory."""
r = subprocess.run(['git', 'rev-parse', 'HEAD'],
cwd=ROOT,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
+ capture_output=True,
encoding='utf-8')
if r.returncode == 0:
return r.stdout.strip()
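
For reference, capture_output=True is documented shorthand for piping both streams, so the rewritten call behaves exactly like the old one; a quick check (run inside any git checkout):

import subprocess

new = subprocess.run(["git", "rev-parse", "HEAD"], capture_output=True, encoding="utf-8")
old = subprocess.run(["git", "rev-parse", "HEAD"],
                     stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8")
assert (new.returncode, new.stdout) == (old.returncode, old.stdout)
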