43 changes: 37 additions & 6 deletions src/layout/layout.cc
@@ -229,11 +229,34 @@ Fragment FragmentNode::BindThreadRange(Range thread_range) const {
return Fragment(n);
}

Layout LayoutNode::Inverse() const {
std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
arith::Analyzer analyzer;
auto collect_symbolic = [&](const Array<PrimExpr> &shape) {
Array<PrimExpr> symbolic_dims;
for (const auto &dim : shape) {
if (!as_const_int(dim)) {
symbolic_dims.push_back(dim);
}
}
return symbolic_dims;
};
Array<PrimExpr> symbolic_dims = collect_symbolic(input_size_);
Array<PrimExpr> output_shape = OutputShape();
symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(),
output_shape.end());
symbolic_dims = collect_symbolic(symbolic_dims);
bool is_static_shape = symbolic_dims.empty();
auto level = is_static_shape ? arith::IterMapLevel::Bijective
: arith::IterMapLevel::NoCheck;
if (!is_static_shape) {
// Runtime guards keep dynamic tails safe, so we allow NoCheck here and
// warn.
LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
"NoCheck; symbolic dims: "
<< symbolic_dims;
}
Comment on lines +251 to +257
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify the NoCheck fallback doesn't mask correctness issues.

The warning states "Runtime guards keep dynamic tails safe," but verify that every call site that obtains an inverse through InverseWithLevel() actually uses the returned level to generate appropriate guards. If any caller ignores the level and assumes Bijective, dynamic shapes could produce incorrect results.

Run the following script to check which code paths use InverseWithLevel() and whether they respect the returned level:


🏁 Script executed:

```shell
#!/bin/bash
# Description: Find all usages of InverseWithLevel and verify they handle the returned level

# Search for InverseWithLevel calls
echo "=== InverseWithLevel call sites ==="
rg -nP 'InverseWithLevel\s*\(' --type=cpp -A 5

# Search for uses of the level from the pair
echo -e "\n=== Checking if returned level is used ==="
rg -nP '(auto|std::pair).*InverseWithLevel' --type=cpp -A 10
```

Length of output: 7817


Need to dive deeper to understand the full picture. Let me search for the wrapper methods and their callers, then gather more specific information about the context of this code:


---

🏁 Script executed:

```shell
#!/bin/bash
# Get the broader context around the warning in layout.cc
echo "=== Context of warning in layout.cc (lines 245-260) ==="
sed -n '245,260p' src/layout/layout.cc

# Find the complete Inverse() method and check what it does
echo -e "\n=== Inverse() method around line 284 ==="
sed -n '280,290p' src/layout/layout.cc

# Find FragmentNode's wrapper Inverse() around line 396
echo -e "\n=== FragmentNode Inverse() around line 396 ==="
sed -n '390,400p' src/layout/layout.cc

# Check for any guards or assertions when Inverse() is called
echo -e "\n=== Search for usage patterns of Inverse() method ==="
rg -nP 'Inverse\(\)' --type=cpp -B 2 -A 3 -g '!test' -g '!*_test.cc'
```

Length of output: 5676


Let me check if the callers of Inverse() have mechanisms to detect and handle symbolic shapes:


---

🏁 Script executed:

```shell
#!/bin/bash
# Check the context of Inverse() calls to see if they guard against symbolic shapes
echo "=== Checking if parallel.cc:65 handles symbolic shapes ==="
sed -n '60,75p' src/op/parallel.cc

echo -e "\n=== Checking if parallel.cc:540 handles symbolic shapes ==="
sed -n '535,550p' src/op/parallel.cc

echo -e "\n=== Checking if copy.cc:1025 handles symbolic shapes ==="
sed -n '1020,1035p' src/op/copy.cc

# Check function signatures to understand context
echo -e "\n=== Function containing parallel.cc:65 ==="
rg -nB 20 'large_frag->Inverse' src/op/parallel.cc | head -40

echo -e "\n=== Check if there's any documented assumption about Inverse() ==="
rg -nP 'Inverse' src/layout/layout.h -B 3 -A 1
```

Length of output: 3891


Verify the NoCheck fallback doesn't mask correctness issues — CONCERN CONFIRMED.

The review comment identifies a real design flaw. Investigation reveals:

  1. InverseWithLevel() properly detects symbolic dimensions (lines 232-280) and returns NoCheck level when symbolic dims exist, with a warning logged.

  2. However, the public Inverse() wrapper method discards the returned level (line 283 and line 395), returning only the Layout object.

  3. Multiple call sites use Inverse() without guard detection:

    • src/op/parallel.cc:65 (ProveFragmentContains)
    • src/op/parallel.cc:540 and src/op/parallel.cc:614
    • src/op/copy.cc:1025

    These callers cannot detect whether the inverse layout requires runtime guards.

  4. Contrast: loop_partition.cc:81 properly uses InverseWithLevel() and checks the level before proceeding: bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;

The issue: The warning states guards keep dynamic tails safe, but Inverse() callers have no mechanism to know when guards are needed. If any of these callers operates on symbolic layouts, they could silently produce incorrect results without guards.

The fix requires either: (a) documenting that Inverse() assumes static/bijective layouts only, or (b) modifying callers to either detect symbolic shapes independently or switch to InverseWithLevel().
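For concreteness, here is a minimal sketch of option (b): a strict wrapper that callers such as parallel.cc and copy.cc could adopt instead of raw Inverse(). It assumes only the APIs visible in this diff (InverseWithLevel(), DebugOutput(), ICHECK); the helper name InverseOrFail and the include paths are illustrative, not existing code:

```cpp
// Hypothetical helper, not an existing API: fail fast when an inverse is not
// provably bijective and the caller has no way to emit runtime guards.
#include <tvm/runtime/logging.h>  // ICHECK (path assumed)
#include "layout/layout.h"        // Layout, InverseWithLevel (path assumed)

namespace tvm {
namespace tl {

Layout InverseOrFail(const Layout &layout) {
  auto [inv, level] = layout->InverseWithLevel();
  // NoCheck means the inverse is only valid under runtime guards; a caller
  // with no guard mechanism must not silently accept it.
  ICHECK(level == arith::IterMapLevel::Bijective)
      << "Layout " << layout->DebugOutput()
      << " is not provably bijective; use InverseWithLevel() and emit guards";
  return inv;
}

}  // namespace tl
}  // namespace tvm
```

Callers that can emit guards would instead branch on the returned level, as loop_partition.cc now does with need_guard.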

🤖 Prompt for AI Agents
In src/layout/layout.cc around lines 251-257 (and note the Inverse wrapper at
~lines 283 and 395), the current Inverse() discards the IterMapLevel returned by
InverseWithLevel() causing callers to be unable to detect when the inverse is
only valid with runtime guards; update the implementation and call sites so
callers don't silently miss this: either (1) change/overload Inverse() to
propagate the level (e.g., return a pair or provide an InverseChecked API) so
callers can inspect the IterMapLevel, or (2) update each caller that currently
calls Inverse() (src/op/parallel.cc lines ~65, ~540, ~614; src/op/copy.cc line
~1025) to call InverseWithLevel() directly and handle non-bijective results by
inserting the appropriate runtime guard or failing with a clear error, and add a
DCHECK or documentation in Inverse() that it must only be used for
static/bijective layouts if you keep the simpler API.

arith::IterMapResult res =
arith::DetectIterMap(forward_index_, getVarMap(), 1,
arith::IterMapLevel::Bijective, &analyzer);
arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
ICHECK(res->errors.empty())
<< "Layout " << DebugOutput() << " has errors: " << res->errors;

@@ -254,9 +277,13 @@ Layout LayoutNode::Inverse() const {
}
}

return Layout(outputs_shape, backward_index);
return {Layout(outputs_shape, backward_index), level};
}

Layout LayoutNode::Inverse() const {
auto inverse_result = InverseWithLevel();
return std::move(inverse_result.first);
}
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
const PrimExpr &forward_thread,
arith::Analyzer *analyzer) {
@@ -366,15 +393,19 @@ PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
}

Layout FragmentNode::Inverse() const {
auto result = InverseWithLevel();
return std::move(result.first);
}

std::pair<Layout, arith::IterMapLevel> FragmentNode::InverseWithLevel() const {
auto input_size_copy = input_size_;
input_size_copy.push_back(ReplicateExtent());
auto forward_index_copy = forward_index_;
forward_index_copy.push_back(
Substitute(forward_thread_,
{{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
auto fwd = Layout(input_size_copy, forward_index_copy);
auto bwd = fwd->Inverse();
return bwd;
return fwd->InverseWithLevel();
}

Fragment FragmentNode::CondenseReplicateVar() const {
4 changes: 4 additions & 0 deletions src/layout/layout.h
@@ -7,6 +7,8 @@
#define TVM_TL_LAYOUT_LAYOUT_H_

#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <utility>

namespace tvm {
namespace tl {
@@ -36,6 +38,7 @@ class LayoutNode : public Object {
virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;

virtual Layout Inverse() const;
virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;

virtual std::string DebugOutput() const;

@@ -76,6 +79,7 @@ class FragmentNode : public LayoutNode {
Array<PrimExpr> GetForwardVars() const final;

Layout Inverse() const final;
std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;

PrimExpr ThreadExtent() const;

104 changes: 82 additions & 22 deletions src/transform/loop_partition.cc
@@ -64,28 +64,88 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
ICHECK(thread_var.defined());
int old_loop_depth = loop_layout->InputDim();
int new_loop_depth = loop_layout->OutputDim();

// Create the new loop iter var
Array<Var> vars;
for (int i = 0; i < new_loop_depth; i++) {
Var var = Var(std::string{char('i' + i)});
analyzer->Bind(var, Range::FromMinExtent(make_zero(var->dtype),
loop_layout->OutputShape()[i]));
vars.push_back(var);
}
vars.push_back(thread_var);
// create the substitute map, and the loop body
Map<Var, PrimExpr> vmap;
Stmt body = std::move(op);
auto inv_loop = loop_layout->Inverse();
Array<PrimExpr> loop_mins;
Array<PrimExpr> loop_extents;
auto inverse_info = loop_layout->InverseWithLevel();
auto inv_loop = inverse_info.first;
// Must check the guard if the layout can not be proved as bijective
bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;
auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
// Normalize thread var once so we can reuse the same substitution later.
Map<Var, PrimExpr> thread_offset_map;
bool has_thread_offset = false;
if (loop_layout->ThreadRange().defined()) {
auto range = loop_layout->ThreadRange();
thread_offset_map.Set(thread_var, thread_var - range->min);
has_thread_offset = true;
}
for (int i = 0; i < old_loop_depth; i++) {
const ForNode *loop = body.as<ForNode>();
ICHECK(loop != nullptr);
vmap.Set(loop->loop_var, indices[i]);
loop_mins.push_back(loop->min);
loop_extents.push_back(loop->extent);
body = loop->body;
Comment on lines +86 to 100
⚠️ Potential issue | 🔴 Critical

Substitute loop bounds incrementally before collecting.

The past review comment remains unaddressed. When an inner loop's min or extent depends on an outer iterator (e.g., triangular loops for j in [i, N)), lines 94–95 capture those bounds without substituting the accumulated vmap. After the new loops are constructed, the old loop variables become free symbols, rendering the guard at lines 119–121 ill‑formed and incorrectly rejecting valid iterations at runtime.

At iteration i of the extraction loop (lines 90–96), vmap already holds substitutions for loops 0..i-1. Before appending to loop_mins and loop_extents, apply Substitute(loop->min, vmap) and Substitute(loop->extent, vmap) so that the collected bounds are expressed solely in terms of the new iterator variables.

Apply this diff to fix the substitution:

```diff
     vmap.Set(loop->loop_var, indices[i]);
-    loop_mins.push_back(loop->min);
-    loop_extents.push_back(loop->extent);
+    loop_mins.push_back(Substitute(loop->min, vmap));
+    loop_extents.push_back(Substitute(loop->extent, vmap));
     body = loop->body;
```
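To see why the substitution matters, consider a triangular nest (hypothetical example, not code from this PR):

```cpp
// Illustration only: the inner loop's bounds depend on the outer iterator.
void triangular(float *A, int N) {
  for (int i = 0; i < N; ++i) {
    for (int j = i; j < N; ++j) {  // min = i, extent = N - i
      A[i * N + j] += 1.0f;
    }
  }
}
```

At extraction step i = 1, vmap already maps i to indices[0]. Without the substitution, loop_mins[1] stays i and loop_extents[1] stays N - i; once the new loops replace the nest, i is a free symbol and the guard built from these bounds is ill-formed. With the substitution they become indices[0] and N - indices[0], expressed purely in the new iterator variables.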
🤖 Prompt for AI Agents
In src/transform/loop_partition.cc around lines 82 to 96, the loop mins and
extents are being collected without applying the current vmap, which leaves
inner-loop bounds depending on outer-loop variables and later produces
ill-formed guards; fix this by substituting accumulated mappings into the bounds
before storing them: replace pushing loop->min and loop->extent with pushing
Substitute(loop->min, vmap) and Substitute(loop->extent, vmap) (also ensure any
thread_offset_map substitution is applied consistently if needed) so the
collected bounds are expressed only in terms of new iterator variables.

}

// substitute and re-construct the serial loop
body = Substitute(body, vmap);
// Guard executes the recovered loop body only if each inverse-mapped iterator
// falls back into the original For ranges. We first check every axis from the
// old loop nest (old_loop_depth) and then the extra index produced by inverse
// layouts that carry a replicate/thread component (`inv_output_shape`). Both
// must stay within bounds to ensure correctness. Example: layout([i, j]) =
// floor((i * 16 + j) / 32) may generate extra points when the new loop
// enumerates 0..31; the guard drops iterations whose inverse-mapped (i, j)
// or replicate index fall outside their original extents. This protects
// non-surjective loop_layout mappings that otherwise over-cover the parallel
// space.
PrimExpr guard = const_true();

if (need_guard) {
for (int i = 0; i < old_loop_depth; i++) {
PrimExpr index = indices[i];
if (has_thread_offset) {
index = Substitute(index, thread_offset_map);
}
PrimExpr lower_bound = analyzer->Simplify(index >= loop_mins[i]);
PrimExpr upper_bound =
analyzer->Simplify(index < loop_mins[i] + loop_extents[i]);
guard = And(guard, And(lower_bound, upper_bound));
}
auto inv_output_shape = inv_loop->OutputShape();
if (inv_output_shape.size() > static_cast<size_t>(old_loop_depth)) {
PrimExpr replicate_index = indices[old_loop_depth];
if (has_thread_offset) {
replicate_index = Substitute(replicate_index, thread_offset_map);
}
PrimExpr replicate_extent = inv_output_shape[old_loop_depth];
PrimExpr lower_bound = analyzer->Simplify(
replicate_index >= make_zero(replicate_index.dtype()));
PrimExpr upper_bound =
analyzer->Simplify(replicate_index < replicate_extent);
guard = And(guard, And(lower_bound, upper_bound));
}
PrimExpr simplified_guard = analyzer->Simplify(guard);
if (!analyzer->CanProve(simplified_guard)) {
body = IfThenElse(simplified_guard, body, Stmt());
}
}

for (int i = new_loop_depth - 1; i >= 0; i--) {
body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
ForKind::kSerial, body);
@@ -94,13 +154,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,

body = BufferIndiceSimplify(analyzer)(body);

auto for_node = LoopPragmaUnroll(Downcast<For>(body));
if (loop_layout->ThreadRange().defined()) {
auto range = loop_layout->ThreadRange();
auto thread_var_with_offset = thread_var - range->min;
for_node.CopyOnWrite()->body =
Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
if (has_thread_offset) {
body = Substitute(body, thread_offset_map);
}

auto for_node = LoopPragmaUnroll(Downcast<For>(body));
return for_node;
}

@@ -111,6 +169,10 @@ class LoopPramaUnroller : public StmtExprMutator {
private:
Stmt VisitStmt_(const ForNode *node) final {
if (node->kind == ForKind::kSerial) {
auto analyzer = std::make_shared<arith::Analyzer>();
if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
return StmtExprMutator::VisitStmt_(node);
}
For new_for = GetRef<For>(node);
auto for_ptr = new_for.CopyOnWrite();
for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
@@ -127,22 +189,20 @@ class LoopPartitioner : public StmtExprVisitor {

Fragment Partition(const For &op, int num_thread, int vectorize_size) {
this->VisitStmt(op);
int loop_size_full = 1;
PrimExpr flattened = 0;
ICHECK(!loop_vars_.empty());
DataType dtype = loop_vars_[0]->var.dtype();
PrimExpr flattened = make_const(dtype, 0);
PrimExpr vector_extent = make_const(dtype, vectorize_size);
PrimExpr thread_extent_const = make_const(dtype, num_thread);
for (size_t i = 0; i < loop_vars_.size(); i++) {
auto ext_ptr = as_const_int(loop_vars_[i]->dom->extent);
ICHECK(ext_ptr)
<< "Loop partitioner only works with constant loop sizes, but got "
<< loop_vars_[i]->dom->extent;
int extent = *ext_ptr;
loop_size_full *= extent;
PrimExpr extent = loop_vars_[i]->dom->extent;
flattened = flattened * extent + loop_vars_[i]->var;
}
ICHECK(loop_size_full % vectorize_size == 0);
PrimExpr access_idx = FloorDiv(flattened, vectorize_size);
PrimExpr thd = FloorMod(access_idx, num_thread);
PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
FloorMod(flattened, vectorize_size);
PrimExpr access_idx = FloorDiv(flattened, vector_extent);
PrimExpr thd = FloorMod(access_idx, thread_extent_const);
PrimExpr idx = FloorDiv(access_idx, thread_extent_const) * vector_extent +
FloorMod(flattened, vector_extent);

auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
if (has_fragment_) {
// for fragment buffer, we don't need to replicate the loop layout
2 changes: 1 addition & 1 deletion src/transform/loop_vectorize.cc
@@ -94,7 +94,7 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
private:
void VisitStmt_(const ForNode *node) final {
inner_for_ = node;
auto extent_ptr = as_const_int(node->extent);
auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent));
// Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the divisibility of vector size
72 changes: 72 additions & 0 deletions testing/python/language/test_tilelang_language_parallel.py
@@ -0,0 +1,72 @@
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import pytest

tilelang.testing.set_random_seed()


@tilelang.jit(out_idx=[1])
def parallel_elementwise_static(length=256, dtype="float32"):

@T.prim_func
def main(
A: T.Tensor((length,), dtype),
B: T.Tensor((length,), dtype),
):
with T.Kernel(1, threads=length) as _:
for i in T.Parallel(length):
B[i] = A[i] + 1.0

return main


@tilelang.jit(out_idx=[1])
def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):

@T.prim_func
def main(
A: T.Tensor((max_len,), dtype),
B: T.Tensor((max_len,), dtype),
valid_len: T.int32,
):
with T.Kernel(1, threads=threads) as _:
for i in T.Parallel(max_len):
B[i] = 0.0
span = T.min(valid_len, max_len)
for i in T.Parallel(span):
B[i] = A[i] - 1.0

return main


def _require_cuda_tensor(shape, dtype=torch.float32):
if not torch.cuda.is_available():
pytest.skip("CUDA not available")
try:
return torch.randn(*shape, device="cuda", dtype=dtype)
except RuntimeError as err:
pytest.skip(f"CUDA runtime unavailable: {err}")


def test_parallel_static_extent():
kernel = parallel_elementwise_static(length=256)
data = _require_cuda_tensor((256,), torch.float32)
result = kernel(data)
torch.testing.assert_close(result, data + 1.0, atol=1e-5, rtol=1e-5)


def test_parallel_dynamic_extent():
kernel = parallel_elementwise_dynamic(max_len=512, threads=256)
data = _require_cuda_tensor((512,), torch.float32)
for valid_len in [0, 13, 200, 600]:
out = kernel(data, valid_len)
reference = torch.zeros_like(data)
clip = min(valid_len, data.shape[0])
reference[:clip] = data[:clip] - 1.0
torch.testing.assert_close(out, reference, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
tilelang.testing.main()