@Rachmanino (Collaborator)
This pull request introduces several improvements to the test configuration and code organization for multi-head attention (MHA) kernels, as well as a new utility function for device architecture detection. The main changes include adding support for configurable data types in the MHA tests, refactoring imports for better modularity, and introducing an is_hopper utility function.

Test configuration improvements

  • Added a --dtype argument to the MHA test script (tests/test_mha.py), allowing selection between float16 and bfloat16, and updated the test function to honor it. Also adjusted the defaults of the --causal and --tune arguments to improve usability; a sketch of the resulting CLI follows this list.
  • Changed the default value of the --tune argument in the MHA decode test script (tests/test_mha_decode.py) to False for consistency with other test scripts.
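A minimal sketch of the updated CLI in tests/test_mha.py, assuming a standard argparse setup; the flag names --dtype, --causal, and --tune come from the PR, while the wiring around them is illustrative:

```python
import argparse

import torch

parser = argparse.ArgumentParser()
# New flag from this PR: selects the data type used by the MHA kernel.
parser.add_argument("--dtype", type=str, default="float16",
                    choices=["float16", "bfloat16"])
# Defaults adjusted in this PR; store_true yields a False default, so both
# flags become opt-in (the exact argparse style here is an assumption).
parser.add_argument("--causal", action="store_true")
parser.add_argument("--tune", action="store_true")
args = parser.parse_args()

# Map the string flag onto the corresponding torch dtype.
dtype = getattr(torch, args.dtype)  # torch.float16 or torch.bfloat16
```

Resolving the dtype with getattr keeps the mapping in sync with the choices list without a hand-written lookup table.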

Code organization

  • Refactored the kernel imports in top/__init__.py and top/kernel/__init__.py to import MHADecodeKernel from its own module (mha_decode) instead of mha, improving modularity and maintainability; see the sketch after this list.
  • Added an __all__ declaration to top/kernel/mla.py to explicitly define the public API of the module.
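A hedged before/after sketch of the import refactor and the new __all__ declaration; the module paths mha_decode and mla come from the PR, while the class name exported by mla.py is a placeholder:

```python
# top/kernel/__init__.py
# Before: MHADecodeKernel was re-exported from the mha module.
# from .mha import MHAKernel, MHADecodeKernel

# After: each kernel is imported from its own module.
from .mha import MHAKernel
from .mha_decode import MHADecodeKernel

# top/kernel/mla.py
# Explicitly declare the module's public API; "MLAKernel" is a
# hypothetical name standing in for the actual exported symbol.
__all__ = ["MLAKernel"]
```

With __all__ in place, `from top.kernel.mla import *` pulls in only the declared names, which documents the intended public surface of the module.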

New utility function

  • Added a new is_hopper function in top/utils/utils.py to detect whether the current CUDA device is of the Hopper architecture (compute capability 9.0); a sketch follows.
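A minimal sketch of is_hopper, assuming it is implemented on top of torch's device-capability query (the actual body in top/utils/utils.py may differ):

```python
import torch

def is_hopper() -> bool:
    """Return True if the current CUDA device is Hopper (compute capability 9.0)."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) == (9, 0)
```

Callers can use this to gate Hopper-specific kernel paths while falling back to a generic implementation elsewhere.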
