Commit a839ef3: add source comments

mmoskal committed Jan 18, 2024
1 parent d758cdb

Showing 1 changed file with 21 additions and 21 deletions.

tch-cuda/build.rs: 21 additions & 21 deletions

@@ -1,12 +1,12 @@
-// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel.
-// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
-// variable in order to cache the compiled artifacts and avoid recompiling too often.
 use anyhow::{Context, Result};
 use rayon::prelude::*;
 use std::env;
 use std::path::PathBuf;
 use std::str::FromStr;
 
+// build logic from https://github.com/huggingface/candle/blob/30313c308106fff7b20fc8cb2b27eb79800cb818/candle-flash-attn/build.rs
+
+// torch detection from https://github.com/LaurentMazare/tch-rs/blob/5480d6fd4be12e748e0d87555db54a5f6e74edf2/torch-sys/build.rs
 const PYTHON_PRINT_PYTORCH_DETAILS: &str = r"
 import torch
 from torch.utils import cpp_extension
@@ -134,30 +134,30 @@ impl SystemInfo {
 }
 
 const KERNEL_FILES: [&str; 25] = [
-    "flash_api.cpp",
-    "flash_fwd_split_hdim128_bf16_sm80.cu",
-    "flash_fwd_split_hdim160_bf16_sm80.cu",
-    "flash_fwd_split_hdim192_bf16_sm80.cu",
-    "flash_fwd_split_hdim224_bf16_sm80.cu",
-    "flash_fwd_split_hdim256_bf16_sm80.cu",
-    "flash_fwd_split_hdim32_bf16_sm80.cu",
-    "flash_fwd_split_hdim64_bf16_sm80.cu",
-    "flash_fwd_split_hdim96_bf16_sm80.cu",
-    "flash_fwd_split_hdim128_fp16_sm80.cu",
-    "flash_fwd_split_hdim160_fp16_sm80.cu",
-    "flash_fwd_split_hdim192_fp16_sm80.cu",
-    "flash_fwd_split_hdim224_fp16_sm80.cu",
-    "flash_fwd_split_hdim256_fp16_sm80.cu",
-    "flash_fwd_split_hdim32_fp16_sm80.cu",
-    "flash_fwd_split_hdim64_fp16_sm80.cu",
-    "flash_fwd_split_hdim96_fp16_sm80.cu",
+    "flash_attn/flash_api.cpp",
+    "flash_attn/flash_fwd_split_hdim128_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim160_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim192_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim224_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim256_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim32_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim64_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim96_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim128_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim160_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim192_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim224_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim256_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim32_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim64_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim96_fp16_sm80.cu",
     "vllm/activation_kernels.cu",
     "vllm/cache_kernels.cu",
     "vllm/cuda_utils_kernels.cu",
     "vllm/layernorm_kernels.cu",
     "vllm/pos_encoding_kernels.cu",
     "vllm/attention/attention_kernels.cu",
-    "vllm/bindings.cpp",
+    "vllm_bindings.cpp",
     "cuda.cpp",
 ];

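Note on the torch-detection comment added in the first hunk: torch-sys locates the local PyTorch installation by executing a small Python snippet (the PYTHON_PRINT_PYTORCH_DETAILS constant whose first lines appear above) and parsing the key/value lines it prints. A minimal sketch of that pattern, assuming a python3 interpreter on PATH; the snippet body and the LIBTORCH_* keys here are illustrative, not the exact contents of this file:

    use anyhow::{Context, Result};
    use std::process::Command;

    // Hypothetical snippet body; the real constant in build.rs prints more details.
    const PYTHON_PRINT_PYTORCH_DETAILS: &str = r"
    import torch
    from torch.utils import cpp_extension
    for p in cpp_extension.include_paths():
        print('LIBTORCH_INCLUDE:', p)
    for p in cpp_extension.library_paths():
        print('LIBTORCH_LIB:', p)
    ";

    fn detect_torch() -> Result<()> {
        // Run the snippet with the system Python and capture its stdout.
        let output = Command::new("python3")
            .args(["-c", PYTHON_PRINT_PYTORCH_DETAILS])
            .output()
            .context("failed to spawn python3")?;
        anyhow::ensure!(output.status.success(), "python3 exited with an error");
        for line in String::from_utf8(output.stdout)?.lines() {
            // Forward discovered paths to cargo as include/link-search directives.
            if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") {
                println!("cargo:include={path}");
            } else if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") {
                println!("cargo:rustc-link-search=native={path}");
            }
        }
        Ok(())
    }

The real build script presumably also emits rustc-link-lib directives and version checks; this sketch only shows the detect-and-parse shape credited to torch-sys.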

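The rayon import in the first hunk points at how these 25 sources get built: the candle-flash-attn script credited above compiles each .cu file with nvcc in parallel and links the resulting objects. A hedged sketch of that shape, not the exact logic of this build.rs; the nvcc flags, the sm_80 architecture choice (implied by the _sm80 kernel names), and the output layout are assumptions:

    use anyhow::Result;
    use rayon::prelude::*;
    use std::path::{Path, PathBuf};
    use std::process::Command;

    fn compile_cuda_kernels(kernel_files: &[&str], out_dir: &Path) -> Result<Vec<PathBuf>> {
        kernel_files
            .par_iter() // one independent nvcc process per kernel file
            .filter(|f| f.ends_with(".cu"))
            .map(|file| -> Result<PathBuf> {
                let stem = Path::new(file).file_stem().unwrap();
                let obj = out_dir.join(stem).with_extension("o");
                let status = Command::new("nvcc")
                    .args(["-O3", "-std=c++17", "--compiler-options", "-fPIC"])
                    .arg("--gpu-architecture=sm_80")
                    .arg("-c")
                    .arg(file)
                    .arg("-o")
                    .arg(&obj)
                    .status()?;
                anyhow::ensure!(status.success(), "nvcc failed on {file}");
                Ok(obj)
            })
            .collect() // rayon collects into Result<Vec<_>, _>, surfacing the first failure
    }

Compiling the kernels in parallel is what keeps the long CUDA build tolerable, and caching the objects across runs (the point of the CANDLE_FLASH_ATTN_BUILD_DIR variable mentioned in the comment this commit removes) avoids recompiling them at all.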