-
Notifications
You must be signed in to change notification settings - Fork 188
/
autoquant.py
1289 lines (1086 loc) · 44.5 KB
/
autoquant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import torch
import torch.nn.functional as F
from torch.utils._python_dispatch import return_and_correct_aliasing
import torchao
from torchao.dtypes import (
AffineQuantizedTensor,
Float8Layout,
MarlinSparseLayout,
PlainLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
)
from torchao.dtypes.utils import Layout
from torchao.float8.inference import Float8MMConfig
from torchao.kernel import safe_int_mm
from torchao.quantization.linear_activation_quantized_tensor import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.utils import (
compute_error,
quantize_activation_per_token_absmax,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_5,
TorchAOBaseTensor,
is_sm_at_least_89,
is_sm_at_least_90,
)
from .granularity import (
PerRow,
PerTensor,
)
from .subclass import ( # noqa
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
# Public names exported by `from <this module> import *`.
# NOTE(review): apart from AutoQuantizableLinearWeight, these names are not defined in the
# part of the file visible here; presumably they are defined further down -- confirm.
__all__ = [
    "AutoQuantizableLinearWeight",
    "autoquant",
    "DEFAULT_AUTOQUANT_CLASS_LIST",
    "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
    "GEMLITE_INT4_AUTOQUANT_CLASS_LIST",
    "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
    "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
    "OTHER_AUTOQUANT_CLASS_LIST",
    "ALL_AUTOQUANT_CLASS_LIST",
]
# Shorthand for the ATen operator namespace; used below to match ops such as
# aten.detach.default in the dispatch handlers.
aten = torch.ops.aten
# Module-level store of benchmark results. Each key is the candidate class followed by the
# elements of a `shapes_and_dtype` tuple; the value is the recorded result (None means
# "seen but not benchmarked yet").
AUTOQUANT_CACHE = {}


def check_cache(cls, shapes_and_dtype):
    """Return the cached result for ``cls`` on ``shapes_and_dtype``, or None when nothing is stored.

    Args:
        cls: The candidate class the result belongs to.
        shapes_and_dtype (tuple): The shapes/dtype tuple that identifies the workload.
    """
    cache_key = (cls,) + shapes_and_dtype
    return AUTOQUANT_CACHE.get(cache_key)


def update_cache(cls, shapes_and_dtype, res):
    """Store ``res`` as the result for ``cls`` on ``shapes_and_dtype``, overwriting any earlier value.

    Args:
        cls: The candidate class the result belongs to.
        shapes_and_dtype (tuple): The shapes/dtype tuple that identifies the workload.
        res: The value to record (may be None).
    """
    cache_key = (cls,) + shapes_and_dtype
    AUTOQUANT_CACHE[cache_key] = res
# TODO: Document the methods
class AutoQuantizableLinearWeight(torch.Tensor):
"""
A subclass of torch.Tensor that, when run, finds the best type of quantization for itself and swaps
its data with the quantized version.
Args:
weight (torch.Tensor): The initial weight tensor.
qtensor_class_list (list): A list of tensor classes to be considered for quantization.
*args: Additional positional arguments.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
(e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
**kwargs: Additional keyword arguments.
"""
@staticmethod
def __new__(
cls,
weight,
qtensor_class_list,
*args,
mode=["relu", None],
min_sqnr=None,
**kwargs,
):
kwargs["device"] = weight.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
)
kwargs["dtype"] = (
kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype
)
kwargs["requires_grad"] = False
shape = kwargs.pop("shape", weight.shape)
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
def __init__(
self,
weight,
qtensor_class_list,
*args,
mode=["relu", None],
min_sqnr=None,
**kwargs,
):
self.weight = weight
self.qtensor_class_list = qtensor_class_list
self.logged_data = {}
self.mode = mode
self.min_sqnr = min_sqnr
def __repr__(self):
return (
f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, "
f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})"
)
@staticmethod
def log_shape(act_mat, w_autoquant, bias):
act_mat = act_mat.reshape(-1, act_mat.shape[-1])
logged_dtype = act_mat.dtype
logged_shapes = (
act_mat.shape,
w_autoquant.shape,
None if bias is None else bias.shape,
)
shapes_and_dtype = logged_shapes + (logged_dtype,)
w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(
shapes_and_dtype, 0
)
for q_cls in w_autoquant.qtensor_class_list:
if check_cache(q_cls, shapes_and_dtype) is None:
update_cache(q_cls, shapes_and_dtype, None)
def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
if check_cache(q_cls, shapes_and_dtype) is None:
with torch.no_grad():
act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
bias = (
None
if bias_shape is None
else torch.randn(bias_shape, dtype=act_dtype, device=self.device)
)
try:
ref_output = AQDefaultLinearWeight._quantized_linear_op(
act_mat, self.weight, bias
)
q_output = q_cls._quantized_linear_op(
act_mat, q_cls.from_float(self.weight), bias
)
if (
self.min_sqnr is not None
and (sqnr := compute_error(q_output, ref_output))
< self.min_sqnr
):
print(
f"skipping q_cls: {q_cls} because the sqnr is too small, minimum expected sqnr: {self.min_sqnr}, got {sqnr}"
)
res = torch.inf
else:
res = q_cls._autoquant_test(
act_mat, self.weight, bias, best_time, self.mode
)
except Exception as e:
print(
f"warning: failed to autoquant {q_cls.__name__} for shape: {shapes_and_dtype} due to {e}"
)
res = torch.inf
update_cache(q_cls, shapes_and_dtype, res)
@torch.no_grad()
def to_quantized(self, error_on_unseen, **kwargs):
if error_on_unseen and self.logged_data == {}:
raise RuntimeError(
"must run module normally to get shape, dtype info for autoquant"
)
elif (self.logged_data == {}) and not error_on_unseen:
# default back to non-quantized weight if not seen
self = AQDefaultLinearWeight.from_float(self.weight)
return self
# only want to print shape (at start) and final result (at end)
# once per shape+quantization subclass combination.
ran_new_benchmarks = False
print_shape_once = True
def count_shapes(self, do_print=True):
differe_shape_count = 0
for shapes_and_dtype, times_seen in self.logged_data.items():
differe_shape_count += 1
if do_print:
act_shape, weight_shape, bias_shape, dtype = shapes_and_dtype
print(f"activation_shapes: {act_shape}, times_seen: {times_seen}")
if do_print:
print(
f"weight_shape: {weight_shape}, dtype: {dtype}, bias_shape: {bias_shape}"
)
return differe_shape_count
# check each class
best_time = torch.inf
best_cls = None
for q_cls in self.qtensor_class_list:
# for each logged shape+dtype, benchmark
cur_time = 0
total_seen = 0
shape_count = count_shapes(self, do_print=False)
for shapes_and_dtype, times_seen in self.logged_data.items():
if check_cache(q_cls, shapes_and_dtype) is None:
# only print shapes once
if print_shape_once:
print_shape_once = False
count_shapes(self, do_print=True)
time_for_best_shape = check_cache(best_cls, shapes_and_dtype)
time_for_best_shape = (
torch.inf
if time_for_best_shape is None
else time_for_best_shape
)
self.tune_autoquant(q_cls, shapes_and_dtype, time_for_best_shape)
ran_new_benchmarks = True
torch._dynamo.reset()
cur_time += check_cache(q_cls, shapes_and_dtype) * times_seen
total_seen += times_seen
cur_time = cur_time / total_seen
# print aggregated time if there were multiple shapes to aggregate and some new benchmarking was done
if shape_count is not None and shape_count > 1 and ran_new_benchmarks:
print(
f">time (all shapes): {cur_time:0.4f}ms for {q_cls}, prev_best: {best_time:0.4f}ms"
)
if cur_time != torch.inf and best_time >= cur_time:
best_time = cur_time
best_cls = q_cls
# if no new benchmarking was done, don't print the final result, it will be the same as for another layer
if ran_new_benchmarks:
print(f"best_cls={best_cls}\n")
if best_cls is None:
best_cls = AQDefaultLinearWeight
# TODO handle random cls args/kwargs? or should they be curried?
self = best_cls.from_float(self.weight)
return self
def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.weight),
self.qtensor_class_list,
dtype=self.dtype,
mode=self.mode,
min_sqnr=self.min_sqnr,
)
def __tensor_flatten__(self):
return ["weight"], [
self.qtensor_class_list,
self.mode,
self.min_sqnr,
self.dtype,
self.shape,
]
@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None
):
weight = tensor_data_dict["weight"]
qtensor_class_list, mode, min_sqnr, dtype, shape = tensor_attributes
return cls(
weight,
qtensor_class_list,
mode=mode,
min_sqnr=min_sqnr,
shape=shape if outer_size is None else outer_size,
dtype=dtype,
strides=outer_stride,
)
@classmethod
def from_float(cls, weight, qtensor_class_list, **kwargs):
return cls(weight, qtensor_class_list, **kwargs)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
if func is torch.nn.functional.linear:
mat1, w_autoquant, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
cls.log_shape(mat1, w_autoquant, bias)
return func(mat1, w_autoquant.weight, bias)
try:
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
except Exception:
print(f"ERR: subclass doesn't implement {func}")
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
@torch.no_grad()
def do_autoquant_bench(op, *args, **kwargs):
    """
    runs benchmark op(*args, **kwargs) avoiding torch.compile overhead

    The op is first run once on a side CUDA stream, then captured into a CUDA graph; only
    replays of that graph are timed.

    Args:
        op (callable): The operation to benchmark.
        *args: Positional arguments passed to ``op``.
        **kwargs: Keyword arguments passed to ``op``; ``rep`` (default 100) and ``warmup``
            (default 25) are removed first and forwarded to the benchmark helper instead.

    Returns:
        The value returned by the selected benchmark helper with ``return_mode="median"``
        (callers in this file print it as milliseconds).
    """
    rep = kwargs.pop("rep", 100)
    warmup = kwargs.pop("warmup", 25)

    # Run the op once on a side stream before capturing it.
    torch.cuda.synchronize()
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        op(*args, **kwargs)
    side_stream.synchronize()
    torch.cuda.current_stream().wait_stream(side_stream)
    torch.cuda.synchronize()

    # Capture a single invocation into a CUDA graph.
    cuda_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(cuda_graph, stream=side_stream):
        op(*args, **kwargs)

    # The benchmark helper lives in a different place depending on the torch version.
    if TORCH_VERSION_AT_LEAST_2_5:
        from torch._inductor.runtime.benchmarking import benchmarker

        bench_fn = benchmarker.benchmark_gpu
    elif TORCH_VERSION_AT_LEAST_2_3:
        from torch._inductor.runtime.runtime_utils import do_bench_gpu as bench_fn
    else:
        from torch._inductor.utils import do_bench as bench_fn

    return bench_fn(
        lambda: cuda_graph.replay(), warmup=warmup, rep=rep, return_mode="median"
    )
