Skip to content

Commit 2bb6829

Browse files
committed
[Xe] Additional comments
1 parent 460d34a commit 2bb6829

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,8 @@ template <class FMHAKernel> struct ExampleRunner {
375375
auto head_size_qk = shape.head_size_qk = options.head_size_qk;
376376
auto head_size_vo = shape.head_size_vo = options.head_size_vo;
377377

378+
// Set up strides.
379+
// These lines can be adjusted to support different data layouts, as needed.
378380
stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, num_heads_q, batch));
379381
stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, num_heads_kv, batch));
380382
stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, num_heads_kv, batch));

include/cute/tensor_sg.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ struct is_sg_tensor<SubgroupTensor<Engine,Layout,SubgroupTVLayout>> : true_type
9898
template <class Engine, class Layout, class SubgroupTVLayout>
9999
struct is_tensor<SubgroupTensor<Engine,Layout,SubgroupTVLayout>> : true_type {};
100100

101+
// Create a SubgroupTensor from its component parts:
102+
// a regular rmem Tensor and the subgroup-scope TV-layout.
101103
template <class Engine,
102104
class Layout,
103105
class SubgroupTVLayout,
@@ -111,6 +113,9 @@ make_subgroup_tensor(Tensor<Engine, Layout> const& tensor, SubgroupTVLayout cons
111113
return static_cast<SubgroupTensor<Engine,Layout,SubgroupTVLayout> const&>(tensor);
112114
}
113115

116+
// Create a new owning SubgroupTensor with the given subgroup-level layout.
117+
// Elements are assigned to work-items following the normal Xe interleaved mapping
118+
// (i.e. work-item i gets elements i, i + 16, i + 32, ...)
114119
template <typename T, class Shape, class Stride>
115120
CUTE_HOST_DEVICE
116121
constexpr auto
@@ -123,6 +128,8 @@ make_subgroup_tensor(Layout<Shape,Stride> const& sg_layout)
123128
return make_subgroup_tensor(make_fragment_like<T>(sv_layout(0,_)), sv_layout);
124129
}
125130

131+
// Create a new owning SubgroupTensor with a subgroup-level layout, constructed
132+
// from the argument list with make_layout.
126133
template <typename T, class... Args>
127134
CUTE_HOST_DEVICE
128135
constexpr auto

0 commit comments

Comments
 (0)