Skip to content

Commit

Permalink
bug fix for splitk
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Dec 19, 2024
1 parent ffc31ee commit 08ce1fc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 4 additions & 6 deletions bitblas/base/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from bitblas.base.base_scheduler import BaseScheduler
from bitblas.base.utils import apply_and_build as tir_apply_and_build
from bitblas.tl.tuner import apply_and_build as tl_apply_and_build
from bitblas.utils import retrieve_func_from_module
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -376,17 +377,14 @@ def fast_tune_with_dynamic_range_tilelang(
_, best = fast_tune(unit_scheduler, target, topk, parallel_build)
if best is None:
return None
specialized_func = best.sch.mod["main"]
specialized_func = retrieve_func_from_module(best.sch.mod)
function_symbol = global_symbol
if kernel_name_generator is not None:
scheduled_mod = best.sch.mod
best_hint = best.config
assert len(scheduled_mod.get_global_vars()) == 1, (
"The optimized module should only have one global variable for default schedule.")
assert "main" in scheduled_mod, (
"The optimized module should have a function named 'main' for default schedule.")
default_kernal_name = kernel_name_generator.generate(best_hint)
specialized_func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name)
prim_func = retrieve_func_from_module(scheduled_mod)
specialized_func = prim_func.with_attr("global_symbol", default_kernal_name)
function_symbol = default_kernal_name

function_symbols.append(function_symbol)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def serialize_hints_to_configs(self, hints: List[Hint]):

# Double the split-k factor and check if the resulting K-dimension size is too large
expand_split_k = split_k_factor * 2
if K % (expand_split_k * block_K) != 0:
break
if expand_split_k * block_K >= K:
break

Expand Down

0 comments on commit 08ce1fc

Please sign in to comment.