def _is_interpolate_mode(mode):
if (
isinstance(mode, list)
and mode[0] == "interpolate"
and len(mode) == 2
and isinstance(mode[1], float)
):
return True
return False
class AQMixin:
    """
    Mixin that adds the autoquant benchmarking hook to a weight class.

    The host class is expected to provide ``from_float`` and ``_quantized_linear_op``; both are
    called by ``_autoquant_test``.
    """

    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        """
        Tests and benchmarks the autoquantization process for the given activation matrix, weight, and bias.

        In interpolate mode the bare linear op is compiled and timed; otherwise the op is wrapped
        with a relu on its input and on its output. A second, longer benchmark is run (and blended
        in with weight 0.9) when the first result is within 10% of ``best_time``.

        Args:
            act_mat (torch.Tensor): The activation matrix.
            weight (torch.Tensor): The weight tensor.
            bias (torch.Tensor or None): The bias tensor.
            best_time (float): The best time to beat for the quantization process.
            mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
                (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
        Returns:
            float: The benchmarked time for the autoquantization process.
        """
        w_qtensor = cls.from_float(weight)

        if _is_interpolate_mode(mode):
            target = cls._quantized_linear_op
        else:

            def target(a, b, c):
                return F.relu(cls._quantized_linear_op(F.relu(a), b, c))

        q_c_op = torch.compile(target, mode="max-autotune-no-cudagraphs")
        bench_args = (q_c_op, act_mat, w_qtensor, bias)

        res = do_autoquant_bench(*bench_args, warmup=25, rep=100)
        if res < best_time * 1.1:
            # Close to (or better than) the current best: measure again with more repetitions.
            res2 = do_autoquant_bench(*bench_args, warmup=25, rep=900)
            res = res2 * 0.9 + res * 0.1
        print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ")
        return res
