@@ -23,6 +23,7 @@ pub struct Qwen3Config {
     pub rope_theta: f32,
     pub sliding_window: Option<usize>,
     pub use_sliding_window: bool,
+    pub eos_token_id: usize,
 }
 
 struct Qwen3Attention {
@@ -164,8 +165,8 @@ impl Qwen3Attention {
             .concat(),
         )?;
 
-        let (q, _res) = self.q_norm.forward(&q, None)?;
-        let (k, _res) = self.k_norm.forward(&k, None)?;
+        let (q, _) = self.q_norm.forward(&q, None)?;
+        let (k, _) = self.k_norm.forward(&k, None)?;
 
         let q = q.transpose(1, 2)?;
         let k = k.transpose(1, 2)?;
@@ -355,16 +356,21 @@ impl Qwen3Layer {
     ) -> Result<Tensor> {
         let _enter = self.span.enter();
 
-        let (normed_hidden_states, res) = self.input_layer_norm.forward(hidden_states, None)?;
+        let (normed_hidden_states, residual) =
+            self.input_layer_norm.forward(hidden_states, None)?;
+
         let attn_output =
             self.attention
                 .forward(&normed_hidden_states, attention_bias, cos, sin)?;
+
         let (normed_attn_res_output, attn_res) = self
             .post_attention_layer_norm
-            .forward(&attn_output, Some(&res))?;
+            .forward(&attn_output, Some(&residual))?;
+
         let mlp_output = self.mlp.forward(&normed_attn_res_output)?;
 
         let output = (&mlp_output + &attn_res)?;
+
         Ok(output)
     }
 }
@@ -378,6 +384,7 @@ pub struct Qwen3Model {
     pool: Pool,
     pub device: Device,
     num_attention_heads: usize,
+    pad_token_id: u32,
 
     span: tracing::Span,
 }
@@ -427,12 +434,35 @@ impl Qwen3Model {
             rotary_cache,
             rotary_dim,
             pool,
+            pad_token_id: config.eos_token_id as u32,
             device: vb.device().clone(),
             num_attention_heads: config.num_attention_heads,
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }
 
+    fn get_causal_attention_bias(&self, attention_bias: Tensor) -> Result<Tensor> {
+        let (bs, dim, seq_len, _) = attention_bias.dims4()?;
+
+        let device = attention_bias.device();
+
+        let mask: Vec<u8> = (0..seq_len)
+            .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
+            .collect();
+
+        let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), &Device::Cpu)?;
+        let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?;
+
+        let negatives = Tensor::full(f32::MIN, attention_bias.shape(), &Device::Cpu)?;
+        let zeros = Tensor::zeros_like(&attention_bias)?.to_device(&Device::Cpu)?;
+
+        let causal_mask = causal_mask
+            .where_cond(&negatives, &zeros)?
+            .to_device(device)?;
+
+        attention_bias.broadcast_add(&causal_mask)
+    }
+
     pub fn forward(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
         let _enter = self.span.enter();
 
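The mask built by `get_causal_attention_bias` above is the standard upper-triangular causal pattern: entry (i, j) is set wherever j > i, i.e. wherever a position would attend to a later one, and `where_cond` then maps those entries to `f32::MIN` before adding them to the padding bias. Below is a minimal plain-Rust sketch of just that mask layout (hypothetical `causal_mask` helper, no candle), not the tensor code itself:

    fn causal_mask(seq_len: usize) -> Vec<u8> {
        // 1 where attention must be blocked (j > i), 0 where it is allowed.
        (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| (j > i) as u8))
            .collect()
    }

    fn main() {
        // For seq_len = 4 this prints:
        // [0, 1, 1, 1]
        // [0, 0, 1, 1]
        // [0, 0, 0, 1]
        // [0, 0, 0, 0]
        for row in causal_mask(4).chunks(4) {
            println!("{:?}", row);
        }
    }
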
@@ -441,93 +471,77 @@ impl Qwen3Model {
 
         let shape = (batch_size, max_length);
 
-        let (input_ids, position_ids, input_lengths, attention_bias, _attention_mask) =
-            if batch_size > 1 {
-                // Prepare padded batch
-                let elems = batch_size * max_length;
-
-                let mut input_ids = Vec::with_capacity(elems);
-                let mut position_ids = Vec::with_capacity(elems);
-                let mut attention_mask = Vec::with_capacity(elems);
-                let mut attention_bias = Vec::with_capacity(elems);
-                let mut input_lengths = Vec::with_capacity(batch_size);
-                let mut masking = false;
-
-                for i in 0..batch_size {
-                    let start = batch.cumulative_seq_lengths[i] as usize;
-                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
-                    let seq_length = end - start;
-                    input_lengths.push(seq_length);
-
-                    // Input ids
-                    for j in start..end {
-                        input_ids.push(batch.input_ids[j]);
-                        position_ids.push(batch.position_ids[j]);
-                        attention_mask.push(1.0_f32);
-                        attention_bias.push(0.0);
-                    }
+        let (input_ids, position_ids, input_lengths, attention_bias) = if batch_size > 1 {
+            // Prepare padded batch
+            let elems = batch_size * max_length;
+
+            let mut input_ids = Vec::with_capacity(elems);
+            let mut position_ids = Vec::with_capacity(elems);
+            let mut attention_bias = Vec::with_capacity(elems);
+            let mut input_lengths = Vec::with_capacity(batch_size);
+            let mut masking = false;
+
+            for i in 0..batch_size {
+                let start = batch.cumulative_seq_lengths[i] as usize;
+                let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                let seq_length = end - start;
+                input_lengths.push(seq_length);
+
+                for j in start..end {
+                    input_ids.push(batch.input_ids[j]);
+                    position_ids.push(batch.position_ids[j]);
+                    attention_bias.push(0.0);
+                }
 
-                    // Pad to max_length
-                    for _ in seq_length..max_length {
-                        input_ids.push(0);
-                        position_ids.push(0);
-                        attention_mask.push(0.0_f32);
-                        attention_bias.push(f32::NEG_INFINITY);
-                        masking = true;
+                let padding = max_length - seq_length;
+                if padding > 0 {
+                    masking = true;
+                    for _ in 0..padding {
+                        input_ids.insert(start, self.pad_token_id);
+                        position_ids.insert(start, 0);
+                        attention_bias.insert(start, f32::MIN);
                     }
                 }
+            }
 
-                let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
-                let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
-                let attention_mask = if masking {
-                    Some(Tensor::from_vec(attention_mask, shape, &self.device)?)
-                } else {
-                    None
-                };
-
-                let attention_bias = if masking {
-                    let attention_bias = Tensor::from_vec(
-                        attention_bias,
-                        (batch_size, 1, 1, max_length),
-                        &self.device,
-                    )?;
-                    // Broadcast once instead of at every layer
-                    let attention_bias = attention_bias
-                        .broadcast_as((
-                            batch_size,
-                            self.num_attention_heads,
-                            max_length,
-                            max_length,
-                        ))?
-                        .contiguous()?;
-                    Some(attention_bias)
-                } else {
-                    None
-                };
-
-                (
-                    input_ids,
-                    position_ids,
-                    input_lengths,
-                    attention_bias,
-                    attention_mask,
-                )
+            let input_ids = Tensor::from_vec(input_ids, shape, &self.device)?;
+            let position_ids = Tensor::from_vec(position_ids, shape, &self.device)?;
+
+            let attention_bias = if masking {
+                let attention_bias =
+                    Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?;
+                // Broadcast once instead of at every layer
+                let attention_bias = attention_bias
+                    .broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))?
+                    .contiguous()?;
+                Some(attention_bias)
             } else {
-                let input_ids = Tensor::from_vec(
-                    batch.input_ids.clone(),
-                    (1, batch.input_ids.len()),
-                    &self.device,
-                )?;
-                let position_ids = Tensor::from_vec(
-                    batch.position_ids.clone(),
-                    (1, batch.position_ids.len()),
-                    &self.device,
-                )?;
-                let input_lengths = vec![batch.input_ids.len()];
-
-                (input_ids, position_ids, input_lengths, None, None)
+                None
             };
 
+            (input_ids, position_ids, input_lengths, attention_bias)
+        } else {
+            let input_ids = Tensor::from_vec(
+                batch.input_ids.clone(),
+                (1, batch.input_ids.len()),
+                &self.device,
+            )?;
+            let position_ids = Tensor::from_vec(
+                batch.position_ids.clone(),
+                (1, batch.position_ids.len()),
+                &self.device,
+            )?;
+            let input_lengths = vec![batch.input_ids.len()];
+
+            (input_ids, position_ids, input_lengths, None)
+        };
+
+        let attention_bias = if let Some(attn_bias) = attention_bias {
+            Some(self.get_causal_attention_bias(attn_bias)?)
+        } else {
+            None
+        };
+
         let mut hidden_states = self.embeddings.forward(&input_ids)?;
 
         let cos = self
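The padded-batch branch above now left-pads: pad tokens are inserted at the front of each sequence instead of being appended, so every sequence's real tokens end at index `max_length - 1`. A rough standalone sketch of that idea, using a hypothetical `left_pad` helper rather than the exact in-place `insert` bookkeeping above:

    fn left_pad(tokens: &[u32], max_length: usize, pad_token_id: u32) -> Vec<u32> {
        // Pads go in front; real tokens keep their order and end at max_length - 1.
        let mut padded = vec![pad_token_id; max_length - tokens.len()];
        padded.extend_from_slice(tokens);
        padded
    }

    // left_pad(&[5, 6, 7], 5, 0) == vec![0, 0, 5, 6, 7]
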
@@ -583,7 +597,7 @@ impl Qwen3Model {
             .iter()
             .map(|&i| {
                 let i = i as usize;
-                let last_token_idx = input_lengths[i] - 1;
+                let last_token_idx = max_length - 1;
                 outputs.i((i, last_token_idx))?.unsqueeze(0)
             })
             .collect();
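With left padding in place, the last real token of every sequence sits at the same position, which is what the fixed pooling index above relies on. Continuing the hypothetical `left_pad` sketch:

    let padded = left_pad(&[5, 6, 7], 5, 0);
    assert_eq!(padded[5 - 1], 7); // last real token at max_length - 1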