[Lint] Enable pyupgrade linter in ruff #963
base: main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the formatting script before pushing — we appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Modernizes Python code formatting and syntax across the codebase: converts legacy `%` and `.format()` string formatting to f-strings, removes redundant `object` base classes and explicit `"r"` open modes, simplifies `super()` calls and decorator syntax, and updates type hints to PEP 585/604 style.
Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Rationale: Predominantly homogeneous, repetitive changes: conversion of legacy formatting and syntax patterns to their modern equivalents across many files.
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Force-pushed from 5b9312f to 0973611
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/jit/adapter/utils.py (1)

78-81: Guard `attrs["global_symbol"]` access on PrimFunc

Accessing `func_or_mod.attrs["global_symbol"]` may raise `KeyError` if the attribute isn't set (e.g. a PrimFunc created in the parser). Use `attrs.get("global_symbol")` with a clear fallback, or assert its presence before indexing.
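A minimal sketch of the suggested guard, assuming `func_or_mod` carries a TVM-style `attrs` mapping; the helper name and fallback handling are illustrative, not part of the PR:

```python
# Hypothetical helper illustrating the review suggestion; the name
# get_global_symbol and the default handling are assumptions.
def get_global_symbol(func_or_mod, default=None):
    attrs = getattr(func_or_mod, "attrs", None)
    if attrs is None:
        return default
    # .get() avoids the KeyError that attrs["global_symbol"] can raise
    # when a PrimFunc was created without the attribute (e.g. by the parser).
    return attrs.get("global_symbol", default)
```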
🧹 Nitpick comments (2)
tilelang/carver/roller/shape_inference/tir.py (1)

354-354: LGTM: Modernized string formatting.

The f-string is clearer and more readable than the old `%` formatting style. For completeness, the static analysis tool suggests preferring `TypeError` for invalid-type errors and potentially creating a custom exception class, but these are minor style improvements that can be deferred.
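A minimal sketch of what that deferred TRY002/TRY004 suggestion could look like; the exception name and the type check are illustrative stand-ins, not the actual `tir.py` logic:

```python
# Illustrative only: a custom exception derived from TypeError keeps the
# message inside the exception class and signals an invalid-type failure.
class UnsupportedStatementError(TypeError):
    """Raised when shape inference meets a statement kind it cannot handle."""

def check_statement(stmt):
    if not isinstance(stmt, (list, tuple)):  # stand-in for the real check
        raise UnsupportedStatementError(type(stmt).__name__)
    return stmt
```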
tilelang/carver/roller/node.py (1)

304-304: LGTM: Simplified decorator syntax.

The simplified `@functools.lru_cache` (without parentheses) is valid and cleaner in Python 3.8+.

Note: Static analysis warns that using `lru_cache` on instance methods can lead to memory leaks because the cache holds references to `self`, preventing garbage collection. This is an existing pattern in the codebase and not introduced by this change, but consider whether these methods truly need caching on instance methods or if the cache should be cleared when instances are no longer needed.
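A minimal sketch of the usual B019 workaround, with an illustrative `Node` class rather than the actual roller `Node`: building the cache per instance in `__init__` lets the instance (and its cache) be garbage-collected normally.

```python
import functools

class Node:  # illustrative stand-in for the real class
    def __init__(self, name: str):
        self.name = name
        # Wrapping the bound method here gives each instance its own cache,
        # so nothing outlives the instance; a class-level @functools.lru_cache
        # would keep `self` alive inside the shared cache.
        self.signature = functools.lru_cache(maxsize=None)(self._signature)

    def _signature(self, bits: int) -> str:
        return f"{self.name}:{bits}"  # stand-in for an expensive computation

n = Node("gemm")
assert n.signature(16) is n.signature(16)  # second call hits the cache
```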
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (72)
- docs/conf.py (1 hunks)
- examples/amd/example_amd_flash_attn_bwd.py (2 hunks)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3 hunks)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2 hunks)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (3 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (2 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2 hunks)
- examples/bitnet-1.58b/configuration_bitnet.py (0 hunks)
- examples/bitnet-1.58b/eval_ppl.py (1 hunks)
- examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1 hunks)
- examples/bitnet-1.58b/modeling_bitnet.py (1 hunks)
- examples/bitnet-1.58b/tokenization_bitnet.py (0 hunks)
- examples/bitnet-1.58b/utils_quant.py (1 hunks)
- examples/bitnet-1.58b/vllm_workspace/conftest.py (1 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2 hunks)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (1 hunks)
- examples/cast/example_per_token_cast_to_fp8.py (2 hunks)
- examples/deepseek_mla/example_mla_decode_paged.py (1 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (1 hunks)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1 hunks)
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1 hunks)
- examples/flash_attention/example_gqa_bwd.py (2 hunks)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2 hunks)
- examples/flash_attention/example_gqa_fwd_bshd.py (1 hunks)
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1 hunks)
- examples/flash_attention/example_mha_bwd.py (2 hunks)
- examples/flash_attention/example_mha_bwd_bhsd.py (2 hunks)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (2 hunks)
- examples/flash_attention/example_mha_fwd_bhsd.py (1 hunks)
- examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- examples/flash_attention/example_mha_fwd_bshd.py (1 hunks)
- examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1 hunks)
- examples/flash_decoding/example_gqa_decode.py (1 hunks)
- examples/flash_decoding/example_mha_inference.py (1 hunks)
- examples/grouped_gemm/example_grouped_gemm_bwd.py (1 hunks)
- examples/hadamard_transform/example_hadamard.py (1 hunks)
- examples/linear_attention/example_mamba_chunk_scan.py (1 hunks)
- examples/linear_attention/example_mamba_chunk_state.py (2 hunks)
- examples/minference/example_vertical_slash_sparse_attn.py (1 hunks)
- examples/norm/rms_norm.py (1 hunks)
- pyproject.toml (1 hunks)
- setup.py (6 hunks)
- testing/python/kernel/test_tilelang_kernel_gemm.py (1 hunks)
- testing/python/kernel/test_tilelang_kernel_gemm_simt.py (1 hunks)
- testing/python/language/test_tilelang_language_pipeline.py (1 hunks)
- tilelang/autotuner/param.py (3 hunks)
- tilelang/cache/kernel_cache.py (1 hunks)
- tilelang/carver/arch/cuda.py (1 hunks)
- tilelang/carver/arch/metal.py (1 hunks)
- tilelang/carver/roller/bestfit.py (1 hunks)
- tilelang/carver/roller/hint.py (1 hunks)
- tilelang/carver/roller/node.py (4 hunks)
- tilelang/carver/roller/rasterization.py (1 hunks)
- tilelang/carver/roller/shape_inference/common.py (2 hunks)
- tilelang/carver/roller/shape_inference/tir.py (3 hunks)
- tilelang/contrib/hipcc.py (1 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (1 hunks)
- tilelang/intrinsics/mma_macro_generator.py (2 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
- tilelang/jit/adapter/cython/adapter.py (3 hunks)
- tilelang/jit/adapter/libgen.py (1 hunks)
- tilelang/jit/adapter/utils.py (3 hunks)
- tilelang/jit/adapter/wrapper.py (6 hunks)
- tilelang/jit/kernel.py (1 hunks)
- tilelang/language/proxy.py (4 hunks)
- tilelang/quantize/lop3.py (1 hunks)
- tilelang/quantize/quantization.py (2 hunks)
- tilelang/tileop/gemm/gemm_base.py (1 hunks)
- tilelang/version.py (1 hunks)
💤 Files with no reviewable changes (2)
- examples/bitnet-1.58b/tokenization_bitnet.py
- examples/bitnet-1.58b/configuration_bitnet.py
🧰 Additional context used
🧬 Code graph analysis (25)
examples/minference/example_vertical_slash_sparse_attn.py (1)
- tilelang/language/builtin.py (1): mbarrier_wait_parity (172-219)

examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_attention/example_mha_fwd_bshd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

tilelang/carver/roller/shape_inference/common.py (1)
- tilelang/carver/roller/shape_inference/tir.py (2): Statement (7-43), InputShapeInference (169-318)

examples/flash_decoding/example_mha_inference.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

tilelang/intrinsics/wgmma_macro_generator.py (1)
- tilelang/utils/language.py (1): is_fragment (68-78)

examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2): maybe_contiguous (356-359), tl_bwd (484-485)

examples/linear_attention/example_mamba_chunk_state.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (4)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1): maybe_contiguous (356-359)
- examples/flash_attention/example_gqa_bwd.py (1): maybe_contiguous (377-380)
- examples/grouped_gemm/example_grouped_gemm_bwd.py (1): maybe_contiguous (134-137)

examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_decoding/example_gqa_decode.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/linear_attention/example_mamba_chunk_scan.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

tilelang/intrinsics/mma_macro_generator.py (2)
- tilelang/intrinsics/wgmma_macro_generator.py (1): TensorCoreIntrinEmitter (64-519)
- tilelang/utils/language.py (1): is_fragment (68-78)

examples/flash_attention/example_mha_bwd_bhsd.py (1)
- examples/flash_attention/example_mha_bwd.py (1): maybe_contiguous (258-261)

examples/flash_attention/example_mha_fwd_bhsd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/grouped_gemm/example_grouped_gemm_bwd.py (1)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (1): maybe_contiguous (364-367)
- examples/flash_attention/example_gqa_bwd.py (1): maybe_contiguous (377-380)

examples/flash_attention/example_mha_bwd.py (9)
- examples/amd/example_amd_flash_attn_bwd.py (2): maybe_contiguous (242-245), run1 (340-341)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1): maybe_contiguous (356-359)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (1): maybe_contiguous (364-367)
- examples/flash_attention/example_gqa_bwd.py (2): maybe_contiguous (377-380), run1 (514-515)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2): maybe_contiguous (399-402), run1 (536-537)
- examples/flash_attention/example_mha_bwd_bhsd.py (2): maybe_contiguous (259-262), run1 (336-337)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (2): maybe_contiguous (268-271), run1 (343-344)
- examples/grouped_gemm/example_grouped_gemm_bwd.py (1): maybe_contiguous (134-137)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/amd/example_amd_flash_attn_bwd.py (1)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1): maybe_contiguous (356-359)

examples/flash_attention/example_gqa_fwd_bshd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

tilelang/jit/adapter/utils.py (1)
- tilelang/language/ast/ir.py (1): target (1682-1713)

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1)
- examples/flash_attention/example_gqa_bwd.py (2): maybe_contiguous (377-380), run1 (514-515)

examples/flash_attention/example_gqa_bwd.py (1)
- examples/amd/example_amd_flash_attn_bwd.py (2): maybe_contiguous (242-245), run1 (340-341)
🪛 GitHub Actions: CI
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
[error] 1-1: Reformatted by formatter. Please review and stage the changes.
docs/conf.py
[error] 1-1: Reformatted by formatter. Please review and stage the changes.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
[error] 1-1: Reformatted by formatter. Please review and stage the changes.
🪛 GitHub Actions: CI Test on AMD
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
[error] 1-1: Reformatted files. Please review and stage the changes.
docs/conf.py
[error] 1-1: Reformatted files. Please review and stage the changes.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
[error] 1-1: Reformatted files. Please review and stage the changes.
🪛 GitHub Actions: CI Test on Metal
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
[error] 1-1: Reformatted by formatting script. Changes not staged for commit.
docs/conf.py
[error] 1-1: Reformatted by formatting script. Changes not staged for commit.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
[error] 1-1: Reformatted by formatting script. Changes not staged for commit.
🪛 Ruff (0.13.3)
tilelang/carver/roller/node.py

304-304: Use of `functools.lru_cache` or `functools.cache` on methods can lead to memory leaks (B019)

421-421: Use of `functools.lru_cache` or `functools.cache` on methods can lead to memory leaks (B019)

tilelang/quantize/lop3.py

1189-1189: Avoid specifying long messages outside the exception class (TRY003)

setup.py

312-312: Unused `noqa` directive (non-enabled: `SIM115`). Remove unused `noqa` directive (RUF100)

tilelang/carver/roller/shape_inference/tir.py

354-354: Prefer `TypeError` exception for invalid type (TRY004)

354-354: Create your own exception (TRY002)

354-354: Avoid specifying long messages outside the exception class (TRY003)
🔇 Additional comments (96)
examples/linear_attention/example_mamba_chunk_state.py (1)

171-175: Print modernisation looks good. Nice switch to f-strings: same formatting semantics, cleaner syntax.
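Since f-string conversion is the dominant change in this PR, a minimal before/after sketch of the rewrite pyupgrade performs (illustrative values, not code from the examples):

```python
latency = 1.234567
# Before: the two legacy styles pyupgrade rewrites (UP031/UP032).
print("Latency: {:.2f} ms".format(latency))
print("Latency: %0.2f ms" % latency)
# After: an f-string with the same format spec prints identical output.
print(f"Latency: {latency:.2f} ms")
```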
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)

198-202: Nice f-string modernization. Formatting stays identical while matching the new lint rule scope. 👍

examples/flash_attention/example_gqa_fwd_bshd.py (1)

251-255: F-string migration looks good. Formatting precision and computed values are unchanged; the output stays consistent while modernizing the style.

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)

211-215: F-string conversion looks good. Formatting specifiers are preserved, so output remains identical while complying with pyupgrade.

examples/bitnet-1.58b/eval_ppl.py (1)

57-57: LGTM! Clean f-string conversion. The modernization from legacy string formatting to f-strings improves readability and aligns with Python 3.6+ best practices. This change is consistent with the PR's goal of enabling pyupgrade linter rules.

examples/hadamard_transform/example_hadamard.py (1)

154-154: LGTM! Clean modernization to f-string. The conversion from `.format()` to f-string improves readability and aligns with modern Python best practices while maintaining identical functionality.

tilelang/carver/roller/bestfit.py (1)

20-20: LGTM! Modernization to f-string improves readability. The conversion from `str.format` to an f-string is correct and aligns with the PR objective to enable the pyupgrade linter. F-strings are more readable and typically faster than older formatting methods.

examples/bitnet-1.58b/utils_quant.py (1)

219-219: LGTM! Modernized super() call syntax. The change from `super(BitLinear, self).__init__(*kargs, **kwargs)` to `super().__init__(*kargs, **kwargs)` correctly modernizes the code to use Python 3+ idiomatic syntax. The behavior remains identical.
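A minimal sketch of this UP008 rewrite, using illustrative classes rather than the real BitLinear hierarchy:

```python
class Linear:  # stand-in base class
    def __init__(self, in_features: int, out_features: int):
        self.shape = (in_features, out_features)

class BitLinear(Linear):
    def __init__(self, *args, **kwargs):
        # Legacy spelling, still legal but redundant in Python 3:
        #   super(BitLinear, self).__init__(*args, **kwargs)
        super().__init__(*args, **kwargs)  # zero-argument form

assert BitLinear(128, 256).shape == (128, 256)
```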
tilelang/carver/arch/metal.py (1)

1-1: LGTM! Adding `from __future__ import annotations` enables postponed evaluation of type annotations (PEP 563), which is a modern Python practice and aligns with the PR's objective to modernize the codebase. This works well with the existing PEP 604 union syntax (`Target | str`) on line 12.
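A minimal sketch of why the future import matters for the `Target | str` annotation; the `Target` class below is an illustrative stand-in for the TVM one:

```python
from __future__ import annotations  # PEP 563: annotations stored as strings

class Target:  # stand-in for tvm.target.Target
    pass

def normalize(target: Target | str) -> str:
    # Without the future import, `Target | str` is evaluated at definition
    # time and raises TypeError on Python < 3.10; with it, the annotation
    # stays a string, so the module imports on older interpreters too.
    return target if isinstance(target, str) else "custom-target"

print(normalize("metal"))
```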
examples/deepseek_mla/example_mla_decode_paged.py (1)

403-404: LGTM! Clean modernization to f-strings. The conversion from `format()` to f-strings is correct and improves readability while maintaining identical output semantics.

examples/flash_decoding/example_gqa_decode.py (1)

476-480: LGTM! Clean modernization to f-strings. The conversion from `.format()` to f-strings is correct and follows Python best practices. F-strings are more readable, concise, and performant.

examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (2)

225-226: LGTM! Clean conversion to f-strings. The migration from `.format()` to f-strings is correct and follows modern Python conventions. The formatting specifiers and output remain identical.

228-229: LGTM! Consistent f-string conversion. The conversion maintains identical output while improving code readability.

examples/linear_attention/example_mamba_chunk_scan.py (1)

232-236: LGTM! Clean modernization to f-strings. The conversion from `.format()` to f-strings improves readability while preserving the exact formatting (`.2f`). This aligns with the PR objective to enable pyupgrade linting and follows Python 3.6+ best practices.

examples/flash_attention/example_mha_fwd_bhsd.py (1)

206-210: LGTM! Clean modernization to f-strings. The conversion from `.format()` to f-strings is correct, maintains identical output formatting (`.2f` precision), and improves readability. These changes align with the PR objective of enabling pyupgrade linter rules.

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1)

537-538: LGTM! Clean f-string conversion. The formatting is preserved (`.2f` for 2 decimal places) and the f-string syntax is more readable and Pythonic.

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (1)

553-554: LGTM! Consistent f-string conversion. The change maintains identical formatting and improves code readability.

examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1)

436-437: LGTM! Consistent f-string conversion. The formatting is identical and the f-string syntax is cleaner and more maintainable.

examples/flash_attention/example_mha_fwd_bshd.py (1)

193-197: LGTM! Clean modernization to f-strings. The conversion from `.format()` to f-strings is correct and improves readability. The formatting specifications and variable references are properly preserved.

examples/cast/example_per_token_cast_to_fp8.py (1)

103-117: LGTM! Clean f-string modernization. The conversion from `.format()` to f-strings is correct and follows modern Python best practices. The format specifications are preserved, and the output remains identical.

examples/cast/example_group_per_split_token_cast_to_fp8.py (1)

202-205: LGTM! Consistent f-string modernization. The f-string conversions are correct and consistent with the broader codebase modernization in this PR.

tilelang/intrinsics/mfma_macro_generator.py (1)

14-14: LGTM! Modern Python 3 class declaration. Removing the explicit `object` base class aligns with Python 3 best practices, as all classes implicitly inherit from `object`.

tilelang/intrinsics/wgmma_macro_generator.py (1)

408-408: LGTM! Cleaner assertion message with f-string. The f-string format improves readability and aligns with modern Python conventions.

tilelang/intrinsics/mma_macro_generator.py (2)

28-28: LGTM! Modern Python 3 class declaration. Removing the explicit `object` base class aligns with Python 3 best practices, consistent with the modernization in `mfma_macro_generator.py`.

521-521: LGTM! Cleaner assertion message with f-string. The f-string format improves readability and is consistent with the same change in `wgmma_macro_generator.py`.

examples/amd/example_amd_flash_attn_bwd.py (2)

247-247: LGTM! Generator expression improves efficiency. The change from list comprehension to generator expression is a good modernization. Since the generator is consumed immediately during unpacking, behavior is unchanged while memory efficiency is slightly improved.
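A minimal sketch of the unpacking pattern discussed here, with a trivial stand-in for `maybe_contiguous`:

```python
def maybe_contiguous(x):
    return x  # stand-in for the .contiguous() helper in the examples

# Tuple unpacking consumes the generator immediately, so it binds the same
# values as a list comprehension while skipping the temporary list.
do, q, k = (maybe_contiguous(x) for x in (1, 2, 3))
assert (do, q, k) == (1, 2, 3)
```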
346-350: LGTM! F-string conversions improve readability. The conversion from `.format()` to f-strings is correct, with all formatting specifications (`.2f`) properly preserved. This modernization improves code readability without changing behavior.

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (2)

273-273: LGTM! Modern generator expression for unpacking. The change from a list comprehension to a generator expression is a good modernization. The generator is more memory-efficient since the values are immediately consumed during unpacking, and this aligns with pyupgrade recommendations.

349-353: LGTM! F-string conversions are correct. The conversion to f-strings is a modern Python best practice that improves readability. All formatting specifications (`.2f`) are correctly preserved, and the logic remains unchanged.

examples/flash_attention/example_mha_bwd_bhsd.py (2)

264-264: LGTM! Generator expression for unpacking is correct. The change from list comprehension to generator expression is a valid modernization. While the memory benefit is minimal for 5 items, this aligns with pyupgrade's recommendations and works correctly.

342-346: LGTM! F-string conversions are correct. The print statements have been properly converted from `.format()` to f-strings with correct formatting specifiers preserved.

tilelang/quantize/quantization.py (2)

226-226: LGTM! Redundant parentheses removed. The removal of the outer parentheses is a safe cosmetic improvement. The expression remains functionally identical, and the remaining parentheses correctly ensure the bitwise AND operation is evaluated before the shift.

235-235: LGTM! Consistent style improvement. The redundant outer parentheses have been removed, making the expression cleaner while maintaining the correct evaluation order. This change is consistent with the improvement on line 226.

examples/flash_attention/example_mha_bwd.py (2)

263-263: LGTM! Generator expression modernization. The change from list comprehension to generator expression for unpacking is a safe, standard Python modernization. Both are functionally equivalent when unpacking, but the generator expression is more memory-efficient.

339-343: LGTM! F-string conversion. The conversion from `.format()` to f-strings is a standard Python modernization that improves readability. All format specifiers and expressions are correctly preserved.

examples/flash_attention/example_gqa_bwd.py (2)

382-382: LGTM! Generator expression is more memory-efficient. The change from list comprehension to generator expression for unpacking is a valid modernization. Both are functionally equivalent, and the generator expression avoids creating an intermediate list.

520-524: LGTM! F-strings improve readability. The conversion to f-strings modernizes the code and improves readability while correctly preserving the format specifiers for floating-point precision.

examples/grouped_gemm/example_grouped_gemm_bwd.py (1)

139-139: LGTM! Standard pyupgrade optimization. Replacing the list comprehension with a generator expression is correct and avoids allocating an intermediate list. This is a standard pyupgrade rule (UP015) for immediate unpacking.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (4)

183-183: LGTM! Modernized super() call. The simplified `super().__init__()` syntax is the Python 3 standard and is preferred over the explicit class/self parameters.

357-357: LGTM! F-string conversion. The f-string syntax is preferred over `.format()` and improves readability.

362-362: LGTM! F-string conversion. The f-string syntax with multiple interpolations is more readable than the equivalent `.format()` call.

1-1: Stage formatting changes. CI is failing due to unstaged formatter updates. Run `ruff --fix` (or `black .` if used) and commit all modified files.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3)

196-196: LGTM! Modernized super() call. The simplified `super().__init__()` syntax is the Python 3 standard and is preferred over the explicit class/self parameters.

369-369: LGTM! F-string conversion. The f-string syntax is preferred over `.format()` and improves readability.

372-372: LGTM! F-string conversion. The f-string syntax with multiple interpolations is more readable than the equivalent `.format()` call.

tilelang/contrib/hipcc.py (1)

57-57: F-string update preserves behavior. Switching to the f-string keeps the target path identical while satisfying pyupgrade.

testing/python/language/test_tilelang_language_pipeline.py (1)

106-107: Redundant parentheses removal is safe. Dropping the extra parentheses leaves the casting logic untouched; nice stylistic cleanup.

tilelang/carver/roller/shape_inference/common.py (1)

7-22: Modern class declarations look good. Removing the explicit `(object)` base aligns with Python 3 style and has no runtime impact.

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)

3-427: Formatting modernization retains semantics. The added annotations import supports future typing tweaks, and the f-strings mirror the prior formatting exactly.

tilelang/version.py (1)

24-25: Default text-mode `open` is equivalent. Dropping the explicit `'r'` keeps text reading behavior unchanged while satisfying the linter.
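A minimal sketch of the equivalence behind these `open()` cleanups, using a throwaway file path:

```python
from pathlib import Path

path = Path("demo.txt")          # illustrative file, created for the check
path.write_text("hello")

with open(path) as f:            # implicit text-read mode
    implicit = f.read()
with open(path, "r") as f:       # explicit "r": the default, now redundant
    explicit = f.read()
assert implicit == explicit
```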
37-39
: Simplified file open stays readableUsing the default read mode is fine here and keeps the helper neat.
testing/python/kernel/test_tilelang_kernel_gemm.py (1)
98-99
: Parentheses trim keeps conversion intactThe tensor bitcast still executes exactly the same, so this satisfies the formatter without behavioral change.
tilelang/carver/roller/rasterization.py (1)
91-93
: F-string emission is straightforwardInjecting
panel_width
via an f-string preserves the generated code verbatim.pyproject.toml (1)
34-34
: LGTM! Pyupgrade linter enabled successfully.The addition of
"UP"
and"FA102"
to the ruff lint selection enables pyupgrade rules, which modernize Python syntax. This change aligns with the PR objectives and the style updates throughout the codebase.tilelang/jit/adapter/cython/adapter.py (1)
105-105
: LGTM! Simplified file open calls.Removing the explicit
"r"
mode is consistent with Python 3 conventions, where text read mode is the default. The behavior remains identical.Also applies to: 118-118, 135-135
tilelang/autotuner/param.py (1)
249-249
: LGTM! Simplified file open calls.Removing the explicit
"r"
mode follows Python 3 conventions. The default text read mode preserves the intended behavior.Also applies to: 318-318, 330-330
setup.py (1)
140-140
: LGTM! Simplified file open calls.Removing the explicit
"r"
mode follows Python 3 conventions and is consistent with pyupgrade recommendations.Also applies to: 171-171, 711-711, 723-723, 740-740
tilelang/carver/arch/cuda.py (1)
94-94
: LGTM! Removed redundant base class.In Python 3, all classes implicitly inherit from
object
, making the explicit base class declaration unnecessary. This modernization aligns with PEP 8 and pyupgrade recommendations.examples/bitnet-1.58b/modeling_bitnet.py (1)
1515-1515
: LGTM! Simplified file open call.Removing the explicit
"r"
mode is consistent with Python 3 conventions and the modernization pattern applied throughout the codebase.examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1)
88-88
: LGTM! Simplified file open call.The change follows Python 3 conventions and is consistent with the modernization applied across the codebase.
examples/norm/rms_norm.py (1)
74-74
: LGTM! Modernized string formatting.Replacing
.format()
with f-strings improves readability and performance. This is a recommended Python 3.6+ practice and aligns with pyupgrade recommendations.Also applies to: 76-76
tilelang/cache/kernel_cache.py (1)
366-368
: No behavior change from dropping"r"
open()
defaults to text-read mode, so this remains a safe, no-op cleanup.tilelang/quantize/lop3.py (1)
1189-1193
: F-string conversion looks goodSwitching to an f-string keeps the message identical while matching the new lint expectations.
docs/conf.py (1)
5-9
: Docs tweaks are soundBoth the f-string and the implicit read mode are clean modernizations with no functional impact.
tilelang/carver/roller/hint.py (1)
150-152
: Class declaration cleanup is correctRemoving
(object)
is idiomatic in Python 3 and keeps semantics identical.tilelang/jit/kernel.py (1)
20-32
: Python 3 class style looks goodDropping the explicit
object
base keeps behavior the same while modernizing the declaration.examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
2-2
: LGTM: Future annotations import added.Enables PEP 563 postponed evaluation of annotations, consistent with modern Python 3 practices.
289-293
: LGTM: Format strings modernized to f-strings.The conversion from
.format()
to f-strings is correct and improves readability. The numeric formatting and calculations remain unchanged.examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)
3-3
: LGTM: Modern Python formatting applied.The addition of future annotations and conversion to f-strings are standard modernizations that improve code readability without changing functionality.
Also applies to: 438-446
testing/python/kernel/test_tilelang_kernel_gemm_simt.py (1)
109-109
: LGTM: Removed redundant parentheses.The extra parentheses around the integer expression were unnecessary and have been correctly removed.
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)
2-2
: LGTM: Modern Python patterns applied.The future annotations import and generator expression (instead of list comprehension) are appropriate modernizations. The generator expression is safe here since values are immediately unpacked into separate variables.
Also applies to: 361-361
488-492
: LGTM: Format strings modernized to f-strings.The conversion maintains the same output formatting while improving readability.
examples/flash_decoding/example_mha_inference.py (1)
321-325
: LGTM: Format strings modernized to f-strings.The conversion correctly maintains the formatting precision (
.2f
and.4f
) while improving code readability.examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2)
404-404
: LGTM: Generator expression applied.The switch from list comprehension to generator expression is more memory efficient and safe here since values are immediately unpacked.
542-546
: LGTM: Format strings modernized to f-strings.The conversion maintains the same output formatting while improving readability.
examples/attention_sink/example_mha_sink_bwd_bhsd.py (3)
2-2
: LGTM: Future annotations import added.Enables PEP 563 forward reference support, consistent with modern Python typing practices.
369-369
: LGTM: Generator expression unpacking.More memory-efficient than creating an intermediate list for unpacking, while maintaining the same functionality.
492-496
: LGTM: Modernized to f-strings.Cleaner and more readable than
.format()
calls, consistent with modern Python style.tilelang/carver/roller/shape_inference/tir.py (2)
50-50
: LGTM: Removed redundant explicit base class.In Python 3, all classes implicitly inherit from
object
, so the explicit base is unnecessary.
79-79
: LGTM: Removed redundant explicit base class.Consistent with modern Python 3 style where
object
inheritance is implicit.tilelang/carver/roller/node.py (3)
32-32
: LGTM: Removed redundant explicit base class.Modern Python 3 style where
object
inheritance is implicit.
93-93
: LGTM: Removed redundant explicit base class.Consistent with Python 3 conventions.
421-421
: LGTM: Simplified decorator syntax.Same modernization as line 304, consistent with Python 3.8+ conventions.
tilelang/jit/adapter/wrapper.py (6)
179-179
: LGTM: Removed redundant explicit base class.Modern Python 3 style.
325-326
: LGTM: Modernized to f-strings.Cleaner and more readable for constructing the
dim3
strings.
350-351
: LGTM: Modernized to f-strings.Consistent string formatting improvements for kernel launch and error checking.
895-895
: LGTM: Removed redundant explicit base class.Consistent with Python 3 conventions.
993-993
: LGTM: Modernized to f-string.Cleaner string formatting for the function call construction.
1062-1062
: LGTM: Removed redundant explicit base class.Consistent Python 3 style.
tilelang/jit/adapter/utils.py (2)
4-4
: LGTM: Modernized type hints.Updated to use PEP 604 union syntax (
|
) and built-intuple
instead ofUnion
andTuple
, consistent with Python 3.10+ conventions.Also applies to: 68-72, 110-110
75-76
: LGTM: Added input validation.Early validation ensures
model_type
is one of the expected values, providing clear error messages.tilelang/language/proxy.py (5)
4-4
: LGTM: Updated imports.Removed unnecessary imports as the built-in
tuple
is now used instead ofTuple
.
146-146
: LGTM: Modernized type hint.Using built-in
tuple
instead ofTuple
from typing module, consistent with Python 3.9+ style.
154-154
: LGTM: Modernized type hint.Using PEP 604 union syntax (
|
) and built-intuple
, consistent with modern Python conventions.
175-176
: LGTM: Modernized type hints.Consistent use of built-in
tuple
type.
276-276
: LGTM: Modernized type hint.Using PEP 604 union syntax (
str | None
) instead ofOptional[str]
.
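A minimal sketch of the modernized hint style these comments describe; the function and its signature are illustrative, not the actual proxy.py API:

```python
from __future__ import annotations  # lets older interpreters parse the syntax

# PEP 585 builtin generics and PEP 604 unions replace typing.Tuple,
# typing.Union, and typing.Optional.
def make_shape(dims: tuple[int, ...], name: str | None = None) -> tuple[int, ...]:
    # Old spelling: Tuple[int, ...] and Optional[str] imported from typing.
    return dims

assert make_shape((128, 256)) == (128, 256)
```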
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)

1-1: Fix formatting issues before merging.

Pipeline failures indicate that this file does not match the project's formatting standards:
- CI: "Reformatted by formatter. Changes not staged for commit."
- CI Test on Metal: "Reformatted files. Please review and stage the changes."
- CI Test on AMD: "clang-format reformatted file. Please review and stage the changes."

Please run the project's formatter (likely `yapf`, based on pyproject.toml) on this file and commit the changes.
🧹 Nitpick comments (1)
examples/flash_attention/example_mha_bwd_bhsd.py (1)

264-264: Consider reverting to list comprehension for clarity.

While unpacking a generator expression is syntactically valid, it's less common and arguably less clear than the list comprehension. For 5 elements, the memory/performance benefit is negligible. If you prefer the more conventional pattern, apply this diff:

```diff
-do, q, k, v, o = (maybe_contiguous(x) for x in (do, q, k, v, o))
+do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (72)

Same 72 files as in the first review above.

💤 Files with no reviewable changes (2)
- examples/bitnet-1.58b/tokenization_bitnet.py
- examples/bitnet-1.58b/configuration_bitnet.py
🧰 Additional context used
🧬 Code graph analysis (23)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

tilelang/intrinsics/mma_macro_generator.py (2)
- tilelang/intrinsics/wgmma_macro_generator.py (1): TensorCoreIntrinEmitter (64-519)
- tilelang/utils/language.py (1): is_fragment (68-78)

examples/flash_decoding/example_gqa_decode.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

tilelang/intrinsics/wgmma_macro_generator.py (1)
- tilelang/utils/language.py (1): is_fragment (68-78)

examples/flash_decoding/example_mha_inference.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_attention/example_mha_fwd_bshd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_attention/example_gqa_bwd.py (2)
- examples/amd/example_amd_flash_attn_bwd.py (2): maybe_contiguous (242-245), run1 (340-341)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2): maybe_contiguous (399-402), run1 (536-537)

examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/attention_sink/example_mha_sink_bwd_bhsd.py (3)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1): maybe_contiguous (356-359)
- examples/flash_attention/example_mha_bwd.py (1): maybe_contiguous (258-261)

examples/minference/example_vertical_slash_sparse_attn.py (1)
- tilelang/language/builtin.py (1): mbarrier_wait_parity (172-219)

examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/linear_attention/example_mamba_chunk_scan.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

tilelang/jit/adapter/utils.py (1)
- tilelang/language/ast/ir.py (1): target (1682-1713)

examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)

examples/amd/example_amd_flash_attn_bwd.py (1)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1): maybe_contiguous (356-359)

examples/flash_attention/example_mha_bwd_bhsd.py (3)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1): maybe_contiguous (356-359)
- examples/flash_attention/example_mha_bwd.py (1): maybe_contiguous (258-261)

examples/linear_attention/example_mamba_chunk_state.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_attention/example_gqa_fwd_bshd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (2): maybe_contiguous (364-367), tl_bwd (488-489)

examples/flash_attention/example_mha_bwd.py (2)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
- tilelang/profiler/__init__.py (1): do_bench (218-281)

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1)
- examples/flash_attention/example_gqa_bwd.py (2): maybe_contiguous (377-380), run1 (514-515)
🪛 GitHub Actions: CI
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
[error] 1-1: Reformatted by formatter. Changes not staged for commit.
docs/conf.py
[error] 1-1: Reformatted by formatter. Changes not staged for commit.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
[error] 1-1: Reformatted by formatter. Changes not staged for commit.
🪛 GitHub Actions: CI Test on AMD
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
[error] 1-1: clang-format reformatted file. Please review and stage the changes.
docs/conf.py
[error] 1-1: clang-format reformatted file. Please review and stage the changes.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
[error] 1-1: clang-format reformatted file. Please review and stage the changes.
🪛 GitHub Actions: CI Test on Metal
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
[error] 1-1: Reformatted files. Please review and stage the changes. Changes not staged for commit.
docs/conf.py
[error] 1-1: Reformatted files. Please review and stage the changes. Changes not staged for commit.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
[error] 1-1: Reformatted files. Please review and stage the changes. Changes not staged for commit.
🪛 Ruff (0.13.3)
tilelang/quantize/lop3.py

1189-1189: Avoid specifying long messages outside the exception class (TRY003)

tilelang/carver/roller/shape_inference/tir.py

354-354: Prefer `TypeError` exception for invalid type (TRY004)

354-354: Create your own exception (TRY002)

354-354: Avoid specifying long messages outside the exception class (TRY003)

tilelang/carver/roller/node.py

304-304: Use of `functools.lru_cache` or `functools.cache` on methods can lead to memory leaks (B019)

421-421: Use of `functools.lru_cache` or `functools.cache` on methods can lead to memory leaks (B019)

setup.py

312-312: Unused `noqa` directive (non-enabled: `SIM115`). Remove unused `noqa` directive (RUF100)
🔇 Additional comments (70)
examples/flash_decoding/example_mha_inference.py (1)
321-325
: F-string conversion looks goodFormatting and precision stay the same, readability improves.
testing/python/kernel/test_tilelang_kernel_gemm.py (1)
98-99
: Redundant parentheses cleanup looks good.The simplified expression preserves behavior and keeps the intent clear.
examples/cast/example_group_per_split_token_cast_to_fp8.py (1)
202-205
: LGTM on the f-string swapStraightforward modernization—f-strings keep the same output while matching the new Ruff style expectations.
examples/hadamard_transform/example_hadamard.py (1)
154-154
: LGTM! Modern f-string formatting applied correctly.The conversion from
.format()
to f-string is correct and improves readability while maintaining identical output.examples/cast/example_per_token_cast_to_fp8.py (1)
103-103
: LGTM! Clean modernization to f-strings.The migration from
.format()
to f-strings is correct and aligns with the PR objective of enabling pyupgrade linter. The formatting specifications are preserved accurately.Also applies to: 105-105, 117-117
examples/norm/rms_norm.py (1)
74-76
: LGTM! Clean modernization to f-strings.The conversion from format strings to f-strings is correct and improves readability. The formatting specifiers are preserved accurately.
examples/bitnet-1.58b/eval_ppl.py (1)
57-57
: LGTM! F-string modernization improves readability.The conversion to an f-string is correct and aligns with modern Python best practices. It also maintains consistency with the existing f-string usage at line 53.
testing/python/language/test_tilelang_language_pipeline.py (1)
106-107
: LGTM! Clean formatting improvement.Removing the redundant outer parentheses improves readability without affecting functionality. The tfloat32 conversion logic remains correct.
tilelang/carver/roller/rasterization.py (1)
92-92
: LGTM! Clean f-string conversion.The modernization from
.format()
to an f-string is correct and aligns with the PR's objective to enable pyupgrade linting.examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1)
493-494
: LGTM! Clean modernization to f-strings.The conversion from
.format()
to f-strings is correct and improves readability. The formatting specifiers are preserved accurately, and the changes align with the PR's objective of modernizing Python style across the codebase.testing/python/kernel/test_tilelang_kernel_gemm_simt.py (1)
109-109
: LGTM! Redundant parentheses removed.The removal of redundant parentheses around
block_K // micro_size_k
improves code clarity and aligns with the usage ofT.serial
elsewhere in the file (lines 99, 110, 115, 121, 126). This change is part of the pyupgrade linter enforcement and has no functional impact.examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
198-202
: LGTM! Clean f-string conversion.The print statements have been correctly modernized to use f-strings, improving readability while maintaining identical output formatting. The conversions follow Python best practices and align with the PR's objective to enable pyupgrade linting.
examples/bitnet-1.58b/utils_quant.py (1)
219-219
: LGTM! Modern Python 3 super() call.The modernization from
super(BitLinear, self).__init__(...)
tosuper().__init__(...)
is correct and follows Python 3 best practices. This change is consistent with the existing code inBitLinearBitBLAS.__init__
(line 44) and aligns with the PR's objective to modernize Python style.examples/linear_attention/example_mamba_chunk_state.py (2)
43-43
: LGTM! Redundant parentheses removed.The extra parentheses around the subtraction expression are unnecessary and have been correctly removed without affecting functionality.
171-175
: LGTM! Print statements modernized to f-strings.The conversion from
.format()
to f-strings improves readability and aligns with modern Python style. The formatting specifiers and output values remain identical.examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (2)
225-226
: LGTM! F-string conversion is correct.The conversion from
.format()
to f-strings is properly done and aligns with the PR objective of modernizing Python style.
228-229
: LGTM! Consistent f-string modernization.The f-string conversion matches the pattern used for the reference benchmark output above, maintaining consistency throughout the file.
examples/flash_attention/example_mha_bwd_bhsd.py (1)
342-346
: LGTM! Good modernization to f-strings.The conversion from
.format()
to f-strings improves readability and aligns with modern Python style.examples/bitnet-1.58b/vllm_workspace/conftest.py (1)
37-37
: LGTM! Clean modernization.Removing the explicit
"r"
mode is appropriate since it's Python's default foropen()
. This change aligns with the pyupgrade linter recommendations.tilelang/carver/arch/metal.py (1)
1-1
: LGTM! Enables modern type hint syntax.Adding
from __future__ import annotations
is appropriate for this module. It enables the modern union syntax (Target | str
on line 12) and improves forward reference handling during type checking.examples/bitnet-1.58b/modeling_bitnet.py (1)
1515-1515
: LGTM! Consistent with file I/O modernization.Removing the explicit
"r"
mode follows the same pattern as other file I/O updates in this PR and aligns with pyupgrade recommendations.tilelang/intrinsics/wgmma_macro_generator.py (1)
408-408
: LGTM! F-string improves readability.The conversion from
.format()
to an f-string is correct and aligns with modern Python style guidelines enforced by pyupgrade.examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (1)
553-554
: LGTM! Print statements modernized.The f-string conversions for the latency and TFlops output are correct and improve code readability while maintaining identical output formatting.
tilelang/carver/roller/bestfit.py (1)
20-20
: LGTM! Cleaner repr implementation.The f-string conversion in the
__repr__
method is correct and makes the code more readable while preserving the exact string representation.examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
211-215
: LGTM! Benchmark output formatting improved.All four print statements have been correctly converted to f-strings, maintaining the same output format while improving code clarity and consistency with the rest of the PR.
examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1)
88-88
: LGTM! Consistent file I/O modernization.Removing the explicit
"r"
mode is appropriate and aligns with the broader PR pattern of modernizing file I/O operations across the codebase.pyproject.toml (1)
34-34
: LGTM! Pyupgrade rules enabled.Enabling the UP (pyupgrade) and FA102 rules aligns with the PR objective to prevent issues by modernizing Python syntax across the codebase.
examples/flash_attention/example_mha_fwd_bhsd.py (1)
206-210
: LGTM! F-string conversion improves readability.The conversion from
.format()
to f-strings is a standard Python modernization that improves code readability while maintaining identical functionality.tilelang/carver/roller/hint.py (1)
150-150
: LGTM! Python 3 class declaration modernization.Removing explicit
object
inheritance is correct for Python 3, where all classes implicitly inherit fromobject
. This is a standard pyupgrade modernization with no behavioral change.tilelang/autotuner/param.py (3)
249-249
: LGTM! Default file mode simplification.Omitting the explicit
"r"
mode is correct since text-read mode is the default foropen()
. This is a standard pyupgrade modernization.
318-318
: LGTM! Default file mode simplification.Omitting the explicit
"r"
mode is correct since text-read mode is the default foropen()
.
330-330
: LGTM! Default file mode simplification.Omitting the explicit
"r"
mode is correct since text-read mode is the default foropen()
.examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2)
183-183
: LGTM! Modern super() call.The modern
super()
syntax without explicit class and self arguments is the recommended Python 3 pattern and is functionally equivalent.
357-362
: LGTM! F-string conversions improve readability.The conversion to f-strings modernizes the code while maintaining identical output.
tilelang/jit/kernel.py (1)
20-20
: LGTM! Python 3 class declaration modernization.Removing explicit
object
inheritance is correct for Python 3. This is a standard pyupgrade modernization with no behavioral change.tilelang/jit/adapter/libgen.py (1)
32-32
: LGTM! Python 3 class declaration modernization.Removing explicit
object
inheritance is correct for Python 3. This is a standard pyupgrade modernization with no behavioral change.examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2)
404-404
: LGTM! Generator expression for unpacking is more efficient.Using a generator expression instead of a list comprehension for tuple unpacking is more memory-efficient. Python will consume the generator during unpacking to assign all values correctly.
542-546
: LGTM! F-string conversions improve readability.The conversion from
.format()
to f-strings is a standard Python modernization that improves code readability.tilelang/language/proxy.py (1)
4-4
: LGTM! Type hints modernized to Python 3.10+ syntax.The changes correctly update type hints to use:
- Built-in
tuple
instead oftyping.Tuple
- Union syntax
X | Y
instead ofUnion[X, Y]
- Union syntax
X | None
instead ofOptional[X]
These align with PEP 604 and PEP 585 best practices.
Also applies to: 146-146, 154-154, 175-176, 276-276
tilelang/jit/adapter/utils.py (1)
4-4
: LGTM! Type hints modernized and validation improved.The changes correctly:
- Replace
Dict
with built-indict[...]
- Use union syntax
X | Y
instead ofUnion[X, Y]
- Use
X | None
instead ofOptional[X]
The early validation for
model_type
at lines 74-76 is a good addition that provides clearer error messages.Also applies to: 67-72, 110-110
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
3-3
: LGTM! Future annotations import added.This enables postponed evaluation of annotations (PEP 563), which is required for the modern type hints used elsewhere in the codebase.
423-427
: LGTM! Print statements modernized to f-strings.The f-string conversions preserve the exact formatting (
.2f
) and improve readability. No functional changes.examples/flash_decoding/example_gqa_decode.py (1)
476-480
: LGTM! Print statements modernized to f-strings.The f-string conversions preserve the exact formatting (
.2f
) and improve code readability.tilelang/version.py (1)
24-24
: LGTM! Explicit read mode removed.The explicit
"r"
mode is redundant as it's the default foropen()
. This modernization aligns with Python best practices.docs/conf.py (1)
5-5
: LGTM! String formatting and file opening modernized.The changes correctly:
- Convert
%
formatting to f-string for better readability- Remove redundant explicit
"r"
mode (default foropen()
)However, note the pipeline failures indicating formatting issues that need to be staged.
Please address the pipeline failures mentioned in the CI:
- "Reformatted by formatter. Changes not staged for commit."
- "clang-format reformatted file. Please review and stage the changes."
These appear to be formatting issues that need to be staged for commit.
Also applies to: 8-8
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)
2-2
: LGTM! Future annotations import added.This enables postponed evaluation of annotations (PEP 563), supporting the modern type hints used throughout the codebase.
361-361
: LGTM! List comprehension changed to generator expression.The generator expression is more memory-efficient and is safe here because the result is immediately unpacked into individual variables. This is a good optimization for transforming multiple items.
488-492
: LGTM! Print statements modernized to f-strings.The f-string conversions preserve exact formatting (
.2f
) and improve readability.examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
2-2
: LGTM! Future annotations import added.This enables postponed evaluation of annotations (PEP 563), which is necessary for modern type hint syntax.
289-293
: LGTM! Print statements modernized to f-strings.The f-string conversions preserve the exact formatting (
.2f
) and are more Pythonic.tilelang/cache/kernel_cache.py (1)
366-366
: LGTM! Safe modernization.Removing the explicit
"r"
mode is a safe Python 3 idiom since"r"
is the default foropen()
.examples/deepseek_mla/example_mla_decode_paged.py (1)
403-404
: LGTM! Modern f-string syntax.The conversion from
.format()
to f-strings is a safe, idiomatic modernization with identical output.examples/flash_attention/example_mha_bwd.py (2)
263-263
: LGTM! Memory-efficient generator unpacking.Replacing the list comprehension with a generator expression is safe here since the unpacking consumes exactly five elements. This reduces memory overhead by avoiding an intermediate list.
339-343
: LGTM! Modern f-string syntax.The conversion from
.format()
to f-strings is a safe, idiomatic modernization with identical output.tilelang/tileop/gemm/gemm_base.py (1)
tilelang/tileop/gemm/gemm_base.py (1)
12-12: LGTM! Implicit object inheritance.
In Python 3, all classes implicitly inherit from `object`, so removing the explicit base is a safe modernization. The `@dataclass` decorator and class behavior remain unchanged.
examples/amd/example_amd_flash_attn_bwd.py (2)
247-247: LGTM! Memory-efficient generator unpacking.
Replacing the list comprehension with a generator expression is safe here since the unpacking consumes exactly five elements. This reduces memory overhead by avoiding an intermediate list.
346-350: LGTM! Modern f-string syntax.
The conversion from `.format()` to f-strings is a safe, idiomatic modernization with identical output.
setup.py (3)
140-140: LGTM! Safe modernization.
Removing the explicit `"r"` mode is a safe Python 3 idiom since `"r"` is the default for `open()`.
171-171: LGTM! Safe modernization.
Removing the explicit `"r"` mode is a safe Python 3 idiom since `"r"` is the default for `open()`.
711-740: LGTM! Safe modernization.
Removing the explicit `"r"` mode from text file opens is a safe Python 3 idiom since `"r"` is the default. The behavior remains identical.
examples/attention_sink/example_mha_sink_bwd_bhsd.py (3)
2-2: LGTM! Forward-compatible annotations.
Adding `from __future__ import annotations` enables PEP 563 (postponed evaluation of annotations), which is forward-compatible and allows using newer annotation syntax. This is a safe, future-proof change.
369-369: LGTM! Memory-efficient generator unpacking.
Replacing the list comprehension with a generator expression is safe here since the unpacking consumes exactly six elements. This reduces memory overhead by avoiding an intermediate list.
492-496: LGTM! Modern f-string syntax.
The conversion from `.format()` to f-strings is a safe, idiomatic modernization with identical output.
examples/flash_attention/example_gqa_bwd.py (2)
382-382: LGTM! Memory-efficient generator unpacking.
Replacing the list comprehension with a generator expression is safe here since the unpacking consumes exactly five elements. This reduces memory overhead by avoiding an intermediate list.
520-524: LGTM! Modern f-string syntax.
The conversion from `.format()` to f-strings is a safe, idiomatic modernization with identical output.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3)
196-196: LGTM! Modern super() syntax.
The change from `super(SparseFlashAttn, self).__init__()` to `super().__init__()` follows Python 3 best practices and is the idiomatic way to call parent class constructors.
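A minimal sketch of the two `super()` spellings (generic classes, not the example's `SparseFlashAttn`):

```python
class Base:
    def __init__(self):
        self.ready = True

class Child(Base):
    def __init__(self):
        # Python 2 style: super(Child, self).__init__()
        super().__init__()  # Python 3: class and instance are inferred

print(Child().ready)  # True
```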
369-369: LGTM! f-string conversion.
The change from `.format()` to f-string improves readability and aligns with modern Python style.
372-372: LGTM! f-string conversion with multiple expressions.
The conversion correctly maps all format arguments to f-string expressions, improving readability.
tilelang/jit/adapter/cython/adapter.py (1)
105-105: LGTM! Default file mode usage.
Removing the explicit `"r"` mode is correct since Python defaults to text read mode. This modernization aligns with pyupgrade linter rules and makes the code cleaner without changing behavior.
Also applies to: 118-118, 135-135
examples/grouped_gemm/example_grouped_gemm_bwd.py (1)
139-139: Confirm generator-expression change.
The switch from `[maybe_contiguous(x) for x in …]` to `(maybe_contiguous(x) for x in …)` for immediate unpacking isn't a known pyupgrade rule and offers no benefit. Verify this refactor is intentional and actually produced by your pyupgrade/ruff setup (e.g., share your lint output or rule config).
Actionable comments posted: 1
🧹 Nitpick comments (5)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
357-357: Good f-string conversion with minor optimization opportunity.
The f-string conversion improves readability. Consider simplifying further by eliminating the concatenation:
- print(name + f" all_close={all_close}")
+ print(f"{name} all_close={all_close}")
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)
369-369: Good f-string conversion with minor optimization opportunity.
The f-string conversion improves readability. Consider simplifying further by eliminating the concatenation:
- print(name + f" all_close={all_close}")
+ print(f"{name} all_close={all_close}")
examples/grouped_gemm/example_grouped_gemm_bwd.py (1)
139-139: Consider using `map()` for better idiomaticity.
The generator expression works correctly but is unconventional for immediate unpacking. The `map()` builtin is the standard Python pattern for applying a function to multiple values. Apply this diff to use the more idiomatic `map()`:
- A, B, batch_sizes = (maybe_contiguous(x) for x in (A, B, batch_sizes))
+ A, B, batch_sizes = map(maybe_contiguous, (A, B, batch_sizes))
setup.py (1)
312-312: Remove the unused `noqa` directive.
The `# noqa: SIM115` comment is no longer needed since the code has been updated to address the linting issue. Apply this diff:
- return open(get_path("README.md"), encoding="utf-8").read()  # noqa: SIM115
+ return open(get_path("README.md"), encoding="utf-8").read()
tilelang/carver/roller/shape_inference/tir.py (1)
354-354: F-string migration looks good; consider using `TypeError` for type errors.
The f-string conversion is correct. As an optional improvement, consider raising `TypeError` instead of generic `Exception` when encountering unexpected types, as this provides clearer intent. Apply this diff if desired:
- raise Exception(f"Unhandled node type in walk_indice(): {expr}")
+ raise TypeError(f"Unhandled node type in walk_indice(): {type(expr).__name__}")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (72)
docs/conf.py
(1 hunks)examples/amd/example_amd_flash_attn_bwd.py
(2 hunks)examples/attention_sink/example_gqa_sink_bwd_bhsd.py
(3 hunks)examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
(2 hunks)examples/attention_sink/example_mha_sink_bwd_bhsd.py
(3 hunks)examples/attention_sink/example_mha_sink_fwd_bhsd.py
(2 hunks)examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
(2 hunks)examples/bitnet-1.58b/configuration_bitnet.py
(0 hunks)examples/bitnet-1.58b/eval_ppl.py
(1 hunks)examples/bitnet-1.58b/maint/create_bitblas_ckpt.py
(1 hunks)examples/bitnet-1.58b/modeling_bitnet.py
(1 hunks)examples/bitnet-1.58b/tokenization_bitnet.py
(0 hunks)examples/bitnet-1.58b/utils_quant.py
(1 hunks)examples/bitnet-1.58b/vllm_workspace/conftest.py
(1 hunks)examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
(2 hunks)examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
(2 hunks)examples/cast/example_group_per_split_token_cast_to_fp8.py
(1 hunks)examples/cast/example_per_token_cast_to_fp8.py
(2 hunks)examples/deepseek_mla/example_mla_decode_paged.py
(1 hunks)examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py
(1 hunks)examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
(1 hunks)examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py
(1 hunks)examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
(1 hunks)examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
(1 hunks)examples/flash_attention/example_gqa_bwd.py
(2 hunks)examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
(2 hunks)examples/flash_attention/example_gqa_fwd_bshd.py
(1 hunks)examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
(1 hunks)examples/flash_attention/example_mha_bwd.py
(2 hunks)examples/flash_attention/example_mha_bwd_bhsd.py
(2 hunks)examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
(2 hunks)examples/flash_attention/example_mha_fwd_bhsd.py
(1 hunks)examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
(1 hunks)examples/flash_attention/example_mha_fwd_bshd.py
(1 hunks)examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
(1 hunks)examples/flash_decoding/example_gqa_decode.py
(1 hunks)examples/flash_decoding/example_mha_inference.py
(1 hunks)examples/grouped_gemm/example_grouped_gemm_bwd.py
(1 hunks)examples/hadamard_transform/example_hadamard.py
(1 hunks)examples/linear_attention/example_mamba_chunk_scan.py
(1 hunks)examples/linear_attention/example_mamba_chunk_state.py
(2 hunks)examples/minference/example_vertical_slash_sparse_attn.py
(1 hunks)examples/norm/rms_norm.py
(1 hunks)pyproject.toml
(1 hunks)setup.py
(6 hunks)testing/python/kernel/test_tilelang_kernel_gemm.py
(1 hunks)testing/python/kernel/test_tilelang_kernel_gemm_simt.py
(1 hunks)testing/python/language/test_tilelang_language_pipeline.py
(1 hunks)tilelang/autotuner/param.py
(3 hunks)tilelang/cache/kernel_cache.py
(1 hunks)tilelang/carver/arch/cuda.py
(1 hunks)tilelang/carver/arch/metal.py
(1 hunks)tilelang/carver/roller/bestfit.py
(1 hunks)tilelang/carver/roller/hint.py
(1 hunks)tilelang/carver/roller/node.py
(4 hunks)tilelang/carver/roller/rasterization.py
(1 hunks)tilelang/carver/roller/shape_inference/common.py
(2 hunks)tilelang/carver/roller/shape_inference/tir.py
(3 hunks)tilelang/contrib/hipcc.py
(1 hunks)tilelang/intrinsics/mfma_macro_generator.py
(1 hunks)tilelang/intrinsics/mma_macro_generator.py
(2 hunks)tilelang/intrinsics/wgmma_macro_generator.py
(1 hunks)tilelang/jit/adapter/cython/adapter.py
(3 hunks)tilelang/jit/adapter/libgen.py
(1 hunks)tilelang/jit/adapter/utils.py
(3 hunks)tilelang/jit/adapter/wrapper.py
(6 hunks)tilelang/jit/kernel.py
(1 hunks)tilelang/language/proxy.py
(4 hunks)tilelang/quantize/lop3.py
(1 hunks)tilelang/quantize/quantization.py
(2 hunks)tilelang/tileop/gemm/gemm_base.py
(1 hunks)tilelang/version.py
(1 hunks)
💤 Files with no reviewable changes (2)
- examples/bitnet-1.58b/configuration_bitnet.py
- examples/bitnet-1.58b/tokenization_bitnet.py
🧰 Additional context used
🧬 Code graph analysis (25)
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/minference/example_vertical_slash_sparse_attn.py (1)
- tilelang/language/builtin.py (1): mbarrier_wait_parity (172-219)
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/linear_attention/example_mamba_chunk_scan.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
tilelang/jit/adapter/utils.py (1)
- tilelang/language/ast/ir.py (1): target (1682-1713)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (1): maybe_contiguous (364-367)
- examples/flash_attention/example_gqa_bwd.py (1): maybe_contiguous (377-380)
examples/flash_decoding/example_gqa_decode.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_attention/example_mha_bwd.py (2)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
tilelang/carver/roller/shape_inference/common.py (1)
- tilelang/carver/roller/shape_inference/tir.py (2): Statement (7-43), InputShapeInference (169-318)
examples/linear_attention/example_mamba_chunk_state.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
tilelang/intrinsics/mma_macro_generator.py (1)
- tilelang/utils/language.py (1): is_fragment (68-78)
examples/flash_attention/example_mha_fwd_bshd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/grouped_gemm/example_grouped_gemm_bwd.py (1)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
examples/flash_decoding/example_mha_inference.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_attention/example_gqa_fwd_bshd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_attention/example_gqa_bwd.py (1)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
examples/flash_attention/example_mha_fwd_bhsd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1)
- examples/flash_attention/example_gqa_bwd.py (2): maybe_contiguous (377-380), run1 (514-515)
examples/amd/example_amd_flash_attn_bwd.py (1)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1): maybe_contiguous (356-359)
examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1)
- examples/amd/example_amd_flash_attn_bwd.py (1): maybe_contiguous (242-245)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
- examples/flash_attention/example_mha_bwd.py (1): maybe_contiguous (258-261)
examples/flash_attention/example_mha_bwd_bhsd.py (1)
- examples/amd/example_amd_flash_attn_bwd.py (2): maybe_contiguous (242-245), run1 (340-341)
tilelang/intrinsics/wgmma_macro_generator.py (1)
- tilelang/utils/language.py (1): is_fragment (68-78)
🪛 Ruff (0.13.3)
tilelang/carver/roller/shape_inference/tir.py
- 354-354: Prefer `TypeError` exception for invalid type (TRY004)
- 354-354: Create your own exception (TRY002)
- 354-354: Avoid specifying long messages outside the exception class (TRY003)
setup.py
- 312-312: Unused `noqa` directive (non-enabled: `SIM115`). Remove unused `noqa` directive (RUF100)
tilelang/quantize/lop3.py
- 1189-1189: Avoid specifying long messages outside the exception class (TRY003)
tilelang/carver/roller/node.py
- 304-304: Use of `functools.lru_cache` or `functools.cache` on methods can lead to memory leaks (B019)
- 421-421: Use of `functools.lru_cache` or `functools.cache` on methods can lead to memory leaks (B019)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-amd
- GitHub Check: format-check
🔇 Additional comments (98)
testing/python/language/test_tilelang_language_pipeline.py (1)
106-107: No-op cleanup looks good.
Dropping the redundant parentheses keeps the expression readable with identical semantics while satisfying the linter. ✅
examples/minference/example_vertical_slash_sparse_attn.py (1)
209-209: Redundant parentheses removal looks good.
The simplified parity expression keeps identical behavior while matching the linter's style expectations.
examples/bitnet-1.58b/eval_ppl.py (1)
57-57: LGTM! Clean modernization to f-string.
The conversion from `.format()` to f-string is correct and improves readability. This change also enhances consistency, as line 53 already uses f-strings for similar output formatting.
tilelang/contrib/hipcc.py (1)
57-57: LGTM! Clean string formatting modernization.
The conversion from `%` formatting to f-string is correct and improves readability. Since `target_format` is validated to be `"hsaco"` on lines 54-55, both the old and new formatting produce identical output.
examples/cast/example_per_token_cast_to_fp8.py (1)
103-103: LGTM! Clean modernization to f-strings.
The conversion from `.format()` to f-strings is correct and aligns with Python 3.6+ best practices. F-strings are more readable and slightly more performant than the older string formatting method.
Also applies to: 105-105, 117-117
examples/flash_decoding/example_gqa_decode.py (1)
476-480: F-string conversions are equivalent.
The switchover keeps the formatting semantics intact while aligning with the pyupgrade style target.
tilelang/quantize/quantization.py (2)
219-227: No behavior change in prefix assembly.
Parentheses cleanup keeps the same precedence, so the packed bf16 exponent path remains identical. Looks good.
231-236: Bitwise OR chaining still correct.
Associativity of `|` means the expanded form matches the prior grouped expression; the e4m3 conversion stays stable.
examples/cast/example_group_per_split_token_cast_to_fp8.py (2)
202-202
: LGTM! Clean f-string migration.The conversion from
.format()
to f-string is correct and improves readability, aligning with the PR objective to enable pyupgrade linter rules.
205-205
: LGTM! Clean f-string migration.The conversion from
.format()
to f-string is correct and improves readability, consistent with the modernization effort.examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1)
493-494
: LGTM! Clean modernization to f-strings.The conversion from
.format()
to f-strings is correct and preserves the formatting specification (.2f
). This change improves code consistency with the existing f-string usage elsewhere in the file (e.g., lines 436-437) and aligns with modern Python best practices.examples/flash_decoding/example_mha_inference.py (1)
321-325
: LGTM! String formatting modernization improves readability.The conversion from
.format()
to f-strings is correct and aligns with modern Python best practices (PEP 498) and the PR's objective to enable pyupgrade linting. The formatting specifiers and calculations remain identical, ensuring no behavioral changes.examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1)
436-437
: LGTM: Clean f-string migration.The conversion from
.format()
to f-strings improves readability and aligns with modern Python best practices.examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (1)
553-554
: LGTM: Consistent f-string migration.The formatting update is correct and consistent with the broader modernization effort across the codebase.
examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1)
537-538
: LGTM: F-string conversion applied correctly.The formatting change maintains the same output while improving code quality.
examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1)
280-284
: LGTM: F-string updates for all benchmark outputs.All four print statements have been correctly migrated to f-strings, maintaining consistent formatting across reference and kernel benchmarks.
examples/deepseek_mla/example_mla_decode_paged.py (1)
403-404
: LGTM! Clean f-string conversion.The conversion from
.format()
to f-strings is correct and maintains identical output. This modernization improves readability and aligns with Python 3.6+ best practices.examples/bitnet-1.58b/utils_quant.py (1)
219-219
: LGTM! Modern super() syntax applied correctly.The change from
super(BitLinear, self).__init__(*kargs, **kwargs)
tosuper().__init__(*kargs, **kwargs)
is correct and aligns with Python 3+ best practices. The behavior remains equivalent.examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2)
183-183
: LGTM! Modern super() syntax.The simplified
super().__init__()
call is the recommended Python 3 idiom, reducing boilerplate and potential errors.
362-364
: LGTM! Clean multi-line f-string formatting.The conversion from
.format()
to f-strings improves readability and performance while maintaining the multi-line structure.examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)
196-196
: LGTM! Modern super() syntax.The simplified
super().__init__()
call is the recommended Python 3 idiom, reducing boilerplate and potential errors.
372-374
: LGTM! Clean multi-line f-string formatting.The conversion from
.format()
to f-strings improves readability and performance while maintaining the multi-line structure.examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
3-3
: LGTM! Essential for modern type hints.The
from __future__ import annotations
import enables theint | None
syntax used throughout the file (lines 209, 338, 383) without requiringtyping.Union
. This is the recommended approach for Python 3.10+ style type hints in earlier Python versions.
438-446
: LGTM! Clean f-string migration.The conversion to f-strings improves readability and follows modern Python best practices. All format specifiers are correctly preserved (
.2f
for floating-point precision).examples/flash_attention/example_mha_bwd.py (2)
263-263
: LGTM! Generator expression for unpacking is idiomatic.The change from list comprehension to generator expression aligns with pyupgrade's UP027 rule. For unpacking into a fixed number of variables, generators are preferred as they avoid creating an intermediate list.
339-343
: LGTM! F-strings improve readability.The conversion from
.format()
to f-strings follows pyupgrade's UP031/UP032 rules. The formatting specifications are correctly preserved, and f-strings offer better readability and performance.examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2)
3-3
: LGTM!The
from __future__ import annotations
import is correctly placed and enables the use of modern type hint syntax (e.g.,int | None
,tuple[...]
at lines 202, 331, 358) while maintaining compatibility with Python 3.7-3.9.
423-427
: LGTM!The conversion from
.format()
to f-strings is correct and maintains identical formatting. The changes improve readability and align with modern Python best practices.Note: The AI summary mentions a speedup calculation line that is not present in the annotated code changes. This may indicate the summary is inconsistent or referencing changes outside the shown diff.
examples/linear_attention/example_mamba_chunk_scan.py (1)
232-236
: LGTM! Clean modernization to f-strings.The conversion from
.format()
to f-strings improves readability while preserving identical formatting behavior. This aligns well with the PR's objective to enable pyupgrade linting.examples/attention_sink/example_mha_sink_bwd_bhsd.py (3)
2-2
: LGTM: Modern annotation handling enabled.The
from __future__ import annotations
import enables postponed evaluation of type hints (PEP 563), which allows the modern union syntaxint | None
used in function signatures (lines 395, 442) to work correctly across Python versions.
369-369
: LGTM: Efficient unpacking with generator expression.The conversion from list comprehension to generator expression is a minor memory optimization. During unpacking, the generator is fully consumed, so the behavior remains identical.
492-496
: LGTM: Modern f-string formatting.The conversion to f-strings improves readability while maintaining identical output formatting. All format specifiers (
.2f
) are correctly preserved.examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
2-2
: LGTM! Essential import for modern type hints.The
from __future__ import annotations
import enables PEP 563 postponed evaluation, which is necessary for the modern union syntax (int | None
) used throughout this file (lines 190, 246). This is a standard modernization pattern recommended by pyupgrade.
289-293
: LGTM! Clean f-string conversions.The migration from
.format()
to f-strings preserves the exact formatting semantics (2 decimal places) while improving readability and performance. The expressiontotal_flops / latency * 1e-9
is correctly embedded in the f-strings without any functional changes.tilelang/carver/roller/rasterization.py (1)
92-92
: LGTM! Clean modernization to f-string.The conversion from
.format()
to f-string is correct and aligns with the PR's pyupgrade objective. The variablepanel_width
is guaranteed to be defined at this point (handled by the None-check on lines 88-89), and the generated CUDA code maintains proper semicolon syntax.tilelang/language/proxy.py (5)
4-4
: LGTM! Clean import modernization.Correctly removed
Optional
,Tuple
, andUnion
from typing imports since they're replaced with Python 3.10+ built-in syntax (tuple
,X | Y
) throughout the file.
146-146
: LGTM! Type hint modernization.Correctly updated from
Tuple[Any]
totuple[Any]
per Python 3.10+ standards.
154-154
: LGTM! Enhanced type safety.The added type annotation
tuple[Any] | PrimExpr | int
correctly captures the method's overloaded parameter using modern union syntax, aligning with the runtime check at lines 158-159.
175-176
: LGTM! Consistent type hint modernization.Both
shape
andstrides
parameters correctly updated to usetuple[Any]
.
276-276
: LGTM! Modern union syntax.Correctly updated from
Optional[str]
tostr | None
, using the modern union operator. Thefrom __future__ import annotations
at line 3 ensures runtime compatibility while enabling this modern syntax.examples/flash_attention/example_mha_fwd_bhsd.py (1)
206-210
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability and follows modern Python conventions. The output formatting remains identical.examples/flash_attention/example_mha_bwd_bhsd.py (2)
264-264
: LGTM! Valid generator expression unpacking.Replacing the list comprehension with a generator expression is valid—Python consumes the generator during tuple unpacking. This is slightly more memory efficient while maintaining identical behavior.
342-346
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability and follows modern Python conventions.examples/flash_attention/example_mha_fwd_bshd.py (1)
193-197
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability and follows modern Python conventions.examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2)
404-404
: LGTM! Valid generator expression unpacking.Replacing the list comprehension with a generator expression is valid and slightly more memory efficient.
542-546
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability and follows modern Python conventions.examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (2)
273-273
: LGTM! Valid generator expression unpacking.Replacing the list comprehension with a generator expression is valid and slightly more memory efficient.
349-353
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability and follows modern Python conventions.examples/flash_attention/example_gqa_fwd_bshd.py (1)
251-255
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability and follows modern Python conventions.examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
198-202
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability and follows modern Python conventions.examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
211-215
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability and follows modern Python conventions. All changes in this PR consistently apply Python 3 modernization best practices.examples/flash_attention/example_gqa_bwd.py (2)
382-382
: LGTM: Memory-efficient unpacking.The generator expression is more efficient than a list comprehension for unpacking, eliminating the intermediate list object.
520-524
: LGTM: Modern string formatting.The f-string conversions improve readability and performance compared to
.format()
style.examples/linear_attention/example_mamba_chunk_state.py (2)
171-175
: LGTM! Clean f-string conversion.The migration from
.format()
to f-strings is correct and improves readability. These changes align with the PR's objective of enabling the pyupgrade linter in ruff.
43-43
: Approve parentheses removal in decay_states
This change only removes redundant parentheses around the subtraction; functionality is unchanged.tilelang/jit/adapter/utils.py (2)
1-4
: LGTM! Import cleanup aligns with modern Python.The removal of
Dict
from typing imports is correct, as the built-indict
generic can be used directly with Python 3.9+.
110-110
: LGTM! Type annotation modernization is correct.The signature update correctly uses modern Python type hint syntax. The Python version verification requested for
get_annotated_mod
(lines 67-72) also applies to this change.tilelang/jit/kernel.py (1)
20-20
: LGTM! Modernized class declaration.Removed explicit
object
base class, which is implicit in Python 3+. This aligns with the pyupgrade rule UP004.tilelang/intrinsics/wgmma_macro_generator.py (1)
408-408
: LGTM! Modernized string formatting.Converted to f-string for better readability and performance. This aligns with the pyupgrade rule UP032.
tilelang/carver/arch/metal.py (1)
1-1
: LGTM! Enabled postponed annotation evaluation.Added
from __future__ import annotations
to enable PEP 563, which supports the modern union type syntax (Target | str
on line 12) and improves performance by deferring annotation evaluation.examples/bitnet-1.58b/modeling_bitnet.py (1)
1515-1515
: LGTM! Removed redundant file mode.The default mode for
open()
is'r'
(read text), so explicitly specifying it is unnecessary. This aligns with the pyupgrade rule UP015.tilelang/jit/adapter/cython/adapter.py (3)
105-105
: LGTM! Removed redundant file mode.The default mode for
open()
is'r'
(read text), so explicitly specifying it is unnecessary. This aligns with the pyupgrade rule UP015.
118-118
: LGTM! Removed redundant file mode.The default mode for
open()
is'r'
(read text), so explicitly specifying it is unnecessary. This aligns with the pyupgrade rule UP015.
135-135
: LGTM! Removed redundant file mode.The default mode for
open()
is'r'
(read text), so explicitly specifying it is unnecessary. This aligns with the pyupgrade rule UP015.tilelang/carver/roller/bestfit.py (1)
20-20
: LGTM! Modernized string formatting.Converted to f-string for better readability and performance. This aligns with the pyupgrade rule UP032.
examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1)
88-88
: LGTM! Removed redundant file mode.The default mode for
open()
is'r'
(read text), so explicitly specifying it is unnecessary. This aligns with the pyupgrade rule UP015.tilelang/version.py (1)
24-24
: LGTM! Removed redundant file mode.The default mode for
open()
is'r'
(read text), so explicitly specifying it is unnecessary. This aligns with the pyupgrade rule UP015.tilelang/carver/arch/cuda.py (1)
94-94
: LGTM! Python 3 modernization.Removing the explicit
object
base class is correct. In Python 3, all classes implicitly inherit fromobject
, making this explicit declaration redundant.tilelang/carver/roller/hint.py (1)
150-150
: LGTM! Python 3 modernization.The removal of the explicit
object
base class is appropriate for Python 3, where all classes implicitly inherit fromobject
.examples/hadamard_transform/example_hadamard.py (1)
154-154
: LGTM! F-string modernization.The conversion from
.format()
to f-string is a good modernization. F-strings are more readable and slightly more performant in Python 3.6+.examples/bitnet-1.58b/vllm_workspace/conftest.py (1)
37-37
: LGTM! Simplified file opening.Removing the explicit
"r"
mode is appropriate since it's the default foropen()
. This makes the code cleaner without changing behavior.tilelang/quantize/lop3.py (1)
1189-1191
: LGTM! F-string modernization.The conversion to f-strings improves readability. The error message and function name construction are now clearer and more idiomatic.
tilelang/jit/adapter/libgen.py (1)
32-32
: LGTM! Python 3 modernization.Removing the explicit
object
base class aligns with Python 3 best practices, where all classes implicitly inherit fromobject
.tilelang/cache/kernel_cache.py (1)
366-366
: LGTM! Simplified file opening.Omitting the explicit
"r"
mode is appropriate since it's the default foropen()
. This simplifies the code without affecting behavior.tilelang/autotuner/param.py (3)
249-249
: LGTM! Simplified file opening.Removing the explicit
"r"
mode is correct since it's the default foropen()
. This makes the code cleaner.
318-318
: LGTM! Simplified file opening.Omitting the explicit
"r"
mode is appropriate and aligns with Python best practices.
330-330
: LGTM! Simplified file opening.The removal of the explicit
"r"
mode is correct and makes the code more concise.tilelang/intrinsics/mma_macro_generator.py (2)
28-28
: LGTM! Python 3 modernization.Removing the explicit
object
base class is the recommended style for Python 3, as all classes implicitly inherit fromobject
.
521-521
: LGTM! Improved string formatting.The f-string provides the same functionality as the previous
.format()
call with clearer syntax.docs/conf.py (2)
4-4
: LGTM! Modernized string formatting.The f-string provides the same result as the previous
%
formatting with improved readability.
7-7
: LGTM! Simplified file opening.Removing the explicit
"r"
mode is appropriate since text read mode is the default foropen()
.examples/norm/rms_norm.py (1)
74-76
: LGTM! Improved print formatting.The f-string conversion maintains the same output format while improving code readability.
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (3)
2-2
: LGTM! Enables modern type hint syntax.Adding
from __future__ import annotations
enables PEP 563 postponed evaluation of annotations, allowing the use of modern type hint syntax likeint | None
on Line 389.
361-361
: LGTM! Memory-efficient unpacking.The generator expression is more memory-efficient than the list comprehension since values are consumed immediately during unpacking. The behavior is identical for this use case.
488-492
: LGTM! Improved string formatting.The f-string conversions maintain the same output format with better readability.
tilelang/intrinsics/mfma_macro_generator.py (1)
14-14
: LGTM! Python 3 modernization.Removing the explicit
object
base class aligns with Python 3 conventions, as all classes implicitly inherit fromobject
.tilelang/carver/roller/shape_inference/common.py (1)
7-7
: LGTM! Python 3 modernization.Removing explicit
object
base classes aligns with Python 3 conventions, as all classes implicitly inherit fromobject
.Also applies to: 21-21
examples/amd/example_amd_flash_attn_bwd.py (2)
247-247
: LGTM! Memory-efficient unpacking.The generator expression is more memory-efficient than a list comprehension since values are consumed immediately during unpacking. The behavior is identical for this use case.
346-350
: LGTM! Improved string formatting.The f-string conversions maintain the same output format with better readability and performance.
tilelang/tileop/gemm/gemm_base.py (1)
11-12
: Ensure subclasses ofGemmBase
remain instantiable after removing@dataclass
.Removing
@dataclass
drops the generated__init__
, so any subclass without its own constructor will inheritobject.__init__
and break when passedgemm_node
. We found no direct instantiations ofGemmBase
or its subclasses (tilelang/tileop/gemm/gemm_wgmma.py
,tilelang/tileop/gemm/gemm_mma.py
,tilelang/primitives/gemm/gemm_mma.py
), but none define__init__
. Confirm these classes aren’t instantiated at runtime or add explicit__init__
methods.setup.py (1)
140-140
: LGTM!The removal of explicit
"r"
mode fromopen()
calls is a safe modernization since text mode is the default in Python 3. These changes align with enabling the pyupgrade linter.Also applies to: 171-171, 711-712, 723-724, 740-741
tilelang/carver/roller/node.py (2)
32-32: LGTM!
Removing explicit `object` base classes is correct for Python 3, where all classes implicitly inherit from `object`.
Also applies to: 93-93
304-304: Verify that instance lifetime justifies method-level caching.
The static analysis tool (B019) flags that using `@functools.lru_cache` on methods can prevent garbage collection of instances. The cache holds references to `self`, which can lead to memory leaks if instances are meant to be short-lived.
If `PrimFuncNode` instances have long lifetimes or are reused frequently, this is acceptable. Otherwise, consider using a separate cache or `@lru_cache` with `maxsize` to limit memory growth.
Also applies to: 421-421
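A sketch of the B019 hazard and one common per-instance workaround (class names are illustrative, not from this repo):

```python
import functools

class Leaky:
    @functools.lru_cache(maxsize=None)
    def footprint(self, k):      # cache entries keep `self` alive (B019)
        return k * 2

class PerInstance:
    def __init__(self):
        # the cache lives on the instance, so entries die with the object
        self.footprint = functools.lru_cache(maxsize=None)(self._footprint)

    def _footprint(self, k):
        return k * 2

print(PerInstance().footprint(21))  # 42
```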
tilelang/carver/roller/shape_inference/tir.py (1)
50-50
: LGTM!Removing explicit
object
base classes is correct for Python 3.Also applies to: 79-79
tilelang/jit/adapter/wrapper.py (4)
179-179: LGTM!
Removing explicit `object` base classes is correct for Python 3.
Also applies to: 895-895, 1062-1062
325-326: LGTM!
The f-string conversion for CUDA kernel launch configuration is correct and more readable.
350-351: LGTM!
The f-string conversion for kernel launch code is correct and improves readability.
993-993: LGTM!
The f-string conversion is correct.
@XuehaiPan Would you mind helping review this PR?
Overall LGTM.
Some small nits:
- Add `__future__` imports only when needed.
- We'd best add the `tool.ruff.target-version` to `pyproject.toml` when `UP` code is enabled:
  [tool.ruff]
  target-version = "py38"
Thanks! @XuehaiPan
This should be covered by the added
I stole a commit from your PR #955 to resolve that.
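For reference, a minimal sketch of the suggested configuration; the `extend-select` line is an assumption about how this repo enables the rule set, not a quote from the PR:

```toml
[tool.ruff]
target-version = "py38"

[tool.ruff.lint]
extend-select = ["UP"]  # enable the pyupgrade rules
```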
7273746 to 391ddc2 Compare
Actionable comments posted: 22
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (19)
tilelang/language/persistent.py (1)
18-18: Update docstring to match the modernized type annotation.
The docstring still references `List[tir.PrimExpr]`, but the function signature now uses `list[tir.PrimExpr]`. Update the docstring for consistency. Apply this diff:
- domain : List[tir.PrimExpr]
+ domain : list[tir.PrimExpr]
tilelang/primitives/gemm/base.py (1)
229-239: Update the docstring to match the new type annotation syntax.
The method parameter type was correctly updated to `int | None`, but the docstring at line 239 still references `Optional[int]`. Please update the docstring to maintain consistency. Apply this diff:
  Parameters
  ----------
- threads : Optional[int]
+ threads : int | None
      The total number of threads in a block. Must be provided
      if any block partition parameter is not already set.
tilelang/language/frame.py (1)
201-210: Update docstring to match the new type annotation syntax.
The return type annotation on line 201 uses the modern `PrimExpr | None` syntax, but the docstring on line 208 still references `Optional[PrimExpr]`. For consistency, update the docstring to use the new union syntax. Apply this diff:
  """Get the value bound to a variable in the current let frame stack.

  Args:
      var (Var): The variable to look up

  Returns:
-     Optional[PrimExpr]: The bound value if found, None otherwise
+     PrimExpr | None: The bound value if found, None otherwise
  """
examples/bitnet-1.58b/vllm_workspace/conftest.py (1)
220-232: Restore the explicit `None` check for `images`.
Switching the guard to `if images:` makes an empty list skip the length assertion, yet the code still does `images[i]` later when `images is not None`, raising an `IndexError`. Revert to the explicit `None` check to keep the old, safe behavior.
- if images:
+ if images is not None:
      assert len(prompts) == len(images)
tilelang/language/tir/ir.py (2)
10-32: Fix implicit Optional violations and update docstring.
Three issues need attention:
- Line 11: `stop: PrimExpr = None` violates PEP 484 (implicit Optional). Should be `PrimExpr | None = None`.
- Line 13: `annotations: dict[str, Any] = None` violates PEP 484. Should be `dict[str, Any] | None = None`.
- Line 24: Docstring still references `Dict[str, Any]` instead of `dict[str, Any]`.
Apply this diff:
-def serial(start: PrimExpr,
-           stop: PrimExpr = None,
+def serial(start: PrimExpr,
+           stop: PrimExpr | None = None,
            *,
-           annotations: dict[str, Any] = None) -> frame.ForFrame:
+           annotations: dict[str, Any] | None = None) -> frame.ForFrame:
     """The serial For statement.

     Parameters
     ----------
     start : PrimExpr
         The minimum value of iteration.
     stop : PrimExpr
         The maximum value of iteration.
-    annotations : Dict[str, Any]
+    annotations : dict[str, Any] | None
         The optional annotations of the For statement.
35-57: Apply the same fixes to the remaining functions.
The `parallel`, `vectorized`, `unroll`, and `thread_binding` functions have the same three issues as `serial`:
- `stop: PrimExpr = None` → `stop: PrimExpr | None = None`
- `annotations: dict[str, Any] = None` → `annotations: dict[str, Any] | None = None`
- Docstrings reference `Dict[str, Any]` → update to `dict[str, Any] | None`
For `parallel` (lines 35-57), `vectorized` (lines 60-82), and `unroll` (lines 85-107), the diff mirrors the one shown for `serial`. For `thread_binding` (lines 110-138):
 def thread_binding(
     start: PrimExpr,
-    stop: PrimExpr = None,
-    thread: str = None,
+    stop: PrimExpr | None = None,
+    thread: str | None = None,
     *,
-    annotations: dict[str, Any] = None,
+    annotations: dict[str, Any] | None = None,
 ) -> frame.ForFrame:
Note: `thread_binding` also has `thread: str = None` that needs fixing.
Also applies to: 60-82, 85-107, 110-138
examples/fusedmoe/example_fusedmoe_torch.py (1)
10-16
: Don't use PEP 604/585 syntax under a Python 3.8 target.
The new annotations (`int | None`, `tuple[...]`, bare `dict`) rely on PEP 604/585 features that aren't available on Python 3.8. Even with `from __future__ import annotations`, any runtime evaluation of these hints (e.g. `typing.get_type_hints`) will raise, breaking the stated goal of keeping 3.8 compatibility. Please revert to the `typing.Optional`, `typing.Tuple`, and `typing.Dict` forms (or equivalent) in this module before enabling the lint. Apply this diff:
-from __future__ import annotations
+from __future__ import annotations
+from typing import Dict, Optional, Tuple
@@
- def __init__(self, config: dict, d_expert: int | None = None):
+ def __init__(self, config: Dict, d_expert: Optional[int] = None):
@@
- def __init__(self, config: dict):
+ def __init__(self, config: Dict):
@@
- def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@
-def ref_kernel(data: tuple[torch.Tensor, dict, dict]) -> torch.Tensor:
+def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
@@
-    seed: int) -> tuple[torch.Tensor, dict, dict]:
+    seed: int) -> Tuple[torch.Tensor, Dict, Dict]:
Also applies to: 37-38, 47-48, 100-146
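A small sketch of the failure mode described above, assuming a Python 3.8 or 3.9 interpreter (the function is illustrative):

```python
from __future__ import annotations
import typing

def route(x: int | None = None) -> tuple[int, ...]:
    return (x or 0,)

print(route.__annotations__["x"])  # the string "int | None" — always fine

# On 3.8/3.9 the next call eval()s that string and fails with
# "TypeError: unsupported operand type(s) for |", because PEP 604 unions
# only exist at runtime from Python 3.10 onwards:
# typing.get_type_hints(route)
```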
tilelang/language/experimental/gemm_sp.py (1)
45-56: Type hint modernized. Update docstring and consider adding a return type annotation.
The PEP 604 union syntax is correctly applied to the parameter type hint. However, the docstrings (lines 49, 52) still reference the old `Union[tir.Buffer, tir.Var]` syntax. Additionally, consider adding a return type annotation for completeness. Apply this diff:
- def legalize_arguments(arg: tir.Buffer | tir.Var):
+ def legalize_arguments(arg: tir.Buffer | tir.Var) -> tir.Buffer | tir.Var:
      """Convert let-bound variables to their corresponding buffers.

      Args:
-         arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
+         arg (tir.Buffer | tir.Var): Input argument to legalize

      Returns:
-         Union[tir.Buffer, tir.Var]: The legalized argument
+         tir.Buffer | tir.Var: The legalized argument
      """
tilelang/carver/matmul_analysis.py (1)
337-344: Correct the `get_ordered_axes` return annotation.
The helper still builds and returns a list, but the annotation now advertises `set[Var]`. This will confuse type-checkers (and readers) because downstream code indexes it (`axes[-1]`). Please change the return type back to `list[Var]` here and in the mirrored helper inside `analysis_tensorcore_tags`.
- def get_ordered_axes(region: list[Range]) -> set[Var]:
+ def get_ordered_axes(region: list[Range]) -> list[Var]:
      axes: list[Var] = []
tilelang/language/kernel.py (1)
195-200: threads property includes the block frame; should return only threadIdx.{x,y,z}.
This currently returns 4 items (includes the last block frame). Align with `get_thread_bindings` and the doc intent.
  def threads(self) -> list[Var]:
      """
      Returns the thread indices from the topmost frame.
      """
-     return [frame.iter_var.var for frame in self.frames[-4:]]
+     # Exclude the trailing block frame; only return threadIdx.{x,y,z}
+     return [frame.iter_var.var for frame in self.frames[-4:-1]]
tilelang/engine/phase.py (1)
22-30: Guard None target before calling have_tma.
`have_tma(None)` will raise (it accesses `target.kind`). Add a fast-path for None.
  def allow_tma_and_warp_specialized(pass_ctx: PassContext | None = None,
                                     target: Target | None = None) -> bool:
      if pass_ctx is None:
          pass_ctx = tilelang.transform.get_pass_context()
+     if target is None:
+         return False
      if not have_tma(target):
          return False
      disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
      return not disable_tma_lower and allow_warp_specialized(pass_ctx=pass_ctx, target=target)
Reference: have_tma expects target.kind.name (see tilelang/contrib/nvcc.py). [Based on relevant code snippet]
tilelang/carver/arch/arch_base.py (1)
tilelang/carver/arch/arch_base.py (1)
17-34: Duplicate attribute definitions overwrite initial values.
The attributes `transaction_size` and `bandwidth` are defined twice in the `__init__` method:
- First at lines 17-25, with comments describing their purpose
- Again at lines 32-34, with comments about units
The second definitions (lines 32, 34) overwrite the first ones (lines 17, 22), making the initial assignments redundant. This appears to be a pre-existing issue that should be addressed. Either remove the duplicates and consolidate into a single set:
- self.transaction_size: list[int] = [0, 0]  # The size of memory transactions, typically in bytes
- self.bandwidth: list[int] = [0, 0]  # Bandwidth specifications, possibly including peak and sustained rates
or, if the second definitions should replace the first, remove lines 17-25 instead so each attribute is assigned exactly once.
tilelang/engine/lower.py (1)
128-133: Remove the unused `target` parameter from `canon_target_host`.
The `target` argument isn't referenced; drop it from the signature and update both call sites.
In tilelang/engine/lower.py:
- def canon_target_host(target: str | Target, target_host: str | Target | None):
+ def canon_target_host(target_host: str | Target | None):
- target_host = canon_target_host(target, target_host)
+ target_host = canon_target_host(target_host)
In tilelang/jit/adapter/utils.py:
- target_host = tvm.target.Target.canon_target(canon_target_host(target, target_host))
+ target_host = tvm.target.Target.canon_target(canon_target_host(target_host))
tilelang/carver/arch/driver/cuda_driver.py (1)
124-143: Don't silently return `None` from `get_device_attribute`.
The new `try/except` now returns `None` on failure while the signature remains `-> int`. Existing callers expect an `int` (or a raised error) and will now receive `None`, causing downstream breakage or type confusion. Please restore the previous behavior — let the RuntimeError propagate (or at least re-raise it) and keep returning an `int` — or explicitly change the return type and update every caller to handle the `None` case.
tilelang/jit/adapter/nvrtc/adapter.py (1)
163-167: Guard the `buffer_map` lookup before indexing.
`PrimFunc.params` often include scalar vars in addition to buffers. Those scalars are absent from `buffer_map`, so the new direct lookup raises a `KeyError` during adapter initialisation. Please skip params that aren't backed by buffers, like the ctypes adapter still does. Apply:
- for i, param in enumerate(params):
-     buffer = buffer_map[param]
+ for i, param in enumerate(params):
+     if param not in buffer_map:
+         continue
+     buffer = buffer_map[param]
tilelang/jit/adapter/ctypes/adapter.py (2)
226-233: Unpack bug after changing dynamic_symbolic_map to (id, i, j).
`dynamic_symbolic_map` now stores triples (id, buffer_idx, dim). This two-value unpack will raise "too many values to unpack".
- if isinstance(s, tir.Var):
-     ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[s]
-     shape.append(ins[ref_tensor_idx].shape[ref_shape_idx])
+ if isinstance(s, tir.Var):
+     _, ref_tensor_idx, ref_shape_idx = self.dynamic_symbolic_map[s]
+     shape.append(ins[ref_tensor_idx].shape[ref_shape_idx])
239-245: Use args, not ins, for dynamic symbolics.
When appending dynamic dim/stride values, index into `args` (which includes both inputs and outputs) rather than `ins` (inputs-only) to support symbols on output buffers. Replace `ins[buffer_idx]` with `args[buffer_idx]` in the loop at lines 239-245.
tilelang/autotuner/tuner.py (1)
232-256: generate_cache_key return type is wrong.
The function returns a hex string, not `AutotuneResult | None`.
- def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None:
+ def generate_cache_key(self, parameters: dict[str, Any]) -> str:
tilelang/jit/adapter/wrapper.py (1)
493-499: Use of undefined variable `function_name`.
Inside the host_mod loop, `function_name` is not defined; this will raise at runtime. `l2_persistent_map` likely applies to all kernels; set the map directly.
- if "l2_persistent_map" in func.attrs:
-     self.l2_persistent_map[function_name] = func.attrs["l2_persistent_map"]
+ if "l2_persistent_map" in func.attrs:
+     self.l2_persistent_map = func.attrs["l2_persistent_map"]
♻️ Duplicate comments (2)
setup.py (1)
312-312: Remove unused noqa (SIM115 not enabled).
Ruff flags this as an unused suppression.
- return open(get_path("README.md"), encoding="utf-8").read()  # noqa: SIM115
+ return open(get_path("README.md"), encoding="utf-8").read()
tilelang/carver/roller/node.py (1)
305-307: Restore the `lru_cache()` invocation, or document the 3.8+ requirement.
On Python < 3.8, dropping the parentheses passes the method object in as `maxsize`, and the definition raises `TypeError: 'function' object cannot be interpreted as an integer`; on Python >= 3.8 the bare `@functools.lru_cache` form is accepted, so this only matters if older interpreters must be supported. To stay explicit, call the decorator (applies to both cached methods). Fix:
- @functools.lru_cache
+ @functools.lru_cache()
@@
- @functools.lru_cache
+ @functools.lru_cache()
Also applies to: 420-422
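A quick comparison of the two decorator spellings (stand-in class, not the repo's node types):

```python
import functools

class Node:
    @functools.lru_cache(maxsize=None)  # parenthesized form: every Python 3
    def cost(self):
        return 42

    @functools.lru_cache  # bare form: valid only on Python >= 3.8, where
    def cost_bare(self):  # lru_cache also accepts the function directly
        return 43
```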
🧹 Nitpick comments (23)
examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py (1)
76-76: Remove redundant self-assignment.
Line 76 assigns `import_source` to itself, which has no effect. This line can be safely removed:
  assert import_source is not None, "lop3_intrin_info is not found"
  assert func_name is not None, "lop3_intrin_info is not found"
- import_source = import_source
tilelang/utils/language.py (1)
88-98
: Update docstring to match the new type annotation.The function signature correctly uses the modern
list[int]
syntax, but the docstring at line 93 still referencesList[int]
. For consistency, update the docstring to match the signature.Apply this diff to update the docstring:
Args: - array (List[int]): The array of integers to reduce. + array (list[int]): The array of integers to reduce. Returns:testing/python/kernel/test_tilelang_kernel_fp8_gemv_simt.py (1)
25-26: Type annotation modernization looks good.
The migration from `Optional[int]` to `int | None` syntax is correct and aligns with the PR's goal of enabling pyupgrade linter rules. The runtime behavior is unchanged since the defaults and assertions remain the same.
Optional refactor: consider tightening the type hints. Since both parameters have non-None defaults (4 and 32) and are immediately validated as non-None (lines 28-31), the type hints could be simplified to just `int` rather than `int | None`:
- n_partition: int | None = 4,
- reduce_thread: int | None = 32,
+ n_partition: int = 4,
+ reduce_thread: int = 32,
This would make the signature more accurate — the function doesn't meaningfully accept None since it's rejected immediately. However, this is a pre-existing pattern from the original `Optional[int]` annotations and is outside the scope of this linting PR.
48-50
: Consider simplifying the list construction.The current loop-based approach could be streamlined using a list constructor, though this is beyond the scope of the linting changes.
Apply this diff to simplify:
- warp_group_ids: list[int] = [] - for warp_group_id in warp_group_idx: - warp_group_ids.append(warp_group_id) + warp_group_ids: list[int] = list(warp_group_idx)examples/bitnet-1.58b/vllm_workspace/utils.py (1)
1-2: Clean up unused imports and modernize type aliases for consistency.
After updating the function signatures to use built-in `list`, the `List` import from typing on line 2 is no longer used in the function signatures. Additionally, for consistency with the modernized signatures, the type aliases `TokensText` (line 4) and `TokensTextLogprobs` (line 27) should also be updated to use built-in generics instead of `Tuple` and `List` from typing. Apply this diff:
- from typing import Dict, List, Tuple
+ from typing import Dict
- TokensText = Tuple[List[int], str]
+ TokensText = tuple[list[int], str]
- TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
+ TokensTextLogprobs = tuple[list[int], str, list[Dict[int, float]]]
Note: if `Dict` is also unused elsewhere in the codebase, consider removing it too. The modernization to `dict[int, float]` would require postponed evaluation support, which is already enabled via the `from __future__ import annotations` statement.
211-211
: LGTM! Type annotation correctly modernized.The change from
typing.Dict[str, str]
todict[str, str]
is correct and aligns with modern Python typing practices (PEP 585). The future annotations import at line 18 ensures Python 3.8 compatibility.Optional: Consider updating the docstring for consistency.
The docstring at line 224 still uses the old-style
Dict[str, str]
notation:Returns ------- symbol_section_map: Dict[str, str] A map from defined global symbol to their sectionsFor consistency with the actual annotation, you could update it to:
Returns ------- - symbol_section_map: Dict[str, str] + symbol_section_map: dict[str, str] A map from defined global symbol to their sectionsexamples/fusedmoe/example_fusedmoe_tilelang.py (4)
274-278: LGTM! Type hints modernized.
The conversion to built-in generic types (`dict`, `int | None`) is correct and consistent with PEP 585/604 standards. Optionally, consider making the `dict` type more specific for better type safety:
  def __init__(self,
-              config: dict,
+              config: dict[str, Any],
               gate: torch.Tensor,
               up: torch.Tensor,
               down: torch.Tensor,
               d_expert: int | None = None):
This would require importing `Any` from typing if not already imported.
298-298: LGTM! Type hints modernized.
The conversion to the built-in `dict` type is correct and consistent with PEP 585 standards. Optionally, consider making the `dict` types more specific:
- def __init__(self, config: dict, weights: dict):
+ def __init__(self, config: dict[str, Any], weights: dict[str, torch.Tensor]):
317-320: LGTM! Type hints modernized.
The conversion to built-in `dict` types is correct and consistent with PEP 585 standards. Optionally:
  def __init__(self,
-              config: dict,
+              config: dict[str, Any],
               shared_kernel: tilelang.JITKernel,
               routed_kernel: tilelang.JITKernel,
-              weights: dict,
+              weights: dict[str, torch.Tensor],
               padding_M: int = 128):
478-478: LGTM! Type hint modernized.
The conversion to the built-in `tuple` type is correct and consistent with PEP 585 standards. Optionally, consider making the `dict` types more specific to match the documented structure:
- def custom_kernel(data: tuple[torch.Tensor, dict, dict]) -> torch.Tensor:
+ def custom_kernel(data: tuple[torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]) -> torch.Tensor:
examples/cast/example_per_token_cast_to_fp8.py (1)

103-117: All print statements use f-strings; refactor Tuple imports

- No remaining `.format()` or `%` formatting in `print()` calls.
- Replace `from typing import Tuple` with built-in `tuple[...]` annotations (see the sketch after this list) in:
  • maint/precision/compare_ops.py
  • tilelang/language/ast/ir.py
  • examples/deepseek_v32/inference/kernel.py
  • examples/deepseek_v32/inference/model.py
  • examples/bitnet-1.58b/vllm_workspace/utils.py
  • examples/deepseek_v32/utils.py
tilelang/contrib/nvrtc.py (1)

16-17: Update docstrings to match the modernized type annotations.

The parameter type annotations have been correctly modernized to use PEP 604 union syntax (`int | None`, `str | list[str] | None`). However, the docstrings at lines 29 and 32 still reference the old typing notation (`Optional[int]`, `Optional[Union[str, List[str]]]`).

Consider updating the docstring to match the new annotation style for consistency:

-    arch : Optional[int]
+    arch : int | None
         The cuda architecture code.
-    options : Optional[Union[str, List[str]]]
+    options : str | list[str] | None
         The additional options.
tilelang/language/builtin.py (1)

173-173: Consider expanding type hints to match the implementation.

The type hints for `mbarrier`/`barrier_id` specify `int | PrimExpr | tir.Call`, but the implementations (lines 212-219, 230-237) also accept `tir.BufferLoad` and `tir.Buffer`. Consider adding these types to the annotations for more accurate API documentation:

-def mbarrier_wait_parity(mbarrier: int | PrimExpr | tir.Call, parity: int | Var):
+def mbarrier_wait_parity(mbarrier: int | PrimExpr | tir.Call | tir.Buffer | tir.BufferLoad, parity: int | Var):

(Apply similar changes to `mbarrier_arrive`, `barrier_wait`, and `barrier_arrive`.)

Also applies to: 223-223, 266-266, 281-281
setup.py (1)
124-129
: Harden requirements parsing and set encoding.Avoid env‑dependent defaults and stray entries in install_requires. Filter comments/empties and set UTF‑8.
-def get_requirements(file_path: str = "requirements.txt") -> list[str]: - """Get Python package dependencies from requirements.txt.""" - with open(get_path(file_path)) as f: - requirements = f.read().strip().split("\n") - return requirements +def get_requirements(file_path: str = "requirements.txt") -> list[str]: + """Get Python package dependencies from requirements.txt.""" + with open(get_path(file_path), encoding="utf-8") as f: + requirements: list[str] = [] + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + requirements.append(line) + return requirementstilelang/language/kernel.py (2)
tilelang/language/kernel.py (2)

210-214: Fix the varargs annotation for blocks.

For varargs, the annotation applies to each element. Use `tir.PrimExpr`, not `list[tir.PrimExpr]`.

-def Kernel(
-    *blocks: list[tir.PrimExpr],
+def Kernel(
+    *blocks: tir.PrimExpr,

158-164: Docstring doesn't match the return type.

The method returns a list of three bindings, not a single dim's binding.

 def get_thread_bindings(self) -> list[Var]:
-    """
-    Returns the thread binding for the given dimension.
-    dim=0 corresponds to threadIdx.x, dim=1 to threadIdx.y, and dim=2 to threadIdx.z.
-    """
+    """
+    Returns all three thread bindings: threadIdx.x, threadIdx.y, threadIdx.z.
+    """
tilelang/language/copy.py (1)

11-16: Include tir.Var in the type hints to match runtime handling.

The implementation accepts tir.Var (resolved via T.has_let_value/T.get_let_value). Reflect that in the signature.

-def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion,
-         dst: tir.Buffer | tir.BufferLoad,
+def copy(src: tir.Buffer | tir.BufferLoad | tir.BufferRegion | tir.Var,
+         dst: tir.Buffer | tir.BufferLoad | tir.Var,
          coalesced_width: int | None = None,
          disable_tma: bool = False,
          eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None):
tilelang/carver/template/conv.py (1)

47-59: Restore the TileDevice parameter annotation

All other templates still surface the `arch: TileDevice` annotation, and `BaseTemplate.get_hardware_aware_configs` declares the same signature. Dropping it here hides the contract from static analysis and breaks consistency. Please keep the type annotation (re-importing `TileDevice` if necessary) so the override matches the base method.

-from ..roller import Hint
+from ..arch import TileDevice
+from ..roller import Hint
@@
-    def get_hardware_aware_configs(self, arch=None, topk=10) -> list[Hint]:
+    def get_hardware_aware_configs(self, arch: TileDevice = None, topk: int = 10) -> list[Hint]:
tilelang/carver/arch/cuda.py (1)

137-138: Fix the `available_tensor_instructions` annotation

The attribute is annotated as `list[TensorInstruction]` but initialized to `None`, which violates the declared type and will trip stricter type checkers now that the new linting is enabled. Please declare it as `list[TensorInstruction] | None` (or initialize it with an empty list) to keep the annotation truthful (see the sketch below).
tilelang/jit/adapter/ctypes/adapter.py (1)

102-106: Check the library init result and surface errors

For parity with the cython adapter and better diagnostics, check the init() return value and expose get_last_error.

-        self.lib = self.lib_generator.load_lib()
-        self.lib.init()
+        self.lib = self.lib_generator.load_lib()
+        self.lib.get_last_error.restype = ctypes.c_char_p
+        _res = self.lib.init()
+        if _res != 0:
+            _err = self.lib.get_last_error().decode("utf-8")
+            raise RuntimeError(f"Initialization failed: {_err}")
tilelang/jit/adapter/cython/adapter.py (3)

378-393: buffer_dtype_map key type annotation mismatch

Keys are buffer names (str), not tir.Var. Fix the annotations.

-    def _process_buffer_dtype(self) -> dict[tir.Var, tuple[int, torch.dtype]]:
+    def _process_buffer_dtype(self) -> dict[str, tuple[int, torch.dtype]]:
@@
-        buffer_dtype_map = {}
+        buffer_dtype_map: dict[str, tuple[int, torch.dtype]] = {}

408-412: _process_static_buffer_infos return type annotations incorrect

The maps are keyed by buffer.name (str) and static_contiguous_list holds (index, name).

-    def _process_static_buffer_infos(self) -> \
-        tuple[dict[tir.Var, tuple[int, list[tuple[int, int]]]],
-              dict[tir.Var, tuple[int, list[tuple[int, int]]]],
-              list[tuple[tir.Var]]]:
+    def _process_static_buffer_infos(self) -> \
+        tuple[dict[str, tuple[int, list[tuple[int, int]]]],
+              dict[str, tuple[int, list[tuple[int, int]]]],
+              list[tuple[int, str]]]:

442-467: buffer_device_map key type annotation mismatch

Same as the others: the key is the buffer name (str).

-    def _process_buffer_device(self) -> dict[tir.Var, tuple[int, torch.device]]:
+    def _process_buffer_device(self) -> dict[str, tuple[int, torch.device]]:
@@
-        buffer_device_map = {}
+        buffer_device_map: dict[str, tuple[int, torch.device]] = {}
 def run_general_reduction_recommend_hints(structure: str = "SSR",
-                                          shape: List[int] = None,
+                                          shape: list[int] = None,
                                           dtype: str = "float16",
                                           topk: int = 20):
Fix implicit `Optional` in type annotations.

The `shape` parameter has a default value of `None`, but the type annotation `list[int]` doesn't include `None` as a valid type. PEP 484 prohibits implicit `Optional`. With `from __future__ import annotations` enabled, use the PEP 604 union syntax `list[int] | None`.

Apply this diff to fix both occurrences:

 def run_general_reduction_recommend_hints(structure: str = "SSR",
-                                          shape: list[int] = None,
+                                          shape: list[int] | None = None,
                                           dtype: str = "float16",
                                           topk: int = 20):

-def run_elementwise_recommend_hints(shape: list[int] = None,
+def run_elementwise_recommend_hints(shape: list[int] | None = None,
                                     dtype: str = "float16",
                                     topk: int = 20):

Also applies to: 31-33
🧰 Tools
🪛 Ruff (0.13.3)
8-8: PEP 484 prohibits implicit `Optional`; convert to `Optional[T]` (RUF013)
🤖 Prompt for AI Agents
In testing/python/carver/test_tilelang_carver_recommend_hints.py around lines
7-10 and 31-33, the function parameter `shape` is annotated as `list[int]` but
has a default of `None`, creating an implicit Optional; update both annotations
to use PEP 604 union syntax `list[int] | None` so the type explicitly allows
None (i.e., change `shape: list[int] = None` to `shape: list[int] | None = None`
in both locations).
tilelang/autotuner/param.py (Outdated)
     out_idx: list[int] | int | None = None
     execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
     target: Literal['auto', 'cuda', 'hip'] = 'auto'
-    target_host: Union[str, Target] = None
+    target_host: str | Target = None
     verbose: bool = False
Include `None` in the annotation for `target_host`.

With Ruff's `RUF013` rule enabled, a default of `None` without `None` in the annotation fails lint. Please widen the type to keep the new lint clean.

-    target_host: str | Target = None
+    target_host: str | Target | None = None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    out_idx: list[int] | int | None = None
    execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
    target: Literal['auto', 'cuda', 'hip'] = 'auto'
    target_host: str | Target | None = None
    verbose: bool = False
🤖 Prompt for AI Agents
In tilelang/autotuner/param.py around lines 51 to 55, the parameter annotation
for target_host is missing None even though its default is None; update the type
to include None (e.g. change the annotation to include | None or use Optional[])
so the declared type matches the default and satisfies the RUF013 lint rule.
tilelang/autotuner/tuner.py (Outdated)
    _kernel_parameters: tuple[str, ...] | None = None
    _function_parameters: dict[str, Any] | None = None
    _lock = threading.Lock()  # For thread safety
_kernel_parameters/_function_parameters types regressed

`_key` is used as `(args_tuple, sorted_kwargs_tuple)`. The current `tuple[str, ...]` is incorrect and will confuse tools.

-    _kernel_parameters: tuple[str, ...] | None = None
-    _function_parameters: dict[str, Any] | None = None
+    _kernel_parameters: tuple[tuple[Any, ...], tuple[tuple[str, Any], ...]] | None = None
+    _function_parameters: dict[str, Any] | None = None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    _kernel_parameters: tuple[tuple[Any, ...], tuple[tuple[str, Any], ...]] | None = None
    _function_parameters: dict[str, Any] | None = None
    _lock = threading.Lock()  # For thread safety
🤖 Prompt for AI Agents
In tilelang/autotuner/tuner.py around lines 107 to 109, the type annotations for _kernel_parameters and _function_parameters are incorrect: the cache key has the form (args_tuple, sorted_kwargs_tuple), not tuple[str, ...]. Annotate _kernel_parameters as tuple[tuple[Any, ...], tuple[tuple[str, Any], ...]] | None and keep _function_parameters as dict[str, Any] | None, leaving _lock as-is for thread safety.
tilelang/autotuner/tuner.py (Outdated)
    def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]):
        # for cache key generation
        self._kernel_parameters = k_parameters
        self._function_parameters = f_parameters
set_kernel_parameters signature should match the actual payload

The first argument is that (args_tuple, kwargs_tuple) key, not a `tuple[str, ...]`.

-    def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]):
+    def set_kernel_parameters(self, k_parameters: tuple[tuple[Any, ...], tuple[tuple[str, Any], ...]], f_parameters: dict[str, Any]):
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]): | |
# for cache key generation | |
self._kernel_parameters = k_parameters | |
self._function_parameters = f_parameters | |
def set_kernel_parameters( | |
self, | |
k_parameters: tuple[tuple[Any, ...], tuple[tuple[str, Any], ...]], | |
f_parameters: dict[str, Any] | |
): | |
# for cache key generation | |
self._kernel_parameters = k_parameters | |
self._function_parameters = f_parameters |
🤖 Prompt for AI Agents
In tilelang/autotuner/tuner.py around lines 227 to 231, the set_kernel_parameters signature and typing are incorrect: the first argument is the (args_tuple, kwargs_tuple) cache key, not a tuple[str, ...]. Change the parameter type to reflect the actual payload (e.g., k_parameters: tuple[tuple[Any, ...], tuple[tuple[str, Any], ...]], or a dedicated alias), then assign it and f_parameters to the internal fields as before, updating any callers if necessary to match the new signature.
tilelang/cache/__init__.py (Outdated)
 def cached(
     func: PrimFunc = None,
-    out_idx: List[int] = None,
+    out_idx: list[int] = None,
Add an explicit `| None` annotation.

Per PEP 484, parameters with `None` defaults must explicitly include `None` in the type annotation.

Apply this diff:

 def cached(
     func: PrimFunc = None,
-    out_idx: list[int] = None,
+    out_idx: list[int] | None = None,
     *args,
🧰 Tools
🪛 Ruff (0.13.3)
17-17: PEP 484 prohibits implicit `Optional`; convert to `Optional[T]` (RUF013)
🤖 Prompt for AI Agents
In tilelang/cache/__init__.py around line 17, the parameter annotation `out_idx:
list[int] = None` needs to explicitly include None per PEP 484; change the type
to `list[int] | None` (or `Optional[list[int]]` for Python <3.10) so the
signature becomes `out_idx: list[int] | None = None`.
 def __init__(
     self,
     func: PrimFunc = None,
-    out_idx: Union[List[int], int] = None,
+    out_idx: list[int] | int = None,
     execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
-    target: Union[str, Target] = "auto",
-    target_host: Union[str, Target] = None,
+    target: str | Target = "auto",
+    target_host: str | Target = None,
     verbose: bool = False,
-    pass_configs: Optional[Dict[str, Any]] = None,
+    pass_configs: dict[str, Any] | None = None,
     from_database: bool = False,
-    compile_flags: Optional[List[str]] = None,
+    compile_flags: list[str] | None = None,
 ):
Fix implicit Optional annotations in `JITKernel.__init__`.

`out_idx` and `target_host` default to `None`, but their annotations omit it. Ruff's `RUF013` will flag this, blocking the new lint configuration. Please fold `None` into the unions.

-    out_idx: list[int] | int = None,
-    target: str | Target = "auto",
-    target_host: str | Target = None,
+    out_idx: list[int] | int | None = None,
+    target: str | Target = "auto",
+    target_host: str | Target | None = None,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    def __init__(
        self,
        func: PrimFunc = None,
        out_idx: list[int] | int | None = None,
        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
        target: str | Target = "auto",
        target_host: str | Target | None = None,
        verbose: bool = False,
        pass_configs: dict[str, Any] | None = None,
        from_database: bool = False,
        compile_flags: list[str] | None = None,
    ):
🧰 Tools
🪛 Ruff (0.13.3)
47-47: PEP 484 prohibits implicit `Optional`; convert to `Optional[T]` (RUF013)
🤖 Prompt for AI Agents
In tilelang/jit/kernel.py around lines 44 to 55, the parameters out_idx and
target_host default to None but their type annotations omit None; update the
signatures so None is included in the unions (e.g., out_idx: list[int] | int |
None and target_host: str | Target | None) to make the types explicit and
satisfy the pyupgrade RUF013 lint rule; ensure any other parameters with default
None are similarly annotated if applicable.
    A_sparse: tir.Buffer | tir.Var,
    E: tir.Buffer | tir.Var,
    B: tir.Buffer | tir.Var,
    C: tir.Buffer | tir.Var,
Type hints correctly modernized. Update the docstrings for consistency.

The PEP 604 union syntax (`|`) is correctly applied to the function parameters and works with the `__future__` import for Python 3.8+ compatibility.

However, the docstrings (lines 27-30) still reference the old `Union[tir.Buffer, tir.Var]` syntax. Update them to match the actual type hints.

Apply this diff to update the docstrings:

 Args:
-    A_sparse (Union[tir.Buffer, tir.Var]): First input matrix dense values
-    E (Union[tir.Buffer, tir.Var]): First input matrix sparse metadata
-    B (Union[tir.Buffer, tir.Var]): Second input matrix
-    C (Union[tir.Buffer, tir.Var]): Output matrix for results
+    A_sparse (tir.Buffer | tir.Var): First input matrix dense values
+    E (tir.Buffer | tir.Var): First input matrix sparse metadata
+    B (tir.Buffer | tir.Var): Second input matrix
+    C (tir.Buffer | tir.Var): Output matrix for results
     transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    A_sparse: tir.Buffer | tir.Var,
    E: tir.Buffer | tir.Var,
    B: tir.Buffer | tir.Var,
    C: tir.Buffer | tir.Var,
    """
    Args:
        A_sparse (tir.Buffer | tir.Var): First input matrix dense values
        E (tir.Buffer | tir.Var): First input matrix sparse metadata
        B (tir.Buffer | tir.Var): Second input matrix
        C (tir.Buffer | tir.Var): Output matrix for results
        transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
    """
🤖 Prompt for AI Agents
In tilelang/language/experimental/gemm_sp.py around lines 27 to 30, the
docstring parameter types still use the old Union[tir.Buffer, tir.Var] syntax;
update the docstring entries for A_sparse, E, B, and C to reflect the modern PEP
604 style used in the signature (tir.Buffer | tir.Var), keeping wording and
punctuation identical otherwise so the docs match the actual type hints.
tilelang/language/parallel.py (Outdated)
-def Parallel(*extents: tir.PrimExpr, coalesced_width: Optional[int] = None):
+def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None):
🛠️ Refactor suggestion | 🟠 Major

Update the docstring to match the new type annotation.

The parameter type annotation has been correctly modernized from `Optional[int]` to `int | None`, but the docstring at line 18 still references `Optional[int]`. Update the docstring for consistency.

Apply this diff to update the docstring:

-    coalesced_width : Optional[int]
+    coalesced_width : int | None
         The coalesced width of the parallel loop.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def Parallel(*extents: tir.PrimExpr, coalesced_width: int | None = None): | |
coalesced_width : Optional[int] | |
coalesced_width : int | None | |
The coalesced width of the parallel loop. |
🤖 Prompt for AI Agents
In tilelang/language/parallel.py around lines 9 to 18, the function signature
now uses the modern union type `int | None` for the parameter `coalesced_width`,
but the docstring still refers to `Optional[int]`; update the docstring text to
reference `int | None` (or simply "int or None") so it matches the annotation,
keeping the same description and formatting.
tilelang/language/tir/entry.py (Outdated)
 from __future__ import annotations
 import inspect
-from typing import Callable, Optional, Union
+from typing import Callable

 import tvm.script.parser.tir.entry as _tir_entry
 from tvm.tir.function import PrimFunc
 from tvm.script.parser._core import parse, scan_macro, utils


-def prim_func(func: Optional[Callable] = None,
+def prim_func(func: Callable | None = None,
               private: bool = False,
-              check_well_formed: bool = False) -> Union[PrimFunc, Callable]:
+              check_well_formed: bool = False) -> PrimFunc | Callable:
🛠️ Refactor suggestion | 🟠 Major

Update the docstring to match the new return type annotation.

The type annotations are correctly modernized using union syntax (`Callable | None`, `PrimFunc | Callable`). However, the docstring on line 31 still references the old `Union[PrimFunc, Callable]` syntax.

Apply this diff to update the docstring:

 Returns
 -------
-res : Union[PrimFunc, Callable]
+res : PrimFunc | Callable
     The parsed tir prim func.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
from __future__ import annotations | |
import inspect | |
from typing import Callable, Optional, Union | |
from typing import Callable | |
import tvm.script.parser.tir.entry as _tir_entry | |
from tvm.tir.function import PrimFunc | |
from tvm.script.parser._core import parse, scan_macro, utils | |
def prim_func(func: Optional[Callable] = None, | |
def prim_func(func: Callable | None = None, | |
private: bool = False, | |
check_well_formed: bool = False) -> Union[PrimFunc, Callable]: | |
check_well_formed: bool = False) -> PrimFunc | Callable: | |
Returns | |
------- | |
res : PrimFunc | Callable | |
The parsed tir prim func. |
🧰 Tools
🪛 Ruff (0.13.3)
11-11: Unused function argument: private (ARG001)
🤖 Prompt for AI Agents
In tilelang/language/tir/entry.py around lines 1 to 40, the module docstring
still uses the old typing Union[PrimFunc, Callable]; update that docstring to
reflect the modern return type annotation by replacing any occurrence of
"Union[PrimFunc, Callable]" (or similar) with the new union syntax "PrimFunc |
Callable" so the docstring matches the function signature.
Force-pushed from 013834f to 94610e8
version_provider.py (Outdated)
 def _read_cmake_bool(i: str | None, default=False):
     if i is None:
         return default
-    return i.lower() not in ('0', 'false', 'off', 'no', 'n', '')
+    return i.lower() not in ("0", "false", "off", "no", "n", "")
It seems there are too many quote changes. Perhaps we should prefer single quotes over double quotes. You can add the following to `pyproject.toml`:
[tool.ruff.lint]
select = [
...,
"Q", # flake8-quotes
]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
multiline-quotes = "double"
inline-quotes = "single"
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (9)
benchmark/matmul/benchmark_matmul_intrinsic.py (1)
305-305: Remove the hardcoded override of user input.

Line 305 unconditionally sets `with_roller = True`, overriding the CLI argument parsed on line 304. This makes the argparse setup (lines 290-295) and user input meaningless.

Given the PR's scope (linting/formatting), this appears to be a debugging artifact that was unintentionally committed.

Apply this diff to remove the hardcoded override:

 accum_dtype = "float32" if in_dtype == "int8" else "float16"
 with_roller = args.with_roller
-with_roller = True

 # Compute total floating-point operations
 total_flops = 2 * M * N * K
examples/gdn/example_chunk_delta_h.py (1)
10-17: Critical: Incomplete fallback for the missing fla dependency.

When the `fla` import fails, `chunk_gated_delta_rule_fwd_h` is undefined, causing a `NameError` at line 299. The error message "fla not found, using tilelang implementation" is misleading: no actual fallback occurs.

Apply this diff to either skip the reference comparison gracefully or provide a proper fallback.

Option 1: Skip the reference comparison when fla is unavailable

 try:
     import fla
     print(fla.__file__)
     from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
+    FLA_AVAILABLE = True
 except ImportError:
     print("fla not found, using tilelang implementation")
     fla = None
+    FLA_AVAILABLE = False

Then at line 298, add:

-    # fla ref
-    h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state,
-                                                                     store_final_state, chunk_size,
-                                                                     save_new_value)
+    # fla ref
+    if FLA_AVAILABLE:
+        h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state,
+                                                                         store_final_state, chunk_size,
+                                                                         save_new_value)
+    else:
+        print("Skipping reference comparison (fla not available)")
+        h_ref, V_new_ref, final_state_ref = None, None, None

And update the correctness checks (lines 343-386) to skip when `h_ref is None`.
57-69
: Fix: Python if on CUDA bool tensor in ref pathUsing a CUDA bool tensor in an if causes errors or sync. Convert mask to CPU (once) and compare Python bools.
Apply this diff:
def ref_program(A, B, BlockMask, block_M, block_N, block_K): - ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device) + ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device) + mask_cpu = BlockMask.bool().cpu() @@ - for k in range(K // block_K): - if BlockMask[i, j, k]: + for k in range(K // block_K): + if mask_cpu[i, j, k].item(): accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( torch.float32) @ B[k * block_K:(k + 1) * block_K, j * block_N:(j + 1) * block_N].to(torch.float32)
72-88
: Fix: BlockMask generation uses invalid device and cross-device assignment
- device=torch.cuda.current_device() passes an int; needs torch.device.
- torch.rand(...) creates CPU tensor; assigning to CUDA tensor fails.
- Prefer robust BlockMask detection via dtype+rank and ensure shape is a tuple.
Apply this diff:
-def supply_program(params: list[KernelParam]): - input_tensors = [] - - for p in params: - # Check if the kernel parameter is BlockMask tensor. - # Here, BlockMask is uniquely identified by having 3 dimensions. - if len(p.shape) != 3: - # For non-BlockMask tensors, use the default tensor generation logic. - input_tensors.append(default_tensor_supply(p)) - else: - # For BlockMask tensor, randomly set elements to True based on desired - # sparsity level. - block_mask = torch.zeros(p.shape, dtype=torch.bool, device=torch.cuda.current_device()) - block_mask[:, :, :] = torch.rand(p.shape) > sparsity - input_tensors.append(block_mask) - - return input_tensors +def supply_program(params: list[KernelParam]): + inputs: list[torch.Tensor] = [] + device = torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() \ + else torch.device("cpu") + for p in params: + # Identify BlockMask robustly: boolean 3D tensor + if hasattr(p, "is_boolean") and p.is_boolean() and len(p.shape) == 3: + shape = tuple(p.shape) # torch expects a tuple of ints + mask = torch.rand(shape, device=device) > sparsity + inputs.append(mask) + else: + inputs.append(default_tensor_supply(p)) + return inputsNote: If the autotuner invokes this function, this fix prevents CUDA/CPU mismatches during tuning. See KernelParam.is_boolean() in tilelang/engine/param.py. [Based on relevant code snippets]
examples/deepseek_nsa/example_tilelang_nsa_bwd.py (4)
6-6: Remove the duplicate import

Redundant second `import torch`. Drop it.

-import torch

206-206: Fix undefined H and the wrong grid split in the bwd_dkv kernel

`H` is not defined; use `heads_kv = heads // groups`. Also drop the stray print.

-    print("NV", NV, "NS", NS, "B", B, "H", H)
+    # print removed: avoid debug noise

-    with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
+    with T.Kernel(NV, NS, B * heads_kv, threads=num_threads) as (i_v, i_s, i_bh):

-        i_b, i_h = i_bh // H, i_bh % H
+        i_b, i_h = i_bh // heads_kv, i_bh % heads_kv

Also applies to: 220-220, 239-239

387-387: Fix undefined H and the wrong grid split in the bwd_dqkv kernel

Use `heads_kv` to size the grid and split `i_bh`.

-    with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
+    with T.Kernel(NV, NS, B * heads_kv, threads=num_threads) as (i_v, i_s, i_bh):

-        i_b, i_h = i_bh // H, i_bh % H
+        i_b, i_h = i_bh // heads_kv, i_bh % heads_kv

Also applies to: 407-407

537-545: Fix undefined `block_counts` in block_mask and robustify the call site

- `block_counts` is referenced but not in scope; add it to the wrapper signature and compute `USE_BLOCK_COUNTS` from it.
- The call site unconditionally calls `.to()`; guard for None/int.

-def tilelang_kernel_block_mask(
-    batch,
-    heads,
-    seq_len,
-    selected_blocks,
-    block_size,
-    dtype="int32",
-):
+def tilelang_kernel_block_mask(
+    batch,
+    heads,
+    seq_len,
+    selected_blocks,
+    block_size,
+    dtype="int32",
+    block_counts=None,
+):
@@
+    USE_BLOCK_COUNTS = block_counts is not None

Update the call site:

-    block_mask = tilelang_kernel_block_mask(B, H, T, S,
-                                            BS)(block_indices.to(torch.int32),
-                                                block_counts.to(torch.int32)).to(torch.bool)
+    bc_arg = (
+        block_counts.to(torch.int32)
+        if isinstance(block_counts, torch.Tensor)
+        else torch.empty(0, dtype=torch.int32, device=block_indices.device)
+    )
+    block_mask = tilelang_kernel_block_mask(B, H, T, S, BS, block_counts=block_counts)(
+        block_indices.to(torch.int32), bc_arg
+    ).to(torch.bool)

Also applies to: 554-554, 610-613
examples/deepseek_nsa/example_triton_nsa_bwd.py (1)
354-355: Use bitwise boolean ops in Triton expressions

Python `and` on Triton tensors is invalid; use elementwise `&` with parentheses.

-    b_p_swa = tl.where((i >= o_s and (i - WS) < o_s)[:, None], b_p_swa, 0)
+    b_p_swa = tl.where(((i >= o_s) & ((i - WS) < o_s))[:, None], b_p_swa, 0)
🧹 Nitpick comments (23)
examples/bitnet-1.58b/load_from_quantized.py (1)
52-52: The outer parentheses are unnecessary.

The parentheses wrapping the entire expression add visual noise without improving readability or changing behavior. Consider removing them for cleaner code.

Apply this diff:

- qmodel = (BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half())
+ qmodel = BitnetForCausalLM.from_quantized(saved_model_path).cuda().half()

Note: This also removes the trailing comma in `from_quantized(saved_model_path,)`, since it's unnecessary for a single-argument call.
from_quantized(saved_model_path,)
since it's unnecessary for a single-argument call.examples/blocksparse_attention/block_sparse_attn_triton.py (1)
201-213
: Consider prefixing unusedctx
parameter with underscore.The multi-line formatting improves readability. However, the
ctx
parameter is unused throughout the function body. Since this helper is called from_sparse_attention.forward
(line 262) andctx
is never utilized (backward pass is not implemented), consider prefixing it with an underscore (_ctx
) to indicate it's intentionally unused and silence the linter warning.Apply this diff if you want to silence the linter warning:
def _forward( - ctx, + _ctx, q, k, v,examples/flash_decoding/example_gqa_decode.py (1)
475-475
: Consider defining a custom exception class (optional).Static analysis flags TRY003 on this line. While the current error message is concise and clear, defining a custom exception class (e.g.,
class SimilarityError(AssertionError)
) would align with the TRY003 guideline. However, this is a minor style concern and may be overkill for an example/benchmark file.examples/deepseek_v32/fp8_lighting_indexer.py (1)
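A sketch of what that could look like; `SimilarityError` is the reviewer's hypothetical name, not an existing tilelang API:

class SimilarityError(AssertionError):
    """Raised when kernel output similarity falls below the threshold."""

    def __init__(self, sim: float, threshold: float) -> None:
        # Keeping the message inside the class is what satisfies TRY003
        super().__init__(f"similarity {sim:.4f} is below threshold {threshold}")

# At the call site the raise then stays short:
# raise SimilarityError(sim, threshold)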
1-1
: Consider removing or making the noqa directive more specific.The blanket
# ruff: noqa
disables all ruff checks for this file, which seems to conflict with the PR's goal of enabling pyupgrade rules. If specific rules need to be suppressed, consider using targeted ignores like# ruff: noqa: UP001, UP032
instead.examples/convolution/example_convolution.py (1)
50-51
: Consider removing unused parameter overrides.The
dtype
andaccum_dtype
parameters are immediately overridden with hardcoded values, making the function parameters ineffective. Since you're updating the function signature anyway, consider either:
- Removing these parameters entirely (if always hardcoded), or
- Removing lines 50-51 to honor the passed values
Note: This pattern also appears in
example_convolution_autotune.py
(lines 117-118), so you may want to address it consistently across related files.examples/deepseek_mla/example_mla_decode_paged.py (1)
313-328
: Note: Unused parameters in reference implementation.The
block_table
andblock_size
parameters are unused in the reference implementationrun_torch_mla
. The function uses direct indexing (i * max_seqlen_pad
) rather than block table lookups.These parameters are likely kept for API consistency with
run_tilelang_mla
, which does use them. This is acceptable for maintaining a uniform interface, though you could consider documenting this or using leading underscore naming (e.g.,_block_table
) to indicate intentionally unused parameters.If you want to explicitly mark them as intentionally unused:
def run_torch_mla( q, - block_table, + _block_table, # unused, kept for API consistency blocked_k, max_seqlen_pad, - block_size, + _block_size, # unused, kept for API consistency b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype, ):Based on learnings
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1)
66-93: Comprehensive documentation added.

The new docstring provides thorough documentation of the function's parameters, behavior, and requirements. While technically correct and informative, the 28-line docstring is quite verbose for an example file. Consider whether a more concise summary would suffice, reserving this level of detail for core library functions.
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
206-215: LGTM! Consider fixing the comment typo.

The type annotation modernization from `Optional[int]` to `int | None` is correct and aligns with PEP 604. The public API signature change is backward-compatible at runtime.

Minor: Line 215 has a typo in the comment: `function'sinterface` should be `function's interface` (missing space).
169-169
: LGTM! Consider fixing the comment typo.The type annotation update and
T.alloc_local
usage are correct. The public API signature change maintains runtime compatibility.Minor: Line 202 has the same typo as the first file:
function's interface
is missing a space between "function's" and "interface".Also applies to: 193-202
examples/bitnet-1.58b/vllm_workspace/conftest.py (1)
366-367: Verify the defensive getattr change and ensure consistency.

This change from direct attribute access to `getattr()` with a default is a logic improvement but appears unrelated to the pyupgrade linting objectives. While the defensive coding is good practice, note that line 319 uses a similar pattern without this defensive check:

 if self.model.get_output_embeddings().bias is not None:

For consistency, consider updating line 319 similarly, or clarify whether this change addresses a specific issue with certain model types that lack a `bias` attribute.

Consider applying the same defensive pattern to line 319:

-if self.model.get_output_embeddings().bias is not None:
+if getattr(self.model.get_output_embeddings(), "bias", None) is not None:
1-351
: Consider separating functional changes from linting fixes.This PR mixes functional changes (lines 30-35:
pass_configs
addition) with linting/formatting improvements. While the linting changes are appropriate for a "Enable pyupgrade linter" PR, functional changes that alter behavior should ideally be in separate commits or PRs for easier review and potential rollback.examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
415-415
: Remove unnecessary int() cast.The value is already an integer from
math.ceil()
, which returns an int in Python 3. Theint()
cast is redundant.Apply this diff:
- max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)
418-418
: Remove unnecessary int() cast.
math.ceil()
already returns an int in Python 3.Apply this diff:
- max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)examples/fusedmoe/example_fusedmoe_tilelang.py (1)
7-8
: Consider replacing star imports.Static analysis flags these star imports as they make it difficult to track which names are imported and can lead to namespace pollution.
If these modules have a limited set of commonly used exports, consider explicit imports:
-from tilelang.autotuner import * -from example_fusedmoe_torch import * +from tilelang.autotuner import <specific items> +from example_fusedmoe_torch import <specific items>Alternatively, use qualified imports:
-from tilelang.autotuner import * -from example_fusedmoe_torch import * +import tilelang.autotuner +import example_fusedmoe_torchexamples/blocksparse_gemm/example_blocksparse_gemm.py (1)
93-105
: Rename “enable_rasteration” → “enable_rasterization” for consistencySpelling is inconsistent with DEFAULT_ENABLE_RASTERIZATION and typical terminology. Rename for clarity and avoid future confusion.
Apply these diffs within this file’s changed regions:
def blocksparse_matmul( @@ - enable_rasteration, + enable_rasterization, @@ - T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.use_swizzle(panel_size=10, enable=enable_rasterization)- kernel = blocksparse_matmul( + kernel = blocksparse_matmul( @@ - enable_rasteration=DEFAULT_ENABLE_RASTERIZATION, + enable_rasterization=DEFAULT_ENABLE_RASTERIZATION,Also update get_configs for consistency (outside the changed hunk; example snippet):
enable_rasterization = [True, False] # ... { # ... "enable_rasterization": c[5], }Also applies to: 168-169
examples/flash_attention/example_mha_bwd.py (1)
290-299: Place the scalar on the same device, or use a Python scalar

Replace

 scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))

with either

 scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype, device=scores.device))

or

 scores = scores / (dim ** 0.5)

to avoid implicit CPU→CUDA copies.
examples/bitnet-1.58b/benchmark_inference_latency.py (1)
15-16: Drop the unnecessary NumPy import and mean; `times` is already a scalar.

Simplify and avoid an extra dependency.

 def profile(model, input_data):
     import time
-
-    import numpy as np
@@
-    return np.mean(times)
+    return times

Also applies to: 34-34
examples/bitnet-1.58b/tokenization_bitnet.py (1)
326-337: Align the return type with the implementation (it can return None).

save_vocabulary returns early without a value on an invalid dir (line 339), conflicting with `-> tuple[str]`.

-def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]:
+def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str] | None:
371-382
: Silence unused-argument lint while preserving HF API.use_cache/kwargs are intentionally unused for signature compatibility. Add deletions to satisfy Ruff ARG002.
def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, cache_position: torch.LongTensor | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + # Unused; kept for HF interface compatibility. + del use_cache, kwargs bsz, q_len, _ = hidden_states.size()[Based on static analysis hints]
524-535
: Apply same ARG002 fix in fused attention.def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None, position_ids: torch.LongTensor | None = None, past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, cache_position: torch.LongTensor | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + # Unused; kept for HF interface compatibility. + del use_cache, kwargs bsz, q_len, _ = hidden_states.size()[Based on static analysis hints]
612-623
: Apply same ARG002 fix in FlashAttention2 forward.def forward( self, hidden_states: torch.Tensor, attention_mask: torch.LongTensor | None = None, position_ids: torch.LongTensor | None = None, past_key_value: Cache | None = None, output_attentions: bool = False, use_cache: bool = False, cache_position: torch.LongTensor | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + # Unused; kept for HF interface compatibility. + del use_cache, kwargs output_attentions = False[Based on static analysis hints]
examples/bitnet-1.58b/eval_correctness.py (1)
49-51: Avoid importing NumPy; return the scalar directly.

`times` is already averaged; `np.mean(times)` is redundant.

 def profile(model, input_data):
-    import numpy as np
@@
-    return np.mean(times)
+    return times

Also applies to: 69-69
examples/deepseek_nsa/example_triton_nsa_fwd.py (1)
21-24: Wire up USE_BLOCK_COUNTS or drop the heuristic

`USE_BLOCK_COUNTS` is defined but not used (NS is hard-coded to S). Prefer using it to respect per-token counts.

-    # if USE_BLOCK_COUNTS:
-    #     NS = tl.load(block_counts + (bos + i_t) * H + i_h)
-    # else:
-    NS = S
+    NS = tl.load(block_counts + (bos + i_t) * H + i_h) if USE_BLOCK_COUNTS else S

Also applies to: 66-71
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (107)
benchmark/blocksparse_attention/benchmark_library_dense_fmha.py
(1 hunks)benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
(3 hunks)benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py
(1 hunks)benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py
(4 hunks)benchmark/matmul/benchmark_matmul.py
(1 hunks)benchmark/matmul/benchmark_matmul_intrinsic.py
(1 hunks)benchmark/matmul/benchmark_matmul_sp.py
(3 hunks)benchmark/matmul_fp8/benchmark_matmul.py
(0 hunks)docs/conf.py
(2 hunks)examples/amd/example_amd_flash_attn_bwd.py
(3 hunks)examples/amd/example_amd_flash_attn_fwd.py
(7 hunks)examples/analyze/example_conv_analyze.py
(2 hunks)examples/attention_sink/example_gqa_sink_bwd_bhsd.py
(14 hunks)examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
(10 hunks)examples/attention_sink/example_mha_sink_bwd_bhsd.py
(14 hunks)examples/attention_sink/example_mha_sink_fwd_bhsd.py
(6 hunks)examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
(9 hunks)examples/bitnet-1.58b/benchmark_generate.py
(3 hunks)examples/bitnet-1.58b/benchmark_inference_latency.py
(3 hunks)examples/bitnet-1.58b/configuration_bitnet.py
(2 hunks)examples/bitnet-1.58b/eval_correctness.py
(3 hunks)examples/bitnet-1.58b/eval_gpu_memory.py
(2 hunks)examples/bitnet-1.58b/eval_ppl.py
(3 hunks)examples/bitnet-1.58b/eval_utils.py
(1 hunks)examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py
(2 hunks)examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py
(0 hunks)examples/bitnet-1.58b/load_from_quantized.py
(1 hunks)examples/bitnet-1.58b/maint/create_bitblas_ckpt.py
(3 hunks)examples/bitnet-1.58b/modeling_bitnet.py
(27 hunks)examples/bitnet-1.58b/tokenization_bitnet.py
(9 hunks)examples/bitnet-1.58b/utils_quant.py
(2 hunks)examples/bitnet-1.58b/vllm_workspace/conftest.py
(22 hunks)examples/bitnet-1.58b/vllm_workspace/utils.py
(3 hunks)examples/blocksparse_attention/block_sparse_attn_triton.py
(9 hunks)examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
(5 hunks)examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
(17 hunks)examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
(15 hunks)examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
(16 hunks)examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
(12 hunks)examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
(12 hunks)examples/blocksparse_attention/heuristic.py
(1 hunks)examples/blocksparse_gemm/example_blocksparse_gemm.py
(5 hunks)examples/cast/example_group_per_split_token_cast_to_fp8.py
(7 hunks)examples/cast/example_per_token_cast_to_fp8.py
(7 hunks)examples/cast/example_triton_cast_to_fp8.py
(3 hunks)examples/convolution/example_convolution.py
(3 hunks)examples/convolution/example_convolution_autotune.py
(5 hunks)examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
(4 hunks)examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
(7 hunks)examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
(9 hunks)examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
(9 hunks)examples/deepseek_mla/benchmark_mla.py
(13 hunks)examples/deepseek_mla/example_mla_decode.py
(7 hunks)examples/deepseek_mla/example_mla_decode_paged.py
(12 hunks)examples/deepseek_mla/example_mla_decode_persistent.py
(4 hunks)examples/deepseek_mla/example_mla_decode_ws.py
(14 hunks)examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py
(3 hunks)examples/deepseek_mla/torch_refs.py
(1 hunks)examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
(28 hunks)examples/deepseek_nsa/example_tilelang_nsa_bwd.py
(4 hunks)examples/deepseek_nsa/example_tilelang_nsa_decode.py
(3 hunks)examples/deepseek_nsa/example_tilelang_nsa_fwd.py
(2 hunks)examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py
(7 hunks)examples/deepseek_nsa/example_triton_nsa_bwd.py
(22 hunks)examples/deepseek_nsa/example_triton_nsa_fwd.py
(8 hunks)examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py
(12 hunks)examples/deepseek_nsa/reference.py
(9 hunks)examples/deepseek_v32/fp8_lighting_indexer.py
(3 hunks)examples/deepseek_v32/sparse_mla_bwd.py
(12 hunks)examples/deepseek_v32/sparse_mla_fwd.py
(6 hunks)examples/deepseek_v32/sparse_mla_fwd_pipelined.py
(13 hunks)examples/deepseek_v32/topk_selector.py
(3 hunks)examples/deepseek_v32/utils.py
(4 hunks)examples/dequantize_gemm/dequantize_utils.py
(4 hunks)examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py
(7 hunks)examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
(11 hunks)examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py
(11 hunks)examples/dequantize_gemm/example_dequant_gemm_fine_grained.py
(8 hunks)examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
(9 hunks)examples/dequantize_gemm/example_dequant_gemm_w4a8.py
(4 hunks)examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py
(5 hunks)examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
(14 hunks)examples/dynamic_shape/example_dynamic.py
(2 hunks)examples/elementwise/example_elementwise_add.py
(2 hunks)examples/elementwise/example_elementwise_add_tma_1d.py
(1 hunks)examples/flash_attention/bert_padding.py
(1 hunks)examples/flash_attention/example_gqa_bwd.py
(12 hunks)examples/flash_attention/example_gqa_bwd_tma_reduce.py
(12 hunks)examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
(12 hunks)examples/flash_attention/example_gqa_fwd_bshd.py
(5 hunks)examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
(6 hunks)examples/flash_attention/example_mha_bwd.py
(5 hunks)examples/flash_attention/example_mha_bwd_bhsd.py
(5 hunks)examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
(5 hunks)examples/flash_attention/example_mha_fwd_bhsd.py
(5 hunks)examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
(6 hunks)examples/flash_attention/example_mha_fwd_bshd.py
(5 hunks)examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
(6 hunks)examples/flash_attention/example_mha_fwd_varlen.py
(4 hunks)examples/flash_decoding/example_gqa_decode.py
(11 hunks)examples/flash_decoding/example_mha_inference.py
(9 hunks)examples/fusedmoe/example_fusedmoe_tilelang.py
(16 hunks)examples/fusedmoe/example_fusedmoe_torch.py
(7 hunks)examples/fusedmoe/test_example_fusedmoe.py
(1 hunks)examples/gdn/example_chunk_delta_bwd.py
(10 hunks)examples/gdn/example_chunk_delta_h.py
(9 hunks)examples/gdn/example_chunk_o.py
(5 hunks)
⛔ Files not processed due to max files limit (36)
- examples/gdn/example_chunk_o_bwd.py
- examples/gdn/example_chunk_scaled_dot_kkt.py
- examples/gdn/example_cumsum.py
- examples/gdn/example_wy_fast.py
- examples/gdn/example_wy_fast_bwd_split.py
- examples/gdn/test_example_gdn_compilation.py
- examples/gdn/utils.py
- examples/gemm/example_gemm_autotune.py
- examples/gemm/example_gemm_intrinsics.py
- examples/gemm/example_gemm_persistent.py
- examples/gemm_fp8/example_tilelang_gemm_amd.py
- examples/gemm_fp8/example_tilelang_gemm_fp8.py
- examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
- examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
- examples/gemm_sm100/gemm_mma.py
- examples/gemm_sm100/gemm_tcgen5mma.py
- examples/gemm_sp/example_gemm_sp.py
- examples/gemm_splitk/example_tilelang_gemm_splitk.py
- examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
- examples/gemm_streamk/example_tilelang_gemm_streamk.py
- examples/grouped_gemm/example_grouped_gemm_bwd.py
- examples/grouped_gemm/example_grouped_gemm_fwd.py
- examples/hadamard_transform/example_hadamard.py
- examples/linear_attention/example_linear_attn_bwd.py
- examples/linear_attention/example_linear_attn_fwd.py
- examples/linear_attention/example_mamba_chunk_scan.py
- examples/linear_attention/example_mamba_chunk_state.py
- examples/linear_attention/example_retention_fwd.py
- examples/minference/example_vertical_slash_sparse_attn.py
- examples/norm/rms_norm.py
- examples/online_softmax/online_softmax.py
- examples/plot_layout/fragment_mma_load_a.py
- examples/seer_attention/block_sparse_attn_tilelang.py
- examples/seer_attention/block_sparse_attn_triton.py
- examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
- examples/topk/example_topk.py
💤 Files with no reviewable changes (2)
- examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py
- benchmark/matmul_fp8/benchmark_matmul.py
✅ Files skipped from review due to trivial changes (21)
- examples/flash_attention/bert_padding.py
- benchmark/matmul/benchmark_matmul.py
- examples/amd/example_amd_flash_attn_fwd.py
- examples/deepseek_nsa/example_tilelang_nsa_decode.py
- examples/deepseek_mla/example_mla_decode_ws.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
- examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py
- examples/deepseek_nsa/reference.py
- examples/blocksparse_attention/heuristic.py
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
- examples/bitnet-1.58b/eval_utils.py
- examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
- examples/deepseek_v32/topk_selector.py
- examples/deepseek_nsa/example_tilelang_nsa_fwd.py
- examples/gdn/example_chunk_o.py
- examples/elementwise/example_elementwise_add.py
- benchmark/matmul/benchmark_matmul_sp.py
- examples/deepseek_v32/sparse_mla_fwd_pipelined.py
- examples/deepseek_mla/torch_refs.py
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
🚧 Files skipped from review as they are similar to previous changes (22)
- docs/conf.py
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
- examples/flash_attention/example_mha_bwd_bhsd.py
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
- examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py
- examples/bitnet-1.58b/configuration_bitnet.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
- examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
- examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
- examples/bitnet-1.58b/utils_quant.py
- examples/bitnet-1.58b/vllm_workspace/utils.py
- examples/flash_attention/example_gqa_bwd.py
- examples/bitnet-1.58b/maint/create_bitblas_ckpt.py
- examples/cast/example_triton_cast_to_fp8.py
- examples/flash_attention/example_gqa_fwd_bshd.py
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py
- examples/cast/example_per_token_cast_to_fp8.py
- examples/amd/example_amd_flash_attn_bwd.py
- examples/attention_sink/example_mha_sink_bwd_bhsd.py
🧰 Additional context used
🧬 Code graph analysis (48)
examples/cast/example_group_per_split_token_cast_to_fp8.py (3)
- tilelang/language/tir/op.py (1): ceildiv (3116-3133)
- tilelang/language/kernel.py (2): Kernel (229-303), threads (215-219)
- tilelang/language/copy.py (1): copy (15-93)
examples/elementwise/example_elementwise_add_tma_1d.py (1)
- tilelang/language/parallel.py (1): Parallel (9-29)
examples/flash_decoding/example_gqa_decode.py (2)
- examples/gemm/example_gemm_autotune.py (1): get_heuristic_config (165-199)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (6): flashattn (35-207), ref_program (212-257), main (141-205), main (394-473), gen_inputs (379-391), triton_program (349-376)
examples/flash_attention/example_mha_fwd_bhsd.py (3)
- tilelang/jit/kernel.py (2): out_idx (471-472), get_profiler (385-401)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- tilelang/profiler/__init__.py (2): assert_allclose (77-146), do_bench (219-282)
examples/deepseek_mla/benchmark_mla.py (1)
- examples/deepseek_mla/example_mla_decode_paged.py (1): run_torch_mla (313-354)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py (1)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py (6): run_torch_mla (35-73), run_flash_mla_triton (327-373), flash_mla_triton (352-369), mla_decode_triton (292-323), compare_a (458-505), compare_ab (382-455)
benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py (1)
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (1): Softmax (87-114)
examples/bitnet-1.58b/load_from_quantized.py (1)
- examples/bitnet-1.58b/modeling_bitnet.py (1): from_quantized (1500-1578)
examples/flash_attention/example_mha_fwd_bshd.py (2)
- tilelang/jit/kernel.py (2): out_idx (471-472), get_profiler (385-401)
- tilelang/profiler/__init__.py (2): assert_allclose (77-146), do_bench (219-282)
examples/deepseek_mla/example_mla_decode.py (5)
- tilelang/jit/kernel.py (1): out_idx (471-472)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- tilelang/tileop/gemm/gemm_base.py (2): clear_accum (107-108), policy (119-120)
- tilelang/language/kernel.py (1): threads (215-219)
- tilelang/language/copy.py (1): copy (15-93)
examples/gdn/example_chunk_delta_bwd.py (1)
- tilelang/language/copy.py (1): copy (15-93)
examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (2)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (1): parallel_nsa (232-308)
- examples/gemm_streamk/example_tilelang_gemm_streamk.py (1): cdiv (8-9)
examples/fusedmoe/example_fusedmoe_tilelang.py (2)
- tilelang/language/copy.py (1): copy (15-93)
- examples/fusedmoe/example_fusedmoe_torch.py (3): forward (21-24), forward (37-42), forward (56-67)
benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py (1)
- benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py (1): get_sparse_attn_mask_from_topk (14-26)
examples/bitnet-1.58b/eval_ppl.py (1)
- examples/bitnet-1.58b/modeling_bitnet.py (1): BitnetForCausalLM (1231-1578)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- tilelang/language/tir/op.py (1): if_then_else (2906-2936)
- tilelang/language/__init__.py (1): symbolic (87-98)
examples/deepseek_mla/example_mla_decode_persistent.py (2)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- tilelang/tileop/gemm/gemm_base.py (1): policy (119-120)
examples/bitnet-1.58b/eval_gpu_memory.py (5)
- examples/bitnet-1.58b/benchmark_generate.py (1): profile (54-74)
- examples/bitnet-1.58b/benchmark_inference_latency.py (1): profile (12-34)
- examples/bitnet-1.58b/eval_correctness.py (1): profile (49-69)
- examples/bitnet-1.58b/benchmark_model_10k_loops.py (1): profile (19-41)
- examples/bitnet-1.58b/modeling_bitnet.py (1): _post_process_weights (1487-1491)
examples/bitnet-1.58b/benchmark_generate.py (2)
- examples/bitnet-1.58b/eval_ppl.py (1): main (31-61)
- examples/bitnet-1.58b/benchmark_inference_latency.py (1): main (37-53)
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
- tilelang/language/allocate.py (1): alloc_local (39-50)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1): ref_program (414-459)
examples/deepseek_nsa/example_triton_nsa_fwd.py (1)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (2): parallel_nsa_fwd_kernel (26-107), parallel_nsa (232-308)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2): ref_program (212-257), gen_inputs (379-391)
- tilelang/language/allocate.py (1): alloc_local (39-50)
examples/deepseek_v32/sparse_mla_fwd.py (2)
- tilelang/math/__init__.py (1): next_power_of_2 (1-2)
- examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1): sparse_mla_fwd_interface (343-395)
examples/analyze/example_conv_analyze.py (2)
- examples/gemm/example_gemm_autotune.py (1): kernel (110-150)
- examples/analyze/example_gemm_analyze.py (1): kernel (10-46)
examples/gdn/example_chunk_delta_h.py (2)
- examples/gdn/example_chunk_o.py (2): prepare_input (26-44), kernel (92-197)
- examples/gdn/example_chunk_delta_bwd.py (3): prepare_input (33-64), kernel (232-398), do_bench (614-636)
examples/bitnet-1.58b/vllm_workspace/conftest.py (1)
- examples/bitnet-1.58b/modeling_bitnet.py (1): get_output_embeddings (1249-1250)
examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)
- tilelang/language/tir/op.py (2): call_extern (173-195), address_of (464-480)
examples/bitnet-1.58b/benchmark_inference_latency.py (4)
- examples/bitnet-1.58b/benchmark_generate.py (1): profile (54-74)
- examples/bitnet-1.58b/eval_correctness.py (1): profile (49-69)
- examples/bitnet-1.58b/eval_gpu_memory.py (1): profile (12-34)
- examples/bitnet-1.58b/benchmark_model_10k_loops.py (1): profile (19-41)
examples/deepseek_nsa/example_tilelang_nsa_bwd.py (2)
- tilelang/jit/kernel.py (1): out_idx (471-472)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
examples/flash_decoding/example_mha_inference.py (2)
- tilelang/language/copy.py (1): copy (15-93)
- tilelang/profiler/__init__.py (1): do_bench (219-282)
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (2)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1): matmul (86-381)
examples/convolution/example_convolution_autotune.py (2)
- tilelang/autotuner/tuner.py (1): autotune (727-820)
- examples/gemm/example_gemm_autotune.py (1): get_configs (22-105)
examples/convolution/example_convolution.py (2)
- examples/convolution/example_convolution_autotune.py (1): convolution (97-168)
- tilelang/language/kernel.py (1): threads (215-219)
examples/fusedmoe/example_fusedmoe_torch.py (2)
- examples/fusedmoe/example_fusedmoe_tilelang.py (3): forward (317-320), forward (333-338), forward (432-535)
- tilelang/language/customize.py (1): view (51-61)
examples/bitnet-1.58b/modeling_bitnet.py (1)
- examples/bitnet-1.58b/configuration_bitnet.py (1): BitnetConfig (29-194)
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (2)
- tilelang/jit/kernel.py (2): out_idx (471-472), get_profiler (385-401)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
examples/flash_attention/example_gqa_bwd_tma_reduce.py (3)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- tilelang/language/__init__.py (1): annotate_layout (110-148)
- tilelang/jit/__init__.py (1): jit (242-318)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (3)
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1): matmul (49-354)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1): matmul (86-381)
- tilelang/quantize/mxfp.py (1): get_mxfp_intrin_group (52-109)
examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py (2)
- tilelang/jit/kernel.py (1): out_idx (471-472)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
examples/flash_attention/example_mha_fwd_varlen.py (2)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- tilelang/language/tir/op.py (1): if_then_else (2906-2936)
examples/deepseek_nsa/example_triton_nsa_bwd.py (3)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (2): parallel_nsa_fwd_kernel (26-107), parallel_nsa (232-308)
- examples/deepseek_nsa/example_triton_nsa_fwd.py (2): parallel_nsa_fwd_kernel (30-113), parallel_nsa (238-314)
- examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (2): parallel_nsa_fwd_kernel (30-157), parallel_nsa (259-335)
examples/flash_attention/example_mha_bwd.py (3)
- tilelang/jit/kernel.py (1): out_idx (471-472)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- tilelang/profiler/__init__.py (1): do_bench (219-282)
examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (2)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (2): parallel_nsa_fwd_kernel (26-107), parallel_nsa (232-308)
- examples/deepseek_nsa/example_triton_nsa_bwd.py (3): parallel_nsa_fwd_kernel (30-113), parallel_nsa_fwd_kernel (559-686), parallel_nsa (914-990)
examples/blocksparse_gemm/example_blocksparse_gemm.py (2)
- tilelang/jit/kernel.py (1): params (475-476)
- tilelang/engine/param.py (1): KernelParam (12-104)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- tilelang/language/__init__.py (1): symbolic (87-98)
examples/deepseek_mla/example_mla_decode_paged.py (2)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-105)
- examples/deepseek_mla/benchmark_mla.py (1): run_torch_mla (36-74)
🪛 Ruff (0.14.0)
examples/flash_decoding/example_gqa_decode.py
475-475: Avoid specifying long messages outside the exception class (TRY003)
examples/fusedmoe/example_fusedmoe_tilelang.py
7-7: from tilelang.autotuner import * used; unable to detect undefined names (F403)
8-8: from example_fusedmoe_torch import * used; unable to detect undefined names (F403)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
332-332: Unused function argument: max_cache_seqlen (ARG001)
332-332: Unused function argument: num_blocks (ARG001)
334-334: Unpacked variable heads is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
373-373: Unused function argument: block_indices (ARG001)
373-373: Unused function argument: max_cache_seqlen (ARG001)
373-373: Unused function argument: num_blocks (ARG001)
374-374: Unused function argument: block_size (ARG001)
418-418: Value being cast to int is already an integer. Remove unnecessary int call (RUF046)
examples/blocksparse_attention/block_sparse_attn_triton.py
202-202: Unused function argument: ctx (ARG001)
examples/convolution/example_convolution_autotune.py
112-112: Unused function argument: enable_rasteration (ARG001)
182-182: Unused function argument: with_roller (ARG001)
examples/bitnet-1.58b/modeling_bitnet.py
378-378: Unused method argument: use_cache (ARG002)
380-380: Unused method argument: kwargs (ARG002)
531-531: Unused method argument: use_cache (ARG002)
533-533: Unused method argument: kwargs (ARG002)
619-619: Unused method argument: use_cache (ARG002)
621-621: Unused method argument: kwargs (ARG002)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
440-440: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)? (RUF003)
440-440: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)? (RUF003)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
368-368: Unused function argument: block_indices (ARG001)
368-368: Unused function argument: max_cache_seqlen (ARG001)
368-368: Unused function argument: num_blocks (ARG001)
369-369: Unused function argument: block_size (ARG001)
415-415: Value being cast to int is already an integer. Remove unnecessary int call (RUF046)
examples/deepseek_mla/example_mla_decode_paged.py
315-315: Unused function argument: block_table (ARG001)
318-318: Unused function argument: block_size (ARG001)
 parser = argparse.ArgumentParser()
-parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
+parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Pre-existing issue: CLI argument defined but never used.
The --hf_path argument is defined at line 9 but never parsed or used in the main() function. The hardcoded value "1bitLLM/bitnet_b1_58-3B" at line 39 duplicates the default, making the CLI argument non-functional.
Run the following script to verify if this pattern exists in other similar files:
If this is a widespread pattern in the bitnet examples, consider either:
- Removing the unused argument parser, or
- Updating main() to parse and use args.hf_path:
  args = parser.parse_args()
  model = BitnetForCausalLM.from_pretrained(
      args.hf_path,
      device_map="auto",
      ...
  )
Also applies to: 38-44
🏁 Script executed:
#!/bin/bash
# Check if other eval/benchmark scripts properly use their CLI arguments
echo "=== Files with argparse but potentially unused args ==="
for file in examples/bitnet-1.58b/{eval,benchmark}*.py; do
if grep -q "argparse.ArgumentParser" "$file"; then
echo "--- $file ---"
echo "Arguments defined:"
grep -n "add_argument" "$file" || echo " (none found)"
echo "Args parsed:"
grep -n "parse_args\|args\." "$file" || echo " (none found)"
echo ""
fi
done
Length of output: 1859
Use or remove the unused --hf_path argument in Bitnet examples
In these scripts the --hf_path flag is defined but never parsed or applied:
• examples/bitnet-1.58b/eval_gpu_memory.py
• examples/bitnet-1.58b/benchmark_inference_latency.py
• examples/bitnet-1.58b/benchmark_model_10k_loops.py
Either drop the unused --hf_path argument or insert
  args = parser.parse_args()
  model = BitnetForCausalLM.from_pretrained(args.hf_path, …)
(and similarly replace the hard-coded "1bitLLM/bitnet_b1_58-3B" with args.hf_path).
🤖 Prompt for AI Agents
In examples/bitnet-1.58b/eval_gpu_memory.py around lines 8-9, the parser defines
--hf_path but never parses or uses it; replace the hard-coded model path by
calling args = parser.parse_args() and pass args.hf_path into
BitnetForCausalLM.from_pretrained (and any other model load sites) instead of
the literal "1bitLLM/bitnet_b1_58-3B", or remove the --hf_path argument if you
prefer not to make the path configurable.
def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
                      block_size):
🛠️ Refactor suggestion | 🟠 Major
Remove unused parameters in ref_program_torch.
The parameters max_cache_seqlen and num_blocks are never used in the function body.
Apply this diff:
-def ref_program_torch(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
- block_size):
+def ref_program_torch(query, key, value, block_indices, cache_seqlens, block_size):
Update the call site at line 464:
- ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks,
- block_size)
+ ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, block_size)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def ref_program_torch(query, key, value, block_indices, cache_seqlens, block_size):
    …
🧰 Tools
🪛 Ruff (0.14.0)
332-332: Unused function argument: max_cache_seqlen (ARG001)
332-332: Unused function argument: num_blocks (ARG001)
🤖 Prompt for AI Agents
In
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
around lines 332-333, the function ref_program_torch includes unused parameters
max_cache_seqlen and num_blocks; remove these two parameters from the function
signature and update its single call at line 464 to stop passing those two
arguments (pass only query, key, value, block_indices, cache_seqlens, block_size
in the same order). Ensure there are no other references to max_cache_seqlen or
num_blocks inside the function or remaining call sites and run tests/linters to
confirm the signature change is consistent project-wide.
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
                   block_size):
🛠️ Refactor suggestion | 🟠 Major
Remove unused parameters in ref_program_fa.
The parameters block_indices, max_cache_seqlen, num_blocks, and block_size are never used.
Apply this diff:
-def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
- block_size):
+def ref_program_fa(query, key, value, cache_seqlens):
Update call sites at lines 475 and 480:
- ref = ref_program_fa(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen,
- max_num_blocks, block_size)
+ ref = ref_program_fa(Q, K, V, cache_seqlens)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def ref_program_fa(query, key, value, cache_seqlens):
🧰 Tools
🪛 Ruff (0.14.0)
373-373: Unused function argument: block_indices (ARG001)
373-373: Unused function argument: max_cache_seqlen (ARG001)
373-373: Unused function argument: num_blocks (ARG001)
374-374: Unused function argument: block_size (ARG001)
🤖 Prompt for AI Agents
In
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
around lines 373-374, the function ref_program_fa declares unused parameters
block_indices, max_cache_seqlen, num_blocks, and block_size; remove these from
the function signature and update its call sites at lines 475 and 480 to stop
passing those arguments (keep only the remaining required arguments: query, key,
value, cache_seqlens). Ensure the updated signature and all calls preserve the
original argument order for the remaining parameters and run tests/lint to
confirm no other references remain.
def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
                   block_size):
🛠️ Refactor suggestion | 🟠 Major
Remove unused parameters in ref_program_fa.
The parameters block_indices, max_cache_seqlen, num_blocks, and block_size are declared but never used in the function body. Since this is a reference implementation for latency comparison, consider removing them to match the actual usage.
Apply this diff:
-def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_seqlen, num_blocks,
- block_size):
+def ref_program_fa(query, key, value, cache_seqlens):
Update the call sites at lines 463 and 468 accordingly:
- ref = ref_program_fa(Q, K, V, block_mask, cache_seqlens, max_cache_seqlen, num_blocks,
- block_size)
+ ref = ref_program_fa(Q, K, V, cache_seqlens)
🧰 Tools
🪛 Ruff (0.14.0)
368-368: Unused function argument: block_indices (ARG001)
368-368: Unused function argument: max_cache_seqlen (ARG001)
368-368: Unused function argument: num_blocks (ARG001)
369-369: Unused function argument: block_size (ARG001)
🤖 Prompt for AI Agents
In
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
around lines 368-369, ref_program_fa declares unused parameters block_indices,
max_cache_seqlen, num_blocks, and block_size; remove these parameters from the
function signature and then update its call sites at lines 463 and 468 to pass
only the remaining arguments (query, key, value, cache_seqlens) so the signature
and all calls match actual usage.
def convolution(
    N,
    C,
    H,
    W,
    F,
    K,
    S,
    D,
    P,
    block_M,
    block_N,
    block_K,
    num_stages,
    thread_num,
    enable_rasteration,
    dtype="float16",
    accum_dtype="float",
):
🧩 Analysis chain
Address the unused enable_rasteration parameter.
Static analysis correctly identifies that enable_rasteration is unused in the function body. The autotuner explores this dimension, but it has no effect on the kernel behavior. This wastes tuning time exploring a non-functional parameter.
Either implement the rasterization feature in the kernel (lines 129-166) or remove it from the autotune configuration. Run this script to verify the parameter is indeed unused:
🏁 Script executed:
#!/bin/bash
# Description: Verify enable_rasteration parameter usage in convolution function
# Search for enable_rasteration usage in the function body
ast-grep --pattern $'def convolution($$$):
$$$
enable_rasteration
$$$'
Length of output: 0
🏁 Script executed:
#!/bin/bash
# Search for enable_rasteration usage with context
rg -n 'enable_rasteration' -C3 examples/convolution/example_convolution_autotune.py
Length of output: 952
Remove or implement the unused enable_rasteration parameter
enable_rasteration is declared in convolution() and included in autotune configs but never referenced in the function body, causing wasted tuning iterations. Either implement its rasterization logic in the kernel or remove it (and related config entries) from the autotuner interface.
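To make the cost concrete, here is a minimal sketch of how one dead boolean doubles the number of configurations a tuner must benchmark (the block sizes below are illustrative, not the example's real search space):

import itertools

# Hypothetical tuning space; enable_rasteration doubles it without changing the kernel
block_M = [64, 128]
block_N = [64, 128]
enable_rasteration = [True, False]
print(len(list(itertools.product(block_M, block_N))))                      # 4 configs
print(len(list(itertools.product(block_M, block_N, enable_rasteration))))  # 8 configs, half redundant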
🧰 Tools
🪛 Ruff (0.14.0)
112-112: Unused function argument: enable_rasteration (ARG001)
🤖 Prompt for AI Agents
In examples/convolution/example_convolution_autotune.py around lines 97 to 115,
the function signature includes an unused parameter enable_rasteration which is
also present in autotune configs but never referenced; either remove the
parameter and all related config entries to avoid wasted tuning iterations
(delete enable_rasteration from the function signature, remove any autotuner
search space entries or config keys that set it, and update call sites/tests) or
implement the rasterization flag inside the convolution body and kernel path
(add conditional logic to toggle rasterization code paths, ensure the kernel
respects the flag, and validate autotuner uses the implemented behavior).
 o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
 kernel(
-    q.view(C_SEQ_LEN, HQ, D), k.view(C_SEQ_LEN, H, D), v.view(C_SEQ_LEN, H, D),
+    q.view(C_SEQ_LEN, HQ, D),
+    k.view(C_SEQ_LEN, H, D),
+    v.view(C_SEQ_LEN, H, D),
     o_slc.view(C_SEQ_LEN, HQ, V),
     block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
-    block_counts.to(torch.int32).view(C_SEQ_LEN, H), offsets.to(torch.int32),
-    token_indices.to(torch.int32))
+    block_counts.to(torch.int32).view(C_SEQ_LEN, H),
+    offsets.to(torch.int32),
+    token_indices.to(torch.int32),
+)
Fix undefined D and wrong view shapes in varlen fwd call
D is not defined; use K for the q/k last dim and V for v. Also ensure offsets/token_indices exist.
- kernel(
- q.view(C_SEQ_LEN, HQ, D),
- k.view(C_SEQ_LEN, H, D),
- v.view(C_SEQ_LEN, H, D),
+ # Varlen path requires offsets/token_indices; ensure present
+ assert offsets is not None, "offsets must be provided for varlen forward"
+ kernel(
+ q.view(C_SEQ_LEN, HQ, K),
+ k.view(C_SEQ_LEN, H, K),
+ v.view(C_SEQ_LEN, H, V),
o_slc.view(C_SEQ_LEN, HQ, V),
block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
- block_counts.to(torch.int32).view(C_SEQ_LEN, H),
- offsets.to(torch.int32),
- token_indices.to(torch.int32),
+ block_counts.to(torch.int32).view(C_SEQ_LEN, H),
+ offsets.to(torch.int32),
+ token_indices.to(torch.int32),
)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
# Varlen path requires offsets/token_indices; ensure present
assert offsets is not None, "offsets must be provided for varlen forward"
kernel(
    q.view(C_SEQ_LEN, HQ, K),
    k.view(C_SEQ_LEN, H, K),
    v.view(C_SEQ_LEN, H, V),
    o_slc.view(C_SEQ_LEN, HQ, V),
    block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
    block_counts.to(torch.int32).view(C_SEQ_LEN, H),
    offsets.to(torch.int32),
    token_indices.to(torch.int32),
)
🤖 Prompt for AI Agents
In examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py around lines
192-202, the kernel call uses an undefined D and incorrect view shapes; replace
D with K for q/k last dim and V for v last dim, change q.view(...) to
q.view(C_SEQ_LEN, HQ, K), k.view(...) to k.view(C_SEQ_LEN, H, K), v.view(...) to
v.view(C_SEQ_LEN, H, V), keep o_slc.view(C_SEQ_LEN, HQ, V) and ensure
block_indices/block_counts are cast as int32 with the shown shapes, and also
make sure offsets and token_indices are defined earlier in the function (or
computed) before calling kernel and passed as offsets.to(torch.int32) and
token_indices.to(torch.int32).
 sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32)  # (padding_M,)
 # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`.
-expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # （padding_M,）
+expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # （padding_M,）
Fix fullwidth parentheses in comment.
The comment contains ambiguous fullwidth parentheses （） instead of ASCII (). This can cause confusion and is flagged by the linter.
Apply this diff to fix the comment:
-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # (padding_M,)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # (padding_M,)
🧰 Tools
🪛 Ruff (0.14.0)
440-440: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)? (RUF003)
440-440: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)? (RUF003)
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
around line 440, the inline comment uses fullwidth parentheses '（）'; replace
them with standard ASCII parentheses '()' so the comment reads using normal
ASCII punctuation to satisfy the linter and avoid ambiguity.
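For reference, a quick way to see the code points involved (the fullwidth pair below is taken from the RUF003 message):

for ch in "（）()":
    print(f"U+{ord(ch):04X}", repr(ch))
# U+FF08/U+FF09 are FULLWIDTH LEFT/RIGHT PARENTHESIS; U+0028/U+0029 are the ASCII pair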
parser.add_argument("--batch", type=int, default=8, help="Batch size") | ||
parser.add_argument("--h", type=int, default=32, help="Number of heads") | ||
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") | ||
parser.add_argument("--d_head", type=int, default=64, help="Head dimension") | ||
parser.add_argument("--causal", type=bool, default=False, help="Causal flag") |
Argparse bool flag bug
type=bool is misleading (e.g., --causal False parses as True). Use store_true.
Apply:
- parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
+ parser.add_argument("--causal", action="store_true", help="Enable causal attention")
Optionally add a negation:
+ parser.add_argument("--no-causal", dest="causal", action="store_false", help="Disable causal attention")
+ parser.set_defaults(causal=False)
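To see why, here is a standalone repro using only the standard library (the flag name is borrowed from the example): argparse hands the raw string to the type callable, and any non-empty string is truthy.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--causal", type=bool, default=False)
# bool("False") is True because the string is non-empty
print(parser.parse_args(["--causal", "False"]).causal)  # True

parser2 = argparse.ArgumentParser()
parser2.add_argument("--causal", action="store_true")
print(parser2.parse_args([]).causal)            # False (flag absent)
print(parser2.parse_args(["--causal"]).causal)  # True (flag present)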
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
parser.add_argument("--batch", type=int, default=8, help="Batch size") | |
parser.add_argument("--h", type=int, default=32, help="Number of heads") | |
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") | |
parser.add_argument("--d_head", type=int, default=64, help="Head dimension") | |
parser.add_argument("--causal", type=bool, default=False, help="Causal flag") | |
parser.add_argument("--batch", type=int, default=8, help="Batch size") | |
parser.add_argument("--h", type=int, default=32, help="Number of heads") | |
parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") | |
parser.add_argument("--d_head", type=int, default=64, help="Head dimension") | |
parser.add_argument("--causal", action="store_true", help="Enable causal attention") | |
parser.add_argument("--no-causal", dest="causal", action="store_false", help="Disable causal attention") | |
parser.set_defaults(causal=False) |
🤖 Prompt for AI Agents
In examples/flash_attention/example_mha_bwd.py around lines 354 to 358 the
argparse definition uses type=bool for the --causal flag which is unreliable
(e.g., "--causal False" still evaluates True); change this to use
action="store_true" (and set default=False) so providing the flag sets
causal=True, and optionally add a complementary "--no-causal" with
action="store_false" and dest="causal" to support explicit negation.
routed_expert_weights_shape = group_sum
group_sizes_shape = n_routed_experts
🧩 Analysis chain
Verify the shape definition changes.
The shape definitions changed from tuple expressions to scalar values. Ensure this change is intentional and doesn't affect the tensor shape semantics in the tilelang type system.
🏁 Script executed:
#!/bin/bash
# Check how these shapes are used in the kernel definition and if scalar vs tuple matters
rg -A 5 "routed_expert_weights_shape|group_sizes_shape" --type=py
Length of output: 1639
🏁 Script executed:
#!/bin/bash
# Locate definitions of group_sum and n_routed_experts
rg -nP "group_sum\s*=" -C3 examples/fusedmoe/example_fusedmoe_tilelang.py
rg -nP "n_routed_experts\s*=" -C3 examples/fusedmoe/example_fusedmoe_tilelang.py
Length of output: 995
🏁 Script executed:
#!/bin/bash
# Locate other shape assignments in the file for comparison
rg -nP "input_shape\s*=" -C3 examples/fusedmoe/example_fusedmoe_tilelang.py
rg -nP "routed_expert_gate_shape\s*=" -C3 examples/fusedmoe/example_fusedmoe_tilelang.py
rg -nP "intermediate_shape\s*=" -C3 examples/fusedmoe/example_fusedmoe_tilelang.py
Length of output: 1605
Critical: restore tuple shapes for 1D tensors
examples/fusedmoe/example_fusedmoe_tilelang.py:137-138
- Change routed_expert_weights_shape = group_sum → routed_expert_weights_shape = (group_sum,)
- Change group_sizes_shape = n_routed_experts → group_sizes_shape = (n_routed_experts,)
🤖 Prompt for AI Agents
In examples/fusedmoe/example_fusedmoe_tilelang.py around lines 137-138, the
shapes for 1D tensors were assigned scalars instead of tuples; change
routed_expert_weights_shape = group_sum to routed_expert_weights_shape =
(group_sum,) and change group_sizes_shape = n_routed_experts to
group_sizes_shape = (n_routed_experts,) so both shapes are 1-element tuples as
required.
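A minimal sketch of the failure mode, independent of tilelang (the helper below is hypothetical, but any shape consumer that iterates over dimensions behaves the same way):

def num_elements(shape):
    n = 1
    for dim in shape:  # iterates over dimensions, so shape must be a sequence
        n *= dim
    return n

group_sum = 128
print(num_elements((group_sum,)))  # 128: a valid 1-D shape
# num_elements(group_sum) raises TypeError: 'int' object is not iterable
print((group_sum) == group_sum)    # True: parentheses alone don't make a tuple; the comma does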
Never mix multiple code styles; be consistent with only one. You should revert this commit aca0569, since you may format the code with …
See my comment above.
Force-pushed from 65cd197 to 2a7a6c5
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (10)
benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py (1)
1-1: Remove or replace the blanket # ruff: noqa directive on line 1 and address the E712 violation.
The file-level noqa suppresses all ruff checks, preventing the PR's objective of enabling pyupgrade rules. Verification reveals an E712 violation at line 61 (if mask_val == True: should be if mask_val:) that is currently being suppressed.
Recommended approach:
- Remove the blanket # ruff: noqa from line 1
- Either fix the E712 violation at line 61, or if needed for Triton compatibility, add a targeted suppression:
  if mask_val == True:  # noqa: E712
benchmark/matmul/benchmark_matmul_intrinsic.py (2)
304-306: CLI flag is ignored: with_roller is forcibly set to True.
This overrides user intent and changes tuning behavior.
Apply:
-    with_roller = args.with_roller
-    with_roller = True
+    with_roller = args.with_roller
316-319: TFLOPS unit bug: factor should be 1e-12 (you're printing GFLOPS as "TFlops").
Labels or math are off by 1e3.
Apply:
-    print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
+    print(f"Best TFlops: {total_flops / best_latency * 1e-12:.3f}")
@@
-    print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
+    print(f"Reference TFlops: {total_flops / ref_latency * 1e-12:.3f}")
Alternatively, keep 1e-9 and change the labels to "GFlops".
examples/bitnet-1.58b/tokenization_bitnet.py (1)
327-340: Fix return type mismatch.
The return type is declared as tuple[str], but line 340 contains an early return with no value (implicitly returns None). This will cause a type error.
Apply this diff to fix the return type:
-    def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]:
+    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str] | None:
Alternatively, change the early return to be explicit:
     if not os.path.isdir(save_directory):
         logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-        return
+        return None
examples/deepseek_nsa/example_triton_nsa_bwd.py (3)
2-2: Remove duplicate torch import.
The torch module is imported twice (line 2 and line 6). Remove one of the duplicate imports.
 # ruff: noqa
 import torch
 from typing import Optional, Union
 from packaging.version import parse
-import torch
 import triton
 import triton.language as tl
Also applies to: 6-6
116-162: Remove the duplicate ParallelNSAFunction class definition at lines 116-162.
The first ParallelNSAFunction definition is dead code; it gets overwritten by the second definition at lines 852-912 when the module loads. All callers across 5+ files pass 9 arguments matching the second definition's signature: (q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens). The first definition has an incompatible signature (q, k, v, block_indices, block_size, scale, offsets) and its backward method contains bugs (references undefined ctx.token_indices, returns 11 values for 7 input parameters).
Delete lines 116-162.
30-113: Critical: Duplicate function definition causes the first implementation to be overwritten.
The function parallel_nsa_fwd_kernel is defined twice in this file (lines 30-113 and lines 559-687). In Python, the second definition overwrites the first, making the first implementation dead code. The second implementation includes window_size handling (via the WS parameter and if WS > 0: logic) while the first does not. All callsites across the codebase (fwd_varlen.py, bwd.py, benchmark_nsa_fwd.py, fwd.py) will execute the second definition.
Rename or remove one of these definitions based on your intended design.
examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)
46-52: Dropped-heads and div-by-zero risks in head tiling
- Grid uses heads // min(block_H, kv_group_num). If heads % VALID_BLOCK_H != 0, the tail heads are skipped.
- cur_kv_head = by // (kv_group_num // block_H) divides by zero when heads < block_H.
As a quick numeric check (illustrative values): with heads = 12, kv_head_num = 1, and block_H = 8, VALID_BLOCK_H is 8 and the grid gets 12 // 8 = 1 block, so heads 8-11 are silently dropped; this is exactly what the divisibility assert guards against.
Fix by asserting divisibility and computing groups_per_kv with VALID_BLOCK_H; then use it in both macros.
Apply this diff:
@@
-    kv_group_num = heads // kv_head_num
-    VALID_BLOCK_H = min(block_H, kv_group_num)
+    kv_group_num = heads // kv_head_num
+    VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
+    # Ensure safe grid shape (no skipped heads) and avoid div-by-zero when heads < block_H
+    assert heads % VALID_BLOCK_H == 0, (
+        f"heads ({heads}) must be a multiple of VALID_BLOCK_H ({VALID_BLOCK_H})"
+    )
+    groups_per_kv = max(1, kv_group_num // VALID_BLOCK_H)
@@
-    cur_kv_head = by // (kv_group_num // block_H)
+    cur_kv_head = by // groups_per_kv
@@
-    cur_kv_head = by // (kv_group_num // block_H)
+    cur_kv_head = by // groups_per_kv
Also applies to: 61-66, 75-76, 84-97, 100-115
examples/fusedmoe/example_fusedmoe_tilelang.py (1)
595-604: main(): config keys don't match the rest of the file (KeyError/behavioral mismatch).
Code elsewhere expects underscored keys (e.g., d_hidden, batch_size). Fix the mapping.
-    config = {
-        "dhidden": d_hidden,
-        "dexpert": d_expert,
-        "nroutedexperts": n_routed_experts,
-        "nsharedexperts": n_shared_experts,
-        "nexpertspertoken": n_experts_per_token,
-        "bs": batch_size,
-        "seqlen": seq_len,
-        "seed": 81394,
-    }
+    config = {
+        "d_hidden": d_hidden,
+        "d_expert": d_expert,
+        "n_routed_experts": n_routed_experts,
+        "n_shared_experts": n_shared_experts,
+        "n_experts_per_token": n_experts_per_token,
+        "batch_size": batch_size,
+        "seq_len": seq_len,
+        "seed": 81394,
+    }
examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (1)
221-223: Remove @torch.compile from ParallelNSAFunction class; it breaks .apply() usage.
The decorator replaces the class with a compiled callable, breaking the .apply() method used in 6 call sites across the codebase, including line 327 in this same file. Move the decorator to the parallel_nsa wrapper function instead (defined at line 259).
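A minimal sketch of the suggested placement (toy forward/backward bodies, assuming the standard torch.autograd.Function protocol; this is not the example's real kernel):

import torch

class ParallelNSAFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

@torch.compile  # compile the thin wrapper, not the autograd.Function subclass
def parallel_nsa(x):
    return ParallelNSAFunction.apply(x)

print(parallel_nsa(torch.ones(2, requires_grad=True)).sum())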
♻️ Duplicate comments (11)
examples/flash_attention/example_mha_bwd.py (1)
358-358: Pre-existing argparse bool flag bug remains unaddressed.
This line still uses type=bool, which is unreliable (e.g., --causal False parses as True). This issue was previously flagged but not fixed in this PR.
Consider using action="store_true" instead:
-    parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
+    parser.add_argument("--causal", action="store_true", help="Enable causal attention")
Or add explicit negation support:
-    parser.add_argument("--causal", type=bool, default=False, help="Causal flag")
+    parser.add_argument("--causal", action="store_true", help="Enable causal attention")
+    parser.add_argument("--no-causal", dest="causal", action="store_false", help="Disable causal attention")
+    parser.set_defaults(causal=False)
examples/convolution/example_convolution_autotune.py (2)
97-115: The unused enable_rasteration parameter remains unaddressed.
This parameter is still unused in the function body (lines 123-168), wasting tuning iterations. As noted in the previous review, either implement rasterization logic or remove the parameter from the signature and autotune configs.
171-183: The unused with_roller parameter remains unaddressed.
This parameter is accepted via CLI (line 219) and passed to main but never used. The get_configs() function (line 28) doesn't accept it, unlike the GEMM example. As noted in the previous review, either implement roller support in get_configs() or remove this parameter.
examples/bitnet-1.58b/utils_quant.py (1)
228-229: Same formatter compatibility concern as lines 176-178.
This multi-line formatting change has the same compatibility concerns mentioned above. Consider reverting to maintain consistency with yapf.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
368-369: Remove unused parameters in ref_program_fa.
The parameters block_indices, max_cache_seqlen, num_blocks, and block_size are never used in the function body.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)
332-333: Remove unused parameters in ref_program_torch.
The parameters max_cache_seqlen and num_blocks are never used in the function body.
373-374: Remove unused parameters in ref_program_fa.
The parameters block_indices, max_cache_seqlen, num_blocks, and block_size are never used in the function body.
examples/bitnet-1.58b/eval_gpu_memory.py (1)
8-9: Pre-existing issue: --hf_path argument still unused.
The argument is defined but never parsed or applied (hardcoded path at line 39). This issue was flagged in previous reviews and remains unresolved.
examples/fusedmoe/example_fusedmoe_tilelang.py (1)
137-138: Bug: 1D shapes must be tuples, not scalars.
Passing scalars as shapes will break T.Tensor construction for 1D buffers. Restore 1-element tuples.
-    routed_expert_weights_shape = group_sum
-    group_sizes_shape = n_routed_experts
+    routed_expert_weights_shape = (group_sum,)
+    group_sizes_shape = (n_routed_experts,)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (2)
441-441: Fix fullwidth parentheses in comment.
Replace the fullwidth （） with ASCII ().
Apply:
-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # (padding_M,)
121-123: Use explicit 1-D tuple for Bias shared shape.
Bias_shared_shape = block_N is ambiguous; make it (block_N,) to avoid shape bugs in alloc/copy and to match other examples.
Apply:
-    Bias_shared_shape = block_N
+    Bias_shared_shape = (block_N,)
🧹 Nitpick comments (33)
examples/blocksparse_attention/block_sparse_attn_triton.py (1)
201-213
: Address the unusedctx
parameter flagged by ruff.Ruff flags the
ctx
parameter as unused (ARG001). Since this PR enables stricter linting rules, consider one of these approaches:
- If
ctx
is required for API compatibility withtorch.autograd.Function
, rename it to_ctx
to signal it's intentionally unused- Add
# noqa: ARG001
if you plan to use it when implementing backward pass- Remove it if it's truly unnecessary
Based on static analysis hints.
examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1)
115-115
: Minor: Unnecessary trailing comma with single argument.The trailing comma after
saved_model_path,
is valid but unusual for a single function argument. While some formatters add trailing commas for easier diffs, it's typically unnecessary in single-argument calls.Consider simplifying to:
- qmodel = (BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()) + qmodel = BitnetForCausalLM.from_quantized(saved_model_path).cuda().half()examples/deepseek_v32/fp8_lighting_indexer.py (2)
254-255
: Avoid hard‑coded CUDA device; derive from inputs.Hard-coding
device="cuda"
can break on CPU/MPS or mismatched devices. Reuse a single index tensor on the same device ascu_seqlen_*
(orkv
).- mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] + idx = torch.arange(seq_len_kv, device=cu_seqlen_ks.device) + mask_lo = idx.unsqueeze(0) >= cu_seqlen_ks.unsqueeze(1) + mask_hi = idx.unsqueeze(0) < cu_seqlen_ke.unsqueeze(1)
260-260
: Use dtype‑aware negative infinity.Guard against dtype changes and improve clarity.
- logits = logits.masked_fill(~mask, float("-inf")) + logits = logits.masked_fill(~mask, torch.finfo(logits.dtype).min)examples/flash_decoding/example_gqa_decode.py (1)
475-475
: Optional: Consider custom exception for better error handling.The static analyzer flags this line for TRY003—exception messages constructed with f-strings can make error handling less structured. However, for an example script, the current approach is clear and pragmatic.
Based on static analysis hints.
examples/convolution/example_convolution.py (1)
44-51
: Remove unused parameters or eliminate the hardcoded override.The
dtype
andaccum_dtype
parameters are immediately overridden on lines 50-51, rendering them useless. This creates a misleading API where callers might expect their arguments to take effect.Consider one of these solutions:
Solution 1 (preferred): Remove the unused parameters entirely.
@tilelang.jit(out_idx=[2]) def convolution( N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, - dtype="float16", - accum_dtype="float", ): KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 dtype = "float16" accum_dtype = "float"Solution 2: Remove the hardcoded overrides and use the parameters.
KH, KW = K, K OH = (H + 2 * P - D * (K - 1) - 1) // S + 1 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1 - dtype = "float16" - accum_dtype = "float" is_hopper = check_hopper()Note: The same pattern exists in
example_convolution_autotune.py
, which should also be addressed.examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
415-415
: Remove unnecessaryint()
cast.In Python 3,
math.ceil()
already returns an integer, so the explicitint()
cast is redundant.Apply this diff:
- max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)
334-334
: Consider prefixing unused variable with underscore.The
heads
variable is unpacked but never used in the function body. Consider prefixing it with an underscore to indicate it's intentionally unused.Apply this diff:
- batch, heads, dim = query.shape + batch, _heads, dim = query.shape
418-418
: Remove unnecessaryint()
cast.In Python 3,
math.ceil()
already returns an integer, so the explicitint()
cast is redundant.Apply this diff:
- max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py (1)
40-42
: Avoid hard‑coding "cuda"; pick device dynamically (CUDA vs HIP) and reuse tensor device.Helps this benchmark run on ROCm or any selected device without edits.
Apply:
@@ - q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) - k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) - v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device="cuda", dtype=torch.float16) + dev = "cuda" if torch.version.hip is None else "hip" + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device=dev, dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device=dev, dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device=dev, dtype=torch.float16) @@ - x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], - device="cuda", + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device=dev, dtype=torch.bfloat16) @@ - full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device="cuda")) + full_mask = torch.kron(block_mask.float(), torch.ones(BLOCK, BLOCK, device=q.device))Also applies to: 50-51, 58-58
benchmark/matmul/benchmark_matmul_intrinsic.py (1)
291-295
: Useaction="store_true"
for boolean CLI flags (Py3.8‑safe).
type=bool
is error‑prone; users must pass an explicit value.Apply:
- parser.add_argument( - "--with_roller", - type=bool, - default=False, - help="Whether to use roller to deduce search spaces", - ) + parser.add_argument( + "--with_roller", + action="store_true", + help="Whether to use roller to deduce search spaces", + )examples/flash_attention/example_gqa_bwd_tma_reduce.py (1)
1-600
: Logic preserved, but review burden is high.The formatting changes do not alter the semantics of the flash attention kernels, backward passes, or the test harness. However, the sheer volume of cosmetic changes (200+ lines touched) makes it difficult to spot any unintended modifications or regressions.
For future PRs, consider separating pure linting/formatting changes from functional changes to improve reviewability.
examples/dynamic_shape/example_dynamic.py (4)
56-70
: Avoid mixing formatters; revert formatting-only signature change.This is a style-only reflow (trailing comma + multiline). Given maintainers asked not to mix ruff/black formatting with yapf, please revert to the previous single‑line or yapf‑produced layout and keep this PR focused on UP/FA102 linting.
Example revert:
-def matmul_dynamic( - M, - N, - K, - block_M, - block_N, - block_K, - trans_A, - trans_B, - in_dtype, - out_dtype, - accum_dtype, - num_stages, - threads, -): +def matmul_dynamic(M, N, K, block_M, block_N, block_K, trans_A, trans_B, + in_dtype, out_dtype, accum_dtype, num_stages, threads):
74-85
: Call-site reflow is also formatter-driven; prefer reverting to repo’s style.No logic changes. To avoid formatter churn, align this call with the repository’s chosen formatter (yapf) or revert along with the signature above.
131-145
: Formatting-only call change; optional usability tweak while here.
- Re the formatter: same note—please revert to the repo’s style to avoid mixing tools.
- Optional: these defaults (16384^2) can OOM many GPUs. Consider simple CLI args to override sizes for local runs.
Example minimal tweak:
-def main(): - M, N, K = 16384, 16384, 16384 +def main(): + import os + M = int(os.getenv("TL_M", 16384)) + N = int(os.getenv("TL_N", 16384)) + K = int(os.getenv("TL_K", 16384))
109-110
: Prefergetattr(torch, out_dtype)
overtorch.__getattribute__(out_dtype)
.Same behavior, clearer and conventional.
- C = C.to(torch.__getattribute__(out_dtype)) + C = C.to(getattr(torch, out_dtype))examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)
344-351
: Tolerances may need adjustment with fast‑mathIf fast‑math is kept on by default, expose rtol/atol as CLI args to avoid fragile test runs on different GPUs.
Apply this diff:
@@ - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser() @@ parser.add_argument("--autotune", action="store_true", help="auto tune") + parser.add_argument("--rtol", type=float, default=0.01) + parser.add_argument("--atol", type=float, default=0.01) @@ - torch.testing.assert_close(tilelang_output, ref_output, rtol=0.01, atol=0.01) + torch.testing.assert_close(tilelang_output, ref_output, rtol=args.rtol, atol=args.atol)examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py (3)
21-28
: Redundant castattn_bias is already created with dtype=query.dtype; attn_bias.to(query.dtype) is a no-op and the result is unused. Remove it.
Apply this diff:
- attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) @@ - attn_bias.to(query.dtype)
323-339
: Make temp buffers explicit in dtype/deviceo and attn_logits rely on global defaults (set by callers). Make them explicit to avoid surprises when these helpers are reused.
Apply this diff:
def run_flash_mla_triton( @@ ): @@ - def flash_mla_triton(): + def flash_mla_triton(): num_kv_splits = 32 - o = torch.empty([b * s_q, h_q, dv]) - attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + o = torch.empty([b * s_q, h_q, dv], dtype=dtype, device=q.device) + attn_logits = torch.empty( + [b * s_q, h_q, num_kv_splits, dv + 1], dtype=dtype, device=q.device + )Also applies to: 349-366
340-347
: Unused variableblocked_v is assigned but never used. Remove or use; currently it’s dead code.
examples/fusedmoe/example_fusedmoe_tilelang.py (2)
7-8
: Avoid star imports; make ruff happy and improve readability.Replace
*
with explicit imports; remove unused autotuner import if not needed.As per static analysis hints.
-from tilelang.autotuner import * -from example_fusedmoe_torch import * +# Remove if unused across this file: +# from tilelang.autotuner import autotune # or delete entirely if not used +from example_fusedmoe_torch import generate_input, ref_kernel, clone_dataIf
autotuner
symbols are used elsewhere, import them explicitly; otherwise delete the line.
548-551
: Docstring return type is incorrect (no tuple returned).The function returns a single tensor but docstring claims a tuple.
- Returns: - Tuple containing: - - output: Processed tensor [batch_size, seq_len, d_model] + Returns: + torch.Tensor: Processed tensor [batch_size, seq_len, d_hidden]examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (4)
10-11
: Fix scale handling: remove unusedscale
from helper to avoid future double-scaling.
_tir_u8_to_f4_to_bf16
acceptsscale
but does not use it; the simple path already multiplies by2**scale
. Keepingscale
in the helper invites accidental double-application later and the docstring is misleading.- Make the helper scale-free and keep the external multiply. Update docstrings accordingly.
Apply:
-def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, - dtype: str): +def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): """ - Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. + Convert a 4-bit field packed in a uint8 into a bfloat16 value (no scaling). - This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its - bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by - `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. + This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its + bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, and assembles the corresponding + bfloat16 representation. - Parameters: + Parameters: nbit (int): Number of bits in the packed field (must be 4). val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). - scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). dtype (str): Destination dtype string (must be "bfloat16"). @@ - - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. + - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". @@ - # Scale is the exponential part, within the representation of uint8 - # To handle the overflow, we may use the min function to limit the exponential part to 8 bits - # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) + # Scaling is applied by callers (e.g., multiply by 2**scale) to keep both fast/simple paths consistent. @@ - B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( - num_bits, - B_local[i, j // num_elems_per_byte], - j % num_elems_per_byte, - Scale_shared[ - i, k * block_K // scale_size + j // - scale_size], # Scale is the exponential part, within the representation of uint8 - dtype=out_dtype, - ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) + B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16( + num_bits, + B_local[i, j // num_elems_per_byte], + j % num_elems_per_byte, + dtype=out_dtype, + ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))Also applies to: 13-32, 44-45, 305-314
**139-141**: Doc/code mismatch: the scaling factor is `2**Scale`, not `2**(Scale - 127)`.

References in comments/docstrings mention a `-127` bias, but both paths multiply by `2**Scale` via `T.shift_left(1, Scale)`, and the reference programs also use `2**Scale`. Please align the docs with the actual behavior (remove the "- 127").

Also applies to: 244-258, 445-446
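A quick standalone check of the claim (a hypothetical snippet, not code from the PR; it only mirrors what `T.shift_left(1, Scale)` computes):

```python
# shift_left(1, s) is an unbiased 2**s; a biased reading would subtract 127.
for s in range(4):
    assert (1 << s) == 2**s          # matches T.shift_left(1, Scale)
    assert (1 << s) != 2**(s - 127)  # what the "- 127" wording would imply
```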
**362-369**: Avoid an unused shared buffer when `with_bias=True`.

You now copy `Bias` directly into `C_local`; `Bias_shared` is allocated/annotated but unused. Drop the allocation/annotation when `with_bias=True`, or restore the previous `Bias_shared` → `C_local` copy to justify its existence.

Also applies to: 341-357
**160-166**: Minor: clarify the intrinsic output dtype argument.

`get_mxfp_intrin_group(out_dtype=in_dtype, ...)` works because `in_dtype == out_dtype` in the callers, but passing `out_dtype=out_dtype` is clearer and reduces future coupling.

Also applies to: 167-171
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1)

**265-267**: Normalize 1-D shapes to tuples for clarity and safety.

These 1-D shapes are provided as bare ints; prefer `(dim,)` for consistency with TVM/TileLang conventions.

Apply:

```diff
-        topk_weights: T.Tensor((M * topk), out_dtype),
-        sorted_token_ids: T.Tensor((padding_M), "int32"),
-        expert_ids: T.Tensor((padding_M // block_M), "int32"),
+        topk_weights: T.Tensor((M * topk,), out_dtype),
+        sorted_token_ids: T.Tensor((padding_M,), "int32"),
+        expert_ids: T.Tensor((padding_M // block_M,), "int32"),
@@
-            topk_weights_shared = T.alloc_shared((block_M), out_dtype)
-            sorted_token_ids_shared = T.alloc_shared((block_M), "int32")
+            topk_weights_shared = T.alloc_shared((block_M,), out_dtype)
+            sorted_token_ids_shared = T.alloc_shared((block_M,), "int32")
```

Also applies to: 280-283
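Why the trailing comma matters here, as a plain-Python illustration (independent of TileLang):

```python
padding_M = 256
assert (padding_M) == 256             # parentheses alone: still an int
assert (padding_M,) == (256,)         # trailing comma: a 1-element tuple
assert type((padding_M,)) is tuple    # the tuple form this review recommends for 1-D shapes
```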
examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (3)

**26-28**: Broaden the autotune cache key to avoid config reuse across shapes/flags.

Include dims/flags that materially affect runtime (H/HQ/G/K/V/S, WS, USE_*). This prevents mis-tuned configs from being reused for different shapes.

```diff
 @triton.autotune(
-    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
-    key=["BS", "BK", "BV"],
+    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
+    key=["BS", "BK", "BV", "H", "HQ", "G", "K", "V", "S", "WS", "USE_OFFSETS", "USE_BLOCK_COUNTS"],
 )
```
**283-301**: Docstring fixes: `block_counts` shape with `head_first`, `scale` type, and a typo.

Clarify shapes when `head_first=True` and fix the type and typo.

```diff
-        g_swa (torch.Tensor):
-            Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
+        g_swa (torch.Tensor):
+            Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
@@
-        block_counts (Union[torch.LongTensor, int]):
+        block_counts (Union[torch.LongTensor, int]):
             Number of selected blocks for each token.
-            If a tensor is provided, with shape `[B, T, H]` if `head_first=True` else `[B, T, H]`,
+            If a tensor is provided: shape `[B, T, H]` if `head_first=False`, else `[B, H, T]`.
             each token can select the same number of blocks.
             If not provided, it will default to `S`, Default: `None`
@@
-        scale (Optional[int]):
+        scale (Optional[float]):
             Scale factor for attention scores.
             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
```
**2-6**: Minor: remove the duplicate `import torch`.

Harmless, but worth cleaning even with `# ruff: noqa`.

```diff
-import torch
@@
-import torch
+import torch
```
135-140
: Guard againstdtype=None
.
torch.finfo(dtype)
will crash if callers omitdtype
. Either set a safe default or assert non-None.@@ - finfo = torch.finfo(dtype) + if dtype is None: + # Only e4m3fn supported per docstring. + dtype = torch.float8_e4m3fn + finfo = torch.finfo(dtype)examples/cast/example_per_token_cast_to_fp8.py (1)
32-38
: Potential ZeroDivisionError in thread mapping for smallblk_m
.
(blk_m // 4)
becomes 0 whenblk_m < 4
, makingi // (blk_m // 4)
invalid. Add a precondition.@@ def per_token_cast_to_fp8(M, N, blk_m): - @T.prim_func + assert blk_m % 4 == 0 and blk_m >= 4, "blk_m must be a multiple of 4 (>=4) for thread mapping." + @T.prim_func def per_token_cast(examples/cast/example_group_per_split_token_cast_to_fp8.py (1)
45-47
: Sameblk_m // 4
division risk.Protect against
blk_m < 4
as in the other example.@@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m): - @T.prim_func + assert blk_m % 4 == 0 and blk_m >= 4, "blk_m must be a multiple of 4 (>=4) for thread mapping." + @T.prim_func def group_per_split_token_cast(
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (107)
- benchmark/blocksparse_attention/benchmark_library_dense_fmha.py (1 hunks)
- benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py (3 hunks)
- benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py (1 hunks)
- benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py (4 hunks)
- benchmark/matmul/benchmark_matmul.py (1 hunks)
- benchmark/matmul/benchmark_matmul_intrinsic.py (1 hunks)
- benchmark/matmul/benchmark_matmul_sp.py (3 hunks)
- benchmark/matmul_fp8/benchmark_matmul.py (0 hunks)
- docs/conf.py (2 hunks)
- examples/amd/example_amd_flash_attn_bwd.py (3 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (7 hunks)
- examples/analyze/example_conv_analyze.py (2 hunks)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (14 hunks)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (10 hunks)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (14 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (6 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (9 hunks)
- examples/bitnet-1.58b/benchmark_generate.py (3 hunks)
- examples/bitnet-1.58b/benchmark_inference_latency.py (3 hunks)
- examples/bitnet-1.58b/configuration_bitnet.py (2 hunks)
- examples/bitnet-1.58b/eval_correctness.py (3 hunks)
- examples/bitnet-1.58b/eval_gpu_memory.py (2 hunks)
- examples/bitnet-1.58b/eval_ppl.py (3 hunks)
- examples/bitnet-1.58b/eval_utils.py (1 hunks)
- examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py (2 hunks)
- examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py (0 hunks)
- examples/bitnet-1.58b/load_from_quantized.py (1 hunks)
- examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (3 hunks)
- examples/bitnet-1.58b/modeling_bitnet.py (27 hunks)
- examples/bitnet-1.58b/tokenization_bitnet.py (9 hunks)
- examples/bitnet-1.58b/utils_quant.py (2 hunks)
- examples/bitnet-1.58b/vllm_workspace/conftest.py (22 hunks)
- examples/bitnet-1.58b/vllm_workspace/utils.py (3 hunks)
- examples/blocksparse_attention/block_sparse_attn_triton.py (9 hunks)
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (5 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (17 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (15 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (16 hunks)
- examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py (12 hunks)
- examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py (12 hunks)
- examples/blocksparse_attention/heuristic.py (1 hunks)
- examples/blocksparse_gemm/example_blocksparse_gemm.py (5 hunks)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (7 hunks)
- examples/cast/example_per_token_cast_to_fp8.py (7 hunks)
- examples/cast/example_triton_cast_to_fp8.py (3 hunks)
- examples/convolution/example_convolution.py (3 hunks)
- examples/convolution/example_convolution_autotune.py (5 hunks)
- examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (4 hunks)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (7 hunks)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py (9 hunks)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py (9 hunks)
- examples/deepseek_mla/benchmark_mla.py (13 hunks)
- examples/deepseek_mla/example_mla_decode.py (7 hunks)
- examples/deepseek_mla/example_mla_decode_paged.py (12 hunks)
- examples/deepseek_mla/example_mla_decode_persistent.py (4 hunks)
- examples/deepseek_mla/example_mla_decode_ws.py (14 hunks)
- examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py (3 hunks)
- examples/deepseek_mla/torch_refs.py (1 hunks)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (28 hunks)
- examples/deepseek_nsa/example_tilelang_nsa_bwd.py (4 hunks)
- examples/deepseek_nsa/example_tilelang_nsa_decode.py (3 hunks)
- examples/deepseek_nsa/example_tilelang_nsa_fwd.py (2 hunks)
- examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (7 hunks)
- examples/deepseek_nsa/example_triton_nsa_bwd.py (22 hunks)
- examples/deepseek_nsa/example_triton_nsa_fwd.py (8 hunks)
- examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (12 hunks)
- examples/deepseek_nsa/reference.py (9 hunks)
- examples/deepseek_v32/fp8_lighting_indexer.py (3 hunks)
- examples/deepseek_v32/sparse_mla_bwd.py (12 hunks)
- examples/deepseek_v32/sparse_mla_fwd.py (6 hunks)
- examples/deepseek_v32/sparse_mla_fwd_pipelined.py (13 hunks)
- examples/deepseek_v32/topk_selector.py (3 hunks)
- examples/deepseek_v32/utils.py (4 hunks)
- examples/dequantize_gemm/dequantize_utils.py (4 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (7 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (11 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (11 hunks)
- examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (8 hunks)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (9 hunks)
- examples/dequantize_gemm/example_dequant_gemm_w4a8.py (4 hunks)
- examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py (5 hunks)
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (14 hunks)
- examples/dynamic_shape/example_dynamic.py (2 hunks)
- examples/elementwise/example_elementwise_add.py (2 hunks)
- examples/elementwise/example_elementwise_add_tma_1d.py (1 hunks)
- examples/flash_attention/bert_padding.py (1 hunks)
- examples/flash_attention/example_gqa_bwd.py (12 hunks)
- examples/flash_attention/example_gqa_bwd_tma_reduce.py (12 hunks)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (12 hunks)
- examples/flash_attention/example_gqa_fwd_bshd.py (5 hunks)
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (6 hunks)
- examples/flash_attention/example_mha_bwd.py (5 hunks)
- examples/flash_attention/example_mha_bwd_bhsd.py (5 hunks)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (5 hunks)
- examples/flash_attention/example_mha_fwd_bhsd.py (5 hunks)
- examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (6 hunks)
- examples/flash_attention/example_mha_fwd_bshd.py (5 hunks)
- examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (6 hunks)
- examples/flash_attention/example_mha_fwd_varlen.py (4 hunks)
- examples/flash_decoding/example_gqa_decode.py (11 hunks)
- examples/flash_decoding/example_mha_inference.py (9 hunks)
- examples/fusedmoe/example_fusedmoe_tilelang.py (16 hunks)
- examples/fusedmoe/example_fusedmoe_torch.py (7 hunks)
- examples/fusedmoe/test_example_fusedmoe.py (1 hunks)
- examples/gdn/example_chunk_delta_bwd.py (10 hunks)
- examples/gdn/example_chunk_delta_h.py (9 hunks)
- examples/gdn/example_chunk_o.py (5 hunks)
⛔ Files not processed due to max files limit (36)
- examples/gdn/example_chunk_o_bwd.py
- examples/gdn/example_chunk_scaled_dot_kkt.py
- examples/gdn/example_cumsum.py
- examples/gdn/example_wy_fast.py
- examples/gdn/example_wy_fast_bwd_split.py
- examples/gdn/test_example_gdn_compilation.py
- examples/gdn/utils.py
- examples/gemm/example_gemm_autotune.py
- examples/gemm/example_gemm_intrinsics.py
- examples/gemm/example_gemm_persistent.py
- examples/gemm_fp8/example_tilelang_gemm_amd.py
- examples/gemm_fp8/example_tilelang_gemm_fp8.py
- examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
- examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
- examples/gemm_sm100/gemm_mma.py
- examples/gemm_sm100/gemm_tcgen5mma.py
- examples/gemm_sp/example_gemm_sp.py
- examples/gemm_splitk/example_tilelang_gemm_splitk.py
- examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
- examples/gemm_streamk/example_tilelang_gemm_streamk.py
- examples/grouped_gemm/example_grouped_gemm_bwd.py
- examples/grouped_gemm/example_grouped_gemm_fwd.py
- examples/hadamard_transform/example_hadamard.py
- examples/linear_attention/example_linear_attn_bwd.py
- examples/linear_attention/example_linear_attn_fwd.py
- examples/linear_attention/example_mamba_chunk_scan.py
- examples/linear_attention/example_mamba_chunk_state.py
- examples/linear_attention/example_retention_fwd.py
- examples/minference/example_vertical_slash_sparse_attn.py
- examples/norm/rms_norm.py
- examples/online_softmax/online_softmax.py
- examples/plot_layout/fragment_mma_load_a.py
- examples/seer_attention/block_sparse_attn_tilelang.py
- examples/seer_attention/block_sparse_attn_triton.py
- examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
- examples/topk/example_topk.py
💤 Files with no reviewable changes (2)
- benchmark/matmul_fp8/benchmark_matmul.py
- examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py
✅ Files skipped from review due to trivial changes (11)
- examples/amd/example_amd_flash_attn_fwd.py
- examples/gdn/example_chunk_o.py
- examples/deepseek_nsa/reference.py
- examples/flash_attention/bert_padding.py
- examples/bitnet-1.58b/eval_utils.py
- examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py
- examples/dequantize_gemm/example_dequant_gemm_w4a8.py
- examples/fusedmoe/test_example_fusedmoe.py
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
- examples/deepseek_v32/topk_selector.py
- benchmark/matmul/benchmark_matmul.py
🚧 Files skipped from review as they are similar to previous changes (40)
- examples/elementwise/example_elementwise_add.py
- examples/deepseek_mla/example_mla_decode.py
- examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py
- docs/conf.py
- examples/deepseek_mla/example_mla_decode_ws.py
- examples/deepseek_mla/torch_refs.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
- examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py
- examples/deepseek_v32/sparse_mla_fwd.py
- benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
- examples/dequantize_gemm/dequantize_utils.py
- examples/deepseek_nsa/example_tilelang_nsa_bwd.py
- examples/deepseek_mla/benchmark_mla.py
- examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
- examples/flash_attention/example_mha_fwd_varlen.py
- examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py
- examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
- examples/gdn/example_chunk_delta_h.py
- examples/flash_decoding/example_mha_inference.py
- examples/deepseek_v32/sparse_mla_bwd.py
- examples/bitnet-1.58b/benchmark_inference_latency.py
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
- examples/deepseek_nsa/example_triton_nsa_fwd.py
- examples/bitnet-1.58b/eval_ppl.py
- examples/deepseek_nsa/example_tilelang_nsa_decode.py
- examples/bitnet-1.58b/eval_correctness.py
- benchmark/blocksparse_attention/benchmark_library_dense_fmha.py
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py
- examples/blocksparse_attention/heuristic.py
- examples/flash_attention/example_mha_bwd_bhsd.py
- examples/deepseek_nsa/example_tilelang_nsa_fwd.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
- examples/bitnet-1.58b/load_from_quantized.py
- examples/flash_attention/example_gqa_bwd.py
- examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py
- examples/bitnet-1.58b/benchmark_generate.py
- examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py
🧰 Additional context used
🧬 Code graph analysis (42)
examples/elementwise/example_elementwise_add_tma_1d.py (1)
- tilelang/language/parallel.py (1)
  - Parallel (10-30)

examples/flash_attention/example_mha_fwd_bshd.py (2)
- tilelang/jit/kernel.py (2)
  - out_idx (471-472)
  - get_profiler (385-401)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)

examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (2)
- tilelang/jit/kernel.py (1)
  - out_idx (471-472)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)

examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1)
- examples/bitnet-1.58b/modeling_bitnet.py (1)
  - from_quantized (1501-1579)

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (2)
- tilelang/jit/kernel.py (2)
  - out_idx (471-472)
  - get_profiler (385-401)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)

benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py (1)
- benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py (1)
  - get_sparse_attn_mask_from_topk (14-26)

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (2)
- tilelang/language/tir/op.py (1)
  - reinterpret (1898-1917)
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1)
  - matmul (49-354)

examples/bitnet-1.58b/eval_gpu_memory.py (2)
- examples/bitnet-1.58b/benchmark_generate.py (1)
  - profile (54-74)
- examples/bitnet-1.58b/modeling_bitnet.py (1)
  - _post_process_weights (1488-1492)

examples/flash_attention/example_mha_fwd_bhsd.py (2)
- tilelang/jit/kernel.py (2)
  - out_idx (471-472)
  - get_profiler (385-401)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)

examples/cast/example_group_per_split_token_cast_to_fp8.py (2)
- tilelang/language/tir/op.py (1)
  - ceildiv (3116-3133)
- examples/cast/example_per_token_cast_to_fp8.py (2)
  - ref_program (81-91)
  - ceil_div (67-78)

examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py (1)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py (6)
  - run_torch_mla (35-73)
  - run_flash_mla_triton (327-373)
  - flash_mla_triton (352-369)
  - mla_decode_triton (292-323)
  - compare_a (458-505)
  - compare_ab (382-455)

examples/blocksparse_gemm/example_blocksparse_gemm.py (2)
- tilelang/jit/kernel.py (1)
  - params (475-476)
- tilelang/engine/param.py (1)
  - KernelParam (13-105)

examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (1)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)

examples/convolution/example_convolution.py (1)
- examples/convolution/example_convolution_autotune.py (1)
  - convolution (97-168)

examples/deepseek_mla/example_mla_decode_paged.py (2)
- tilelang/jit/kernel.py (1)
  - out_idx (471-472)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)

examples/flash_attention/example_gqa_bwd_tma_reduce.py (2)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)
- tilelang/jit/__init__.py (1)
  - jit (243-319)

examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (3)
- tilelang/jit/kernel.py (1)
  - out_idx (471-472)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)
- examples/flash_attention/example_gqa_bwd.py (1)
  - flashattn_bwd_atomic_add (156-252)

examples/flash_attention/example_mha_bwd.py (3)
- tilelang/jit/kernel.py (1)
  - out_idx (471-472)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)
- tilelang/profiler/__init__.py (1)
  - do_bench (220-283)

examples/fusedmoe/example_fusedmoe_torch.py (1)
- examples/fusedmoe/example_fusedmoe_tilelang.py (3)
  - forward (317-320)
  - forward (333-338)
  - forward (432-535)

examples/cast/example_per_token_cast_to_fp8.py (1)
- tilelang/language/tir/op.py (1)
  - ceildiv (3116-3133)

examples/deepseek_mla/example_mla_decode_persistent.py (1)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)

examples/analyze/example_conv_analyze.py (2)
- examples/gemm/example_gemm_autotune.py (1)
  - kernel (110-150)
- examples/analyze/example_gemm_analyze.py (1)
  - kernel (10-46)

examples/amd/example_amd_flash_attn_bwd.py (2)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
  - main (462-533)
- examples/flash_attention/example_gqa_bwd.py (2)
  - main (479-539)
  - run1 (529-530)

examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (2)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (2)
  - parallel_nsa_fwd_kernel (26-107)
  - parallel_nsa (232-308)
- examples/deepseek_nsa/example_triton_nsa_bwd.py (3)
  - parallel_nsa_fwd_kernel (30-113)
  - parallel_nsa_fwd_kernel (559-686)
  - parallel_nsa (914-990)

examples/deepseek_nsa/example_triton_nsa_bwd.py (2)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (2)
  - parallel_nsa_fwd_kernel (26-107)
  - parallel_nsa (232-308)
- examples/deepseek_nsa/example_triton_nsa_fwd.py (2)
  - parallel_nsa_fwd_kernel (30-113)
  - parallel_nsa (238-314)

examples/attention_sink/example_mha_sink_fwd_bhsd.py (2)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (5)
  - flashattn (31-201)
  - ref_program (206-252)
  - main (135-199)
  - main (385-457)
  - gen_inputs (371-382)
- tilelang/language/allocate.py (1)
  - alloc_local (39-50)

examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (3)
- examples/deepseek_nsa/example_triton_nsa_bwd.py (3)
  - parallel_nsa_fwd_kernel (30-113)
  - parallel_nsa_fwd_kernel (559-686)
  - parallel_nsa (914-990)
- examples/deepseek_nsa/example_triton_nsa_fwd.py (2)
  - parallel_nsa_fwd_kernel (30-113)
  - parallel_nsa (238-314)
- examples/deepseek_nsa/reference.py (1)
  - naive_nsa (9-172)

examples/bitnet-1.58b/vllm_workspace/conftest.py (1)
- examples/bitnet-1.58b/modeling_bitnet.py (1)
  - get_output_embeddings (1250-1251)

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (6)
  - flashattn (35-207)
  - ref_program (212-257)
  - main (141-205)
  - main (394-473)
  - gen_inputs (379-391)
  - triton_program (349-376)

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2)
- tilelang/math/__init__.py (1)
  - next_power_of_2 (1-2)
- examples/deepseek_v32/sparse_mla_fwd.py (1)
  - sparse_mla_fwd (15-173)

examples/bitnet-1.58b/modeling_bitnet.py (1)
- examples/bitnet-1.58b/configuration_bitnet.py (1)
  - BitnetConfig (29-194)

examples/gdn/example_chunk_delta_bwd.py (2)
- tilelang/language/copy.py (1)
  - copy (16-94)
- examples/gdn/example_chunk_delta_h.py (1)
  - prepare_input (39-67)

examples/fusedmoe/example_fusedmoe_tilelang.py (4)
- tilelang/jit/__init__.py (1)
  - jit (243-319)
- tilelang/language/copy.py (1)
  - copy (16-94)
- examples/fusedmoe/example_fusedmoe_torch.py (3)
  - forward (21-24)
  - forward (37-42)
  - forward (56-67)
- tilelang/jit/kernel.py (1)
  - JITKernel (27-503)

examples/flash_attention/example_gqa_fwd_bshd.py (2)
- tilelang/jit/kernel.py (2)
  - out_idx (471-472)
  - get_profiler (385-401)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)

examples/attention_sink/example_mha_sink_bwd_bhsd.py (1)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)
  - ref_program (193-239)

examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)
- tilelang/language/tir/op.py (2)
  - call_extern (173-195)
  - address_of (464-480)

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (2)
- tilelang/language/allocate.py (1)
  - alloc_local (39-50)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1)
  - ref_program (414-459)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)
- tilelang/language/tir/op.py (1)
  - if_then_else (2906-2936)
- src/op/gemm.h (4)
  - GemmWarpPolicy (75-98)
  - GemmWarpPolicy (79-83)
  - GemmWarpPolicy (85-89)
  - GemmWarpPolicy (91-97)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (3)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)
- tilelang/language/tir/op.py (1)
  - if_then_else (2906-2936)
- tilelang/language/__init__.py (1)
  - symbolic (88-99)

examples/convolution/example_convolution_autotune.py (2)
- examples/gemm/example_gemm_autotune.py (3)
  - get_configs (22-105)
  - main (122-148)
  - main (245-269)
- examples/convolution/example_convolution.py (4)
  - convolution (29-99)
  - main (18-23)
  - main (55-97)
  - main (102-138)

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (2)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1)
  - matmul (86-381)
- tilelang/quantize/mxfp.py (1)
  - get_mxfp_intrin_group (52-109)

examples/flash_decoding/example_gqa_decode.py (3)
- examples/gemm/example_gemm_autotune.py (1)
  - get_heuristic_config (165-199)
- examples/convolution/example_convolution_autotune.py (1)
  - get_heuristic_config (58-92)
- tilelang/transform/pass_config.py (1)
  - PassConfigKey (6-105)
🪛 Ruff (0.14.0)

examples/deepseek_mla/example_mla_decode_paged.py
- 315-315: Unused function argument: `block_table` (ARG001)
- 318-318: Unused function argument: `block_size` (ARG001)

examples/bitnet-1.58b/modeling_bitnet.py
- 379-379: Unused method argument: `use_cache` (ARG002)
- 381-381: Unused method argument: `kwargs` (ARG002)
- 532-532: Unused method argument: `use_cache` (ARG002)
- 534-534: Unused method argument: `kwargs` (ARG002)
- 620-620: Unused method argument: `use_cache` (ARG002)
- 622-622: Unused method argument: `kwargs` (ARG002)

examples/fusedmoe/example_fusedmoe_tilelang.py
- 7-7: `from tilelang.autotuner import *` used; unable to detect undefined names (F403)
- 8-8: `from example_fusedmoe_torch import *` used; unable to detect undefined names (F403)

examples/blocksparse_attention/block_sparse_attn_triton.py
- 202-202: Unused function argument: `ctx` (ARG001)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
- 332-332: Unused function argument: `max_cache_seqlen` (ARG001)
- 332-332: Unused function argument: `num_blocks` (ARG001)
- 334-334: Unpacked variable `heads` is never used; prefix it with an underscore or any other dummy variable pattern (RUF059)
- 373-373: Unused function argument: `block_indices` (ARG001)
- 373-373: Unused function argument: `max_cache_seqlen` (ARG001)
- 373-373: Unused function argument: `num_blocks` (ARG001)
- 374-374: Unused function argument: `block_size` (ARG001)
- 418-418: Value being cast to `int` is already an integer; remove the unnecessary `int` call (RUF046)

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
- 368-368: Unused function argument: `block_indices` (ARG001)
- 368-368: Unused function argument: `max_cache_seqlen` (ARG001)
- 368-368: Unused function argument: `num_blocks` (ARG001)
- 369-369: Unused function argument: `block_size` (ARG001)
- 415-415: Value being cast to `int` is already an integer; remove the unnecessary `int` call (RUF046)

examples/convolution/example_convolution_autotune.py
- 112-112: Unused function argument: `enable_rasteration` (ARG001)
- 182-182: Unused function argument: `with_roller` (ARG001)

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
- 441-441: Comment contains ambiguous `（` (FULLWIDTH LEFT PARENTHESIS). Did you mean `(` (LEFT PARENTHESIS)? (RUF003)
- 441-441: Comment contains ambiguous `）` (FULLWIDTH RIGHT PARENTHESIS). Did you mean `)` (RIGHT PARENTHESIS)? (RUF003)

examples/flash_decoding/example_gqa_decode.py
- 475-475: Avoid specifying long messages outside the exception class (TRY003)
```python
        query = (query.transpose(1, 2).contiguous().unsqueeze(3)
                )  # align with the original function'sinterface
```

Fix typo in comment.

Same typo: "function's interface" needs a space.

Apply this diff:

```diff
-    query = (query.transpose(1, 2).contiguous().unsqueeze(3)
-            )  # align with the original function'sinterface
+    query = (query.transpose(1, 2).contiguous().unsqueeze(3)
+            )  # align with the original function's interface
```

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents

```
In examples/attention_sink/example_mha_sink_bwd_bhsd.py around lines 417 to 418,
the inline comment has a missing space between "function's" and "interface";
update the comment so it reads "...align with the original function's interface"
(i.e., insert the missing space) to correct the typo.
```
```python
        query = (query.transpose(1, 2).contiguous().unsqueeze(3)
                )  # align with the original function'sinterface
```

Fix typo in comment.

The comment has a typo: "function'sinterface" should be "function's interface" (missing space).

Apply this diff:

```diff
-    query = (query.transpose(1, 2).contiguous().unsqueeze(3)
-            )  # align with the original function'sinterface
+    query = (query.transpose(1, 2).contiguous().unsqueeze(3)
+            )  # align with the original function's interface
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
    query = (query.transpose(1, 2).contiguous().unsqueeze(3)
            )  # align with the original function's interface
```

🤖 Prompt for AI Agents

```
In examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py around
lines 214 to 215, there's a typo in the inline comment: "function'sinterface"
should be corrected to "function's interface"; update the comment text
accordingly to add the missing space so it reads "...align with the original
function's interface".
```
```diff
         return out_string

-    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
+    def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]:
```

🛠️ Refactor suggestion | 🟠 Major

Add the missing type hint for the `save_directory` parameter.

The `save_directory` parameter is missing a type hint. For consistency and type safety, it should be annotated as `str`.

Apply this diff:

```diff
-    def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]:
+    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
```

🤖 Prompt for AI Agents

```
In examples/bitnet-1.58b/tokenization_bitnet.py around line 327, the function
signature for save_vocabulary is missing a type annotation for the
save_directory parameter; update the signature to annotate save_directory as str
(i.e., save_directory: str) and ensure any references or type checks within the
function remain consistent with the new annotation.
```
```diff
@@ -1,9 +1,10 @@
+from __future__ import annotations
 from typing import Dict, List, Tuple
```

🛠️ Refactor suggestion | 🟠 Major

Update type aliases to use built-in generics for consistency.

The function signatures (lines 7, 29-30) now use built-in generics (`list[...]`), but the type aliases still use `typing.List`, `typing.Tuple`, and `typing.Dict`. For consistency and full pyupgrade compliance, update the type aliases to use lowercase built-in types.

Apply this diff:

```diff
-from typing import Dict, List, Tuple
+from __future__ import annotations

-TokensText = Tuple[List[int], str]
+TokensText = tuple[list[int], str]
```

And similarly for line 25:

```diff
-TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
+TokensTextLogprobs = tuple[list[int], str, list[dict[int, float]]]
```

After these changes, the `typing` import on line 2 can be removed entirely since all generics will use built-in types.

Also applies to: 4-4, 25-25

🤖 Prompt for AI Agents

```
In examples/bitnet-1.58b/vllm_workspace/utils.py around lines 2–30, update the
type alias declarations that currently use typing.List, typing.Tuple, and
typing.Dict to use the built-in generics (list[...], tuple[...], dict[...]) for
consistency with the function signatures on lines 7 and 29–30 and the usages on
lines 4 and 25; after converting all type aliases and any other typing.* usages
to built-in generics, remove the now-unused import of typing from line 2.
```
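One nuance worth keeping in mind when applying this: `from __future__ import annotations` defers evaluation of *annotations* only, not of module-level type aliases, so the two cases behave differently on older interpreters. A minimal sketch (the failing line assumes a pre-3.9 Python; names are illustrative):

```python
from __future__ import annotations

def f(xs: list[int]) -> tuple[int, ...]:  # fine even on 3.8: the annotation is never evaluated
    return tuple(xs)

# TokensText = tuple[list[int], str]  # on 3.8 this raises TypeError: aliases evaluate eagerly
```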
```diff
     if USE_OFFSETS:
-        i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 +
-                                                                          1).to(tl.int32)
+        i_n, i_t = (
+            tl.load(token_indices + i_t * 2).to(tl.int32),
+            tl.load(token_indices + i_t * 2 + 1).to(tl.int32),
+        )
         bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
         T = eos - bos
```

🛠️ Refactor suggestion | 🟠 Major

Guard against OOB: ensure `grid.x` matches the total token count and types are correct when varlen.

If offsets are provided, `T` (`grid.x`) must equal `offsets[-1]`. Add pre-call assertions and make `token_indices` contiguous to avoid bad pointer math.

```diff
 def parallel_nsa_fwd(
@@
-    grid = (T, NV, B * H)
+    if offsets is not None:
+        # offsets must be 1D Long on the same device, and total tokens must match T
+        assert offsets.ndim == 1, "offsets must be 1D [N+1]"
+        assert offsets.dtype == torch.long, "offsets must be torch.long"
+        assert int(offsets[-1].item()) == T, "sum of sequence lengths (offsets[-1]) must equal T"
+        assert offsets.device == q.device, "offsets must be on the same device as inputs"
+    grid = (T, NV, B * H)
```

And ensure contiguity/device for `token_indices` at creation:

```diff
 class ParallelNSAFunction(torch.autograd.Function):
@@
-        token_indices = prepare_token_indices(offsets) if offsets is not None else None
+        token_indices = (prepare_token_indices(offsets).contiguous()
+                         if offsets is not None else None)
```

Committable suggestion skipped: line range outside the PR's diff.
```diff
     BI = block_I
     NI = tilelang.cdiv(topk, block_I)
-    assert NI % 2 == 0, 'NI should be a multiple of 2'
+    assert NI % 2 == 0, "NI should be a multiple of 2"
```

🧩 Analysis chain

New runtime constraint added.

This assertion requires `NI % 2 == 0` (where `NI = tilelang.cdiv(topk, block_I)`), which is a new constraint not present in the original code. This could break existing code that calls this function with `topk` and `block_I` combinations where `NI` is odd.

Verify that all existing call sites satisfy this new constraint:
🏁 Script executed:

```shell
#!/bin/bash
# Description: Find all call sites of sparse_mla_fwd and check if they satisfy NI % 2 == 0

# Search for calls to sparse_mla_fwd function
rg -nP '\bsparse_mla_fwd\s*\(' --type=py -C3
```

Length of output: 1878

---

🏁 Script executed:

```shell
#!/bin/bash
# Extract the call site and surrounding context to see what parameters are passed
sed -n '360,385p' examples/deepseek_v32/sparse_mla_fwd_pipelined.py
```

Length of output: 870

🏁 Script executed:

```shell
#!/bin/bash
# Find where block_I is defined in the pipelined file
rg -nP '\bblock_I\s*=' examples/deepseek_v32/sparse_mla_fwd_pipelined.py -B2 -A2
```

Length of output: 381

🏁 Script executed:

```shell
#!/bin/bash
# Find all topk references in context around the function call
rg -nP '\btopk\b' examples/deepseek_v32/sparse_mla_fwd_pipelined.py -B1 -A1
```

Length of output: 1670

🏁 Script executed:

```shell
#!/bin/bash
# Also check the original sparse_mla_fwd.py for comparison and to see if it has this assertion
sed -n '65,80p' examples/deepseek_v32/sparse_mla_fwd.py
```

Length of output: 572

🏁 Script executed:

```shell
#!/bin/bash
# Check if there's a git diff to see what was added
git diff HEAD examples/deepseek_v32/sparse_mla_fwd_pipelined.py | head -150
```

Length of output: 0

🏁 Script executed:

```shell
#!/bin/bash
# Search for context on NI computation in both files
rg -nP '\bNI\s*=' examples/deepseek_v32/sparse_mla_fwd*.py -B1 -A1
```

Length of output: 467

---

🏁 Script executed:

```shell
#!/bin/bash
# Search for any imports or calls to sparse_mla_fwd_pipelined from outside its own file
rg -nP 'sparse_mla_fwd_pipelined|from.*sparse_mla_fwd_pipelined|import.*sparse_mla_fwd_pipelined' --type=py -g '!sparse_mla_fwd_pipelined.py'
```

Length of output: 334

🏁 Script executed:

```shell
#!/bin/bash
# Check if the original sparse_mla_fwd.py has the NI % 2 assertion
sed -n '58,75p' examples/deepseek_v32/sparse_mla_fwd.py
```

Length of output: 596

🏁 Script executed:

```shell
#!/bin/bash
# Verify the constraint chain: does topk get used with other values in the wrapper function?
sed -n '340,395p' examples/deepseek_v32/sparse_mla_fwd_pipelined.py | head -60
```

Length of output: 1684

---

🏁 Script executed:

```shell
#!/bin/bash
# Check test file to see what topk values it uses
sed -n '28,50p' examples/deepseek_v32/test_tilelang_example_deepseek_v32.py
```

Length of output: 556

🏁 Script executed:

```shell
#!/bin/bash
# Also check the main test function parameters
rg -nP 'test_sparse_mla_fwd_pipelined' examples/deepseek_v32/sparse_mla_fwd_pipelined.py -A25 | head -40
```

Length of output: 960
New runtime constraint added that breaks backward compatibility.

This assertion requires `NI % 2 == 0` (where `NI = tilelang.cdiv(topk, block_I)` with `block_I = 64`), which means `topk` must be divisible by 128. This is a stricter constraint than the existing check at line 46 (`topk % block_I == 0`), and it was not present in the original `sparse_mla_fwd.py` (which has the NI computation but no such assertion).

**Breaking change**: Previously, any `topk` divisible by 64 was accepted (64, 192, 320, etc.). Now only `topk` divisible by 128 is accepted (128, 256, 384, etc.). Since `topk` is derived from user-provided tensor dimensions, existing call sites with odd multiples of 64 will fail at runtime.

The existing test suite passes because the test values (256, 2048) satisfy the new constraint, but this does not guarantee external users won't be affected.
🤖 Prompt for AI Agents

```
In examples/deepseek_v32/sparse_mla_fwd_pipelined.py around line 72, the new
assertion "assert NI % 2 == 0" enforces NI even (making topk divisible by 128)
which is a breaking change; revert this to preserve previous behavior by
removing the NI-even assertion and instead ensure the original constraint topk %
block_I == 0 is enforced (or, if NI is used elsewhere and must be
integer-divided, use floor/ceil or cdiv as before without requiring evenness).
Update the surrounding comment to state only the required constraint (topk
divisible by block_I) so existing call sites that are odd multiples of block_I
continue to work.
```
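The arithmetic behind the breaking change, as a standalone check (plain Python, with `cdiv` written out rather than imported from tilelang):

```python
def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division, same result as tilelang.cdiv

block_I = 64
for topk in (64, 128, 192, 256, 320):
    NI = cdiv(topk, block_I)
    print(f"topk={topk}: NI={NI}, passes={NI % 2 == 0}")
# Only topk = 128 and 256 pass; 64, 192, and 320 now fail the new assert.
```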
```python
        B_shape = (
            N // micro_size_y,
            K // micro_size_k,
            micro_size_y,
            micro_size_k // num_elems_per_byte,
        )
        A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
        B_shared_shape = (
            block_N // micro_size_y,
```

🛠️ Refactor suggestion | 🟠 Major

Guard the new 4-D B layout with invariants.

Add early asserts for divisibility: `N % micro_size_y == 0`, `K % micro_size_k == 0`, and `block_K % micro_size_k == 0`, to prevent shape/layout mismatches at runtime.

Apply:

```diff
     B_shape = (
         N // micro_size_y,
         K // micro_size_k,
         micro_size_y,
         micro_size_k // num_elems_per_byte,
     )
+    assert N % micro_size_y == 0
+    assert K % micro_size_k == 0
+    assert block_K % micro_size_k == 0
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
    B_shape = (
        N // micro_size_y,
        K // micro_size_k,
        micro_size_y,
        micro_size_k // num_elems_per_byte,
    )
    assert N % micro_size_y == 0
    assert K % micro_size_k == 0
    assert block_K % micro_size_k == 0
    A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
    B_shared_shape = (
        block_N // micro_size_y,
```

🤖 Prompt for AI Agents

```
In examples/dequantize_gemm/example_dequant_gemm_fine_grained.py around lines
197 to 205, the newly introduced 4-D B layout assumes divisibility that can
break at runtime; add early asserts confirming N % micro_size_y == 0, K %
micro_size_k == 0, and block_K % micro_size_k == 0 before computing B_shape (and
related shapes) so the code fails fast with a clear message when the dimensions
are not divisible as required.
```
```diff
     def __init__(self, config: dict, weights: dict):
         super().__init__()
         self.top_k: int = config["n_experts_per_token"]
         self.num_experts: int = config["n_routed_experts"]
         self.d_hidden: int = config["d_hidden"]

-        self.W_g_weight = weights['router.weight'].t()
+        self.W_g_weight = weights["router.weight"].t()

-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
```

Device mismatch risk: register the router weight as a buffer so `.to(...)` moves it.

`self.W_g_weight` is a plain tensor; moving the module to CUDA won't move it, leading to CPU/GPU matmul errors. Register it as a buffer and ensure contiguity.

```diff
 class MoEGate(nn.Module):
     def __init__(self, config: dict, weights: dict):
         super().__init__()
         self.top_k: int = config["n_experts_per_token"]
         self.num_experts: int = config["n_routed_experts"]
         self.d_hidden: int = config["d_hidden"]
-
-        self.W_g_weight = weights["router.weight"].t()
+        # Register as buffer so module.to(device) migrates it with the rest.
+        self.register_buffer(
+            "W_g_weight",
+            weights["router.weight"].t().contiguous(),
+            persistent=False,
+        )
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
    def __init__(self, config: dict, weights: dict):
        super().__init__()
        self.top_k: int = config["n_experts_per_token"]
        self.num_experts: int = config["n_routed_experts"]
        self.d_hidden: int = config["d_hidden"]

        # Register as buffer so module.to(device) migrates it with the rest.
        self.register_buffer(
            "W_g_weight",
            weights["router.weight"].t().contiguous(),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
```

🤖 Prompt for AI Agents

```
In examples/fusedmoe/example_fusedmoe_tilelang.py around lines 325 to 333, the
router weight is stored as a plain tensor (self.W_g_weight) which will not move
with the module and can cause CPU/GPU device-mismatch errors; replace that
assignment by registering the tensor as a buffer so .to(...) moves it with the
module and ensure it is contiguous and transposed as needed (e.g., use
self.register_buffer("W_g_weight", weights["router.weight"].t().contiguous())).
```
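For reviewers unfamiliar with the buffer mechanics, a self-contained demonstration of why the plain attribute is a problem (standard PyTorch behavior, independent of this example's weights):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.plain = torch.zeros(2)                  # plain attribute: ignored by .to()
        self.register_buffer("buf", torch.zeros(2))  # buffer: tracked by .to()/.cuda()

m = Demo()
if torch.cuda.is_available():
    m.to("cuda")
    print(m.plain.device)  # cpu  -> would trigger a device-mismatch matmul
    print(m.buf.device)    # cuda:0
```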
Force-pushed from `2a7a6c5` to `104c661` (Compare)
Actionable comments posted: 7
Caution

Some comments are outside the diff and can't be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)

benchmark/matmul/benchmark_matmul_intrinsic.py (1)

**305-305**: Remove hardcoded override that ignores user input.

This line forces `with_roller` to `True`, completely ignoring the `--with_roller` CLI argument parsed on line 304. This breaks the command-line interface and is unrelated to the PR's linting objectives.

Apply this diff to remove the hardcoded override:

```diff
 with_roller = args.with_roller
-with_roller = True

 # Compute total floating-point operations
```

If forcing `with_roller=True` is intentional for testing purposes, please document this in the PR description and consider using a different approach (e.g., changing the default value in the argument parser or adding a comment explaining why the override is necessary).

examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py (1)

**36-36**: Fix typo and simplify assertion.

The message has "TRue"; also "is True" is unnecessary.

```diff
-assert trans_B is True, "Dequantize only implement for trans_B=TRue currently"
+assert trans_B, "Dequantize only implemented for trans_B=True currently"
```

examples/deepseek_nsa/example_tilelang_nsa_decode.py (1)

**93-99**: Bounds check bug allows OOB slices when the sentinel SEQ_LEN is present.

`block_indices` is filled with `SEQ_LEN` as padding; with `i_s = BlockIndices[...] * BS`, `if i_s >= 0` admits the sentinel and triggers an out-of-range slice. Guard the upper bound as well.

```diff
-                if i_s >= 0:  # Skip invalid/padding blocks
+                # Skip invalid/padding blocks; ensure the [i_s, i_s+BS) window is valid
+                if 0 <= i_s and i_s + BS <= seq_len:
```

examples/flash_attention/example_gqa_fwd_bshd.py (1)

**201-223**: Fix device mismatch in ref_program (CPU scalar tensor on CUDA tensor).

Match the device when computing the scaling factor. Same issue as in the WGMMA variant.

```diff
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
+    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype, device=scores.device))
+# or:
+# import math
+# scores = scores * (1.0 / math.sqrt(dim))
```
♻️ Duplicate comments (17)
examples/convolution/example_convolution_autotune.py (2)

**97-115**: Multi-line signature improves readability, but the unused parameter issue remains.

The reformatting to a multi-line function signature with trailing commas aligns with the PR's linting objectives and improves readability.

However, the `enable_rasteration` parameter is still unused in the function body (confirmed by static analysis). This critical issue was flagged in previous reviews and remains unresolved.

**171-183**: Multi-line signature improves readability, but the unused parameter issue remains.

The reformatting to a multi-line function signature with trailing commas is consistent with project-wide conventions and improves readability.

However, the `with_roller` parameter is still unused in the function body (confirmed by static analysis). This major issue was flagged in previous reviews and remains unresolved.

examples/bitnet-1.58b/tokenization_bitnet.py (1)

**326-326**: Missing type annotation for the `save_directory` parameter.

The `save_directory` parameter is still missing a type hint. It should be annotated as `str` for consistency with the rest of the codebase.

Apply this diff:

```diff
-    def save_vocabulary(self, save_directory, filename_prefix: str | None = None) -> tuple[str]:
+    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None) -> tuple[str]:
```

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)

**72-72**: Breaking change: new runtime constraint.

This assertion requires `NI % 2 == 0`, meaning `topk` must be divisible by 128 instead of just 64. This breaks backward compatibility for existing code using odd multiples of 64 (e.g., 64, 192, 320).

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)

**368-369**: Remove unused parameters in `ref_program_fa`.

The parameters `block_indices`, `max_cache_seqlen`, `num_blocks`, and `block_size` are never used in the function body. Consider removing them and updating the call sites at lines 463 and 468.

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (2)

**332-333**: Remove unused parameters in `ref_program_torch`.

The parameters `max_cache_seqlen` and `num_blocks` are never used in the function body. Consider removing them and updating the call site at line 464.

**373-374**: Remove unused parameters in `ref_program_fa`.

The parameters `block_indices`, `max_cache_seqlen`, `num_blocks`, and `block_size` are never used in the function body. Consider removing them and updating the call sites at lines 475 and 480.

examples/bitnet-1.58b/vllm_workspace/utils.py (1)

**2-4**: Complete the migration to built-in generics for type aliases.

The function signatures on lines 7 and 29-30 have been updated to use built-in generics (`list[...]`), but the type aliases still use `typing.Tuple`, `typing.List`, and `typing.Dict`. This inconsistency was flagged in a previous review and remains unaddressed.

Apply this diff to complete the migration:

```diff
-from typing import Dict, List, Tuple
+
-TokensText = Tuple[List[int], str]
+TokensText = tuple[list[int], str]
```

And for line 25:

```diff
-TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]
+TokensTextLogprobs = tuple[list[int], str, list[dict[int, float]]]
```

After these changes, the `typing` import on line 2 can be removed entirely.

Also applies to: 25-25

examples/fusedmoe/example_fusedmoe_tilelang.py (2)

**137-138**: Critical: 1D tensor shapes must be tuples, not scalars.

Lines 137-138 define shapes as scalars (`routed_expert_weights_shape = group_sum` and `group_sizes_shape = n_routed_experts`) but TileLang requires all tensor shapes to be tuples. For 1D tensors, use 1-element tuples.

Apply this diff to fix:

```diff
-    routed_expert_weights_shape = group_sum
-    group_sizes_shape = n_routed_experts
+    routed_expert_weights_shape = (group_sum,)
+    group_sizes_shape = (n_routed_experts,)
```

**325-333**: Critical: register the router weight as a buffer to prevent device mismatch.

`self.W_g_weight` is stored as a plain tensor attribute (line 331), which won't move when the module is transferred to a device via `.to(device)`. This causes CPU/GPU mismatch errors during the matmul in `forward` (line 334).

Apply this diff to register it as a buffer:

```diff
 class MoEGate(nn.Module):
     def __init__(self, config: dict, weights: dict):
         super().__init__()
         self.top_k: int = config["n_experts_per_token"]
         self.num_experts: int = config["n_routed_experts"]
         self.d_hidden: int = config["d_hidden"]

-        self.W_g_weight = weights["router.weight"].t()
+        # Register as buffer so module.to(device) migrates it with the rest.
+        self.register_buffer(
+            "W_g_weight",
+            weights["router.weight"].t().contiguous(),
+            persistent=False,
+        )
```

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1)

**441-441**: Fix fullwidth parentheses in comment (duplicate issue).

This issue was flagged in a previous review but remains unfixed. The comment still contains fullwidth parentheses `（padding_M,）` that should be ASCII `(padding_M,)`.

Apply this diff to fix:

```diff
-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda")  # (padding_M,)
```

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)

**211-212**: Typo in inline comment.

Missing space: "function'sinterface" → "function's interface".

Apply this diff:

```diff
-    query = (query.transpose(1, 2).contiguous().unsqueeze(3)
-            )  # align with the original function'sinterface
+    query = (query.transpose(1, 2).contiguous().unsqueeze(3)
+            )  # align with the original function's interface
```

examples/cast/example_per_token_cast_to_fp8.py (1)

**81-81**: Python 3.8 compat: avoid built-in generics in the return type.

`tuple[...]` can break runtime type-hint evaluation on 3.8. Prefer `typing.Tuple[...]` to match the PR's 3.8-compat objective.

```diff
-from __future__ import annotations
+from __future__ import annotations
+from typing import Tuple
@@
-def ref_program(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+def ref_program(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
```

examples/cast/example_group_per_split_token_cast_to_fp8.py (2)

**154-154**: Python 3.8 compat: avoid built-in generics in the return type.

`tuple[...]` may break when evaluated on 3.8; switch to `typing.Tuple[...]` to align with the PR's stated compatibility target.

```diff
-from __future__ import annotations
+from __future__ import annotations
+from typing import Tuple
@@
-def ref_per_token_cast_to_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+def ref_per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
```

**167-176**: Shape bug for non-multiples of 128.

`n // 128` truncates; the rest of the code uses ceil division. This mis-sizes the scales tensor for `N % 128 != 0`.

```diff
-    x_fp8 = (
-        torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
-        torch.empty((num_groups, m, n // 128), device="cuda", dtype=torch.float),
-    )
+    x_fp8 = (
+        torch.empty((num_groups, m, n), device="cuda", dtype=torch.float8_e4m3fn),
+        torch.empty((num_groups, m, ceil_div(n, 128)), device="cuda", dtype=torch.float),
+    )
```

examples/flash_attention/example_gqa_bwd.py (1)

**258-270**: Consistent formatting applied.

The multi-line signature formatting is consistent with `flashattn_bwd_atomic_add` above, maintaining code style uniformity.

examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (1)

**194-202**: Critical: undefined `D`, and `block_counts` may be an int.

- `D` is not defined; use `K` for q/k and `V` for v.
- `block_counts` can be an int; `.to(...)` will crash. Normalize it to a tensor.
- Ensure `offsets`/`token_indices` are provided before the kernel call.

```diff
-    kernel(
-        q.view(C_SEQ_LEN, HQ, D),
-        k.view(C_SEQ_LEN, H, D),
-        v.view(C_SEQ_LEN, H, D),
+    assert offsets is not None, "offsets must be provided for varlen forward"
+    assert token_indices is not None, "token_indices must be provided for varlen forward"
+
+    if isinstance(block_counts, torch.Tensor):
+        block_counts_t = block_counts.to(torch.int32).view(C_SEQ_LEN, H)
+    else:
+        # constant per-token count
+        block_counts_t = torch.full(
+            (C_SEQ_LEN, H), int(block_counts), dtype=torch.int32, device=block_indices.device
+        )
+
+    kernel(
+        q.view(C_SEQ_LEN, HQ, K),
+        k.view(C_SEQ_LEN, H, K),
+        v.view(C_SEQ_LEN, H, V),
         o_slc.view(C_SEQ_LEN, HQ, V),
         block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
-        block_counts.to(torch.int32).view(C_SEQ_LEN, H),
-        offsets.to(torch.int32),
-        token_indices.to(torch.int32),
+        block_counts_t,
+        offsets.to(torch.int32),
+        token_indices.to(torch.int32),
     )
```

Based on past review comments.
🧹 Nitpick comments (27)
examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1)
115-115
: Minor style observation: trailing comma after single argumentThe parentheses wrapping is fine, but the trailing comma after
saved_model_path,
is unusual for a single argument. While syntactically valid, trailing commas are typically used for multi-argument calls or to facilitate future additions. Consider removing it for consistency unless it's an intentional style choice.Apply this diff if you prefer to remove the trailing comma:
- qmodel = (BitnetForCausalLM.from_quantized(saved_model_path,).cuda().half()) + qmodel = (BitnetForCausalLM.from_quantized(saved_model_path).cuda().half())examples/dequantize_gemm/dequantize_utils.py (1)
148-149
: Explicit float literal is fine but optional.Changing
1
to1.0
on line 148 makes the float context more explicit, though sincesim
is already a float, the behavior is unchanged. The f-string quote change on line 149 is consistent with the standardization throughout the file.examples/dequantize_gemm/example_dequant_gemm_w4a8.py (1)
74-75
: Consider renaming variables for consistency and clarity.The variable names are now backwards:
i4_shifted
holds the unshifted masked value, whilei4
holds the final shifted result. This is inconsistent with the similar function_tir_u8_to_i4_to_i8
(lines 17-20), wherei4
is the masked value andi8_shifted
is the left-shifted intermediate.Consider reverting the variable names or renaming them to be self-documenting:
- i4_shifted = (val >> (pos * 4)) & mask - i4 = (i4_shifted << 4) >> 4 + i4 = (val >> (pos * 4)) & mask + i4_signed = (i4 << 4) >> 4 - return i4.view(torch.int8) + return i4_signed.view(torch.int8)examples/deepseek_v32/utils.py (1)
17-17: Inconsistent with PR objectives: Legacy typing imports not modernized.

This file still imports legacy typing constructs (`Dict`, `Optional`, `Tuple`) and lacks `from __future__ import annotations`. According to the PR objectives and AI summary, the PR aims to enable FA102 and modernize type annotations to Python 3.10+ syntax (`dict`, `list`, `X | None`, etc.) throughout the codebase.

Consider updating this file to align with the PR's goals, or clarify if this file is intentionally excluded from the modernization.

Expected changes:

```diff
+from __future__ import annotations
+
 # -*- coding: utf-8 -*-
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa
 import torch
 ...
-from typing import Any, Callable, Dict, Literal, Optional, Tuple
+from typing import Any, Callable, Literal
```

Then update type hints throughout (e.g., lines 49-50):

```diff
-    last_args: Optional[Tuple] = None
-    last_kwargs: Optional[Dict] = None
+    last_args: tuple | None = None
+    last_kwargs: dict | None = None
```

examples/deepseek_v32/sparse_mla_fwd_pipelined.py (1)
: Consider more compact formatting for array indexing.The multi-line expansion of array indexing expressions (one subscript per line) creates very tall code blocks that are harder to scan. While consistent with the trailing-comma pattern for function arguments, array indexing benefits from more compact formatting.
Consider maintaining the original compact style for array subscripts:
# More readable compact form: KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = \ KV[b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v]This preserves the assignment pattern clarity while avoiding excessive vertical space.
Also applies to: 315-332
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
415-415: Remove redundant `int()` cast.

In Python 3, `math.ceil()` already returns an `int`, making the outer `int()` wrapper unnecessary.

Apply this diff:

```diff
-    max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size))
+    max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)
```

examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1)
: Remove redundantint()
cast.In Python 3,
math.ceil()
already returns anint
, making the outerint()
wrapper unnecessary.Apply this diff:
- max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + max_selected_blocks = math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)examples/blocksparse_gemm/example_blocksparse_gemm.py (1)
64-66
: Consider improving line break placement for readability.The current formatting splits
.to(torch.float32)
across lines, which slightly impacts readability. Consider keeping method calls intact when breaking lines.Alternative formatting:
- accu += A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to( - torch.float32) @ B[k * block_K:(k + 1) * block_K, - j * block_N:(j + 1) * block_N].to(torch.float32) + accu += ( + A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(torch.float32) + @ B[k * block_K:(k + 1) * block_K, j * block_N:(j + 1) * block_N].to(torch.float32) + )examples/flash_decoding/example_gqa_decode.py (1)
34-46
: Consider more precise type annotation for the return value.While the modernization from
Tuple[Dict, int]
totuple[dict, int]
is correct, thedict
type is overly generic (equivalent todict[Any, Any]
). Based on the returned dictionary structure, a more precise annotation would be:def get_heuristic_config() -> tuple[dict[str, int | bool], int]:This better reflects that the dictionary has string keys and values that are either
int
(block_N, block_H, num_split, num_stages, threads) orbool
(though the current implementation only usesint
).examples/deepseek_mla/example_mla_decode_paged.py (1)
313-354: Suggest removing unused parameters.

The function signature formatting is good. However, `block_table` and `block_size` parameters are not used in the function body.

Consider removing these unused parameters if they're not required for API consistency:

```diff
 def run_torch_mla(
     q,
-    block_table,
     blocked_k,
     max_seqlen_pad,
-    block_size,
     b,
     s_q,
     cache_seqlens,
```

Alternatively, if keeping them for API parity with `run_tilelang_mla`, add a comment explaining why they're present but unused.

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
231-233: Prefer explicit None check for Optional[int].

Truthiness can mis-handle 0; use `is not None` for clarity and correctness.

Apply this diff:

```diff
-    if sliding_window:
+    if sliding_window is not None:
         too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
         mask.masked_fill_(too_old, float("-inf"))
```

examples/attention_sink/example_mha_sink_bwd_bhsd.py (2)
357-358: Minor: read sink once to a local for clarity (and tiny reuse).

You already load `sink[0] = Sinks[bx]`. Use it below to avoid re-indexing.

Apply this diff:

```diff
-                dsink_fragment[i] = (-T.exp2(Sinks[bx] * 1.44269504 - lse_fragment[i]) *
-                                     delta_fragment[i])
+                dsink_fragment[i] = (-T.exp2(sink[0] * 1.44269504 - lse_fragment[i])
+                                     * delta_fragment[i])
```
437-439: Prefer explicit None check for Optional[int].

Align with other sites and avoid edge-case truthiness.

Apply this diff:

```diff
-    if sliding_window:
+    if sliding_window is not None:
         too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1)
         mask.masked_fill_(too_old, float("-inf"))
```

examples/bitnet-1.58b/benchmark_inference_latency.py (2)
12-16: Drop numpy import; return the scalar directly.

get_runtime returns a float; np.mean(times) is redundant. Removing numpy shrinks deps and speeds cold start slightly.

```diff
 def profile(model, input_data):
     import time
-
-    import numpy as np
 @@
-    return np.mean(times)
+    return times
```
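A minimal numpy-free sketch of the same idea (hypothetical timing helper, not the script's real signature):

```python
import time

def profile(fn, iters: int = 10) -> float:
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters  # plain float; no numpy needed
```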
37-44: Device placement: avoid mixing device_map="auto" with manual .cuda().

You set device_map="auto" at load but later do model = model.cuda() in profile(). These can conflict (sharded/offloaded models vs single-GPU move) and skew timings.

Two options (pick one):

- Single-GPU benchmarking (simplest, consistent with profile()): load on CPU (omit device_map or set None) and move once before profiling; then remove model.cuda() from profile().
- Accelerate/auto sharding route: keep device_map="auto", but don't call model.cuda() in profile(); rely on the map.

Minimal change if you prefer the first:

```diff
-        device_map="auto",
+        # device_map=None,  # or drop entirely; we'll place explicitly
```

If adopting explicit placement, also move the model once in main (and delete the .cuda() in profile):

```diff
     ).half()
-    with torch.no_grad():
+    with torch.no_grad():
+        model = model.cuda()
         model.quantize()
         model = torch.compile(model)
```

Please confirm which path you want; I can produce a final patch accordingly.
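A minimal sketch of the single-GPU route, assuming a transformers-style `from_pretrained` loader (the helper name is hypothetical):

```python
import torch

def load_for_benchmark(model_cls, path: str):
    # Load without device_map so placement is explicit, then move exactly once.
    model = model_cls.from_pretrained(path, torch_dtype=torch.float16)
    return model.cuda().eval()  # profile() should not call .cuda() again
```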
examples/cast/example_per_token_cast_to_fp8.py (1)
11-16: Guard blk_m preconditions used by forward_thread_fn.

blk_m // 4 implies blk_m must be a multiple of 4 and >= 4. Add an explicit assert to fail fast and aid users.

```diff
 @tilelang.jit(out_idx=[1, 2])
 def per_token_cast_to_fp8(M, N, blk_m):
     dtype = "float"
     group_size = 128
     fp8_min = -448.0
     fp8_max = 448.0
+    assert blk_m % 4 == 0 and blk_m >= 4, "blk_m must be a multiple of 4 and >= 4"
```

Also applies to: 36-37
examples/cast/example_group_per_split_token_cast_to_fp8.py (1)
66-73: Avoid divide-by-zero in inactive rows (optional).

When y_s_local[i] is set to 0 for padded rows, y_local[i, j] / 0 is computed and then masked later. Consider skipping the division for those rows to avoid infinities.
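A dense-tensor sketch of the guarded division idea (hypothetical helper, not the kernel code):

```python
import torch

def scale_rows(y: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    """Divide row-wise by s, leaving rows with s == 0 at zero instead of inf."""
    safe_s = torch.where(s == 0, torch.ones_like(s), s)
    return torch.where(s == 0, torch.zeros_like(y), y / safe_s)
```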
examples/cast/example_triton_cast_to_fp8.py (1)
131-133: Optional UX improvement: accept non-divisible last dim.

Instead of asserting divisibility, consider padding to the next multiple of group_size (as other examples do) for a more forgiving API.
examples/flash_attention/example_gqa_bwd.py (1)
382-385: Consider keeping simple tuple unpacking on one line.

The multi-line formatting for this 2-element tuple unpacking may reduce readability. The single-line form is more idiomatic for simple cases:

```python
HEAD_KV, D_HEAD_V = v.shape[-2], v.shape[-1]
```
75-75
: Redundant self-assignment
import_source = import_source
is a no-op. Remove.- import_source = import_source
214-214
: Avoid redundant dtype/device castsAllocate on the correct device with dtype once.
- B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device) + B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half, device=A.device)examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (1)
12-15
: Prefer import fallback over version parsingUsing try/except avoids packaging overhead and handles edge builds.
-if parse(fla.__version__) < parse("0.2.1"): - from fla.ops.common.utils import prepare_token_indices -else: - from fla.ops.utils import prepare_token_indices +try: + from fla.ops.utils import prepare_token_indices +except Exception: + from fla.ops.common.utils import prepare_token_indicesdocs/conf.py (1)
7-8
: Specify encoding when reading VERSIONPrevents locale-dependent issues on some systems.
-with open("../VERSION") as f: +with open("../VERSION", encoding="utf-8") as f:examples/bitnet-1.58b/vllm_workspace/conftest.py (1)
37-41
: Nit: explicit encoding (optional)For reproducible test data reads.
-def _read_prompts(filename: str) -> list[str]: - with open(filename) as f: +def _read_prompts(filename: str) -> list[str]: + with open(filename, encoding="utf-8") as f:examples/deepseek_nsa/example_tilelang_nsa_fwd.py (1)
152-158
: Harden CUDA usage and remove dead tensor
- Add an early CUDA check and reuse a single device variable; avoids crashes on non‑CUDA hosts.
- DO is unused; drop it.
Apply this diff to tensor inits:
- Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True) - K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True) - g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) - g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True) - DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda") - block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda") - block_counts = torch.randint(1, S + 1, (B, SEQ_LEN, H), device="cuda") + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device=device).requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device=device).requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device=device).requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device=device).requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device=device).requires_grad_(True) + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device=device) + block_counts = torch.randint(1, S + 1, (B, SEQ_LEN, H), device=device)Add near the top of main (outside this hunk):
+ if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this example") + device = "cuda"Also applies to: 159-159, 167-167
examples/deepseek_nsa/example_tilelang_nsa_decode.py (1)
153-156: Remove unused tensors and add CUDA guard

- `mask` and `DO` are unused; drop them.
- Add a CUDA check and reuse `device` for inits.

Apply this diff:

```diff
-    Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
-    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
-    V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
-
-    mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda")
-    DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda")
-
-    block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda")
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required for this example")
+    device = "cuda"
+    Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device=device).requires_grad_(True)
+    K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device=device).requires_grad_(True)
+    V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device=device).requires_grad_(True)
+
+    block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device=device)
@@
-    block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda")
+    block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device=device)
```

Also applies to: 157-159, 167-167
examples/flash_attention/example_gqa_fwd_bshd.py (1)
43-44: Prevent zero-sized warp tiles slipping through filter

`warp_M = block_M // warp_count` and `warp_N = block_N // warp_count` can be 0 when `warp_count > block_M` or `> block_N`. `0 % warp_alignment == 0`, so your current check wrongly accepts them. Add lower-bound and multiple-of checks.

```diff
-        if warp_M % config.warp_alignment != 0 or warp_N % config.warp_alignment != 0:
+        if (
+            warp_M < config.warp_alignment
+            or warp_N < config.warp_alignment
+            or warp_M % config.warp_alignment != 0
+            or warp_N % config.warp_alignment != 0
+        ):
             continue
```
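The fixed filter condenses to a predicate like this (illustrative helper name):

```python
def warp_tile_ok(block_m: int, block_n: int, warp_count: int, alignment: int) -> bool:
    warp_m, warp_n = block_m // warp_count, block_n // warp_count
    # 0 % alignment == 0, so a lower bound is needed in addition to divisibility
    return (warp_m >= alignment and warp_n >= alignment
            and warp_m % alignment == 0 and warp_n % alignment == 0)
```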
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (107)

- benchmark/blocksparse_attention/benchmark_library_dense_fmha.py (1 hunks)
- benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py (3 hunks)
- benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py (1 hunks)
- benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py (4 hunks)
- benchmark/matmul/benchmark_matmul.py (1 hunks)
- benchmark/matmul/benchmark_matmul_intrinsic.py (1 hunks)
- benchmark/matmul/benchmark_matmul_sp.py (3 hunks)
- benchmark/matmul_fp8/benchmark_matmul.py (0 hunks)
- docs/conf.py (2 hunks)
- examples/amd/example_amd_flash_attn_bwd.py (9 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (6 hunks)
- examples/analyze/example_conv_analyze.py (2 hunks)
- examples/attention_sink/benchmark_gqa_sink_fwd.py (5 hunks)
- examples/attention_sink/benchmark_mha_sink_fwd.py (4 hunks)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (14 hunks)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (8 hunks)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (14 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (6 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (7 hunks)
- examples/bitnet-1.58b/benchmark_generate.py (3 hunks)
- examples/bitnet-1.58b/benchmark_inference_latency.py (3 hunks)
- examples/bitnet-1.58b/configuration_bitnet.py (2 hunks)
- examples/bitnet-1.58b/eval_correctness.py (3 hunks)
- examples/bitnet-1.58b/eval_gpu_memory.py (2 hunks)
- examples/bitnet-1.58b/eval_ppl.py (3 hunks)
- examples/bitnet-1.58b/eval_utils.py (1 hunks)
- examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py (2 hunks)
- examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py (0 hunks)
- examples/bitnet-1.58b/load_from_quantized.py (1 hunks)
- examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (3 hunks)
- examples/bitnet-1.58b/modeling_bitnet.py (27 hunks)
- examples/bitnet-1.58b/tokenization_bitnet.py (9 hunks)
- examples/bitnet-1.58b/utils_quant.py (2 hunks)
- examples/bitnet-1.58b/vllm_workspace/conftest.py (22 hunks)
- examples/bitnet-1.58b/vllm_workspace/utils.py (3 hunks)
- examples/blocksparse_attention/block_sparse_attn_triton.py (9 hunks)
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py (5 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (17 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (15 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (16 hunks)
- examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py (12 hunks)
- examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py (12 hunks)
- examples/blocksparse_attention/heuristic.py (1 hunks)
- examples/blocksparse_attention/test_example_blocksparse_attention.py (2 hunks)
- examples/blocksparse_gemm/example_blocksparse_gemm.py (5 hunks)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (7 hunks)
- examples/cast/example_per_token_cast_to_fp8.py (7 hunks)
- examples/cast/example_triton_cast_to_fp8.py (3 hunks)
- examples/convolution/example_convolution.py (3 hunks)
- examples/convolution/example_convolution_autotune.py (5 hunks)
- examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py (4 hunks)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py (7 hunks)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py (9 hunks)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py (9 hunks)
- examples/deepseek_mla/benchmark_mla.py (13 hunks)
- examples/deepseek_mla/example_mla_decode.py (7 hunks)
- examples/deepseek_mla/example_mla_decode_paged.py (12 hunks)
- examples/deepseek_mla/example_mla_decode_persistent.py (4 hunks)
- examples/deepseek_mla/example_mla_decode_ws.py (14 hunks)
- examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py (3 hunks)
- examples/deepseek_mla/torch_refs.py (1 hunks)
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py (28 hunks)
- examples/deepseek_nsa/example_tilelang_nsa_bwd.py (4 hunks)
- examples/deepseek_nsa/example_tilelang_nsa_decode.py (3 hunks)
- examples/deepseek_nsa/example_tilelang_nsa_fwd.py (2 hunks)
- examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (7 hunks)
- examples/deepseek_nsa/example_triton_nsa_bwd.py (22 hunks)
- examples/deepseek_nsa/example_triton_nsa_fwd.py (8 hunks)
- examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py (12 hunks)
- examples/deepseek_nsa/reference.py (9 hunks)
- examples/deepseek_v32/fp8_lighting_indexer.py (3 hunks)
- examples/deepseek_v32/sparse_mla_bwd.py (12 hunks)
- examples/deepseek_v32/sparse_mla_fwd.py (6 hunks)
- examples/deepseek_v32/sparse_mla_fwd_pipelined.py (13 hunks)
- examples/deepseek_v32/topk_selector.py (3 hunks)
- examples/deepseek_v32/utils.py (4 hunks)
- examples/dequantize_gemm/dequantize_utils.py (4 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (7 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (11 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (11 hunks)
- examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (8 hunks)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (9 hunks)
- examples/dequantize_gemm/example_dequant_gemm_w4a8.py (4 hunks)
- examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py (5 hunks)
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (14 hunks)
- examples/dynamic_shape/example_dynamic.py (2 hunks)
- examples/elementwise/example_elementwise_add.py (2 hunks)
- examples/elementwise/example_elementwise_add_tma_1d.py (1 hunks)
- examples/flash_attention/bert_padding.py (1 hunks)
- examples/flash_attention/example_gqa_bwd.py (12 hunks)
- examples/flash_attention/example_gqa_bwd_tma_reduce.py (12 hunks)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (8 hunks)
- examples/flash_attention/example_gqa_fwd_bshd.py (5 hunks)
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (6 hunks)
- examples/flash_attention/example_mha_bwd.py (5 hunks)
- examples/flash_attention/example_mha_bwd_bhsd.py (5 hunks)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (5 hunks)
- examples/flash_attention/example_mha_fwd_bhsd.py (5 hunks)
- examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (6 hunks)
- examples/flash_attention/example_mha_fwd_bshd.py (5 hunks)
- examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (6 hunks)
- examples/flash_attention/example_mha_fwd_varlen.py (4 hunks)
- examples/flash_decoding/example_gqa_decode.py (11 hunks)
- examples/flash_decoding/example_mha_inference.py (9 hunks)
- examples/fusedmoe/example_fusedmoe_tilelang.py (16 hunks)
- examples/fusedmoe/example_fusedmoe_torch.py (7 hunks)
- examples/fusedmoe/test_example_fusedmoe.py (1 hunks)
⛔ Files not processed due to max files limit (36)
- examples/gdn/example_chunk_delta_bwd.py
- examples/gdn/example_chunk_delta_h.py
- examples/gdn/example_chunk_o.py
- examples/gdn/example_chunk_o_bwd.py
- examples/gdn/example_chunk_scaled_dot_kkt.py
- examples/gdn/example_cumsum.py
- examples/gdn/example_wy_fast.py
- examples/gdn/example_wy_fast_bwd_split.py
- examples/gdn/test_example_gdn_compilation.py
- examples/gdn/utils.py
- examples/gemm/example_gemm_autotune.py
- examples/gemm/example_gemm_intrinsics.py
- examples/gemm/example_gemm_persistent.py
- examples/gemm_fp8/example_tilelang_gemm_amd.py
- examples/gemm_fp8/example_tilelang_gemm_fp8.py
- examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
- examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
- examples/gemm_sm100/gemm_mma.py
- examples/gemm_sm100/gemm_tcgen5mma.py
- examples/gemm_sp/example_gemm_sp.py
- examples/gemm_splitk/example_tilelang_gemm_splitk.py
- examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py
- examples/gemm_streamk/example_tilelang_gemm_streamk.py
- examples/grouped_gemm/example_grouped_gemm_bwd.py
- examples/grouped_gemm/example_grouped_gemm_fwd.py
- examples/hadamard_transform/example_hadamard.py
- examples/linear_attention/example_linear_attn_bwd.py
- examples/linear_attention/example_linear_attn_fwd.py
- examples/linear_attention/example_mamba_chunk_scan.py
- examples/linear_attention/example_mamba_chunk_state.py
- examples/linear_attention/example_retention_fwd.py
- examples/minference/example_vertical_slash_sparse_attn.py
- examples/norm/rms_norm.py
- examples/online_softmax/online_softmax.py
- examples/plot_layout/fragment_mma_load_a.py
- examples/seer_attention/block_sparse_attn_tilelang.py
💤 Files with no reviewable changes (2)
- examples/bitnet-1.58b/kernel_benchmark/tl_int8xint8.py
- benchmark/matmul_fp8/benchmark_matmul.py
✅ Files skipped from review due to trivial changes (14)
- benchmark/blocksparse_attention/benchmark_tilelang_block_sparse_fmha.py
- examples/flash_attention/example_mha_bwd.py
- examples/bitnet-1.58b/eval_utils.py
- examples/bitnet-1.58b/utils_quant.py
- examples/deepseek_v32/fp8_lighting_indexer.py
- examples/flash_attention/example_mha_fwd_bhsd.py
- examples/flash_decoding/example_mha_inference.py
- examples/blocksparse_attention/example_tilelang_block_sparse_attn.py
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
- examples/blocksparse_attention/test_example_blocksparse_attention.py
- examples/attention_sink/benchmark_gqa_sink_fwd.py
- examples/deepseek_v32/sparse_mla_bwd.py
- benchmark/blocksparse_attention/benchmark_torch_block_sparse_fmha.py
- examples/bitnet-1.58b/load_from_quantized.py
🚧 Files skipped from review as they are similar to previous changes (34)
- examples/flash_attention/bert_padding.py
- benchmark/matmul/benchmark_matmul.py
- examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py
- benchmark/blocksparse_attention/benchmark_library_dense_fmha.py
- examples/bitnet-1.58b/configuration_bitnet.py
- examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py
- examples/elementwise/example_elementwise_add_tma_1d.py
- examples/blocksparse_attention/heuristic.py
- examples/dynamic_shape/example_dynamic.py
- examples/flash_attention/example_mha_bwd_bhsd.py
- examples/flash_attention/example_mha_fwd_bshd.py
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py
- benchmark/blocksparse_attention/benchmark_triton_block_sparse_fmha.py
- examples/flash_attention/example_gqa_bwd_tma_reduce.py
- examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py
- examples/amd/example_amd_flash_attn_fwd.py
- examples/deepseek_v32/sparse_mla_fwd.py
- examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py
- examples/deepseek_mla/example_mla_decode_persistent.py
- examples/deepseek_nsa/reference.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py
- examples/deepseek_mla/benchmark_mla.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
- examples/attention_sink/example_mha_sink_fwd_bhsd.py
- examples/flash_attention/example_mha_fwd_varlen.py
- examples/deepseek_mla/example_mla_decode_ws.py
- examples/deepseek_nsa/example_triton_nsa_bwd.py
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
- examples/deepseek_nsa/example_tilelang_nsa_bwd.py
- examples/deepseek_nsa/example_triton_nsa_fwd_varlen.py
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py
- examples/deepseek_nsa/example_triton_nsa_fwd.py
🧰 Additional context used
🧬 Code graph analysis (39)
- examples/deepseek_mla/example_mla_decode.py (4): tilelang/transform/pass_config.py: PassConfigKey (6-105); tilelang/tileop/gemm/gemm_base.py: clear_accum (107-108), policy (119-120); examples/gemm/example_gemm.py: gemm (9-25); tilelang/language/kernel.py: threads (215-219)
- examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1): examples/bitnet-1.58b/modeling_bitnet.py: from_quantized (1500-1578)
- examples/analyze/example_conv_analyze.py (2): examples/gemm/example_gemm_autotune.py: kernel (110-150); examples/analyze/example_gemm_analyze.py: kernel (10-46)
- examples/cast/example_per_token_cast_to_fp8.py (1): tilelang/language/tir/op.py: ceildiv (3116-3133)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (3): tilelang/language/tir/op.py: ceildiv (3116-3133); tilelang/language/kernel.py: Kernel (229-303), threads (215-219); examples/cast/example_per_token_cast_to_fp8.py: ref_program (81-91), ceil_div (67-78)
- examples/convolution/example_convolution.py (1): examples/convolution/example_convolution_autotune.py: convolution (97-168)
- examples/convolution/example_convolution_autotune.py (2): tilelang/autotuner/tuner.py: autotune (727-820); examples/convolution/example_convolution.py: convolution (29-99), main (18-23), main (55-97), main (102-138)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (2): examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py: flashattn (32-204), ref_program (209-254), main (138-202), main (272-334), gen_inputs (257-269); examples/attention_sink/example_mha_sink_fwd_bhsd.py: flashattn (27-189), ref_program (193-239), main (130-187), main (256-319), gen_inputs (242-253)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py (2): tilelang/transform/pass_config.py: PassConfigKey (6-105); examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py: kernel_func (24-202)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (2): tilelang/language/allocate.py: alloc_local (39-50); examples/attention_sink/example_mha_sink_fwd_bhsd.py: ref_program (193-239)
- examples/blocksparse_gemm/example_blocksparse_gemm.py (2): tilelang/jit/kernel.py: params (475-476); tilelang/engine/param.py: KernelParam (12-104)
- examples/attention_sink/benchmark_mha_sink_fwd.py (1): examples/attention_sink/benchmark_gqa_sink_fwd.py: main (130-201), triton_program (100-127)
- examples/elementwise/example_elementwise_add.py (2): tilelang/language/parallel.py: Parallel (9-29); tilelang/autotuner/tuner.py: AutoTuner (93-588), from_kernel (122-132), set_compile_args (134-165), set_profile_args (167-231)
- examples/flash_decoding/example_gqa_decode.py (2): examples/gemm/example_gemm_autotune.py: get_heuristic_config (165-199); tilelang/transform/pass_config.py: PassConfigKey (6-105)
- examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1): tilelang/language/tir/op.py: call_extern (173-195), address_of (464-480)
- examples/bitnet-1.58b/eval_ppl.py (1): examples/bitnet-1.58b/modeling_bitnet.py: BitnetForCausalLM (1231-1578)
- examples/fusedmoe/example_fusedmoe_torch.py (1): examples/fusedmoe/example_fusedmoe_tilelang.py: forward (317-320), forward (333-338), forward (432-535)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (2): tilelang/language/tir/op.py: reinterpret (1898-1917); examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py: matmul (49-354)
- examples/deepseek_mla/experimental/example_mla_decode_kv_fp8.py (2): tilelang/transform/pass_config.py: PassConfigKey (6-105); tilelang/tileop/gemm/gemm_base.py: policy (119-120)
- examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py (1): examples/deepseek_mla/amd/benchmark_mla_decode_amd_torch.py: run_torch_mla (35-73), run_flash_mla_triton (327-373), flash_mla_triton (352-369), mla_decode_triton (292-323), compare_a (458-505), compare_ab (382-455)
- examples/bitnet-1.58b/benchmark_generate.py (5): examples/bitnet-1.58b/eval_ppl.py: main (31-61); examples/bitnet-1.58b/maint/create_bitblas_ckpt.py: main (64-117); examples/bitnet-1.58b/benchmark_inference_latency.py: main (37-53); examples/bitnet-1.58b/eval_correctness.py: main (75-92); examples/bitnet-1.58b/load_from_quantized.py: main (50-61)
- examples/bitnet-1.58b/vllm_workspace/conftest.py (1): examples/bitnet-1.58b/modeling_bitnet.py: get_output_embeddings (1249-1250)
- examples/flash_attention/example_gqa_fwd_bshd.py (4): tilelang/jit/kernel.py: out_idx (471-472), get_profiler (385-401); tilelang/transform/pass_config.py: PassConfigKey (6-105); tilelang/utils/tensor.py: TensorSupplyType (11-18); tilelang/profiler/__init__.py: assert_allclose (77-146), do_bench (219-282)
- examples/flash_attention/example_gqa_bwd.py (3): tilelang/jit/kernel.py: out_idx (471-472); tilelang/transform/pass_config.py: PassConfigKey (6-105); examples/flash_attention/example_gqa_bwd_tma_reduce.py: flashattn_bwd_atomic_add (175-276), flashattn_bwd_split (282-383)
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (2): tilelang/jit/kernel.py: out_idx (471-472), get_profiler (385-401); tilelang/transform/pass_config.py: PassConfigKey (6-105)
- examples/bitnet-1.58b/eval_gpu_memory.py (2): examples/bitnet-1.58b/benchmark_generate.py: profile (54-74); examples/bitnet-1.58b/modeling_bitnet.py: _post_process_weights (1487-1491)
- examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py (1): examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py: parallel_nsa (232-308)
- examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py (2): tilelang/language/__init__.py: import_source (206-208); tilelang/quantize/utils.py: interleave_weight (72-126)
- examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (2): tilelang/jit/kernel.py: out_idx (471-472), get_profiler (385-401); tilelang/transform/pass_config.py: PassConfigKey (6-105)
- examples/deepseek_v32/sparse_mla_fwd_pipelined.py (2): tilelang/math/__init__.py: next_power_of_2 (1-2); examples/deepseek_v32/sparse_mla_fwd.py: sparse_mla_fwd (15-173)
- examples/bitnet-1.58b/modeling_bitnet.py (1): examples/bitnet-1.58b/configuration_bitnet.py: BitnetConfig (29-194)
- examples/bitnet-1.58b/benchmark_inference_latency.py (4): examples/bitnet-1.58b/benchmark_generate.py: profile (54-74); examples/bitnet-1.58b/eval_correctness.py: profile (49-69); examples/bitnet-1.58b/eval_gpu_memory.py: profile (12-34); examples/bitnet-1.58b/benchmark_model_10k_loops.py: profile (19-41)
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (3): examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py: matmul (49-354); examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py: matmul (86-381); tilelang/quantize/mxfp.py: get_mxfp_intrin_group (52-109)
- examples/fusedmoe/example_fusedmoe_tilelang.py (4): tilelang/jit/__init__.py: jit (242-318); tilelang/language/copy.py: copy (15-93); tilelang/language/parallel.py: Parallel (9-29); tilelang/jit/kernel.py: JITKernel (27-503)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (2): tilelang/transform/pass_config.py: PassConfigKey (6-105); tilelang/language/__init__.py: symbolic (87-98)
- examples/deepseek_nsa/example_tilelang_nsa_decode.py (2): examples/deepseek_nsa/example_tilelang_nsa_bwd.py: native_sparse_attention (69-153); examples/deepseek_nsa/example_tilelang_nsa_fwd.py: native_sparse_attention (19-133), native_sparse_attention (55-131)
- examples/amd/example_amd_flash_attn_bwd.py (2): tilelang/language/copy.py: copy (15-93); examples/attention_sink/example_gqa_sink_bwd_bhsd.py: flash_bwd_prep (159-176)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (3): tilelang/transform/pass_config.py: PassConfigKey (6-105); tilelang/language/tir/op.py: if_then_else (2906-2936); tilelang/language/__init__.py: symbolic (87-98)
- examples/deepseek_mla/example_mla_decode_paged.py (2): tilelang/jit/kernel.py: out_idx (471-472); tilelang/transform/pass_config.py: PassConfigKey (6-105)
🪛 Ruff (0.14.0)
examples/convolution/example_convolution_autotune.py
112-112: Unused function argument: enable_rasteration
(ARG001)
182-182: Unused function argument: with_roller
(ARG001)
examples/flash_decoding/example_gqa_decode.py
475-475: Avoid specifying long messages outside the exception class
(TRY003)
examples/bitnet-1.58b/modeling_bitnet.py
378-378: Unused method argument: use_cache
(ARG002)
380-380: Unused method argument: kwargs
(ARG002)
531-531: Unused method argument: use_cache
(ARG002)
533-533: Unused method argument: kwargs
(ARG002)
619-619: Unused method argument: use_cache
(ARG002)
621-621: Unused method argument: kwargs
(ARG002)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
441-441: Comment contains ambiguous `(` (FULLWIDTH LEFT PARENTHESIS). Did you mean `(` (LEFT PARENTHESIS)?
(RUF003)
441-441: Comment contains ambiguous `)` (FULLWIDTH RIGHT PARENTHESIS). Did you mean `)` (RIGHT PARENTHESIS)?
(RUF003)
examples/fusedmoe/example_fusedmoe_tilelang.py
7-7: `from tilelang.autotuner import *` used; unable to detect undefined names
(F403)
8-8: `from example_fusedmoe_torch import *` used; unable to detect undefined names
(F403)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py
368-368: Unused function argument: block_indices
(ARG001)
368-368: Unused function argument: max_cache_seqlen
(ARG001)
368-368: Unused function argument: num_blocks
(ARG001)
369-369: Unused function argument: block_size
(ARG001)
415-415: Value being cast to `int` is already an integer. Remove unnecessary `int` call.
(RUF046)
examples/amd/example_amd_flash_attn_bwd.py
268-268: Ambiguous variable name: O
(E741)
examples/blocksparse_attention/block_sparse_attn_triton.py
202-202: Unused function argument: ctx
(ARG001)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
332-332: Unused function argument: max_cache_seqlen
(ARG001)
332-332: Unused function argument: num_blocks
(ARG001)
334-334: Unpacked variable `heads` is never used. Prefix it with an underscore or any other dummy variable pattern.
(RUF059)
373-373: Unused function argument: block_indices
(ARG001)
373-373: Unused function argument: max_cache_seqlen
(ARG001)
373-373: Unused function argument: num_blocks
(ARG001)
374-374: Unused function argument: block_size
(ARG001)
418-418: Value being cast to `int` is already an integer. Remove unnecessary `int` call.
(RUF046)
examples/deepseek_mla/example_mla_decode_paged.py
315-315: Unused function argument: block_table
(ARG001)
318-318: Unused function argument: block_size
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Build wheels for Python 3.8 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.8 on ubuntu-24.04-arm with CUDA-12.8
- GitHub Check: Build wheels for Python 3.8 on ubuntu-latest with CUDA-12.1
```diff
 parser = argparse.ArgumentParser()
-parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
+parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)
```
CLI flag is unused: wire it into model loading.

You define --hf_path but ignore it and still hardcode the repo when loading. Parse args in main and pass hf_path to from_pretrained.

Apply:

```diff
 def main():
-    model = BitnetForCausalLM.from_pretrained(
+    args = parser.parse_args()
+    model = BitnetForCausalLM.from_pretrained(
-        "1bitLLM/bitnet_b1_58-3B",
+        args.hf_path,
         device_map="auto",
         low_cpu_mem_usage=True,
         use_flash_attention_2=True,
         torch_dtype=torch.float16,
     ).half()
```
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/bitnet-1.58b/benchmark_inference_latency.py around line 9, the CLI
flag --hf_path is defined but not used; the model repo is still hardcoded when
loading. Update the script to parse arguments inside main (or ensure args are
available where model is loaded) and pass args.hf_path into the model loading
call (from_pretrained) instead of the hardcoded "1bitLLM/bitnet_b1_58-3B";
ensure any default stays the same and propagate hf_path wherever the repo string
is used.
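A minimal sketch of the suggested wiring, standing in for the script's real loader call:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--hf_path", default="1bitLLM/bitnet_b1_58-3B", type=str)

def main() -> None:
    args = parser.parse_args()
    # pass args.hf_path into the model loader instead of a hardcoded repo string
    print(f"loading weights from {args.hf_path}")

if __name__ == "__main__":
    main()
```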
```python
def _forward(
    ctx,
    q,
    k,
    v,
    block_sparse_mask,
    sm_scale,
    BLOCK_M=64,
    BLOCK_N=64,
    num_warps=None,
    num_stages=1,
    out=None,
):
```
Remove or prefix unused `ctx` parameter with underscore.

The parameter `ctx` is declared but never used in the function body (lines 214-254). Per ruff's ARG001 rule, either remove it, rename it to `_ctx`, or implement its intended usage.

If `ctx` is truly unused, apply this diff to remove it:

```diff
-def _forward(
-    ctx,
-    q,
-    k,
-    v,
-    block_sparse_mask,
-    sm_scale,
-    BLOCK_M=64,
-    BLOCK_N=64,
-    num_warps=None,
-    num_stages=1,
-    out=None,
-):
+def _forward(
+    q,
+    k,
+    v,
+    block_sparse_mask,
+    sm_scale,
+    BLOCK_M=64,
+    BLOCK_N=64,
+    num_warps=None,
+    num_stages=1,
+    out=None,
+):
```

Then update the call site at line 262:

```diff
-        return _forward(ctx, q, k, v, block_sparse_dense, sm_scale)
+        return _forward(q, k, v, block_sparse_dense, sm_scale)
```

Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.0)
202-202: Unused function argument: ctx
(ARG001)
🤖 Prompt for AI Agents
In examples/blocksparse_attention/block_sparse_attn_triton.py around lines 201
to 213, the function _forward declares a ctx parameter that is never used; to
satisfy ruff ARG001 either remove ctx from the parameter list or rename it to
_ctx, and then update the corresponding call site at line 262 to drop the
argument or pass _ctx accordingly so signatures match; make the change
consistently (function definition and all call sites) and run the tests/linter
to verify the warning is resolved.
```python
from flash_attn import flash_attn_with_kvcache  # fa2
```
Unrelated import module change in a linting-focused PR.

The import statement changed from `flash_attn_interface` (fa3) to `flash_attn` (fa2), which is a behavioral change rather than a linting/formatting update. This modification is inconsistent with the PR's stated objective of enabling pyupgrade linter rules and could affect runtime behavior.

```bash
#!/bin/bash
# Verify if this import change is intentional or accidental
git log --oneline --all --grep="flash_attn" -- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py | head -5
# Check if other files in the PR also change this import
git diff origin/main HEAD --name-only | xargs -I {} sh -c 'git diff origin/main HEAD -- {} | grep -l "flash_attn_interface\|flash_attn"' 2>/dev/null
```
🤖 Prompt for AI Agents
In examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py
around lines 423-424 the import was changed from "flash_attn_interface" (fa3) to
"flash_attn" (fa2), which is a behavioral change unrelated to linting; verify
whether this swap was intentional and if not, revert the import back to "from
flash_attn_interface import flash_attn_with_kvcache # fa3" (restore original
symbol and alias), or if the change is required, update the PR description and
run relevant tests/benchmarks to confirm runtime compatibility and update any
call sites or documentation to reflect the new module.
```diff
     dtype: torch.dtype | None = None,
     column_major_scales: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
```
Python 3.8 compat: avoid PEP 604 unions and built-in generics in public API.

`torch.dtype | None` and `tuple[...]` can break get_type_hints on 3.8 even with future annotations. Use `Optional[...]` and `typing.Tuple[...]` if 3.8 is supported.

Alternative if you want to keep the None default:

```diff
-from typing import Tuple
+from typing import Tuple, Optional
 @@
-    dtype: torch.dtype | None = None,
+    dtype: Optional[torch.dtype] = None,
 @@
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if dtype is None:
+        dtype = torch.float8_e4m3fn
```
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/cast/example_triton_cast_to_fp8.py around lines 115 to 118, the
public function signature uses PEP 604 union and built-in generics (torch.dtype
| None and tuple[...]) which can break typing on Python 3.8; change the
annotations to typing-compatible forms by importing Optional and Tuple from
typing and annotate as Optional[torch.dtype] for the dtype parameter and
Tuple[torch.Tensor, torch.Tensor] as the return type (keep the default None for
dtype unchanged).
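A 3.8-safe sketch of the suggested spelling; the quantization body here is illustrative and only the annotations matter:

```python
from __future__ import annotations
from typing import Optional, Tuple, get_type_hints

import torch

def quantize(x: torch.Tensor, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    if dtype is None:
        dtype = torch.float32  # pick a concrete default before calling torch.finfo(dtype)
    info = torch.finfo(dtype)
    scale = x.abs().amax().clamp(min=info.tiny)
    return (x / scale).to(dtype), scale.reshape(1)

# typing.Optional/Tuple stay resolvable by get_type_hints on Python 3.8,
# unlike `torch.dtype | None` or bare `tuple[...]`.
hints = get_type_hints(quantize)
```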
Default None for dtype causes a runtime error.

torch.finfo(dtype) will crash if dtype is None. Either require dtype or set a safe default.

```diff
 from __future__ import annotations
+from typing import Tuple
 @@
 def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype | None = None,
+    dtype: torch.dtype = torch.float8_e4m3fn,
     column_major_scales: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, torch.Tensor]:
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
    dtype: torch.dtype = torch.float8_e4m3fn,
    column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.
```
🤖 Prompt for AI Agents
In examples/cast/example_triton_cast_to_fp8.py around lines 115 to 118, the
function allows dtype: torch.dtype | None = None but later calls
torch.finfo(dtype), which will raise if dtype is None; change the signature to
require a concrete dtype (e.g., torch.dtype = torch.float32) or keep Optional
but assign a safe default before use (e.g., if dtype is None: dtype =
torch.float32) so torch.finfo always receives a valid dtype; update type
annotation accordingly and ensure any callers are adjusted if you choose to make
the parameter mandatory.
```diff
 assert Q.size(2) == K.size(2) * groups, (
     f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}")
 assert Q.size(2) == V.size(2) * groups, (
     f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}")

 dim = Q.size(-1)
 K = K.repeat_interleave(groups, dim=2)
 V = V.repeat_interleave(groups, dim=2)
-scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
+scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
 scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
 if is_causal:
     seq_len = Q.size(1)
     mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
     mask = mask.unsqueeze(0).unsqueeze(0)
-    scores = scores.masked_fill(mask == 0, float('-inf'))
+    scores = scores.masked_fill(mask == 0, float("-inf"))
 attention_weights = F.softmax(scores, dim=-1)
-output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
+output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
 return output
```
Fix device mismatch in ref_program (CPU scalar tensor on CUDA tensor)

`torch.tensor(dim, dtype=scores.dtype)` is CPU by default; dividing CUDA `scores` by it errors. Place it on `scores.device` (or use a Python scalar).

```diff
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
+    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype, device=scores.device))
+    # or:
+    # import math
+    # scores = scores * (1.0 / math.sqrt(dim))
```
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
assert Q.size(2) == K.size(2) * groups, ( | |
f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}") | |
assert Q.size(2) == V.size(2) * groups, ( | |
f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}") | |
dim = Q.size(-1) | |
K = K.repeat_interleave(groups, dim=2) | |
V = V.repeat_interleave(groups, dim=2) | |
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) | |
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) | |
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) | |
if is_causal: | |
seq_len = Q.size(1) | |
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) | |
mask = mask.unsqueeze(0).unsqueeze(0) | |
scores = scores.masked_fill(mask == 0, float('-inf')) | |
scores = scores.masked_fill(mask == 0, float("-inf")) | |
attention_weights = F.softmax(scores, dim=-1) | |
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) | |
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) | |
return output | |
assert Q.size(2) == K.size(2) * groups, ( | |
f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}") | |
assert Q.size(2) == V.size(2) * groups, ( | |
f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}") | |
dim = Q.size(-1) | |
K = K.repeat_interleave(groups, dim=2) | |
V = V.repeat_interleave(groups, dim=2) | |
scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) | |
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype, device=scores.device)) | |
if is_causal: | |
seq_len = Q.size(1) | |
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) | |
mask = mask.unsqueeze(0).unsqueeze(0) | |
scores = scores.masked_fill(mask == 0, float("-inf")) | |
attention_weights = F.softmax(scores, dim=-1) | |
output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) | |
return output |
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py around lines
177 to 194, the code divides a CUDA tensor by torch.tensor(dim,
dtype=scores.dtype) which is created on CPU causing a device mismatch; replace
that CPU tensor with a device-matched scalar (e.g., use dim as a Python int or
create the tensor on scores.device) so the division occurs without crossing
devices; update the division to either scores = scores / float(dim) or
torch.tensor(dim, dtype=scores.dtype, device=scores.device) to fix the error.
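A minimal sketch of the scalar-only variant, assuming the example's `bqhd` tensor layout:

```python
import math
import torch

def scaled_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k)
    # a plain Python scalar sidesteps any CPU/CUDA device mismatch entirely
    return scores * (1.0 / math.sqrt(q.size(-1)))
```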
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py
Force-pushed from 011e082 to f33dd22.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
examples/flash_attention/example_mha_fwd_bshd.py (1)

153-154: Fix device mismatch in ref_program division (CUDA vs CPU tensor).

`torch.tensor(dim, dtype=scores.dtype)` is CPU by default; dividing CUDA `scores` by it will error. Add `device=scores.device` (or use a Python scalar).

```diff
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
+    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype, device=scores.device))
```

Alternative (no tensor creation): `import math` and `scores = scores * (1.0 / math.sqrt(dim))`.

examples/flash_attention/example_mha_bwd_bhsd.py (1)

283-286: Prevent CUDA/CPU device mismatch in ref_program.

Create the `dim` tensor on `scores.device` (or use a math scalar) to avoid a runtime error.

```diff
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
+    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype, device=scores.device))
```

examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)

165-168: Fix CUDA/CPU device mismatch in ref_program.

Ensure the `dim` tensor matches `scores.device` (or use a scalar) to prevent runtime errors.

```diff
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
+    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype, device=scores.device))
```
default.Line 53 has
target_host: Union[str, Target] = None
but the type annotation doesn't includeNone
. With pyupgrade's RUF013 enabled, this will fail lint checks.Apply this diff to fix the type annotation:
- target_host: Union[str, Target] = None + target_host: Optional[Union[str, Target]] = NoneOr alternatively:
- target_host: Union[str, Target] = None + target_host: Union[str, Target, None] = Nonetilelang/carver/roller/shape_inference/common.py (1)
47-47
: Add explicit| None
type annotation.The parameter
rstep
has a default value ofNone
but the type annotation doesn't includeNone
. Per PEP 484, this should beDict[str, int] | None = None
.Apply this diff:
- def infer(self, shape, rstep: Dict[str, int] = None): + def infer(self, shape, rstep: Dict[str, int] | None = None):tilelang/carver/roller/node.py (1)
304-305
: Potential decorator syntax issue flagged in previous review.Past review comments indicated that changing
@functools.lru_cache()
to@functools.lru_cache
causes aTypeError
and was addressed in earlier commits. However, the current code still shows the decorator without parentheses. Additionally, static analysis warns about potential memory leaks when usinglru_cache
on methods (B019).Based on learnings
Also applies to: 421-422
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)
183-184
: Re: device mismatch in ref_program (still present).Create the scalar tensor on the same device as
scores
(or use a math scalar) to avoid CUDA/CPU mismatch.- scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype, device=scores.device))
🧹 Nitpick comments (1)
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
357-357
: Complete the f-string conversion for consistency.The print statement mixes string concatenation with an f-string. For full consistency with the pyupgrade modernization goal, convert the entire expression to an f-string.
Apply this diff:
- print(name + f" all_close={all_close}") + print(f"{name} all_close={all_close}")
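A small sketch of the RUF013-clean spelling (illustrative body):

```python
from __future__ import annotations

def infer(shape, rstep: dict[str, int] | None = None):
    # spell out Optional explicitly; an implicit `= None` default trips RUF013
    if rstep is None:
        rstep = {}
    return shape, rstep
```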
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (64)
- docs/conf.py (1 hunks)
- examples/attention_sink/benchmark_gqa_sink_fwd.py (1 hunks)
- examples/attention_sink/benchmark_mha_sink_fwd.py (1 hunks)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (1 hunks)
- examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- examples/attention_sink/example_mha_sink_bwd_bhsd.py (1 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd.py (1 hunks)
- examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- examples/bitnet-1.58b/configuration_bitnet.py (0 hunks)
- examples/bitnet-1.58b/eval_ppl.py (1 hunks)
- examples/bitnet-1.58b/maint/create_bitblas_ckpt.py (1 hunks)
- examples/bitnet-1.58b/modeling_bitnet.py (1 hunks)
- examples/bitnet-1.58b/tokenization_bitnet.py (0 hunks)
- examples/bitnet-1.58b/vllm_workspace/conftest.py (1 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py (1 hunks)
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1 hunks)
- examples/cast/example_group_per_split_token_cast_to_fp8.py (1 hunks)
- examples/cast/example_per_token_cast_to_fp8.py (2 hunks)
- examples/deepseek_mla/example_mla_decode_paged.py (1 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1 hunks)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py (1 hunks)
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py (1 hunks)
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1 hunks)
- examples/flash_attention/example_gqa_bwd.py (1 hunks)
- examples/flash_attention/example_gqa_bwd_tma_reduce.py (1 hunks)
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (1 hunks)
- examples/flash_attention/example_gqa_fwd_bshd.py (1 hunks)
- examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1 hunks)
- examples/flash_attention/example_mha_bwd.py (1 hunks)
- examples/flash_attention/example_mha_bwd_bhsd.py (1 hunks)
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py (1 hunks)
- examples/flash_attention/example_mha_fwd_bhsd.py (1 hunks)
- examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1 hunks)
- examples/flash_attention/example_mha_fwd_bshd.py (1 hunks)
- examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1 hunks)
- examples/flash_decoding/example_gqa_decode.py (1 hunks)
- examples/flash_decoding/example_mha_inference.py (1 hunks)
- examples/hadamard_transform/example_hadamard.py (1 hunks)
- examples/linear_attention/example_mamba_chunk_scan.py (1 hunks)
- examples/linear_attention/example_mamba_chunk_state.py (2 hunks)
- examples/minference/example_vertical_slash_sparse_attn.py (1 hunks)
- examples/norm/rms_norm.py (1 hunks)
- testing/python/kernel/test_tilelang_kernel_gemm.py (1 hunks)
- testing/python/kernel/test_tilelang_kernel_gemm_simt.py (1 hunks)
- testing/python/language/test_tilelang_language_pipeline.py (1 hunks)
- tilelang/autotuner/param.py (3 hunks)
- tilelang/cache/kernel_cache.py (1 hunks)
- tilelang/carver/arch/cuda.py (1 hunks)
- tilelang/carver/roller/bestfit.py (1 hunks)
- tilelang/carver/roller/hint.py (1 hunks)
- tilelang/carver/roller/node.py (4 hunks)
- tilelang/carver/roller/rasterization.py (1 hunks)
- tilelang/carver/roller/shape_inference/common.py (2 hunks)
- tilelang/carver/roller/shape_inference/tir.py (2 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (1 hunks)
- tilelang/intrinsics/mma_macro_generator.py (2 hunks)
- tilelang/intrinsics/wgmma_macro_generator.py (1 hunks)
- tilelang/jit/adapter/libgen.py (1 hunks)
- tilelang/jit/adapter/wrapper.py (6 hunks)
- tilelang/jit/kernel.py (1 hunks)
- tilelang/quantize/lop3.py (1 hunks)
- tilelang/quantize/quantization.py (2 hunks)
- tilelang/tileop/gemm/gemm_base.py (1 hunks)
💤 Files with no reviewable changes (2)
- examples/bitnet-1.58b/tokenization_bitnet.py
- examples/bitnet-1.58b/configuration_bitnet.py
✅ Files skipped from review due to trivial changes (10)
- examples/bitnet-1.58b/maint/create_bitblas_ckpt.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py
- examples/norm/rms_norm.py
- examples/bitnet-1.58b/vllm_workspace/conftest.py
- testing/python/language/test_tilelang_language_pipeline.py
- testing/python/kernel/test_tilelang_kernel_gemm.py
- examples/flash_attention/example_gqa_bwd_tma_reduce.py
- examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py
- tilelang/intrinsics/mfma_macro_generator.py
- examples/attention_sink/example_mha_sink_fwd_bhsd.py
🚧 Files skipped from review as they are similar to previous changes (28)
- testing/python/kernel/test_tilelang_kernel_gemm_simt.py
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
- tilelang/carver/roller/bestfit.py
- examples/flash_attention/example_mha_bwd.py
- examples/deepseek_mla/example_mla_decode_paged.py
- examples/attention_sink/example_mha_sink_bwd_bhsd.py
- examples/attention_sink/benchmark_gqa_sink_fwd.py
- examples/minference/example_vertical_slash_sparse_attn.py
- examples/flash_decoding/example_mha_inference.py
- tilelang/jit/kernel.py
- examples/flash_attention/example_gqa_fwd_bshd.py
- examples/flash_attention/example_mha_fwd_bhsd.py
- tilelang/intrinsics/wgmma_macro_generator.py
- tilelang/intrinsics/mma_macro_generator.py
- examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py
- examples/bitnet-1.58b/eval_ppl.py
- examples/linear_attention/example_mamba_chunk_scan.py
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
- tilelang/cache/kernel_cache.py
- examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py
- examples/flash_attention/example_mha_bwd_wgmma_pipelined.py
- examples/linear_attention/example_mamba_chunk_state.py
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py
- tilelang/carver/roller/rasterization.py
- examples/bitnet-1.58b/modeling_bitnet.py
- examples/hadamard_transform/example_hadamard.py
- tilelang/carver/roller/shape_inference/tir.py
- examples/attention_sink/benchmark_mha_sink_fwd.py
🧰 Additional context used
🧬 Code graph analysis (6)
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
tilelang/carver/roller/shape_inference/common.py (1)
- tilelang/carver/roller/shape_inference/tir.py (2): Statement (7-43), InputShapeInference (169-318)
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_attention/example_mha_fwd_bshd.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
examples/flash_decoding/example_gqa_decode.py (1)
- tilelang/profiler/__init__.py (1): do_bench (218-281)
🪛 Ruff (0.14.0)
tilelang/quantize/lop3.py
1189-1189: Avoid specifying long messages outside the exception class (TRY003)
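TRY003 asks that a long message live inside a dedicated exception class instead of being rebuilt at each raise site. A minimal sketch of the refactor the rule points toward; the class name and lookup table below are illustrative, not code from this repository:

class UnsupportedSourceFormatError(ValueError):
    """Raised when a quantization source format is not recognized."""

    def __init__(self, source_format: str):
        # The message is defined once, here, so every raise site stays
        # short and the wording stays consistent (satisfying TRY003).
        super().__init__(f"Unsupported source format: {source_format}")


def pick_kernel(source_format: str) -> str:
    kernels = {"int4": "lop3_int4", "fp4": "lop3_fp4"}  # hypothetical names
    if source_format not in kernels:
        raise UnsupportedSourceFormatError(source_format)
    return kernels[source_format]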
tilelang/carver/roller/node.py
304-304: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks (B019)
421-421: Use of functools.lru_cache or functools.cache on methods can lead to memory leaks (B019)
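B019 fires because the lru_cache created at class definition time keys every entry on (self, ...), so the cache keeps each instance alive for as long as the cache itself lives. A small sketch of the hazard and a common per-instance workaround; the class names are made up:

import functools


class Leaky:
    @functools.lru_cache  # cache entries reference `self`, pinning instances
    def expensive(self, x: int) -> int:
        return x * x


class PerInstance:
    def __init__(self) -> None:
        # The cache lives on the instance and is collected with it.
        self.expensive = functools.lru_cache(maxsize=None)(self._expensive)

    def _expensive(self, x: int) -> int:
        return x * x


assert Leaky().expensive(3) == 9        # works, but the temporary is now pinned
assert PerInstance().expensive(3) == 9  # works, no class-level cache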
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Build wheels for Python 3.8 on ubuntu-latest with CUDA-12.1
- GitHub Check: Build wheels for Python 3.8 on macos-latest with Metal
- GitHub Check: Build wheels for Python 3.8 on ubuntu-24.04-arm with CUDA-12.8
🔇 Additional comments (31)
examples/flash_attention/example_gqa_bwd.py (2)
518-519: LGTM! Clean modernization to f-strings. The conversion from .format() style to f-strings is correct and improves readability while maintaining the same output format.
521-522: LGTM! Consistent modernization to f-strings. The conversion matches the torch output format above and correctly preserves the formatting specifications.
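For readers skimming the thread, the rewrite all of these comments approve is the standard pyupgrade one; a minimal before/after sketch with made-up numbers:

tflops, latency = 95.3, 1.87

# before: .format() style
print("torch: {:.2f} TFlops, {:.2f} ms".format(tflops, latency))

# after: the f-string keeps the same format specs inline
print(f"torch: {tflops:.2f} TFlops, {latency:.2f} ms")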
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1)
436-437: LGTM! Clean conversion to f-strings. The print statements have been correctly modernized from .format() calls to f-strings. The formatting specifications and expressions are preserved, and f-strings are more readable and performant.
examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py (1)
362-362: LGTM! Clean f-string conversion. The f-string format is properly applied with clear interpolation of multiple values.
examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py (1)
198-202: LGTM! Clean modernization to f-strings. The conversion from .format() to f-strings is correct and improves readability while maintaining the same output format. This aligns well with the PR objective of enabling pyupgrade linting rules.
examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1)
317-318: LGTM! Clean f-string conversion. The conversion from .format() style to f-strings is correct and preserves the original formatting behavior.
tilelang/quantize/quantization.py (2)
226-226: LGTM! Redundant parentheses removed. The removal of the outer parentheses from
(((val & 63) << 7))
to ((val & 63) << 7)
is correct and improves readability without affecting the computation.
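Worth noting why only the outer pair was redundant: << binds tighter than & in Python, so the inner parentheses still change the parse. A quick check with an arbitrary value:

val = 0b1010_1101
assert (((val & 63) << 7)) == ((val & 63) << 7)   # outer parens: no effect
assert (val & 63) << 7 != val & (63 << 7)         # inner parens: essential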
235-235: LGTM! Redundant parentheses removed. Consistent with the change at line 226, the removal of outer parentheses from
(((val & 63) << 7))
to ((val & 63) << 7)
is correct and improves code clarity.
tilelang/quantize/lop3.py (1)
1189-1191: LGTM! F-string conversions are correct. The conversions from
.format()
style to f-strings are properly implemented and align with the PR's pyupgrade modernization goals. Note: The static analysis hint (TRY003) on line 1189 suggests defining error messages outside the exception class, but for single-use error messages like this, inline definitions are acceptable and commonly used.
examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1)
329-330: LGTM! Clean f-string conversion. The conversion from
.format()
to f-strings is syntactically correct, preserves the formatting specifications (:.2f), and aligns with the PR's objective of enabling pyupgrade linting rules. F-strings are more readable and performant for this use case.
examples/cast/example_per_token_cast_to_fp8.py (1)
102-116: LGTM! Clean f-string modernization. The conversion from
.format()
to f-strings improves readability while maintaining identical functionality.docs/conf.py (1)
8-9
: LGTM! Safe simplification of file I/O.Removing the explicit
"r"
mode is safe since it's the default for text file reads.tilelang/carver/arch/cuda.py (1)
94-94
: LGTM! Python 3 modernization.Removing explicit
object
base class is appropriate since all Python 3 classes implicitly inherit fromobject
.tilelang/carver/roller/hint.py (1)
150-150
: LGTM! Python 3 modernization.Removing explicit
object
base class is appropriate since all Python 3 classes implicitly inherit fromobject
.tilelang/autotuner/param.py (2)
249-250
: LGTM! Safe simplification of file I/O.Removing the explicit
"r"
mode is safe since it's the default for text file reads.
318-330
: LGTM! Safe simplification of file I/O.Removing the explicit
"r"
mode from both file reads is safe since it's the default for text file reads.tilelang/jit/adapter/libgen.py (1)
32-32
: LGTM! Python 3 modernization.Removing explicit
object
base class is appropriate since all Python 3 classes implicitly inherit fromobject
.tilelang/tileop/gemm/gemm_base.py (1)
12-12
: LGTM! Python 3 modernization.Removing explicit
object
base class is appropriate since all Python 3 classes implicitly inherit fromobject
.examples/cast/example_group_per_split_token_cast_to_fp8.py (1)
201-204
: LGTM! Clean f-string modernization.The conversion from
.format()
to f-strings improves readability while maintaining identical functionality.tilelang/carver/roller/shape_inference/common.py (1)
7-7
: LGTM! Class declaration syntax modernized.The removal of empty parentheses from class declarations aligns with modern Python style and the pyupgrade linting rules being enabled in this PR.
Also applies to: 21-21
tilelang/carver/roller/node.py (2)
32-32
: LGTM! Class declarations modernized.Removing explicit
object
inheritance aligns with Python 3 conventions where all classes implicitly inherit fromobject
.Also applies to: 93-93
87-88
: The review comment is incorrect. The code still contains quoted references.The actual code at lines 87-88 shows
src_node: 'Node'
anddst_node: 'Node'
with quotes intact. The review claims these forward references were changed to unquoted form, but this change was not made. Since the quotes remain, there is no forward reference issue and no import requirement. The code is correct as written.Likely an incorrect or invalid review comment.
tilelang/jit/adapter/wrapper.py (2)
179-179
: LGTM! Class declarations modernized.Removing explicit
(object)
inheritance is a clean modernization that aligns with Python 3 conventions.Also applies to: 895-895, 1062-1062
325-326
: LGTM! String formatting modernized to f-strings.The conversion from
.format()
to f-strings improves readability and is more idiomatic in modern Python. The logic and output remain unchanged.Also applies to: 350-351, 993-993
examples/flash_attention/example_mha_fwd_bshd.py (1)
193-197
: F-string modernization LGTM.Accurate, clearer, and preserves formatting semantics.
examples/flash_attention/example_mha_bwd_bhsd.py (1)
341-346
: F-string updates look good.No behavior change; formatting preserved.
examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py (1)
225-230
: F-strings: OK.Concise and equivalent formatting.
examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py (1)
211-215
: F-strings applied correctly.Keeps precision and units; no semantic changes.
examples/flash_decoding/example_gqa_decode.py (3)
34-34
: Inconsistency between AI summary and actual code.The AI summary claims the return type annotation was updated to PEP 585 syntax (
tuple[Dict, int]
), but the code still showsTuple[Dict, int]
from thetyping
module. This line was not actually changed in this PR.
476-480
: LGTM! F-string modernization applied correctly.The conversion to f-strings is syntactically correct and improves code readability. The formatting specifications (
.2f
) and inline expressions are properly handled.
10-10
: No action required—code is correct for the target Python version.The codebase targets Python 3.8 (as confirmed in
pyproject.toml
), which requiresTuple
andDict
from thetyping
module. The current code usingfrom typing import Tuple, Dict
andTuple[Dict, int]
is correct and compliant. PEP 585 lowercase generics (e.g.,tuple
,dict
) are only available in Python 3.9+. ruff's pyupgrade rules correctly passed without violations.
and add FA102 for missing from __future__ import annotations. Prevent issues like #959. Or use UP035 (with FA102) only if this is too wide.
This should be sufficient to keep py38 compatibility for a long time.
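Concretely, FA102 guards the case where PEP 585 syntax sneaks into a py38 codebase: with from __future__ import annotations the annotation is never evaluated, so built-in generics are fine even on 3.8; without it, the same line fails at import time. A sketch, with an illustrative function name:

from __future__ import annotations  # makes the annotation below legal on 3.8

def split_work(items: list[int]) -> tuple[dict, int]:
    buckets: dict = {"even": [], "odd": []}
    for item in items:
        buckets["even" if item % 2 == 0 else "odd"].append(item)
    return buckets, len(items)

# Without the future import, evaluating `list[int]` on Python 3.8 raises
# TypeError: 'type' object is not subscriptable.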
Summary by CodeRabbit
Style
Chores
(tuple instead of Tuple). Configuration