class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
    """
    AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
    """

    # Layout handed to the weight quantizer; subclasses override it.
    layout: Layout = PlainLayout()

    @classmethod
    def from_float(cls, weight):
        """Quantize ``weight`` to int8 and attach a per-token int8 activation quantizer.

        Args:
            weight (torch.Tensor): The float weight.

        Returns:
            The quantized weight wrapped by the parent class's ``from_float`` together with the
            activation quantization function, or ``weight`` itself unchanged when it is not 2D.
        """
        # TODO test if this is valid
        # in_features = weight.shape[1]
        # int8 dynamic quantization only has benefit when in_feature > 16
        # if in_features <= 16:
        #     return weight

        if weight.dim() != 2:
            return weight

        # avoid circular dep
        from torchao.dtypes import to_affine_quantized_intx

        # weight settings
        mapping_type = MappingType.SYMMETRIC

        def get_weight_block_size(x):
            # One block per row: block spans the full second dimension.
            return (1, x.shape[1])

        target_dtype = torch.int8
        eps = torch.finfo(torch.float32).eps
        zero_point_dtype = torch.int64

        # input settings
        def get_per_token_block_size(x):
            # Block size 1 on every dimension except the last, which is kept whole.
            block_size = list(x.shape)
            for i in range(len(block_size) - 1):
                block_size[i] = 1
            return block_size

        input_mapping_type = MappingType.SYMMETRIC
        input_target_dtype = torch.int8
        input_eps = 1e-5
        input_quant_min = -127
        input_quant_max = 127
        _layout = cls.layout
        input_quant_func = lambda x: to_affine_quantized_intx(
            x,
            input_mapping_type,
            get_per_token_block_size(x),
            input_target_dtype,
            eps=input_eps,
            quant_min=input_quant_min,
            quant_max=input_quant_max,
            scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
        )

        block_size = get_weight_block_size(weight)
        weight = to_affine_quantized_intx(
            weight,
            mapping_type,
            block_size,
            target_dtype,
            eps=eps,
            zero_point_dtype=zero_point_dtype,
            _layout=_layout,
        )
        weight = super(AQInt8DynamicallyQuantizedLinearWeight, cls).from_float(
            weight, input_quant_func
        )
        return weight

    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        """
        Tests and benchmarks the autoquantization process with special handling for interpolate mode.

        Outside interpolate mode this defers to ``AQMixin._autoquant_test``. In interpolate mode
        the int8 matmul alone is timed first; if it cannot beat ``best_time`` that time is returned
        immediately, otherwise the full op is timed and the two times are blended using the
        interpolation constant ``mode[1]``.

        Args:
            act_mat (torch.Tensor): The activation matrix.
            weight (torch.Tensor): The weight tensor.
            bias (torch.Tensor or None): The bias tensor.
            best_time (float): The best time to beat for the quantization process.
            mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
                (e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
        Returns:
            float: The benchmarked time for the autoquantization process.
        """
        if not _is_interpolate_mode(mode):
            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)

        # SAM best is between .8 and 1, SDXL also performs best in this range
        INTERPOLATION_CONSTANT = mode[1]
        w_qtensor = cls.from_float(weight)
        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
            act_mat.reshape(-1, act_mat.shape[-1])
        )
        quantized_matmul = (
            lambda x_vals_int8, x_scales, w_vals_int8: safe_int_mm(
                x_vals_int8, w_vals_int8
            )
            * x_scales
        )
        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
        with torch.no_grad():
            w_vals_int8 = (
                w_qtensor.original_weight_tensor.tensor_impl.int_data.contiguous().t()
            )
            res_matmul = do_autoquant_bench(
                q_c_matmul, x_vals_int8, x_scales.reshape(-1, 1), w_vals_int8
            )
        print(
            f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms"
        )

        # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
        if res_matmul >= best_time:
            return res_matmul

        # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
        if INTERPOLATION_CONSTANT == 1:
            # The ratio below diverges as the constant approaches 1; use its limit instead of
            # dividing by zero (which would make the caller discard this class entirely).
            to_beat = torch.inf
        else:
            to_beat = best_time + INTERPOLATION_CONSTANT / (
                1 - INTERPOLATION_CONSTANT
            ) * (best_time - res_matmul)
        # NOTE(review): ``mode`` is not forwarded here, so the parent benchmarks with its default
        # (relu-wrapped) mode -- confirm this is intended.
        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
        # Only used for the log line below; guard the degenerate case of identical timings.
        full_minus_matmul = res - res_matmul
        max_int_const_win = (
            (best_time - res_matmul) / full_minus_matmul
            if full_minus_matmul != 0
            else torch.inf
        )
        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
        print(
            f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}"
        )
        return res_f
