@@ -41,6 +41,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
     return quantized_size, quantized_offset
 
 
+def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
+    """For fused modules (QKV and MLP) we have an array of length
+    N that holds 1 scale for each "logical" matrix. So the param
+    is an array of length N. The loaded_weight corresponds to
+    one of the shards on disk. Here, we slice the param based on
+    the shard_id for loading.
+    """
+    qkv_idxs = {"q": 0, "k": 1, "v": 2}
+
+    if isinstance(shard_id, str):
+        shard_id = qkv_idxs[shard_id]
+    elif not isinstance(shard_id, int):
+        raise ValueError(f"Unknown Shard Id {shard_id}")
+
+    # AutoFP8 scales do not have a shape
+    # compressed-tensors scales do have a shape
+    if len(loaded_weight.shape) != 0:
+        assert loaded_weight.shape[0] == 1
+        loaded_weight = loaded_weight[0]
+
+    return param[shard_id], loaded_weight
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
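For context, here is a minimal sketch of how the new `adjust_scalar_to_fused_array` helper is expected to behave when per-shard scales are loaded into a fused parameter. The import path and the concrete tensors are assumptions for illustration only; the helper itself is the function added above.

```python
import torch

# Assumed import path for the helper added above (vllm/model_executor/layers/linear.py).
from vllm.model_executor.layers.linear import adjust_scalar_to_fused_array

# Fused parameter data: one per-tensor scale slot per logical matrix.
fused_scales = torch.zeros(3)

# AutoFP8-style checkpoint scale: a 0-dim scalar tensor, used as-is.
scale_0 = torch.tensor(0.25)
param_data, loaded_weight = adjust_scalar_to_fused_array(fused_scales, scale_0, 0)
param_data.copy_(loaded_weight)  # fused_scales[0] is now 0.25

# compressed-tensors-style scale: shape [1], squeezed to a scalar first.
scale_2 = torch.tensor([0.5])
param_data, loaded_weight = adjust_scalar_to_fused_array(fused_scales, scale_2, 2)
param_data.copy_(loaded_weight)  # fused_scales[2] is now 0.5

print(fused_scales)  # tensor([0.2500, 0.0000, 0.5000])
```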
@@ -358,37 +381,15 @@ def weight_loader(self,
         output_dim = getattr(param, "output_dim", None)
         # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
-
-        param_shard_splitter = getattr(param, "shard_splitter", None)
-
-        if output_dim is not None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support output_dim != None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-        # If a parameter has defined a shard_splitter to be used for
-        # the weight, it should be applied before the weight is
-        # loaded/copied to the parameter. The shard_splitter applies
-        # logic by using the loaded_shard_id to ensure that the loaded
-        # param is loaded to the correct location
-        # within the parameter defined by the linear method.
-        if loaded_shard_id is None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support loaded_shard_id == None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
+        # Special case for per-tensor scale to load scalar into fused array.
+        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
 
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
-                # If fp8 + scale, need to send to each shard.
-                if fp8_scales_shard_indexer is not None:
-                    param_data, loaded_weight = fp8_scales_shard_indexer(
-                        param_data, loaded_weight, loaded_shard_id)
+                if needs_scalar_to_array:
+                    param_data, loaded_weight = adjust_scalar_to_fused_array(
+                        param_data, loaded_weight, 0)
 
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
@@ -450,15 +451,9 @@ def weight_loader(self,
             shard_offset = loaded_shard_id * shard_size
             param_data = param_data.narrow(0, shard_offset, shard_size)
 
-        # If a param_shard_splitter is defined by the LinearMethod, use it.
-        elif param_shard_splitter is not None:
-            logical_widths = getattr(param, "logical_widths", None)
-            param_data, loaded_weight = param_shard_splitter(
-                param_data, loaded_weight, loaded_shard_id, logical_widths)
-
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(
+        # Special case for per-tensor scales in fused case.
+        elif needs_scalar_to_array:
+            param_data, loaded_weight = adjust_scalar_to_fused_array(
                 param_data, loaded_weight, loaded_shard_id)
 
         else:
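With the flag-based dispatch above, a quantization method no longer attaches a `shard_splitter` or `fp8_scales_shard_indexer` callable to the parameter; it only marks its per-tensor scale parameter with the `needs_scalar_to_array` attribute and the loader does the slicing. Below is a rough sketch of the producer side, assuming the `set_weight_attrs` helper from `vllm.model_executor.utils` and hypothetical partition sizes; the actual `create_weights` changes are not part of this hunk.

```python
import torch
from torch.nn import Parameter

# Assumed helper; vLLM uses it to attach loader metadata to parameters.
from vllm.model_executor.utils import set_weight_attrs

# Hypothetical fused layer with three logical matrices (e.g. fused QKV).
output_partition_sizes = [4096, 1024, 1024]

# One scale slot per logical matrix; weight_loader fills each slot from
# the scalar per-tensor scales found in the checkpoint shards.
weight_scale = Parameter(torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                         requires_grad=False)

# This flag routes the parameter through adjust_scalar_to_fused_array
# in weight_loader, replacing the removed shard_splitter hooks.
set_weight_attrs(weight_scale, {"needs_scalar_to_array": True})
```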
@@ -548,36 +543,15 @@ def weight_loader(self,
         # Special case for AQLM codebooks.
         is_metadata = getattr(param, "is_metadata", False)
 
-        param_shard_splitter = getattr(param, "shard_splitter", None)
-
-        if output_dim is not None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support output_dim != None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-        # If a parameter has defined a shard_splitter to be used for
-        # the weight, it should be applied before the weight is
-        # loaded/copied to the parameter. The shard_splitter applies
-        # logic by using the loaded_shard_id to ensure that the loaded
-        # param is loaded to the correct location
-        # within the parameter defined by the linear method.
-        if loaded_shard_id is None and param_shard_splitter is not None:
-            raise NotImplementedError(
-                "We do not currently support loaded_shard_id == None and "
-                "shard_splitter != None for a parameter. Please open an issue."
-            )
-
-        # Special case for Fp8 scales.
-        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
-                                           None)
+        # Special case for per-tensor scales in fused case.
+        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
 
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
-                # If fp8 + scale, need to send to each shard.
-                if fp8_scales_shard_indexer is not None:
-                    param_data, loaded_weight = fp8_scales_shard_indexer(
-                        param_data, loaded_weight, loaded_shard_id)
+                if needs_scalar_to_array:
+                    param_data, loaded_weight = adjust_scalar_to_fused_array(
+                        param_data, loaded_weight, 0)
 
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
@@ -667,15 +641,9 @@ def weight_loader(self,
             shard_index = ["q", "k", "v"].index(loaded_shard_id)
             param_data = param_data.narrow(0, shard_index * shard_size,
                                            shard_size)
-        # If a param_shard_splitter is defined by the LinearMethod, use it.
-        elif param_shard_splitter is not None:
-            logical_widths = getattr(param, "logical_widths", None)
-            param_data, loaded_weight = param_shard_splitter(
-                param_data, loaded_weight, loaded_shard_id, logical_widths)
-
-        # Special case for Fp8 scales.
-        elif fp8_scales_shard_indexer is not None:
-            param_data, loaded_weight = fp8_scales_shard_indexer(
+        # Special case for per-tensor scales in fused case.
+        elif needs_scalar_to_array:
+            param_data, loaded_weight = adjust_scalar_to_fused_array(
                 param_data, loaded_weight, loaded_shard_id)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
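In the QKV path the `loaded_shard_id` handed to `adjust_scalar_to_fused_array` is the string `"q"`, `"k"`, or `"v"`, which the helper maps to an index via `qkv_idxs`; the merged MLP path passes an integer shard id directly. Below is a small illustration of that branch, reusing the assumed import from the earlier sketch.

```python
import torch

from vllm.model_executor.layers.linear import adjust_scalar_to_fused_array  # assumed path

qkv_scales = torch.zeros(3)
checkpoint_scales = {"q": torch.tensor(0.1),
                     "k": torch.tensor(0.2),
                     "v": torch.tensor(0.3)}

# Mirrors the `elif needs_scalar_to_array` branch of QKVParallelLinear.weight_loader.
for shard_id, loaded_weight in checkpoint_scales.items():
    param_data, loaded_weight = adjust_scalar_to_fused_array(
        qkv_scales, loaded_weight, shard_id)
    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)

print(qkv_scales)  # tensor([0.1000, 0.2000, 0.3000])
```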