[Parallel] Support T.Parallel with dynamic extents #990
@@ -64,28 +64,88 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
  ICHECK(thread_var.defined());
  int old_loop_depth = loop_layout->InputDim();
  int new_loop_depth = loop_layout->OutputDim();

  // Create the new loop iter var
  Array<Var> vars;
  for (int i = 0; i < new_loop_depth; i++) {
    Var var = Var(std::string{char('i' + i)});
    analyzer->Bind(var, Range::FromMinExtent(make_zero(var->dtype),
                                             loop_layout->OutputShape()[i]));
    vars.push_back(var);
  }
  vars.push_back(thread_var);
  // create the substitute map, and the loop body
  Map<Var, PrimExpr> vmap;
  Stmt body = std::move(op);
  auto inv_loop = loop_layout->Inverse();
  Array<PrimExpr> loop_mins;
  Array<PrimExpr> loop_extents;
  auto inverse_info = loop_layout->InverseWithLevel();
  auto inv_loop = inverse_info.first;
  // Must check the guard if the layout can not be proved as bijective
  bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;
  auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
  // Normalize thread var once so we can reuse the same substitution later.
  Map<Var, PrimExpr> thread_offset_map;
  bool has_thread_offset = false;
  if (loop_layout->ThreadRange().defined()) {
    auto range = loop_layout->ThreadRange();
    thread_offset_map.Set(thread_var, thread_var - range->min);
    has_thread_offset = true;
  }
  for (int i = 0; i < old_loop_depth; i++) {
    const ForNode *loop = body.as<ForNode>();
    ICHECK(loop != nullptr);
    vmap.Set(loop->loop_var, indices[i]);
    loop_mins.push_back(loop->min);
    loop_extents.push_back(loop->extent);
    body = loop->body;
Comment on lines +86 to 100:
Substitute loop bounds incrementally before collecting.

The past review comment remains unaddressed. When an inner loop's `min` or `extent` references an outer loop variable, the value pushed into `loop_mins`/`loop_extents` still mentions that variable even though, at that iteration, the outer variables have already been remapped through `vmap`, so the guard built later compares against a stale bound. Apply this diff to fix the substitution:

      vmap.Set(loop->loop_var, indices[i]);
-     loop_mins.push_back(loop->min);
-     loop_extents.push_back(loop->extent);
+     loop_mins.push_back(Substitute(loop->min, vmap));
+     loop_extents.push_back(Substitute(loop->extent, vmap));
      body = loop->body;
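To see why the raw bound goes stale, here is a small sketch using sympy as a stand-in for TIR PrimExprs (the loop shape and the remapping expression are invented for illustration, not taken from the pass):

```python
import sympy as sp

i, x, y = sp.symbols("i x y")

# Outer loop var `i` has already been remapped to an expression in the new vars.
vmap = {i: 4 * x + y}  # plays the role of indices[0] from the inverse layout

# Inner loop: for j in [i, i + 8) -- its min references the outer var `i`.
inner_min = i

raw = inner_min               # still mentions `i`, which no longer exists after the rebuild
fixed = inner_min.subs(vmap)  # 4*x + y -- the bound the guard should compare against

print(raw, "->", fixed)
```

With the suggested `Substitute(loop->min, vmap)`, the collected bound is already expressed in the new iteration variables, so the guard constructed later stays valid.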
  }

  // substitute and re-construct the serial loop
  body = Substitute(body, vmap);
  // Guard executes the recovered loop body only if each inverse-mapped iterator
  // falls back into the original For ranges. We first check every axis from the
  // old loop nest (old_loop_depth) and then the extra index produced by inverse
  // layouts that carry a replicate/thread component (`inv_output_shape`). Both
  // must stay within bounds to ensure correctness. Example: layout([i, j]) =
  // floor((i * 16 + j) / 32) may generate extra points when the new loop
  // enumerates 0..31; the guard drops iterations whose inverse-mapped (i, j)
  // or replicate index fall outside their original extents.
  // Example: layout([i, j]) = floor((i * 16 + j) / 32) may produce extra points
  // when the new loop enumerates 0..31; this guard skips iterations where the
  // inverse i, j land outside the original extents. This protects
  // non-surjective loop_layout mappings that otherwise over-cover the parallel
  // space.
  PrimExpr guard = const_true();

  if (need_guard) {
    for (int i = 0; i < old_loop_depth; i++) {
      PrimExpr index = indices[i];
      if (has_thread_offset) {
        index = Substitute(index, thread_offset_map);
      }
      PrimExpr lower_bound = analyzer->Simplify(index >= loop_mins[i]);
      PrimExpr upper_bound =
          analyzer->Simplify(index < loop_mins[i] + loop_extents[i]);
      guard = And(guard, And(lower_bound, upper_bound));
    }
    auto inv_output_shape = inv_loop->OutputShape();
    if (inv_output_shape.size() > static_cast<size_t>(old_loop_depth)) {
      PrimExpr replicate_index = indices[old_loop_depth];
      if (has_thread_offset) {
        replicate_index = Substitute(replicate_index, thread_offset_map);
      }
      PrimExpr replicate_extent = inv_output_shape[old_loop_depth];
      PrimExpr lower_bound = analyzer->Simplify(
          replicate_index >= make_zero(replicate_index.dtype()));
      PrimExpr upper_bound =
          analyzer->Simplify(replicate_index < replicate_extent);
      guard = And(guard, And(lower_bound, upper_bound));
    }
    PrimExpr simplified_guard = analyzer->Simplify(guard);
    if (!analyzer->CanProve(simplified_guard)) {
      body = IfThenElse(simplified_guard, body, Stmt());
    }
  }

  for (int i = new_loop_depth - 1; i >= 0; i--) {
    body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
               ForKind::kSerial, body);
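For intuition about `need_guard`, here is a plain-Python toy (not TVM code; the extents 6 and 4 are invented) that mirrors how the padded parallel space over-covers the original loop and how the bound check drops the extra points:

```python
# A 6-iteration loop partitioned across 4 threads: the padded space has 8 points,
# so 2 of them must be guarded out at runtime.
old_extent = 6
num_threads = 4
per_thread = -(-old_extent // num_threads)  # ceil(6 / 4) = 2

executed = []
for t in range(num_threads):          # thread_var
    for k in range(per_thread):       # the reconstructed serial loop
        i = k * num_threads + t       # inverse-mapped original index
        if 0 <= i < old_extent:       # the runtime guard (min <= i < min + extent)
            executed.append(i)

assert sorted(executed) == list(range(old_extent))
```

When the layout is proved bijective, `need_guard` is false and no `IfThenElse` is emitted, so fully static shapes keep their original codegen.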
@@ -94,13 +154,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,

  body = BufferIndiceSimplify(analyzer)(body);

  auto for_node = LoopPragmaUnroll(Downcast<For>(body));
  if (loop_layout->ThreadRange().defined()) {
    auto range = loop_layout->ThreadRange();
    auto thread_var_with_offset = thread_var - range->min;
    for_node.CopyOnWrite()->body =
        Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
  if (has_thread_offset) {
    body = Substitute(body, thread_offset_map);
  }

  auto for_node = LoopPragmaUnroll(Downcast<For>(body));
  return for_node;
}
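A tiny numeric view of the `thread_offset_map` rebasing (the thread range [128, 256) is invented for illustration):

```python
# Illustrative only: a kernel slice owning threads [128, 256) is rebased to [0, 128)
# before the inverse layout and the guard interpret thread indices.
range_min = 128                       # stands in for loop_layout->ThreadRange()->min
for thread_var in (128, 200, 255):
    rebased = thread_var - range_min  # what Substitute(..., thread_offset_map) yields
    assert 0 <= rebased < 128
```

Normalizing once up front lets the same substitution map serve both the guard indices and the final loop body.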
@@ -111,6 +169,10 @@ class LoopPramaUnroller : public StmtExprMutator {
 private:
  Stmt VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kSerial) {
      auto analyzer = std::make_shared<arith::Analyzer>();
      if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
        return StmtExprMutator::VisitStmt_(node);
      }
      For new_for = GetRef<For>(node);
      auto for_ptr = new_for.CopyOnWrite();
      for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
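The new early return means an unroll pragma is only attached when the loop extent folds to a compile-time constant. A hedged sketch of the same check from the Python side, using the public `tvm.arith`/`tvm.tir` APIs (the extents here are invented):

```python
import tvm
from tvm import tir

analyzer = tvm.arith.Analyzer()
n = tir.Var("n", "int32")            # e.g. a dynamic extent like valid_len

for extent in (tir.IntImm("int32", 16), n * 4):
    simplified = analyzer.simplify(extent)
    is_const = isinstance(simplified, tir.IntImm)  # mirrors as_const_int(...) != nullptr
    print(extent, "-> unroll" if is_const else "-> symbolic extent, skip unroll pragma")
```

Symbolic extents simply fall through to the default visitor, so dynamic loops stay serial instead of tripping the unroller.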
@@ -127,22 +189,20 @@ class LoopPartitioner : public StmtExprVisitor {

  Fragment Partition(const For &op, int num_thread, int vectorize_size) {
    this->VisitStmt(op);
    int loop_size_full = 1;
    PrimExpr flattened = 0;
    ICHECK(!loop_vars_.empty());
    DataType dtype = loop_vars_[0]->var.dtype();
    PrimExpr flattened = make_const(dtype, 0);
    PrimExpr vector_extent = make_const(dtype, vectorize_size);
    PrimExpr thread_extent_const = make_const(dtype, num_thread);
    for (size_t i = 0; i < loop_vars_.size(); i++) {
      auto ext_ptr = as_const_int(loop_vars_[i]->dom->extent);
      ICHECK(ext_ptr)
          << "Loop partitioner only works with constant loop sizes, but got "
          << loop_vars_[i]->dom->extent;
      int extent = *ext_ptr;
      loop_size_full *= extent;
      PrimExpr extent = loop_vars_[i]->dom->extent;
      flattened = flattened * extent + loop_vars_[i]->var;
    }
    ICHECK(loop_size_full % vectorize_size == 0);
    PrimExpr access_idx = FloorDiv(flattened, vectorize_size);
    PrimExpr thd = FloorMod(access_idx, num_thread);
    PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
                   FloorMod(flattened, vectorize_size);
    PrimExpr access_idx = FloorDiv(flattened, vector_extent);
    PrimExpr thd = FloorMod(access_idx, thread_extent_const);
    PrimExpr idx = FloorDiv(access_idx, thread_extent_const) * vector_extent +
                   FloorMod(flattened, vector_extent);

    auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
    if (has_fragment_) {
      // for fragment buffer, we don't need to replicate the loop layout
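To make the flattened/vector/thread index math above concrete, here is a plain-Python worked example (the sizes — two loops of extent 4 and 8, 8 threads, vector width 2 — are invented):

```python
# Plain-Python model of the Partition() index math (illustrative sizes only).
extents = [4, 8]        # two nested loops
num_thread = 8
vec = 2                 # vectorize_size

seen = set()
for i in range(extents[0]):
    for j in range(extents[1]):
        flattened = i * extents[1] + j                # flattened = flattened * extent + var
        access_idx = flattened // vec
        thd = access_idx % num_thread                 # which thread executes this point
        idx = (access_idx // num_thread) * vec + flattened % vec  # per-thread serial index
        seen.add((thd, idx))

# Every (thread, index) pair is distinct, so the mapping covers each iteration exactly once.
assert len(seen) == extents[0] * extents[1]
```

Because the extents stay symbolic PrimExprs in the new code, the same formula now works when a loop bound such as `valid_len` is only known at runtime.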

@@ -0,0 +1,72 @@
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import pytest

tilelang.testing.set_random_seed()


@tilelang.jit(out_idx=[1])
def parallel_elementwise_static(length=256, dtype="float32"):

    @T.prim_func
    def main(
            A: T.Tensor((length,), dtype),
            B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length):
                B[i] = A[i] + 1.0

    return main


@tilelang.jit(out_idx=[1])
def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):

    @T.prim_func
    def main(
            A: T.Tensor((max_len,), dtype),
            B: T.Tensor((max_len,), dtype),
            valid_len: T.int32,
    ):
        with T.Kernel(1, threads=threads) as _:
            for i in T.Parallel(max_len):
                B[i] = 0.0
            span = T.min(valid_len, max_len)
            for i in T.Parallel(span):
                B[i] = A[i] - 1.0

    return main


def _require_cuda_tensor(shape, dtype=torch.float32):
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    try:
        return torch.randn(*shape, device="cuda", dtype=dtype)
    except RuntimeError as err:
        pytest.skip(f"CUDA runtime unavailable: {err}")


def test_parallel_static_extent():
    kernel = parallel_elementwise_static(length=256)
    data = _require_cuda_tensor((256,), torch.float32)
    result = kernel(data)
    torch.testing.assert_close(result, data + 1.0, atol=1e-5, rtol=1e-5)


def test_parallel_dynamic_extent():
    kernel = parallel_elementwise_dynamic(max_len=512, threads=256)
    data = _require_cuda_tensor((512,), torch.float32)
    for valid_len in [0, 13, 200, 600]:
        out = kernel(data, valid_len)
        reference = torch.zeros_like(data)
        clip = min(valid_len, data.shape[0])
        reference[:clip] = data[:clip] - 1.0
        torch.testing.assert_close(out, reference, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    tilelang.testing.main()
Verify the NoCheck fallback doesn't mask correctness issues.

The warning states "Runtime guards keep dynamic tails safe," but ensure that all call sites that obtain InverseWithLevel() actually use the returned level to generate appropriate guards. If any caller ignores the level and assumes Bijective, dynamic shapes could produce incorrect results.

Verification scripts were run to check which code paths use InverseWithLevel() and whether they respect the returned level, to locate the wrapper methods and their callers, and to see whether callers of Inverse() can detect and handle symbolic shapes.
Verify the NoCheck fallback doesn't mask correctness issues — CONCERN CONFIRMED.

The review comment identifies a real design flaw. Investigation reveals:

InverseWithLevel() properly detects symbolic dimensions (lines 232-280) and returns a NoCheck level when symbolic dims exist, with a warning logged. However, the public Inverse() wrapper method discards the returned level (line 283 and line 395), returning only the Layout object.

Multiple call sites use Inverse() without guard detection:
- src/op/parallel.cc:65 (ProveFragmentContains)
- src/op/parallel.cc:540 and src/op/parallel.cc:614
- src/op/copy.cc:1025

These callers cannot detect whether the inverse layout requires runtime guards.

Contrast: loop_partition.cc:81 properly uses InverseWithLevel() and checks the level before proceeding:

    bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;

The issue: the warning states guards keep dynamic tails safe, but Inverse() callers have no mechanism to know when guards are needed. If any of these callers operates on symbolic layouts, they could silently produce incorrect results without guards.
The fix requires either: (a) documenting that Inverse() assumes static/bijective layouts only, or (b) modifying callers to either detect symbolic shapes independently or switch to InverseWithLevel().
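As a rough illustration of option (b), a caller-side precondition could look like the following sketch (Python, with the TVM arithmetic API standing in for the C++ callers; `layout_output_shape` is a hypothetical argument, not an existing helper):

```python
# Hypothetical sketch of option (b): refuse a guard-free inverse unless every extent
# in the layout's output shape folds to a compile-time constant.
import tvm
from tvm import tir

def inverse_is_guard_free(layout_output_shape):
    analyzer = tvm.arith.Analyzer()
    return all(isinstance(analyzer.simplify(e), tir.IntImm) for e in layout_output_shape)
```

Switching the listed callers to InverseWithLevel() would make the same information available without re-deriving it.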