[RFC] Unifying AOT and JIT C++ Code via C Macros #706

Open
yzh119 opened this issue Dec 29, 2024 · 0 comments

yzh119 commented Dec 29, 2024

Background and Motivation

Our current JIT (Just-In-Time) infrastructure is fragile because it embeds C++ template code as Python strings. This design introduces redundancy and complicates maintenance, since we must keep two copies of the same C++/Python binding code:

  1. One version stored as a Python string for JIT.
  2. Another version stored in .cpp files for AOT (Ahead-Of-Time) compilation.

Whenever we change the interface, both copies must be updated, which is error-prone. Additionally, we cannot leverage syntax highlighting or other development tools (like linting) for the embedded JIT code.

Proposed Solution

Per discussions with @lsrcz, we realized that we are not using advanced Jinja features in our JIT code. The few string substitutions that we do require can be implemented with C macros. By switching to a macro-based approach, we can unify the JIT and AOT code bases:

  1. Shared C++ Source
    Use the same .cpp and .h files for both AOT and JIT modes.

  2. Macro-Generated Headers
    • JIT Mode: Generate a header file that defines constant expressions for each JIT instance.
    • AOT Mode: Generate a header file that dispatches between different parameters or attention variants.

  3. Variant and Parameter Definitions
    Both attention variant definitions and additional parameters can be expressed as macros generated by Python. This approach allows us to maintain a single source of truth for all configurations while still supporting both compilation modes.

This macro-based design leaves a single copy of the code, which simplifies maintenance, reduces the chance of the two modes drifting apart, and makes debugging easier because the sources are ordinary C++ files that editors, linters, and the compiler can process directly.

Implementation Details

Below are two illustrative headers: one for JIT mode (showing generated constants) and one for AOT mode (showing parameter dispatch). Both rely on the same C++ code for the core functionality.

JIT Header

// header from JIT instance

constexpr int HEAD_DIM = 128;
using DTypeQ = half;

// Do not dispatch anything here
#define DISPATCH_head_dim(expr, const_expr, ...) \
  __VA_ARGS__

template <typename ParamsT_>
struct FlashSigmoid {
  using ParamsT = ParamsT_;
  using DTypeQ = typename ParamsT::DTypeQ;
  using DTypeKV = typename ParamsT::DTypeKV;
  using DTypeO = typename ParamsT::DTypeO;
  using IdType = typename ParamsT::IdType;

  static constexpr bool use_softmax = false;

  uint32_t window_left, qo_len, kv_len;
  float sigmoid_bias_log2e;

  // Create closure
  __device__ __host__ FlashSigmoid(const ParamsT& params, uint32_t batch_idx, uint8_t* smem_ptr) {
    qo_len = params.get_qo_len(batch_idx);
    kv_len = params.get_kv_len(batch_idx);
    window_left = kv_len;
    sigmoid_bias_log2e = params.sigmoid_bias * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T QueryTransform(const ParamsT& params, T q) {
    return float(q) * params.logits_scale * math::log2e;
  }

  template <typename T>
  __device__ __forceinline__ T LogitsTransform(
      const ParamsT& params,
      T logits,
      uint32_t batch_idx,
      uint32_t qo_idx,
      uint32_t kv_idx,
      uint32_t qo_head_idx,
      uint32_t kv_head_idx) {
    return math::ptx_rcp(1.f + math::ptx_exp2(-float(logits + sigmoid_bias_log2e)));
  }

  __device__ __forceinline__ bool LogitsMask(
      const ParamsT& params,
      uint32_t batch_idx,
      uint32_t qo_idx,
      uint32_t kv_idx,
      uint32_t qo_head_idx,
      uint32_t kv_head_idx) {
    return true;
  }
};

#define ATTENTION_VARIANT FlashSigmoid
#define ADDITIONAL_PARAMS_DECL float logits_scale, float sigmoid_bias
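
To illustrate why a pass-through macro is sufficient in JIT mode, here is a minimal host-only sketch of shared code compiled against a JIT-style header. Only HEAD_DIM and DISPATCH_head_dim come from the header above; SmemBytesPerRow and QuerySmemBytes are hypothetical names invented for this example and are not part of the RFC.

```cpp
#include <cstdint>

// ---- stand-in for a generated JIT header (value is illustrative) ----
constexpr int HEAD_DIM = 128;

// JIT mode: the constant already exists, so the macro only emits its body.
#define DISPATCH_head_dim(expr, const_expr, ...) __VA_ARGS__

// ---- shared source, intended to be identical for JIT and AOT ----

// Hypothetical helper that needs the head dimension at compile time.
template <int kHeadDim>
constexpr uint32_t SmemBytesPerRow() {
  return kHeadDim * sizeof(float);
}

// Hypothetical entry point. The body refers to HEAD_DIM by name, and the
// macro decides where that name comes from.
inline uint32_t QuerySmemBytes(int head_dim) {
  (void)head_dim;  // unused in JIT mode: the instance is already specialized
  uint32_t bytes = 0;
  DISPATCH_head_dim(head_dim, HEAD_DIM, {
    bytes = SmemBytesPerRow<HEAD_DIM>();
  });
  return bytes;
}
```

In this mode the runtime head_dim argument is ignored, because the generated header has already fixed HEAD_DIM for the instance. With HEAD_DIM = 128 and 4-byte floats, QuerySmemBytes(128) evaluates to 512.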

AOT Header

// header from AOT

#define _DISPATCH_CASES_head_dim(case_var, ...) \
  _DISPATCH_CASE(64, case_var, __VA_ARGS__)     \
  _DISPATCH_CASE(128, case_var, __VA_ARGS__)    \
  _DISPATCH_CASE(256, case_var, __VA_ARGS__)    \
// EOL

#define ATTENTION_VARIANT ComposedAttention<...>
// Empty additional parameters
#define ADDITIONAL_PARAMS_DECL
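
The RFC does not show _DISPATCH_CASE or the AOT version of DISPATCH_head_dim. One plausible way to define them, modeled on the common switch-based dispatch-macro pattern, is sketched below. Both definitions, along with VecCount and VecCountFor, are assumptions made for illustration.

```cpp
#include <stdexcept>

// Assumed helper: one `case` that binds a compile-time constant named case_var.
#define _DISPATCH_CASE(case_expr, case_var, ...) \
  case case_expr: {                              \
    constexpr auto case_var = case_expr;         \
    __VA_ARGS__                                  \
    break;                                       \
  }

// Copied from the AOT header in the RFC.
#define _DISPATCH_CASES_head_dim(case_var, ...) \
  _DISPATCH_CASE(64, case_var, __VA_ARGS__)     \
  _DISPATCH_CASE(128, case_var, __VA_ARGS__)    \
  _DISPATCH_CASE(256, case_var, __VA_ARGS__)    \
// EOL

// Assumed wrapper with the same signature as the JIT pass-through macro.
#define DISPATCH_head_dim(expr, const_expr, ...)           \
  switch (expr) {                                          \
    _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)      \
    default:                                               \
      throw std::invalid_argument("unsupported head_dim"); \
  }

// Hypothetical helper that needs the head dimension at compile time.
template <int kHeadDim>
constexpr int VecCount() {
  return kHeadDim / 8;
}

// Hypothetical shared entry point: a runtime value selects one of the
// precompiled instantiations.
inline int VecCountFor(int head_dim) {
  int result = 0;
  DISPATCH_head_dim(head_dim, HEAD_DIM, {
    result = VecCount<HEAD_DIM>();
  });
  return result;
}
```

The call site is written the same way as in JIT mode. Here, though, the macro expands to a switch that instantiates the body once per supported head dimension and rejects any other value at run time.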

In this approach:

  • ATTENTION_VARIANT names the attention implementation to compile; redefining it swaps in a different variant.
  • ADDITIONAL_PARAMS_DECL lists any extra parameters that variant needs. In the examples above the JIT header defines it and the AOT header leaves it empty.
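
As a rough host-only sketch of how shared code might consume these two macros: the AdditionalParams struct, the ADDITIONAL_PARAMS_INIT companion macro, the SigmoidLike variant, and both functions below are hypothetical and do not appear in the RFC. Only ATTENTION_VARIANT and ADDITIONAL_PARAMS_DECL are taken from the headers above.

```cpp
#include <cmath>

// ---- stand-in for a Python-generated header (all contents illustrative) ----
struct AdditionalParams {  // hypothetical: fields mirror the macro below
  float logits_scale;
  float sigmoid_bias;
};
#define ADDITIONAL_PARAMS_DECL float logits_scale, float sigmoid_bias
#define ADDITIONAL_PARAMS_INIT logits_scale, sigmoid_bias  // hypothetical companion macro

// Hypothetical variant, far simpler than FlashSigmoid.
struct SigmoidLike {
  static float Transform(const AdditionalParams& p, float logit) {
    return 1.0f / (1.0f + std::exp(-(logit * p.logits_scale + p.sigmoid_bias)));
  }
};
#define ATTENTION_VARIANT SigmoidLike

// ---- shared source: never names a concrete variant or parameter ----

// The macro forms the entire parameter list, so an empty definition
// still yields a valid signature.
inline AdditionalParams MakeAdditionalParams(ADDITIONAL_PARAMS_DECL) {
  return AdditionalParams{ADDITIONAL_PARAMS_INIT};
}

// Whichever variant the generated header selected is used here.
inline float RunVariant(const AdditionalParams& p, float logit) {
  return ATTENTION_VARIANT::Transform(p, logit);
}
```

Using ADDITIONAL_PARAMS_DECL as a whole parameter list avoids the dangling comma that would appear if an empty macro followed other parameters. If both macros and the struct are empty, MakeAdditionalParams() still compiles and returns a default AdditionalParams.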

Conclusion

By replacing the Jinja templates embedded in Python strings with C macros, we eliminate the duplicated JIT and AOT code, reduce the risk of interface mismatches, and regain ordinary development tooling such as syntax highlighting and linting. We believe this unified approach will be easier to maintain, more robust, and simpler to extend.

cc @lsrcz, @hyhieu for visibility and feedback.

@yzh119 yzh119 changed the title [RFC] JIT refactor [RFC] Unifying AOT and JIT C++ Code via C Macros Dec 29, 2024