@BourneSun0527
Contributor

Add lightning_indexer and sparse_flash_attention

@gemini-code-assist
Contributor

Summary of Changes

Hello @BourneSun0527, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the custom operator library by adding lightning_indexer and sparse_flash_attention. These operators are specifically optimized for Huawei Ascend Atlas A3 inference products, aiming to improve performance in large language models. The changes include their core C++ implementations, integration into the CMake build process, and user-friendly Python interfaces with detailed examples and documentation.

Highlights

  • New Operator Implementations: Introduced two new custom operators, lightning_indexer and sparse_flash_attention, designed for Huawei Ascend Atlas A3 inference products.
  • Comprehensive Build System Integration: Integrated the new operators into the CMake build system, including their C++ definitions, tiling logic, and Python bindings.
  • Python Bindings and Examples: Provided Python bindings and example usage for both lightning_indexer and sparse_flash_attention, supporting eager and graph modes for seamless integration with PyTorch.
  • Detailed Documentation: Added extensive documentation in Chinese for each operator, covering functionality, parameters, constraints, and usage examples.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces two new operators, lightning_indexer and sparse_flash_attention, along with their build configurations, documentation, and example usage. The changes are extensive, covering CMake build scripts, C++ kernel and host implementations, Python bindings, and documentation. My review focuses on improving code quality, maintainability, and correctness. Key feedback includes removing compiler warning suppression, fixing typos in code and build scripts, improving test assertions for correctness, and enhancing code clarity by replacing magic numbers and translating comments to English.

fi

CUSTOM_OPTION="-DBUILD_OPEN_PROJECT=ON"
CUSTOM_OPTION+=" -DCMAKE_CXX_FLAGS=\"-w\" -DCMAKE_C_FLAGS=\"-w\""

high

Suppressing all compiler warnings with -w is a dangerous practice as it can hide potential bugs and issues in the code. It's better to address warnings directly or, if they are known false positives, to suppress them selectively.

Suggested change
- CUSTOM_OPTION+=" -DCMAKE_CXX_FLAGS=\"-w\" -DCMAKE_C_FLAGS=\"-w\""
+ CUSTOM_OPTION+=""

Comment on lines +183 to +190
# compare result
npu_out = npu_out.reshape(-1, sparse_count).cpu()
cpuout = cpuout.reshape(-1, sparse_count).cpu()
t = npu_out.shape[0]
for i in range(t):
    for j in range(sparse_count):
        if npu_out[i][j] != cpuout[i][j]:
            print("t K npu cpu = ", i, j, npu_out[i][j], cpuout[i][j])

high

This test case iterates through the results and prints differences, but it does not include any assertions. This means the test will always pass, regardless of whether the NPU output matches the CPU output. A proper test should use assertions to validate the results and fail if they don't match. This issue is present in all test methods within this file.

Please replace the printing logic with an assertion, for example, using self.assertEqual for integer tensors.

Suggested change
- # compare result
- npu_out = npu_out.reshape(-1, sparse_count).cpu()
- cpuout = cpuout.reshape(-1, sparse_count).cpu()
- t = npu_out.shape[0]
- for i in range(t):
-     for j in range(sparse_count):
-         if npu_out[i][j] != cpuout[i][j]:
-             print("t K npu cpu = ", i, j, npu_out[i][j], cpuout[i][j])
+ # compare result
+ self.assertEqual(npu_out.cpu(), cpuout.cpu())
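
A minimal sketch of the assertion-based comparison described above, assuming plain nested lists as stand-ins for the NPU and CPU tensors so that it runs without any device. The helper names `find_mismatches` and `assert_outputs_match` are hypothetical and do not appear in the PR. It keeps the diagnostic detail of the original print loop but makes the check fail when the outputs differ.

```python
def find_mismatches(actual, expected):
    """Return (row, col, actual_value, expected_value) for every differing entry."""
    if len(actual) != len(expected):
        raise ValueError("row counts differ")
    mismatches = []
    for i, (row_a, row_e) in enumerate(zip(actual, expected)):
        if len(row_a) != len(row_e):
            raise ValueError(f"column counts differ in row {i}")
        for j, (a, e) in enumerate(zip(row_a, row_e)):
            if a != e:
                mismatches.append((i, j, a, e))
    return mismatches


def assert_outputs_match(actual, expected):
    """Raise AssertionError (instead of only printing) when any entry differs."""
    mismatches = find_mismatches(actual, expected)
    if mismatches:
        raise AssertionError(
            f"{len(mismatches)} mismatching entries; "
            f"first (row, col, npu, cpu) = {mismatches[0]}"
        )


# Hypothetical stand-ins for the reshaped NPU output and CPU reference.
npu_rows = [[0, 3, 5], [1, 2, 4]]
cpu_rows = [[0, 3, 5], [1, 2, 4]]
assert_outputs_match(npu_rows, cpu_rows)  # passes silently when equal
```

With real tensors, whether `self.assertEqual` accepts multi-element tensors depends on the test base class in use; under plain `unittest.TestCase` the element-wise comparison may raise an ambiguity error, in which case `self.assertTrue(torch.equal(npu_out, cpuout))` is one alternative.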

HOST_TILING="true"
shift
;;
--disable-check-compatible|--disable-check-compatiable)

medium

There is a typo in the option --disable-check-compatiable. It should be --disable-check-compatible.

Suggested change
- --disable-check-compatible|--disable-check-compatiable)
+ --disable-check-compatible)

if [ "${BUILD}" == "host" ];then
cmake_config -DENABLE_OPS_KERNEL=OFF
build_host
# TO DO

medium

A TO DO comment is present, which suggests incomplete work. This should be resolved or converted into a proper issue in the project's issue tracker for future reference.

set(_OP_LIST)
set(_OP_DIR_LIST)

file(GLOB OP_HOST_CMAKE_FILES "${CMAKE_CURRENT_SOURCE_DIR}/src/lightning_indexer/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/src/sparse_flash_attention/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/src/**/ophost/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/src/**/**/**/ophost/CMakeLists.txt" "${CMAKE_CURRENT_SOURCE_DIR}/src/**/**/**/CMakeLists.txt")

medium

The operator discovery mechanism uses a file(GLOB ...) command with hardcoded operator names (lightning_indexer, sparse_flash_attention). This is not easily extensible. When new operators are added, this line will need to be manually updated. A more robust approach would be to dynamically discover operators by iterating through the subdirectories of src/.
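
The discovery logic suggested above can be illustrated with a short sketch. It is written in Python only to show the idea; the actual change would live in the CMake scripts. The function name `discover_operator_dirs` is hypothetical, and the assumption that every operator is an immediate subdirectory of `src/` containing its own `CMakeLists.txt` is taken from the two hardcoded paths in the reviewed line, not confirmed for the rest of the tree.

```python
from pathlib import Path


def discover_operator_dirs(src_root):
    """List immediate subdirectories of src_root that contain a CMakeLists.txt.

    Directories without a CMakeLists.txt (shared utilities, for example)
    are skipped, so adding a new operator needs no edit to a hardcoded list.
    """
    root = Path(src_root)
    return sorted(
        child.name
        for child in root.iterdir()
        if child.is_dir() and (child / "CMakeLists.txt").is_file()
    )
```

Called on a `src/` directory holding `lightning_indexer/` and `sparse_flash_attention/`, each with a `CMakeLists.txt`, this would return both names, and a third operator directory added later would be picked up without further changes.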

Comment on lines +58 to +59
this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048); // 2048:默认值,筛选前2048
this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3:默认值,只计算下三角

medium

The comments are in Chinese. For consistency and to make the code accessible to a wider audience, it's recommended to write comments in English.

        this->Attr("sparse_count").AttrType(OPTIONAL).Int(2048);  // 2048: default value, top 2048 to be selected
        this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3);      // 3: default value, only compute lower triangle

OPS_LOG_E_IF_NULL(context, attrs, return ge::GRAPH_FAILED);
const char *inputLayoutQueryPtr = attrs->GetAttrPointer<char>(ATTR_QUERY_LAYOUT_INDEX);
OPS_LOG_E_IF_NULL(context, inputLayoutQueryPtr, return ge::GRAPH_FAILED);
const int64_t *seleced_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);

medium

There is a typo in the variable name seleced_count. It should be selected_count. This typo appears in multiple places in this function.

    const int64_t *selected_count = attrs->GetInt(ATTR_SPARSE_COUNT_INDEX);

OPS_LOG_I(context_->GetNodeName(), "layout_key is:%s", opParamInfo_.layOutKey);
}
if (opParamInfo_.sparseCount != nullptr) {
OPS_LOG_I(context_->GetNodeName(), "selscted count is:%d", *opParamInfo_.sparseCount);

medium

There is a typo in the log message: "selscted count". It should be "selected count".

        OPS_LOG_I(context_->GetNodeName(), "selected count is:%d", *opParamInfo_.sparseCount);

Comment on lines +564 to +574
constexpr uint32_t MM1_RES_ELEM_SIZE = 4; // 4: fp32
constexpr uint32_t DOUBLE_BUFFER = 2; // 双Buffer
constexpr uint32_t M_BASE_SIZE = 512; // m轴基本块大小
constexpr uint32_t S2_BASE_SIZE = 512; // S2轴基本块大小
constexpr uint32_t V1_RES_ELEM_SIZE = 4; // 4: int32
constexpr uint32_t V1_RES_ELEM_TYPE = 2; // 保留Index和Value 2种数据
constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = 8; // 8: int64
constexpr uint32_t V1_DECODE_PARAM_NUM = 16; // Decode参数个数
constexpr uint32_t V1_DECODE_DATA_NUM = 2; // Decode每个核需要存储头和尾部两块数据
constexpr uint32_t S1_BASE_SIZE = 8; // S1轴基本块的大小
constexpr uint32_t TOPK_MAX_SIZE = 2048; // TopK选取个数

medium

The workspace size calculation uses several magic numbers for element sizes. This makes the code harder to read and maintain. It's better to use sizeof for data types. Also, some comments are in Chinese and should be translated to English for consistency.

    constexpr uint32_t MM1_RES_ELEM_SIZE = sizeof(float);
    constexpr uint32_t DOUBLE_BUFFER = 2;              // Double buffer
    constexpr uint32_t M_BASE_SIZE = 512;              // M-axis base block size
    constexpr uint32_t S2_BASE_SIZE = 512;             // S2-axis base block size
    constexpr uint32_t V1_RES_ELEM_SIZE = sizeof(int32_t);
    constexpr uint32_t V1_RES_ELEM_TYPE = 2;           // Reserve 2 types of data: Index and Value
    constexpr uint32_t V1_DECODE_PARAM_ELEM_SIZE = sizeof(int64_t);
    constexpr uint32_t V1_DECODE_PARAM_NUM = 16;       // Number of Decode parameters
    constexpr uint32_t V1_DECODE_DATA_NUM = 2;         // Decode needs to store header and tail data for each core
    constexpr uint32_t S1_BASE_SIZE = 8;               // S1-axis base block size
    constexpr uint32_t TOPK_MAX_SIZE = 2048;           // TopK selection count
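
As a quick sanity check that the `sizeof`-based constants keep the values of the original literals, the sketch below compares the byte sizes of the corresponding C types using Python's `ctypes`. The table `ELEMENT_SIZES` and the function `check_element_sizes` are hypothetical names for illustration. The sizes are reported for the host running the script, which is assumed (not guaranteed) to match the target toolchain.

```python
import ctypes

# Constant name -> (C type used in the suggested sizeof, original literal).
ELEMENT_SIZES = {
    "MM1_RES_ELEM_SIZE": (ctypes.c_float, 4),          # fp32
    "V1_RES_ELEM_SIZE": (ctypes.c_int32, 4),           # int32
    "V1_DECODE_PARAM_ELEM_SIZE": (ctypes.c_int64, 8),  # int64
}


def check_element_sizes(table):
    """Map each constant name to True if sizeof(type) equals the old literal."""
    return {
        name: ctypes.sizeof(ctype) == literal
        for name, (ctype, literal) in table.items()
    }


results = check_element_sizes(ELEMENT_SIZES)
```

If every entry in `results` is true, replacing the literals with `sizeof` leaves the computed workspace size unchanged on that platform.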

this->Attr("sparse_block_size").AttrType(REQUIRED).Int(1);
this->Attr("layout_query").AttrType(OPTIONAL).String("BSND");
this->Attr("layout_kv").AttrType(OPTIONAL).String("BSND");
this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3); // 3:默认值,只计算下三角

medium

The comment is in Chinese. For consistency and to make the code accessible to a wider audience, it's recommended to write comments in English.

        this->Attr("sparse_mode").AttrType(OPTIONAL).Int(3);  // 3: default value, only compute lower triangle
