
Conversation

@Rachmanino
Collaborator

This pull request introduces a new modular kernel and function structure for multi-head attention (MHA), refactors imports to support the new organization, and adds comprehensive testing and profiling capabilities. The main focus is on making the MHA kernel implementation more extensible and easier to autotune, while providing a standardized interface for kernel and function classes.

MHA kernel and function modularization

  • Added Kernel base class (top/kernels/kernel.py) and refactored the MHA forward kernel into a new class mha_fwd_kernel_sm80 with default configs and autotuning support (top/kernels/mha.py). [1] [2]
  • Introduced Function base class and implemented the mha_fwd function class, which wraps the kernel and provides a reference implementation, input generation, correctness checks, and profiling (top/functions/function.py, top/functions/mha.py). [1] [2] A minimal sketch of the intended interface follows this list.
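
To make the intended contract concrete, here is a minimal sketch of what the two base classes might look like. Only forward(), gen_inputs(), check(), and profile() come from the snippets quoted later in this thread; the docstrings and everything else are illustrative assumptions, not the PR's actual code.

# Hypothetical sketch of the Kernel/Function split described above.

class Kernel:
    """Base class for a single tuned kernel, e.g. mha_fwd_kernel_sm80."""

    def __call__(self, *args, **kwargs):
        raise NotImplementedError("kernel call is not implemented")


class Function:
    """Base class wrapping a kernel with a reference impl, inputs, and checks."""

    def forward(self, *args, **kwargs):
        raise NotImplementedError("forward method is not implemented")

    def ref_program(self, *args, **kwargs):
        raise NotImplementedError("reference program is not implemented")

    def gen_inputs(self):
        raise NotImplementedError("input generation is not implemented")

    def check(self):
        raise NotImplementedError("correctness check is not implemented")

    def profile(self, warmup=100, rep=100):
        raise NotImplementedError("profiling is not implemented")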

Refactoring and import changes

  • Updated package-level __init__.py files to use the new kernels and functions modules, simplifying imports and exposing the new API (top/__init__.py, top/functions/__init__.py, top/kernels/__init__.py). [1] [2] [3]

Testing and profiling

  • Added a new test script tests/v2/test_mha.py that allows command-line configuration of MHA parameters and runs correctness and performance checks for the new kernel and function classes.

These changes lay the groundwork for future kernel and function extensions, enable easier autotuning, and provide a more robust testing and profiling workflow.

@Rachmanino Rachmanino changed the title [Refactor] Refactor TileOps for better hierarchical abstraction and designs [RFC][Refactor] Refactor TileOps for better hierarchical abstraction and designs Oct 11, 2025
self.total_flops *= 0.5

# TODO: dispatch to different kernels based on archs and inputs
self.kernel = mha_fwd_kernel_sm80(batch, heads, seq_len, dim, is_causal)
Contributor

For future kernel dispatching, it might be useful to keep a whitelist of supported architectures per function. Since TileLang kernels can typically run across multiple architectures, we could register the supported architectures for each kernel.
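
One possible shape for such a whitelist, assuming a simple module-level registry (register_kernel, _KERNEL_REGISTRY, and the arch set shown are made up for illustration, not part of this PR):

# Hypothetical per-kernel architecture registry; names are illustrative.
_KERNEL_REGISTRY = {}  # function name -> list of (supported archs, kernel class)


def register_kernel(func_name, supported_archs):
    """Register a kernel class as supporting the given SM architectures."""
    def decorator(kernel_cls):
        _KERNEL_REGISTRY.setdefault(func_name, []).append(
            (frozenset(supported_archs), kernel_cls))
        return kernel_cls
    return decorator


@register_kernel("mha_fwd", supported_archs={"sm80", "sm86", "sm89"})
class mha_fwd_kernel_sm80:
    # stand-in for the class added in top/kernels/mha.py
    def __init__(self, batch, heads, seq_len, dim, is_causal):
        ...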

Collaborator Author

Great. But kernels for different archs sometimes have distinct computation schedules, so we may have to dispatch in the function if one kernel is not compatible with all archs?
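
If the dispatch does end up in the function, it could consult such a registry at construction time; a rough sketch (current_arch and select_kernel are assumed helpers built on the registry sketched above, not existing code):

# Hypothetical function-level dispatch on top of the registry sketched above.
import torch


def current_arch():
    """Return e.g. 'sm80' for the active CUDA device (illustrative helper)."""
    major, minor = torch.cuda.get_device_capability()
    return f"sm{major}{minor}"


def select_kernel(func_name, *kernel_args):
    arch = current_arch()
    for archs, kernel_cls in _KERNEL_REGISTRY.get(func_name, []):
        if arch in archs:
            return kernel_cls(*kernel_args)
    raise NotImplementedError(f"no {func_name} kernel registered for {arch}")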

Comment on lines +42 to +60
def gen_inputs(self):
    return (torch.randn([self.batch, self.seq_len, self.heads, self.dim],
                        dtype=torch.float16, device='cuda') for _ in range(3))

def check(self):
    Q, K, V = self.gen_inputs()
    o, _ = self.forward(Q, K, V)  # lse is only used for bwd
    o_ref = self.ref_program(Q, K, V)
    assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}'
    print("All checks passed.✅")

def profile(self, warmup=100, rep=100):
    # TODO: support cupti backend for better accuracy
    Q, K, V = self.gen_inputs()
    with torch.no_grad():
        tl_latency = do_bench(lambda: self.forward(Q, K, V), warmup=warmup, rep=rep)

    print(f"Tilelang latency: {tl_latency:.2f} ms")
    print(f"Tilelang TFlops: {self.total_flops / tl_latency * 1e-9:.2f} TFlops")
Contributor

These three functions are primarily used for benchmarking (both correctness and performance). They should live in an independent object rather than being coupled to Function.
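
For example, the three methods could move into a small standalone harness that takes any Function-like object; a sketch (the Benchmark class, its constructor, and the triton do_bench stand-in are assumptions, not proposed code):

# Hypothetical standalone benchmark harness decoupled from Function.
import torch
from triton.testing import do_bench  # stand-in for the do_bench helper used in the PR


class Benchmark:
    def __init__(self, fn, gen_inputs, ref_program, total_flops):
        self.fn = fn                    # callable under test, e.g. an mha_fwd instance
        self.gen_inputs = gen_inputs    # () -> input tensors
        self.ref_program = ref_program  # reference implementation
        self.total_flops = total_flops

    def check(self, atol=1e-2, rtol=1e-2):
        inputs = tuple(self.gen_inputs())
        out = self.fn(*inputs)
        if isinstance(out, tuple):  # e.g. mha_fwd returns (o, lse)
            out = out[0]
        ref = self.ref_program(*inputs)
        torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)

    def profile(self, warmup=100, rep=100):
        inputs = tuple(self.gen_inputs())
        with torch.no_grad():
            latency = do_bench(lambda: self.fn(*inputs), warmup=warmup, rep=rep)
        print(f"latency: {latency:.2f} ms, "
              f"TFlops: {self.total_flops / latency * 1e-9:.2f}")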

# TODO: dispatch to different kernels based on archs and inputs
self.kernel = mha_fwd_kernel_sm80(batch, heads, seq_len, dim, is_causal)

def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
Contributor

We need to add asserts (either here or in the kernel file) to check that the inputs — shape, data type, etc. — are supported. This will prevent undefined behavior and make it clear to users why a failure occurs.
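
For instance, the MHA forward could validate shapes, dtype, and device up front; a sketch against the constructor arguments visible in this diff (attribute names and the kernel invocation are assumed):

# Hypothetical input validation for mha_fwd.forward; attribute names assumed.
import torch


def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    expected = (self.batch, self.seq_len, self.heads, self.dim)
    for name, t in (("Q", Q), ("K", K), ("V", V)):
        assert tuple(t.shape) == expected, \
            f"{name} has shape {tuple(t.shape)}, expected {expected}"
        assert t.dtype == torch.float16, f"{name} must be float16, got {t.dtype}"
        assert t.is_cuda, f"{name} must be a CUDA tensor"
    return self.kernel(Q, K, V)  # kernel invocation shown for illustration only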

Collaborator Author

I agree, but would it introduce unnecessary runtime overhead? Or should we use an argument to control whether to perform such assertions and checks?
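
One low-overhead compromise would be to gate the checks behind a flag or environment variable so they can be disabled in production runs; a sketch (TOP_CHECK_INPUTS and _check_inputs are made-up names):

# Hypothetical toggle for the input checks; TOP_CHECK_INPUTS is a made-up env var.
import os

CHECK_INPUTS = os.environ.get("TOP_CHECK_INPUTS", "1") == "1"


def forward(self, Q, K, V):
    if CHECK_INPUTS:
        self._check_inputs(Q, K, V)  # the assertions sketched above
    return self.kernel(Q, K, V)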


def forward(self, *args, **kwargs):
    raise NotImplementedError("forward method is not implemented")

Contributor

What about backward?

Collaborator Author

As the document explains, a Function here does not have a backward implementation, and its forward() method is merely an alias for __call__(). We use a pair of fwd & bwd Functions to compose an op(erator).
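
In other words, a differentiable op could later be assembled from a fwd Function and a bwd Function via torch.autograd.Function; roughly (the mha_bwd signature and saved-tensor layout are assumptions for illustration):

# Hypothetical composition of an op from a fwd/bwd Function pair.
import torch


class MHAOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, fwd_fn, bwd_fn):
        o, lse = fwd_fn.forward(Q, K, V)  # lse is kept for the backward pass
        ctx.save_for_backward(Q, K, V, o, lse)
        ctx.bwd_fn = bwd_fn
        return o

    @staticmethod
    def backward(ctx, grad_o):
        Q, K, V, o, lse = ctx.saved_tensors
        dQ, dK, dV = ctx.bwd_fn.forward(grad_o, Q, K, V, o, lse)  # signature assumed
        return dQ, dK, dV, None, None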

args = parser.parse_args()
B, S, H, D, causal = args.batch, args.seq_len, args.heads, args.dim, args.causal

test_mha_kernel(B, S, H, D, causal)
Contributor

We could provide some common input configurations along with a script to generate the benchmark results.
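
For example, a small sweep script over a few standard shapes could double as the benchmark generator; a sketch (the configuration list, the top import, and the mha_fwd argument order are assumptions):

# Hypothetical benchmark sweep over common MHA configurations.
from top import mha_fwd  # assumed to be exposed by the new top/__init__.py

CONFIGS = [
    # (batch, seq_len, heads, dim, causal)
    (1, 2048, 32, 128, False),
    (8, 4096, 32, 128, True),
    (16, 1024, 16, 64, True),
]

if __name__ == "__main__":
    for batch, seq_len, heads, dim, causal in CONFIGS:
        print(f"B={batch} S={seq_len} H={heads} D={dim} causal={causal}")
        fn = mha_fwd(batch, heads, seq_len, dim, causal)  # argument order assumed
        fn.check()
        fn.profile()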