class AQInt8DynamicallyQuantizedSemiSparseLinearWeight(
    AQInt8DynamicallyQuantizedLinearWeight
):
    """Int8 dynamic-quantization candidate that quantizes the weight with ``SemiSparseLayout``."""

    layout: Layout = SemiSparseLayout()

    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        """Benchmark via the parent class, always forwarding ``None`` as the mode.

        The ``mode`` argument is accepted for signature compatibility but ignored, so the parent
        never takes its interpolate path for this class.
        """
        forwarded_mode = None
        return super()._autoquant_test(
            act_mat, weight, bias, best_time, forwarded_mode
        )
class AQInt8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
    """

    @classmethod
    def from_float(cls, weight):
        """Quantize ``weight`` to int8 with symmetric mapping, one block per row.

        Args:
            weight (torch.Tensor): The float weight; its second dimension sets the block width.
        """
        per_row_block = (1, weight.shape[1])
        return super().from_hp_to_intx(
            weight,
            MappingType.SYMMETRIC,
            per_row_block,
            torch.int8,
            eps=torch.finfo(torch.float32).eps,
            zero_point_dtype=torch.int64,
        )
class AQInt8WeightOnlyQuantizedLinearWeight2(
    AQInt8WeightOnlyQuantizedLinearWeight, AQMixin
):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
    uses a different kernel
    """

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        """
        Performs the quantized linear operations

        The product is formed by broadcasting the activation against the transposed int8
        weight data and summing over the shared dimension, then scaling by the weight scale.

        Args:
            act_mat (torch.Tensor): The activation matrix.
            w_qtensor (torch.Tensor): The quantized weight tensor.
            bias (torch.Tensor or None): The bias tensor.
        Returns:
            torch.Tensor: The result of the quantized operation, cast back to the activation dtype.
        """
        in_dtype = act_mat.dtype
        in_shape = act_mat.shape
        impl = w_qtensor.tensor_impl

        # (rows, in_features, 1) * (1, in_features, out_features) -> reduce over in_features
        act_col = act_mat.reshape(-1, act_mat.shape[-1], 1)
        w_int = impl.int_data.t().unsqueeze(0)
        y = (act_col * w_int).sum(dim=-2)

        y = y.reshape(*in_shape[:-1], y.shape[-1]) * impl.scale
        if bias is not None:
            y += bias
        return y.to(in_dtype)

    @classmethod
    def _autoquant_test(cls, act_mat, *args):
        """Benchmark this kernel, reporting an infinite time for large activations.

        When the activation, flattened to 2D, has more than 32 rows the kernel is not
        benchmarked at all and ``torch.inf`` is returned.
        """
        num_rows = act_mat.reshape(-1, act_mat.shape[-1]).shape[0]
        if num_rows > 32:
            return torch.inf
        return super()._autoquant_test(act_mat, *args)
