Commit 0f5ec37

Update docs
1 parent ac0dbd2 commit 0f5ec37

File tree

3 files changed: +48 −1 lines


_sources/autoapi/tilelang/language/reduce/index.rst.txt

Lines changed: 24 additions & 0 deletions
@@ -146,6 +146,30 @@ Module Contents
 
 Negative `dim` indices are normalized (Python-style). If `dst` is None, the operation is performed in-place into `src`. Raises ValueError when `dim` is out of bounds for `src.shape`. When `src.scope() == "local.fragment"`, this delegates to `cumsum_fragment`; otherwise it emits the `tl.cumsum` intrinsic.
 
+.. rubric:: Examples
+
+A 1D inclusive scan that writes the result into a separate shared-memory buffer:
+
+>>> import tilelang.language as T
+>>> @T.prim_func
+... def kernel(A: T.Tensor((128,), "float32"), B: T.Tensor((128,), "float32")):
+...     with T.Kernel(1, threads=128):
+...         A_shared = T.alloc_shared((128,), "float32")
+...         T.copy(A, A_shared)
+...         T.cumsum(src=A_shared, dst=A_shared, dim=0)
+...         T.copy(A_shared, B)
+
+A 2D prefix sum along the last dimension with reverse accumulation:
+
+>>> import tilelang.language as T
+>>> @T.prim_func
+... def kernel2d(A: T.Tensor((64, 64), "float16"), B: T.Tensor((64, 64), "float16")):
+...     with T.Kernel(1, 1, threads=256):
+...         tile = T.alloc_shared((64, 64), "float16")
+...         T.copy(A, tile)
+...         T.cumsum(src=tile, dim=1, reverse=True)
+...         T.copy(tile, B)
+
 :returns: A handle to the emitted cumulative-sum operation.
 :rtype: tir.Call
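For intuition about the `reverse=True` case added above: a reverse cumulative sum along a dimension produces suffix sums (element `[i, j]` holds the sum of row `i` from column `j` to the end). The following is a plain-NumPy semantic reference only, not tilelang code — it sketches the same result via flip, scan, flip back:

```python
import numpy as np

# Semantic reference for the 2D example above: reverse (right-to-left)
# cumulative sum along the last axis, i.e. suffix sums per row.
A = np.arange(12, dtype=np.float32).reshape(3, 4)

# Flip along axis 1, take an ordinary inclusive scan, flip back.
rev = np.flip(np.cumsum(np.flip(A, axis=1), axis=1), axis=1)

# First column now holds each full row sum; last column equals A's last column.
print(rev[0])  # suffix sums of [0, 1, 2, 3] -> [6, 6, 5, 3]
```

The forward case (`reverse=False`, as in the 1D example) corresponds to a plain `np.cumsum(A, axis=dim)`.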

autoapi/tilelang/language/reduce/index.html

Lines changed: 23 additions & 0 deletions
@@ -695,6 +695,29 @@ <h2>Module Contents<a class="headerlink" href="#module-contents" title="Link to
 <span class="sig-prename descclassname"><span class="pre">tilelang.language.reduce.</span></span><span class="sig-name descname"><span class="pre">cumsum</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">src</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dst</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dim</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reverse</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#tilelang.language.reduce.cumsum" title="Link to this definition"></a></dt>
 <dd><p>Compute the cumulative sum of <cite>src</cite> along <cite>dim</cite>, writing results to <cite>dst</cite>.</p>
 <p>Negative <cite>dim</cite> indices are normalized (Python-style). If <cite>dst</cite> is None, the operation is performed in-place into <cite>src</cite>. Raises ValueError when <cite>dim</cite> is out of bounds for <cite>src.shape</cite>. When <cite>src.scope() == “local.fragment”</cite>, this delegates to <cite>cumsum_fragment</cite>; otherwise it emits the <cite>tl.cumsum</cite> intrinsic.</p>
+<p class="rubric">Examples</p>
+<p>A 1D inclusive scan that writes the result into a separate shared-memory buffer:</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span><span class="w"> </span><span class="nn">tilelang.language</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">T</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="nd">@T</span><span class="o">.</span><span class="n">prim_func</span>
+<span class="gp">... </span><span class="k">def</span><span class="w"> </span><span class="nf">kernel</span><span class="p">(</span><span class="n">A</span><span class="p">:</span> <span class="n">T</span><span class="o">.</span><span class="n">Tensor</span><span class="p">((</span><span class="mi">128</span><span class="p">,),</span> <span class="s2">&quot;float32&quot;</span><span class="p">),</span> <span class="n">B</span><span class="p">:</span> <span class="n">T</span><span class="o">.</span><span class="n">Tensor</span><span class="p">((</span><span class="mi">128</span><span class="p">,),</span> <span class="s2">&quot;float32&quot;</span><span class="p">)):</span>
+<span class="gp">... </span> <span class="k">with</span> <span class="n">T</span><span class="o">.</span><span class="n">Kernel</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">threads</span><span class="o">=</span><span class="mi">128</span><span class="p">):</span>
+<span class="gp">... </span> <span class="n">A_shared</span> <span class="o">=</span> <span class="n">T</span><span class="o">.</span><span class="n">alloc_shared</span><span class="p">((</span><span class="mi">128</span><span class="p">,),</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span>
+<span class="gp">... </span> <span class="n">T</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">A_shared</span><span class="p">)</span>
+<span class="gp">... </span> <span class="n">T</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">src</span><span class="o">=</span><span class="n">A_shared</span><span class="p">,</span> <span class="n">dst</span><span class="o">=</span><span class="n">A_shared</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
+<span class="gp">... </span> <span class="n">T</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">A_shared</span><span class="p">,</span> <span class="n">B</span><span class="p">)</span>
+</pre></div>
+</div>
+<p>A 2D prefix sum along the last dimension with reverse accumulation:</p>
+<div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span><span class="w"> </span><span class="nn">tilelang.language</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">T</span>
+<span class="gp">&gt;&gt;&gt; </span><span class="nd">@T</span><span class="o">.</span><span class="n">prim_func</span>
+<span class="gp">... </span><span class="k">def</span><span class="w"> </span><span class="nf">kernel2d</span><span class="p">(</span><span class="n">A</span><span class="p">:</span> <span class="n">T</span><span class="o">.</span><span class="n">Tensor</span><span class="p">((</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="s2">&quot;float16&quot;</span><span class="p">),</span> <span class="n">B</span><span class="p">:</span> <span class="n">T</span><span class="o">.</span><span class="n">Tensor</span><span class="p">((</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="s2">&quot;float16&quot;</span><span class="p">)):</span>
+<span class="gp">... </span> <span class="k">with</span> <span class="n">T</span><span class="o">.</span><span class="n">Kernel</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">threads</span><span class="o">=</span><span class="mi">256</span><span class="p">):</span>
+<span class="gp">... </span> <span class="n">tile</span> <span class="o">=</span> <span class="n">T</span><span class="o">.</span><span class="n">alloc_shared</span><span class="p">((</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="s2">&quot;float16&quot;</span><span class="p">)</span>
+<span class="gp">... </span> <span class="n">T</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">A</span><span class="p">,</span> <span class="n">tile</span><span class="p">)</span>
+<span class="gp">... </span> <span class="n">T</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">src</span><span class="o">=</span><span class="n">tile</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
+<span class="gp">... </span> <span class="n">T</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">tile</span><span class="p">,</span> <span class="n">B</span><span class="p">)</span>
+</pre></div>
+</div>
 <dl class="field-list simple">
 <dt class="field-odd">Returns<span class="colon">:</span></dt>
 <dd class="field-odd"><p>A handle to the emitted cumulative-sum operation.</p>

searchindex.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.

0 commit comments

Comments
 (0)