Skip to content

Commit

Permalink
Use flash-attn in gemma. (#2195)
Browse files Browse the repository at this point in the history
* Use flash-attn in gemma.

* Fix flash-attn for head dim 256.
  • Loading branch information
LaurentMazare authored May 18, 2024
1 parent eefc1c7 commit 7ebc354
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 20 deletions.
5 changes: 4 additions & 1 deletion candle-examples/examples/gemma/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ struct Args {
/// The model to use.
#[arg(long, default_value = "2b")]
which: Which,

#[arg(long)]
use_flash_attn: bool,
}

fn main() -> Result<()> {
Expand Down Expand Up @@ -270,7 +273,7 @@ fn main() -> Result<()> {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
let model = Model::new(args.use_flash_attn, &config, vb)?;

println!("loaded the model in {:?}", start.elapsed());

Expand Down
4 changes: 4 additions & 0 deletions candle-flash-attn/kernels/flash_fwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
Expand Down
4 changes: 3 additions & 1 deletion candle-flash-attn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ impl FlashAttn {

let elem_count = out_shape.elem_count();
let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
let softmax_lse = dev
.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
.w()?;

let is_bf16 = if is_bf16 { 1 } else { 0 };

Expand Down
62 changes: 44 additions & 18 deletions candle-transformers/src/models/gemma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,6 @@ struct RotaryEmbedding {
cos: Tensor,
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.head_dim;
Expand All @@ -94,7 +87,6 @@ impl RotaryEmbedding {
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
Expand All @@ -110,10 +102,8 @@ impl RotaryEmbedding {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
Ok((q_embed, k_embed))
}
}
Expand Down Expand Up @@ -163,10 +153,16 @@ struct Attention {
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
use_flash_attn: bool,
}

impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
Expand All @@ -188,6 +184,7 @@ impl Attention {
head_dim,
rotary_emb,
kv_cache: None,
use_flash_attn,
})
}

Expand Down Expand Up @@ -231,7 +228,14 @@ impl Attention {
let value_states =
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;

let attn_output = {
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
let q = query_states.transpose(1, 2)?;
let k = key_states.transpose(1, 2)?;
let v = value_states.transpose(1, 2)?;
let scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(&q, &k, &v, scale, attention_mask.is_some())?.transpose(1, 2)?
} else {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

Expand All @@ -253,6 +257,22 @@ impl Attention {
}
}

#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}

#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
Expand All @@ -262,8 +282,13 @@ struct DecoderLayer {
}

impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
fn new(
rotary_emb: Arc<RotaryEmbedding>,
use_flash_attn: bool,
cfg: &Config,
vb: VarBuilder,
) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, use_flash_attn, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
Expand Down Expand Up @@ -312,15 +337,16 @@ pub struct Model {
}

impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
pub fn new(use_flash_attn: bool, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
let layer =
DecoderLayer::new(rotary_emb.clone(), use_flash_attn, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
Expand Down

0 comments on commit 7ebc354

Please sign in to comment.