class AQInt8WeightOnlyQuantizedLinearWeight3(
    AQInt8WeightOnlyQuantizedLinearWeight, AQMixin
):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
    uses a different kernel
    """

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        """Linear op that scales the transposed int8 weight data first, then does one ``torch.mm``.

        Args:
            act_mat (torch.Tensor): The activation; flattened to 2D for the matmul.
            w_qtensor (torch.Tensor): The quantized weight tensor.
            bias (torch.Tensor or None): The bias tensor, added in place when present.
        Returns:
            torch.Tensor: The result, with the activation's leading dimensions restored.
        """
        in_shape = act_mat.shape
        impl = w_qtensor.tensor_impl

        act_2d = act_mat.reshape(-1, in_shape[-1])
        scaled_w = impl.int_data.t() * impl.scale
        y = torch.mm(act_2d, scaled_w)

        y = y.reshape(*in_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
        return y
class AQInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
    """
    AutoQuantizable version of Int4WeightOnlyQuantizedLinearWeight
    """

    # Quantization group size along the last dimension; subclasses override it.
    group_size: int = 32
    # Layout handed to the quantizer; subclasses may override it.
    layout: Layout = TensorCoreTiledLayout(inner_k_tiles=8)

    @classmethod
    def from_float(cls, weight):
        """Quantize ``weight`` to 4-bit values (range 0..15) in groups of ``cls.group_size``.

        Returns ``weight`` unchanged when its last dimension is not a multiple of the group
        size. With a ``MarlinSparseLayout`` the mapping is symmetric with an integer zero-point
        domain and HQQ disabled; otherwise it is asymmetric with a float zero-point domain
        and HQQ enabled.
        """
        if weight.shape[-1] % cls.group_size != 0:
            return weight

        if isinstance(cls.layout, MarlinSparseLayout):
            mapping_type = MappingType.SYMMETRIC
            preserve_zero = True
            zero_point_domain = ZeroPointDomain.INT
            use_hqq = False
        else:
            mapping_type = MappingType.ASYMMETRIC
            preserve_zero = False
            zero_point_domain = ZeroPointDomain.FLOAT
            use_hqq = True

        return super().from_hp_to_intx(
            weight,
            mapping_type,
            (1, cls.group_size),  # block_size
            torch.int32,  # target_dtype
            0,  # quant_min
            15,  # quant_max
            1e-6,  # eps
            zero_point_dtype=torch.bfloat16,
            preserve_zero=preserve_zero,
            zero_point_domain=zero_point_domain,
            _layout=cls.layout,
            use_hqq=use_hqq,
        )
class AQInt4G64WeightOnlyQuantizedLinearWeight(
    AQInt4G32WeightOnlyQuantizedLinearWeight
):
    # Same as the parent class except for the group size, which the inherited
    # from_float reads from this class attribute.
    group_size: int = 64
class AQInt4G128WeightOnlyQuantizedLinearWeight(
    AQInt4G32WeightOnlyQuantizedLinearWeight
):
    # Same as the parent class except for the group size, which the inherited
    # from_float reads from this class attribute.
    group_size: int = 128
class AQInt4G256WeightOnlyQuantizedLinearWeight(
    AQInt4G32WeightOnlyQuantizedLinearWeight
):
    # Same as the parent class except for the group size, which the inherited
    # from_float reads from this class attribute.
    group_size: int = 256
class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
    AQInt4G32WeightOnlyQuantizedLinearWeight
):
    # Group size read by the inherited from_float.
    group_size: int = 128
    # With this layout the inherited from_float switches to symmetric mapping, an
    # integer zero-point domain, preserve_zero=True and use_hqq=False.
    layout: Layout = MarlinSparseLayout()
class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
    """Int4 weight-only candidate whose quantization arguments come from the gemlite layout helper."""

    # Group size forwarded to get_gemlite_aqt_kwargs; subclasses override it.
    group_size: int = 32

    @classmethod
    def from_float(cls, weight):
        """Quantize ``weight`` using the keyword arguments built by ``get_gemlite_aqt_kwargs``.

        The helper is called with bit width 4, packing bit width 32, ``contiguous=None`` and
        HQQ enabled.
        """
        from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

        gemlite_kwargs = get_gemlite_aqt_kwargs(
            weight,
            cls.group_size,
            4,  # bit_width
            32,  # packing_bitwidth
            None,  # contiguous
            True,  # use_hqq
        )
        return super().from_hp_to_intx(weight, **gemlite_kwargs)
class AQGemliteInt4G64WeightOnlyQuantizedLinearWeight(
    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
):
    # Same as the parent class except for the group size, which the inherited
    # from_float passes to get_gemlite_aqt_kwargs.
    group_size: int = 64
class AQGemliteInt4G128WeightOnlyQuantizedLinearWeight(
    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
):
    # Same as the parent class except for the group size, which the inherited
    # from_float passes to get_gemlite_aqt_kwargs.
    group_size: int = 128
class AQGemliteInt4G256WeightOnlyQuantizedLinearWeight(
    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
):
    # Same as the parent class except for the group size, which the inherited
    # from_float passes to get_gemlite_aqt_kwargs.
    group_size: int = 256
class AQDefaultLinearWeight(torch.Tensor, AQMixin):
    """
    A class to be used in concert with AutoQuantizableLinearWeight to provide a
    default/non-quantized option. Only implements the bare minimum needed to work with the
    AutoQuantizableLinearWeight class using the same interfaces that would normally be
    used by QTensor subclasses but for a default linear op instead. Result of from_float
    is not a tensor subclass, but rather the float tensor.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        """Run the ordinary linear op; ``w_qtensor`` is the plain float weight here."""
        return F.linear(act_mat, w_qtensor, bias)

    @classmethod
    def from_float(cls, weight):
        """Return ``weight`` itself, unchanged and unwrapped."""
        return weight
