-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcompression_tactics.py
64 lines (51 loc) · 3.12 KB
/
compression_tactics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
from decimal import Decimal
# Compression-related imports
from aimet_common.defs import CostMetric, CompressionScheme, GreedySelectionParameters
from aimet_torch.defs import SpatialSvdParameters, ChannelPruningParameters
from aimet_torch.compress import ModelCompressor
def spatial_svd_auto_mode(model: torch.nn.Module, eval_loader, evaluate_model, *,
                          target_comp_ratio=Decimal('0.4'),
                          num_comp_ratio_candidates: int = 10,
                          modules_to_ignore=None,
                          input_shape=(1, 1, 28, 28),
                          eval_iterations: int = 1000):
    """Compress ``model`` with AIMET spatial SVD in auto (greedy-selection) mode.

    Builds greedy-selection and spatial-SVD parameters, then makes a single call to
    ``ModelCompressor.compress_model`` using ``CompressionScheme.spatial_svd`` and
    ``CostMetric.mac``. The compressed model and the returned stats are printed
    before being returned.

    Args:
        model: The model to compress. When ``modules_to_ignore`` is not given, the
            model must expose a ``conv1`` attribute, which is excluded from
            compression.
        eval_loader: Not used by this function; kept so existing callers keep working.
        evaluate_model: Callback forwarded to ``compress_model`` as ``eval_callback``.
        target_comp_ratio: Target compression ratio for the greedy search. It is
            converted through ``str`` so that a float such as ``0.4`` becomes exactly
            ``Decimal('0.4')`` instead of the nearest binary float.
        num_comp_ratio_candidates: Forwarded to ``GreedySelectionParameters``.
        modules_to_ignore: Modules to exclude from compression. Defaults to
            ``[model.conv1]``; an explicitly passed empty list is respected.
        input_shape: Forwarded to ``compress_model`` as ``input_shape``. The default
            ``(1, 1, 28, 28)`` presumably matches a single 28x28 one-channel image
            (MNIST-like); adjust for other models.
        eval_iterations: Forwarded to ``compress_model`` as ``eval_iterations``.

    Returns:
        The ``(compressed_model, stats)`` pair returned by
        ``ModelCompressor.compress_model``.
    """
    if modules_to_ignore is None:
        modules_to_ignore = [model.conv1]

    # Decimal(float) keeps the float's binary error; going through str() yields the
    # exact decimal value the caller wrote (and is a no-op for Decimal inputs).
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(str(target_comp_ratio)),
                                              num_comp_ratio_candidates=num_comp_ratio_candidates)
    auto_params = SpatialSvdParameters.AutoModeParams(greedy_params,
                                                      modules_to_ignore=modules_to_ignore)
    # NOTE(review): multiplicity=8 presumably rounds the chosen ranks to a multiple
    # of 8 -- confirm against the AIMET SpatialSvdParameters documentation.
    params = SpatialSvdParameters(mode=SpatialSvdParameters.Mode.auto,
                                  params=auto_params, multiplicity=8)

    # Single call to compress the model
    compressed_model, stats = ModelCompressor.compress_model(model,
                                                             eval_callback=evaluate_model,
                                                             eval_iterations=eval_iterations,
                                                             input_shape=input_shape,
                                                             compress_scheme=CompressionScheme.spatial_svd,
                                                             cost_metric=CostMetric.mac,
                                                             parameters=params,
                                                             visualization_url=None)
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily
    return compressed_model, stats
def channel_pruning_auto_mode(model: torch.nn.Module, train_loader, evaluate_model, *,
                              target_comp_ratio=Decimal('0.4'),
                              num_comp_ratio_candidates: int = 10,
                              modules_to_ignore=None,
                              num_reconstruction_samples: int = 500,
                              input_shape=(1, 1, 28, 28),
                              eval_iterations: int = 1000):
    """Compress ``model`` with AIMET channel pruning in auto (greedy-selection) mode.

    Builds greedy-selection and channel-pruning parameters, then makes a single call
    to ``ModelCompressor.compress_model`` using ``CompressionScheme.channel_pruning``
    and ``CostMetric.mac``. The compressed model and the returned stats are printed
    before being returned.

    Args:
        model: The model to compress. When ``modules_to_ignore`` is not given, the
            model must expose a ``conv1`` attribute, which is excluded from
            compression.
        train_loader: Passed to ``ChannelPruningParameters`` as ``data_loader``.
        evaluate_model: Callback forwarded to ``compress_model`` as ``eval_callback``.
        target_comp_ratio: Target compression ratio for the greedy search. It is
            converted through ``str`` so that a float such as ``0.4`` becomes exactly
            ``Decimal('0.4')`` instead of the nearest binary float.
        num_comp_ratio_candidates: Forwarded to ``GreedySelectionParameters``.
        modules_to_ignore: Modules to exclude from compression. Defaults to
            ``[model.conv1]``; an explicitly passed empty list is respected.
        num_reconstruction_samples: Forwarded to ``ChannelPruningParameters``.
        input_shape: Forwarded to ``compress_model`` as ``input_shape``. The default
            ``(1, 1, 28, 28)`` presumably matches a single 28x28 one-channel image
            (MNIST-like); adjust for other models.
        eval_iterations: Forwarded to ``compress_model`` as ``eval_iterations``.

    Returns:
        The ``(compressed_model, stats)`` pair returned by
        ``ModelCompressor.compress_model``.
    """
    if modules_to_ignore is None:
        modules_to_ignore = [model.conv1]

    # Decimal(float) keeps the float's binary error; going through str() yields the
    # exact decimal value the caller wrote (and is a no-op for Decimal inputs).
    greedy_params = GreedySelectionParameters(target_comp_ratio=Decimal(str(target_comp_ratio)),
                                              num_comp_ratio_candidates=num_comp_ratio_candidates)
    auto_params = ChannelPruningParameters.AutoModeParams(greedy_params,
                                                          modules_to_ignore=modules_to_ignore)
    # NOTE(review): allow_custom_downsample_ops=True presumably lets AIMET insert its
    # own downsampling layers after pruning -- confirm in the AIMET documentation.
    params = ChannelPruningParameters(data_loader=train_loader,
                                      num_reconstruction_samples=num_reconstruction_samples,
                                      allow_custom_downsample_ops=True,
                                      mode=ChannelPruningParameters.Mode.auto,
                                      params=auto_params)

    # Single call to compress the model
    compressed_model, stats = ModelCompressor.compress_model(model,
                                                             eval_callback=evaluate_model,
                                                             eval_iterations=eval_iterations,
                                                             input_shape=input_shape,
                                                             compress_scheme=CompressionScheme.channel_pruning,
                                                             cost_metric=CostMetric.mac,
                                                             parameters=params,
                                                             visualization_url=None)
    print(compressed_model)
    print(stats)  # Stats object can be pretty-printed easily
    return compressed_model, stats