From a839ef3d687b31ec7f57f736a77ada4603e7ab7e Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 18 Jan 2024 19:20:38 +0000
Subject: [PATCH] add source comments

---
 tch-cuda/build.rs | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/tch-cuda/build.rs b/tch-cuda/build.rs
index 6e6c2759..c7143633 100644
--- a/tch-cuda/build.rs
+++ b/tch-cuda/build.rs
@@ -1,12 +1,12 @@
-// Build script to run nvcc and generate the C glue code for launching the flash-attention kernel.
-// The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
-// variable in order to cache the compiled artifacts and avoid recompiling too often.
 use anyhow::{Context, Result};
 use rayon::prelude::*;
 use std::env;
 use std::path::PathBuf;
 use std::str::FromStr;
 
+// build logic from https://github.com/huggingface/candle/blob/30313c308106fff7b20fc8cb2b27eb79800cb818/candle-flash-attn/build.rs
+
+// torch detection from https://github.com/LaurentMazare/tch-rs/blob/5480d6fd4be12e748e0d87555db54a5f6e74edf2/torch-sys/build.rs
 const PYTHON_PRINT_PYTORCH_DETAILS: &str = r"
 import torch
 from torch.utils import cpp_extension
@@ -134,30 +134,30 @@ impl SystemInfo {
 }
 
 const KERNEL_FILES: [&str; 25] = [
-    "flash_api.cpp",
-    "flash_fwd_split_hdim128_bf16_sm80.cu",
-    "flash_fwd_split_hdim160_bf16_sm80.cu",
-    "flash_fwd_split_hdim192_bf16_sm80.cu",
-    "flash_fwd_split_hdim224_bf16_sm80.cu",
-    "flash_fwd_split_hdim256_bf16_sm80.cu",
-    "flash_fwd_split_hdim32_bf16_sm80.cu",
-    "flash_fwd_split_hdim64_bf16_sm80.cu",
-    "flash_fwd_split_hdim96_bf16_sm80.cu",
-    "flash_fwd_split_hdim128_fp16_sm80.cu",
-    "flash_fwd_split_hdim160_fp16_sm80.cu",
-    "flash_fwd_split_hdim192_fp16_sm80.cu",
-    "flash_fwd_split_hdim224_fp16_sm80.cu",
-    "flash_fwd_split_hdim256_fp16_sm80.cu",
-    "flash_fwd_split_hdim32_fp16_sm80.cu",
-    "flash_fwd_split_hdim64_fp16_sm80.cu",
-    "flash_fwd_split_hdim96_fp16_sm80.cu",
+    "flash_attn/flash_api.cpp",
+    "flash_attn/flash_fwd_split_hdim128_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim160_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim192_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim224_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim256_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim32_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim64_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim96_bf16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim128_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim160_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim192_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim224_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim256_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim32_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim64_fp16_sm80.cu",
+    "flash_attn/flash_fwd_split_hdim96_fp16_sm80.cu",
     "vllm/activation_kernels.cu",
     "vllm/cache_kernels.cu",
     "vllm/cuda_utils_kernels.cu",
     "vllm/layernorm_kernels.cu",
     "vllm/pos_encoding_kernels.cu",
     "vllm/attention/attention_kernels.cu",
-    "vllm/bindings.cpp",
+    "vllm_bindings.cpp",
     "cuda.cpp",
 ];
 
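
For context on the "torch detection" comment added above: the referenced torch-sys build script (and the PYTHON_PRINT_PYTORCH_DETAILS constant visible in the hunk) locates the local PyTorch installation by running a small Python snippet at build time and parsing its output. The following is a minimal sketch of that idea only; the snippet text, prefixes, and main() body are hypothetical and are not the actual tch-cuda or torch-sys code.

    use std::process::Command;

    // Hypothetical snippet for illustration; the real PYTHON_PRINT_PYTORCH_DETAILS
    // constant prints additional details (CUDA availability, ABI flags, etc.).
    const PRINT_TORCH_PATHS: &str = r"
    import torch
    from torch.utils import cpp_extension
    print('LIBTORCH_VERSION:', torch.__version__)
    for p in cpp_extension.include_paths():
        print('LIBTORCH_INCLUDE:', p)
    for p in cpp_extension.library_paths():
        print('LIBTORCH_LIB:', p)
    ";

    fn main() {
        // Run the snippet with whatever Python has torch installed.
        let output = Command::new("python3")
            .args(["-c", PRINT_TORCH_PATHS])
            .output()
            .expect("failed to run python3");
        // Parse the tagged lines back out; a real build script would feed these
        // include/library paths into the nvcc/cc invocation and linker flags.
        for line in String::from_utf8_lossy(&output.stdout).lines() {
            if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") {
                println!("cargo:rerun-if-changed={path}");
            }
        }
    }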