class Float32Tensor(TorchAOBaseTensor):
    """Tensor subclass tensor for fp32 dtype"""

    def __init__(self, weight):
        # Keep a float32 copy/view of the weight for use by the linear op.
        self.weight = weight.to(torch.float32)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        """Run linear in float32 and cast the result back to the activation's dtype.

        Args:
            act_mat (torch.Tensor): The activation; cast to float32 for the op.
            w_qtensor: An instance of this class; its ``weight`` attribute is used as-is.
            bias (torch.Tensor or None): The bias; cast to float32 when present.
        """
        compute_dtype = torch.float32
        cast_bias = None if bias is None else bias.to(compute_dtype)
        out = F.linear(act_mat.to(compute_dtype), w_qtensor.weight, cast_bias)
        return out.to(dtype=act_mat.dtype)

    def _apply_fn_to_data(self, fn):
        """Return a new instance of the same class wrapping ``fn(self.weight)``."""
        transformed = fn(self.weight)
        return self.__class__(transformed)

    @classmethod
    def from_float(cls, weight):
        """Alternate constructor: wrap ``weight`` in this class."""
        return cls(weight)
@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    """Route linear calls on a Float32Tensor weight to the weight's ``_quantized_linear_op``.

    The bias is taken from the third positional argument or, when it is passed by keyword,
    from ``kwargs["bias"]``, so a keyword bias is not silently dropped.
    """
    input_tensor, weight_tensor = args[0], args[1]
    if len(args) > 2:
        bias = args[2]
    else:
        # kwargs may be None depending on the caller; treat that as "no keyword arguments".
        bias = (kwargs or {}).get("bias", None)
    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
