[RFC][Refactor] Refactor TileOps for better hierarchical abstraction and designs #42
Conversation
```python
self.total_flops *= 0.5

# TODO: dispatch to different kernels based on archs and inputs
self.kernel = mha_fwd_kernel_sm80(batch, heads, seq_len, dim, is_causal)
```
For future kernel dispatching, it might be useful to keep a whitelist of supported architectures per function. Since TileLang kernels can typically run across multiple architectures, we could register the supported architectures for each kernel.
Great idea. But kernels for different archs sometimes have distinct computation schedules, so we may still have to dispatch inside the function if no single kernel is compatible with all archs.
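A minimal sketch of what such a per-kernel architecture whitelist could look like; `KERNEL_REGISTRY`, `register_kernel`, and `dispatch_kernel` are hypothetical names, not part of this PR.

```python
# Hypothetical registry: op name -> list of (supported arch set, kernel class)
KERNEL_REGISTRY = {}

def register_kernel(op_name, supported_archs):
    """Register a kernel class together with the archs it can run on."""
    def decorator(kernel_cls):
        KERNEL_REGISTRY.setdefault(op_name, []).append((set(supported_archs), kernel_cls))
        return kernel_cls
    return decorator

def dispatch_kernel(op_name, arch):
    """Return the first registered kernel whose whitelist contains `arch`."""
    for archs, kernel_cls in KERNEL_REGISTRY.get(op_name, []):
        if arch in archs:
            return kernel_cls
    raise NotImplementedError(f"no kernel registered for {op_name} on {arch}")

# Hypothetical usage: the sm80 schedule is also whitelisted for sm86/sm89, while
# an arch with a distinct schedule (e.g. sm90) would register its own class.
# @register_kernel("mha_fwd", supported_archs={"sm80", "sm86", "sm89"})
# class mha_fwd_kernel_sm80: ...
```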
```python
def gen_inputs(self):
    return (torch.randn([self.batch, self.seq_len, self.heads, self.dim],
                        dtype=torch.float16, device='cuda') for _ in range(3))

def check(self):
    Q, K, V = self.gen_inputs()
    o, _ = self.forward(Q, K, V)  # lse is only used for bwd
    o_ref = self.ref_program(Q, K, V)
    assert torch.allclose(o, o_ref, atol=1e-2, rtol=1e-2), f'o max err: {(o - o_ref).abs().max()}'
    print("All checks passed.✅")

def profile(self, warmup=100, rep=100):
    # TODO: support cupti backend for better accuracy
    Q, K, V = self.gen_inputs()
    with torch.no_grad():
        tl_latency = do_bench(lambda: self.forward(Q, K, V), warmup=warmup, rep=rep)

    print(f"Tilelang latency: {tl_latency:.2f} ms")
    print(f"Tilelang TFlops: {self.total_flops / tl_latency * 1e-9:.2f} TFlops")
```
These three functions are primarily used for benchmarking (both correctness and performance). They should be an independent object rather than being coupled with Function.
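A rough sketch of such a decoupled benchmark object, mirroring the `check`/`profile` logic from the diff; the class name `FunctionBenchmark` and its interface are assumptions, and `do_bench` is assumed to be the same helper the diff already uses (e.g. `triton.testing.do_bench`).

```python
import torch
from triton.testing import do_bench  # assumption: same helper as in the diff

class FunctionBenchmark:
    """Standalone correctness/performance harness for any Function-like object."""

    def __init__(self, fn, total_flops):
        self.fn = fn                    # must expose gen_inputs/forward/ref_program
        self.total_flops = total_flops

    def check(self, atol=1e-2, rtol=1e-2):
        Q, K, V = self.fn.gen_inputs()
        o, _ = self.fn.forward(Q, K, V)
        o_ref = self.fn.ref_program(Q, K, V)
        assert torch.allclose(o, o_ref, atol=atol, rtol=rtol), \
            f"o max err: {(o - o_ref).abs().max()}"

    def profile(self, warmup=100, rep=100):
        Q, K, V = self.fn.gen_inputs()
        with torch.no_grad():
            latency = do_bench(lambda: self.fn.forward(Q, K, V), warmup=warmup, rep=rep)
        print(f"latency: {latency:.2f} ms, "
              f"TFlops: {self.total_flops / latency * 1e-9:.2f}")
```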
```python
# TODO: dispatch to different kernels based on archs and inputs
self.kernel = mha_fwd_kernel_sm80(batch, heads, seq_len, dim, is_causal)

def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
```
We need to add asserts (either here or in the kernel file) to check that the inputs — shape, data type, etc. — are supported. This will prevent undefined behavior and make it clear to users why a failure occurs.
I agree, but could it add unnecessary runtime overhead? Perhaps we could use an argument to control whether such assertions and checks are performed?
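A small sketch of flag-gated input validation, so the checks cost nothing when disabled; `validate_mha_inputs` and the `check_inputs` flag are hypothetical names, and the dtype/shape constraints shown are just examples based on the float16 CUDA inputs used in the diff.

```python
import torch

def validate_mha_inputs(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> None:
    """Fail fast with a clear message instead of hitting undefined kernel behavior."""
    assert Q.dtype == K.dtype == V.dtype == torch.float16, \
        f"expected float16 inputs, got {Q.dtype}/{K.dtype}/{V.dtype}"
    assert Q.shape == K.shape == V.shape, \
        f"mismatched shapes: {Q.shape}, {K.shape}, {V.shape}"
    assert Q.is_cuda and K.is_cuda and V.is_cuda, "inputs must be CUDA tensors"

# The Function's forward could then take a flag (or read a global setting) and
# only pay the validation cost when it is enabled:
#
#   def forward(self, Q, K, V, check_inputs=True):
#       if check_inputs:
#           validate_mha_inputs(Q, K, V)
#       ...
```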
```python
def forward(self, *args, **kwargs):
    raise NotImplementedError("forward method is not implemented")
```
what about backward?
As the document explains, a function here does not have a backward implementation, and its forward() method is merely an alias for __call__(). We compose an op(erator) from a pair of fwd & bwd functions.
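For illustration, one way a fwd/bwd function pair could be composed into an op is via `torch.autograd.Function`; this composition pattern and the `MHAOp` name are assumptions, not necessarily how the PR wires it up.

```python
import torch

class MHAOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, fwd_fn, bwd_fn):
        o, lse = fwd_fn(Q, K, V)          # fwd function returns output and log-sum-exp
        ctx.save_for_backward(Q, K, V, o, lse)
        ctx.bwd_fn = bwd_fn
        return o

    @staticmethod
    def backward(ctx, grad_o):
        Q, K, V, o, lse = ctx.saved_tensors
        dQ, dK, dV = ctx.bwd_fn(Q, K, V, o, lse, grad_o)
        # the two function objects themselves receive no gradients
        return dQ, dK, dV, None, None
```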
```python
args = parser.parse_args()
B, S, H, D, causal = args.batch, args.seq_len, args.heads, args.dim, args.causal

test_mha_kernel(B, S, H, D, causal)
```
We could provide some common input configurations along with a script to generate the benchmark results.
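For example, a sweep script could loop over a list of common configurations and call the existing `test_mha_kernel` entry point; the config values below are illustrative only, not an agreed benchmark suite.

```python
# Example MHA configurations: (batch, seq_len, heads, dim, causal)
CONFIGS = [
    (1, 2048, 32, 128, False),
    (8, 2048, 16, 64, True),
    (4, 4096, 32, 128, True),
]

if __name__ == "__main__":
    for B, S, H, D, causal in CONFIGS:
        print(f"=== B={B} S={S} H={H} D={D} causal={causal} ===")
        test_mha_kernel(B, S, H, D, causal)
```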
This pull request introduces a new modular kernel and function structure for multi-head attention (MHA), refactors imports to support the new organization, and adds comprehensive testing and profiling capabilities. The main focus is on making the MHA kernel implementation more extensible and easier to autotune, while providing a standardized interface for kernel and function classes.
MHA kernel and function modularization
- Added a `Kernel` base class (`top/kernels/kernel.py`) and refactored the MHA forward kernel into a new class `mha_fwd_kernel_sm80` with default configs and autotuning support (`top/kernels/mha.py`). [1] [2]
- Added a `Function` base class and implemented the `mha_fwd` function class, which wraps the kernel and provides a reference implementation, input generation, correctness checks, and profiling (`top/functions/function.py`, `top/functions/mha.py`). [1] [2]
Refactoring and import changes
- Updated the `__init__.py` files to use the new `kernels` and `functions` modules, simplifying imports and exposing the new API (`top/__init__.py`, `top/functions/__init__.py`, `top/kernels/__init__.py`). [1] [2] [3]
Testing and profiling
- Added a test script `tests/v2/test_mha.py` that allows command-line configuration of MHA parameters and runs correctness and performance checks for the new kernel and function classes.
These changes lay the groundwork for future kernel and function extensions, enable easier autotuning, and provide a more robust testing and profiling workflow.
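A hypothetical end-to-end usage sketch of the new API described above; the exact import path and constructor signature are assumptions inferred from the diff and the test script.

```python
from top.functions import mha_fwd  # assumed import path

# Constructor arguments mirror the kernel parameters used in the diff.
fn = mha_fwd(batch=8, heads=16, seq_len=2048, dim=64, is_causal=True)

Q, K, V = fn.gen_inputs()        # random float16 CUDA tensors
o, lse = fn.forward(Q, K, V)     # run the dispatched TileLang kernel
fn.check()                       # compare against the reference implementation
fn.profile(warmup=100, rep=100)  # report latency and TFlops
```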