Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# -*- coding: utf-8 -*-

# General information about the project.
project = "Tile Language <br>"
author = "Tile Lang Contributors"
copyright = "2025-2025, %s" % author
copyright = f"2025-2025, {author}"

# Version information.
with open("../VERSION", "r") as f:
with open("../VERSION") as f:
version = f.read().strip()
release = version

Expand Down
18 changes: 14 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,25 @@ target-version = "py38"
line-length = 100
output-format = "full"

exclude = [
"3rdparty",
"examples/deepseek_v32/inference",
]

[tool.ruff.lint.per-file-ignores]
# Do not upgrade type hints in testing and examples.
# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
"testing/**.py" = ["UP", "FA"]
"examples/**.py" = ["UP", "FA"]

[tool.ruff.lint]
select = [
# pycodestyle
"E", "W",
# Pyflakes
"F",
# pyupgrade
# "UP",
"UP", "FA",
# flake8-bugbear
"B",
# flake8-simplify
Expand All @@ -115,16 +126,15 @@ ignore = [
"SIM108",
# key in dict.keys()
"SIM118",
# open file without a context manager
"SIM115",
# memory leaks
"B019",
# zip without explicit strict
"B905",
# No such file or directory
"E902",
]
[tool.ruff.lint.per-file-ignores]
"3rdparty/**/*" = ["ALL"]
"examples/deepseek_v32/inference/**/*" = ["ALL"]

[tool.pytest.ini_options]
verbosity_assertions = 3
Expand Down
7 changes: 4 additions & 3 deletions tilelang/autotuner/capture.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import threading
from typing import List, Any, Optional
from typing import Any

# Use thread-local storage to hold the stack.
# This avoids cross-thread interference.
Expand Down Expand Up @@ -87,7 +88,7 @@ class AutotuneInputsCapture:

__slots__ = ("tensors")

def __init__(self, tensors: List[Any]):
def __init__(self, tensors: list[Any]):
self.tensors = tensors

def __enter__(self) -> None:
Expand Down Expand Up @@ -118,7 +119,7 @@ def set_autotune_inputs(*args) -> AutotuneInputsCapture:
return AutotuneInputsCapture(tensors)


def get_autotune_inputs() -> Optional[List[Any]]:
def get_autotune_inputs() -> list[Any] | None:
"""
Get the current autotune inputs from the stack.
"""
Expand Down
39 changes: 20 additions & 19 deletions tilelang/autotuner/param.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""The auto-tune parameters.
"""
from __future__ import annotations

import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc
from tvm.target import Target
from typing import Callable, List, Literal, Any, Optional, Union, Dict
from typing import Callable, Literal, Any
from dataclasses import dataclass
from pathlib import Path

Expand Down Expand Up @@ -47,12 +48,12 @@ class CompileArgs:
"tl.disable_safe_memory_legalize": bool, default: False
"""

out_idx: Optional[Union[List[int], int]] = None
out_idx: list[int] | int | None = None
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
target: Literal['auto', 'cuda', 'hip'] = 'auto'
target_host: Union[str, Target] = None
target_host: str | Target = None
verbose: bool = False
pass_configs: Optional[Dict[str, Any]] = None
pass_configs: dict[str, Any] | None = None

def compile_program(self, program: PrimFunc):
return tilelang.compile(
Expand Down Expand Up @@ -142,12 +143,12 @@ class AutotuneResult:
func: Optimized function.
kernel: Compiled kernel function.
"""
latency: Optional[float] = None
config: Optional[dict] = None
ref_latency: Optional[float] = None
libcode: Optional[str] = None
func: Optional[Callable] = None
kernel: Optional[Callable] = None
latency: float | None = None
config: dict | None = None
ref_latency: float | None = None
libcode: str | None = None
func: Callable | None = None
kernel: Callable | None = None

def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
"""
Expand Down Expand Up @@ -211,9 +212,9 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo
def _load_kernel_from_disk(
self,
cache_path: Path,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: Optional[Union[List[int], int]] = None,
target: str | Target = "auto",
target_host: str | Target = None,
out_idx: list[int] | int | None = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
pass_configs: dict = None,
func: Callable = None,
Expand All @@ -239,14 +240,14 @@ def _load_kernel_from_disk(
if not os.path.exists(cache_path):
return None

kernel_global_source: Optional[str] = None
kernel_params: Optional[List[KernelParam]] = None
kernel_global_source: str | None = None
kernel_params: list[KernelParam] | None = None

try:
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
if verbose:
logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
with open(wrapped_kernel_path, "r") as f:
with open(wrapped_kernel_path) as f:
kernel_global_source = f.read()
except Exception as e:
logger.error(f"Error loading wrapped kernel source code from disk: {e}")
Expand Down Expand Up @@ -307,15 +308,15 @@ def save_to_disk(self, path: Path, verbose: bool = False):
self._save_kernel_to_disk(path, self.kernel)

@classmethod
def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResult':
def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult:
if not os.path.exists(path):
return None

verbose = compile_args.verbose
# load best config
if verbose:
logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
with open(path / BEST_CONFIG_PATH, "r") as f:
with open(path / BEST_CONFIG_PATH) as f:
config = json.load(f)

# load function
Expand All @@ -327,7 +328,7 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResul
# load latency
if verbose:
logger.debug(f"Loading latency from file: {path / LATENCY_PATH}")
with open(path / LATENCY_PATH, "r") as f:
with open(path / LATENCY_PATH) as f:
latency = json.load(f)
latency, ref_latency = latency["latency"], latency["ref_latency"]

Expand Down
33 changes: 17 additions & 16 deletions tilelang/autotuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
and performance optimization through configuration search.
"""
from __future__ import annotations

import tilelang
from tilelang import tvm as tvm
from tvm.tir import PrimFunc, Var
from tvm.target import Target
import inspect
from functools import partial
from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
from typing import (Callable, Literal, Any, overload)
from tqdm import tqdm
import logging
import functools
Expand Down Expand Up @@ -103,8 +104,8 @@ class AutoTuner:
compile_args = CompileArgs()
profile_args = ProfileArgs()

_kernel_parameters: Optional[Tuple[str, ...]] = None
_function_parameters: Optional[Dict[str, Any]] = None
_kernel_parameters: tuple[str, ...] | None = None
_function_parameters: dict[str, Any] | None = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
Expand All @@ -131,12 +132,12 @@ def from_kernel(cls, kernel: Callable, configs):
return cls(kernel, configs)

def set_compile_args(self,
out_idx: Union[List[int], int, None] = None,
out_idx: list[int] | int | None = None,
target: Literal['auto', 'cuda', 'hip'] = 'auto',
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
target_host: Union[str, Target] = None,
target_host: str | Target = None,
verbose: bool = False,
pass_configs: Optional[Dict[str, Any]] = None):
pass_configs: dict[str, Any] | None = None):
"""Set compilation arguments for the auto-tuner.

Args:
Expand Down Expand Up @@ -223,12 +224,12 @@ def set_profile_args(self,

return self

def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]):
# for cache key generation
self._kernel_parameters = k_parameters
self._function_parameters = f_parameters

def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None:
"""Generate a cache key for the auto-tuning process.
"""

Expand Down Expand Up @@ -307,8 +308,8 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
return result

best_latency: float = 1e8
best_config: Optional[Dict[str, Any]] = None
best_kernel: Optional[tilelang.JITKernel] = None
best_config: dict[str, Any] | None = None
best_kernel: tilelang.JITKernel | None = None

def _compile(**config_arg) -> tilelang.JITKernel:
compile_args = self.compile_args
Expand Down Expand Up @@ -591,7 +592,7 @@ class _AutoTunerImplementation:
warmup: int = 25
rep: int = 100
timeout: int = 100
configs: Union[Dict, Callable] = None
configs: dict | Callable = None
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
ref_prog: Callable = None
supply_prog: Callable = None
Expand All @@ -603,7 +604,7 @@ class _AutoTunerImplementation:
cache_input_tensors: bool = False

def __init__(self,
configs: Union[Dict, Callable],
configs: dict | Callable,
warmup: int = 25,
rep: int = 100,
timeout: int = 100,
Expand Down Expand Up @@ -653,12 +654,12 @@ def __init__(self,
self.cache_input_tensors = cache_input_tensors # Reuse inputs

# Cache for storing tuned kernel implementations
self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel

# This tells the type checker what the *wrapper* function will return.
# It exists for linting; please do not remove it.
@overload
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
...

@overload
Expand Down Expand Up @@ -720,9 +721,9 @@ def jit_compile(**config_arg):


def autotune( # This is the new public interface
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
func: Callable[_P, _RProg] | PrimFunc | None = None,
*, # Indicates subsequent arguments are keyword-only
configs: Union[Dict, Callable],
configs: dict | Callable,
# profile arguments
warmup: int = 25,
rep: int = 100,
Expand Down
17 changes: 9 additions & 8 deletions tilelang/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""The cache utils with class and database persistence - Init file"""
from __future__ import annotations

from typing import List, Union, Literal, Optional
from typing import Literal
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
Expand All @@ -13,14 +14,14 @@

def cached(
func: PrimFunc = None,
out_idx: List[int] = None,
out_idx: list[int] = None,
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None,
compile_flags: Optional[Union[List[str], str]] = None,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] | None = "cython",
verbose: bool | None = False,
pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None,
) -> JITKernel:
"""
Caches and reuses compiled kernels (using KernelCache class).
Expand Down
Loading
Loading