@Float32Tensor.implements(aten.detach.default)
def _(func, types, args, kwargs):
    """Handle ``aten.detach``: rebuild the tensor around its detached inner weight."""
    detached = args[0]._apply_fn_to_data(torch.detach)
    return return_and_correct_aliasing(func, args, kwargs, detached)
@Float32Tensor.implements(aten.clone.default)
def _(func, types, args, kwargs):
    """Handle ``aten.clone``: rebuild the tensor around a clone of its inner weight."""
    cloned = args[0]._apply_fn_to_data(torch.clone)
    return return_and_correct_aliasing(func, args, kwargs, cloned)
@Float32Tensor.implements(aten._to_copy.default)
def _(func, types, args, kwargs):
    """Handle ``aten._to_copy``: apply ``.to(...)`` with the remaining arguments, then clone the inner weight."""
    source, *to_args = args
    converted = source.to(*to_args, **kwargs)
    copied = converted._apply_fn_to_data(torch.clone)
    return return_and_correct_aliasing(func, args, kwargs, copied)
class BFloat16Tensor(Float32Tensor):
    """Variant of Float32Tensor that stores the weight and runs linear in bfloat16."""

    def __init__(self, weight):
        self.weight = weight.to(torch.bfloat16)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        """Run linear in bfloat16 and cast the result back to the activation's dtype."""
        compute_dtype = torch.bfloat16
        cast_bias = None if bias is None else bias.to(compute_dtype)
        out = F.linear(act_mat.to(compute_dtype), w_qtensor.weight, cast_bias)
        return out.to(dtype=act_mat.dtype)
class Float16Tensor(Float32Tensor):
    """Variant of Float32Tensor that stores the weight and runs linear in float16."""

    def __init__(self, weight):
        self.weight = weight.to(torch.float16)

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        """Run linear in float16 and cast the result back to the activation's dtype."""
        compute_dtype = torch.float16
        cast_bias = None if bias is None else bias.to(compute_dtype)
        out = F.linear(act_mat.to(compute_dtype), w_qtensor.weight, cast_bias)
        return out.to(dtype=act_mat.dtype)
class AQFloat32LinearWeight(Float32Tensor, AQMixin):
    """
    AutoQuantizable version for float32 precision weight

    (also converts input activation and bias to float32, and restores the original precision after
    linear)
    """

    @classmethod
    def from_float(cls, weight):
        """Wrap ``weight`` by delegating to the parent class's ``from_float``."""
        return super().from_float(weight)
