
Commit 1193add: Fix Qwen3 (#646)
1 parent d1d65cf

2 files changed: +3171 −3157 lines

backends/candle/src/models/qwen3.rs

Lines changed: 99 additions & 85 deletions
@@ -23,6 +23,7 @@ pub struct Qwen3Config {
     pub rope_theta: f32,
     pub sliding_window: Option<usize>,
     pub use_sliding_window: bool,
+    pub eos_token_id: usize,
 }
 
 struct Qwen3Attention {
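The new eos_token_id field is what the model later reuses as its padding token (see the pad_token_id: config.eos_token_id as u32 line in the constructor hunk below). A minimal sketch of that relationship in plain Rust, with both structs trimmed to the fields this commit touches and a made-up token id:

// Sketch only: structs reduced to the fields touched by this commit.
struct Qwen3Config {
    eos_token_id: usize,
}

struct Qwen3Model {
    pad_token_id: u32,
}

fn main() {
    // Hypothetical value; the real id comes from the model's config.json.
    let config = Qwen3Config { eos_token_id: 151_643 };

    // Mirrors `pad_token_id: config.eos_token_id as u32` in the constructor.
    let model = Qwen3Model { pad_token_id: config.eos_token_id as u32 };

    println!("left padding will use token id {}", model.pad_token_id);
}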
@@ -164,8 +165,8 @@ impl Qwen3Attention {
             .concat(),
         )?;
 
-        let (q, _res) = self.q_norm.forward(&q, None)?;
-        let (k, _res) = self.k_norm.forward(&k, None)?;
+        let (q, _) = self.q_norm.forward(&q, None)?;
+        let (k, _) = self.k_norm.forward(&k, None)?;
 
         let q = q.transpose(1, 2)?;
         let k = k.transpose(1, 2)?;
@@ -355,16 +356,21 @@ impl Qwen3Layer {
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
 
-        let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, None)?;
+        let (normed_hidden_states, residual) =
+            self.input_layer_norm.forward(hidden_states, None)?;
+
         let attn_output =
             self.attention
                 .forward(&normed_hidden_states, attention_bias, cos, sin)?;
+
         let (normed_attn_res_output, attn_res) = self
             .post_attention_layer_norm
-            .forward(&attn_output, Some(&res))?;
+            .forward(&attn_output, Some(&residual))?;
+
         let mlp_output = self.mlp.forward(&normed_attn_res_output)?;
 
         let output = (&mlp_output + &attn_res)?;
+
         Ok(output)
     }
 }
@@ -378,6 +384,7 @@ pub struct Qwen3Model {
     pool: Pool,
     pub device: Device,
     num_attention_heads: usize,
+    pad_token_id: u32,
 
     span: tracing::Span,
 }
@@ -427,12 +434,35 @@ impl Qwen3Model {
             rotary_cache,
             rotary_dim,
             pool,
+            pad_token_id: config.eos_token_id as u32,
             device: vb.device().clone(),
             num_attention_heads: config.num_attention_heads,
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
 
+    fn get_causal_attention_bias(&self, attention_bias: Tensor) -> Result<Tensor> {
+        let (bs, dim, seq_len, _) = attention_bias.dims4()?;
+
+        let device = attention_bias.device();
+
+        let mask: Vec<u8> = (0..seq_len)
+            .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
+            .collect();
+
+        let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
+        let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
+
+        let negatives = Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?;
+        let zeros = Tensor::zeros_like(&attention_bias)?.to_device(&Device::Cpu)?;
+
+        let causal_mask = causal_mask
+            .where_cond(&negatives, &zeros)?
+            .to_device(device)?;
+
+        attention_bias.broadcast_add(&causal_mask)
+    }
+
     pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         let _enter = self.span.enter();
 
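The new get_causal_attention_bias helper builds an upper-triangular mask with the (j > i) flat_map, turns the marked positions into f32::MIN via where_cond, and adds the result onto the padding bias. A minimal sketch of the mask layout in plain Rust (no candle tensors), for a hypothetical seq_len of 4:

// Sketch only: reproduces the flat_map pattern from get_causal_attention_bias
// and prints the resulting rows. A 1 marks a future position (j > i) that the
// real helper replaces with f32::MIN before the softmax.
fn main() {
    let seq_len = 4;
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
        .collect();

    for row in mask.chunks(seq_len) {
        println!("{:?}", row);
    }
    // Prints:
    // [0, 1, 1, 1]
    // [0, 0, 1, 1]
    // [0, 0, 0, 1]
    // [0, 0, 0, 0]
}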
@@ -441,93 +471,77 @@ impl Qwen3Model {
 
         let shape = (batch_size, max_length);
 
-        let (input_ids, position_ids, input_lengths, attention_bias, _attention_mask) =
-            if batch_size > 1 {
-                // Prepare padded batch
-                let elems = batch_size * max_length;
-
-                let mut input_ids = Vec::with_capacity(elems);
-                let mut position_ids = Vec::with_capacity(elems);
-                let mut attention_mask = Vec::with_capacity(elems);
-                let mut attention_bias = Vec::with_capacity(elems);
-                let mut input_lengths = Vec::with_capacity(batch_size);
-                let mut masking = false;
-
-                for i in 0..batch_size {
-                    let start = batch.cumulative_seq_lengths[i] as usize;
-                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
-                    let seq_length = end - start;
-                    input_lengths.push(seq_length);
-
-                    // Input ids
-                    for j in start..end {
-                        input_ids.push(batch.input_ids[j]);
-                        position_ids.push(batch.position_ids[j]);
-                        attention_mask.push(1.0_f32);
-                        attention_bias.push(0.0);
-                    }
+        let (input_ids, position_ids, input_lengths, attention_bias) = if batch_size > 1 {
+            // Prepare padded batch
+            let elems = batch_size * max_length;
+
+            let mut input_ids = Vec::with_capacity(elems);
+            let mut position_ids = Vec::with_capacity(elems);
+            let mut attention_bias = Vec::with_capacity(elems);
+            let mut input_lengths = Vec::with_capacity(batch_size);
+            let mut masking = false;
+
+            for i in 0..batch_size {
+                let start = batch.cumulative_seq_lengths[i] as usize;
+                let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                let seq_length = end - start;
+                input_lengths.push(seq_length);
+
+                for j in start..end {
+                    input_ids.push(batch.input_ids[j]);
+                    position_ids.push(batch.position_ids[j]);
+                    attention_bias.push(0.0);
+                }
 
-                    // Pad to max_length
-                    for _ in seq_length..max_length {
-                        input_ids.push(0);
-                        position_ids.push(0);
-                        attention_mask.push(0.0_f32);
-                        attention_bias.push(f32::NEG_INFINITY);
-                        masking = true;
+                let padding = max_length - seq_length;
+                if padding > 0 {
+                    masking = true;
+                    for _ in 0..padding {
+                        input_ids.insert(start, self.pad_token_id);
+                        position_ids.insert(start, 0);
+                        attention_bias.insert(start, f32::MIN);
                     }
                 }
+            }
 
-                let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
-                let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
-                let attention_mask = if masking {
-                    Some(Tensor::from_vec(attention_mask, shape, &self.device)?)
-                } else {
-                    None
-                };
-
-                let attention_bias = if masking {
-                    let attention_bias = Tensor::from_vec(
-                        attention_bias,
-                        (batch_size, 1, 1, max_length),
-                        &self.device,
-                    )?;
-                    // Broadcast once instead of at every layer
-                    let attention_bias = attention_bias
-                        .broadcast_as((
-                            batch_size,
-                            self.num_attention_heads,
-                            max_length,
-                            max_length,
-                        ))?
-                        .contiguous()?;
-                    Some(attention_bias)
-                } else {
-                    None
-                };
-
-                (
-                    input_ids,
-                    position_ids,
-                    input_lengths,
-                    attention_bias,
-                    attention_mask,
-                )
+            let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+            let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
+
+            let attention_bias = if masking {
+                let attention_bias =
+                    Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?;
+                // Broadcast once instead of at every layer
+                let attention_bias = attention_bias
+                    .broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))?
+                    .contiguous()?;
+                Some(attention_bias)
             } else {
-                let input_ids = Tensor::from_vec(
-                    batch.input_ids.clone(),
-                    (1, batch.input_ids.len()),
-                    &self.device,
-                )?;
-                let position_ids = Tensor::from_vec(
-                    batch.position_ids.clone(),
-                    (1, batch.position_ids.len()),
-                    &self.device,
-                )?;
-                let input_lengths = vec![batch.input_ids.len()];
-
-                (input_ids, position_ids, input_lengths, None, None)
+                None
             };
 
+            (input_ids, position_ids, input_lengths, attention_bias)
+        } else {
+            let input_ids = Tensor::from_vec(
+                batch.input_ids.clone(),
+                (1, batch.input_ids.len()),
+                &self.device,
+            )?;
+            let position_ids = Tensor::from_vec(
+                batch.position_ids.clone(),
+                (1, batch.position_ids.len()),
+                &self.device,
+            )?;
+            let input_lengths = vec![batch.input_ids.len()];
+
+            (input_ids, position_ids, input_lengths, None)
+        };
+
+        let attention_bias = if let Some(attn_bias) = attention_bias {
+            Some(self.get_causal_attention_bias(attn_bias)?)
+        } else {
+            None
+        };
+
         let mut hidden_states = self.embeddings.forward(&input_ids)?;
 
         let cos = self
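The padded branch above now left-pads each sequence: instead of appending zero tokens plus a separate attention_mask, it inserts pad_token_id (and an f32::MIN bias) at the start of every short sequence. A minimal sketch of that layout in plain Rust, with made-up token ids and a hypothetical pad id; the real code works on the flat batch.input_ids buffer rather than nested Vecs:

// Sketch only: shows the left-padding layout the rewritten forward pass builds.
fn main() {
    let sequences = vec![vec![11u32, 12, 13], vec![21u32, 22]];
    let max_length = sequences.iter().map(Vec::len).max().unwrap();
    let pad_token_id = 9999u32; // hypothetical; the model uses eos_token_id

    let padded: Vec<Vec<u32>> = sequences
        .into_iter()
        .map(|seq| {
            let padding = max_length - seq.len();
            // Prepend `padding` copies of the pad token (left padding).
            std::iter::repeat(pad_token_id)
                .take(padding)
                .chain(seq)
                .collect()
        })
        .collect();

    for row in &padded {
        println!("{:?}", row);
    }
    // Prints:
    // [11, 12, 13]
    // [9999, 21, 22]
}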
@@ -583,7 +597,7 @@ impl Qwen3Model {
                 .iter()
                 .map(|&i| {
                     let i = i as usize;
-                    let last_token_idx = input_lengths[i] - 1;
+                    let last_token_idx = max_length - 1;
                     outputs.i((i, last_token_idx))?.unsqueeze(0)
                 })
                 .collect();
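With left padding, the last real token of every sequence ends up in the final column, which is why the pooling index changes from input_lengths[i] - 1 to a fixed max_length - 1. A small sketch of that indexing in plain Rust, with made-up values:

// Sketch only: with left padding, index max_length - 1 hits the last real
// token of every row, regardless of how long each sequence actually is.
fn main() {
    let max_length = 3;
    let padded = vec![
        vec![11u32, 12, 13],   // full-length sequence
        vec![9999u32, 21, 22], // left-padded with a hypothetical pad id
    ];

    let last_token_idx = max_length - 1;
    for row in &padded {
        println!("last token: {}", row[last_token_idx]);
    }
    // Prints 13, then 22: both are real tokens, never padding.
}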