class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin):
    """
    AutoQuantizable version for bfloat16 precision weight

    (also converts input activation and bias to bfloat16, and restores the original precision after
    linear)
    """

    @classmethod
    def from_float(cls, weight):
        """Wrap ``weight`` by delegating to the parent class's ``from_float``."""
        return super().from_float(weight)
class AQFloat16LinearWeight(Float16Tensor, AQMixin):
    """
    AutoQuantizable version for float16 precision weight

    (also converts input activation and bias to float16, and restores the original precision after
    linear)
    """

    @classmethod
    def from_float(cls, weight):
        """Wrap ``weight`` by delegating to the parent class's ``from_float``."""
        return super().from_float(weight)
class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
    """
    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
    """

    # dtype the weight is quantized to in from_float.
    target_dtype: torch.dtype = torch.float8_e4m3fn

    @staticmethod
    def _quantized_linear_op(act_mat, w_qtensor, bias):
        """Dequantize the weight, then run the ordinary linear op on it."""
        dequantized_weight = w_qtensor.dequantize()
        return F.linear(act_mat, dequantized_weight, bias)

    @classmethod
    def from_float(cls, weight):
        """Quantize ``weight`` to ``cls.target_dtype`` with one block per row, using ``Float8Layout``."""
        per_row_block = (1, weight.shape[1])
        return super().from_hp_to_floatx(
            weight,
            per_row_block,
            target_dtype=cls.target_dtype,
            _layout=Float8Layout(),
        )
class AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight(
    AQMixin, LinearActivationQuantizedTensor
):
    """
    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per row scaling
    """

    # Granularity passed to the fp8 activation quantization function.
    activation_granularity = PerRow()

    @classmethod
    def from_float(cls, weight):
        """Quantize ``weight`` to float8 (one block per row) and attach an fp8 activation quantizer.

        Args:
            weight (torch.Tensor): The float weight; its second dimension sets the block width.

        Returns:
            The result of the parent class's ``from_float`` applied to the quantized weight and
            the activation quantization function.
        """
        # avoid circular dep
        from torchao.dtypes import to_affine_quantized_floatx
        from torchao.quantization.quant_api import _input_activation_quant_func_fp8

        # weight settings
        def get_weight_block_size(x):
            return (1, x.shape[1])

        target_dtype = torch.float8_e4m3fn

        # input settings
        input_target_dtype = torch.float8_e4m3fn
        _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
        # Activations are quantized at call time with this class's granularity.
        input_quant_func = lambda x: _input_activation_quant_func_fp8(
            x=x,
            activation_granularity=cls.activation_granularity,
            activation_dtype=input_target_dtype,
        )
        block_size = get_weight_block_size(weight)
        weight = to_affine_quantized_floatx(
            input_float=weight,
            block_size=block_size,
            target_dtype=target_dtype,
            _layout=_layout,
            scale_dtype=torch.float32,
        )
        weight = super(
            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight, cls
        ).from_float(weight, input_quant_func)
        return weight
class AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight(
    AQMixin, LinearActivationQuantizedTensor
):
    """
    AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight using per tensor scaling
    """

    # Granularity passed to the fp8 activation quantization function.
    activation_granularity = PerTensor()

    @classmethod
    def from_float(cls, weight):
        """Quantize ``weight`` to float8 as a single block and attach an fp8 activation quantizer.

        Args:
            weight (torch.Tensor): The float weight; must be 2D (checked with an assert).

        Returns:
            The result of the parent class's ``from_float`` applied to the quantized weight and
            the activation quantization function.
        """
        # avoid circular dep
        from torchao.dtypes import to_affine_quantized_floatx
        from torchao.quantization.quant_api import _input_activation_quant_func_fp8

        fp8_dtype = torch.float8_e4m3fn

        def quantize_activation(x):
            # Activations are quantized at call time with this class's granularity.
            return _input_activation_quant_func_fp8(
                x=x,
                activation_granularity=cls.activation_granularity,
                activation_dtype=fp8_dtype,
            )

        fp8_layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))

        # Per-tensor scaling: the block is the whole (2D) weight.
        assert weight.ndim == 2, "Only works for 2D tensors"
        whole_tensor_block = weight.shape

        quantized_weight = to_affine_quantized_floatx(
            input_float=weight,
            block_size=whole_tensor_block,
            target_dtype=fp8_dtype,
            _layout=fp8_layout,
            scale_dtype=torch.float32,
        )
        return super().from_float(quantized_weight, quantize_activation)