global weight pruning (#5)
* global weight pruning

* print avg sparsity

* add init results
Yu-Zhewen authored Dec 5, 2023
1 parent dfbd37d commit 33f8133
Showing 4 changed files with 47 additions and 8 deletions.
README.md (23 changes: 18 additions & 5 deletions)
@@ -21,11 +21,7 @@ python threshold_relu_example.py
* `camvid`: `unet`
* `cityscapes`: `unet`

-## Quantization Results
-@ commit ec09e56
-```
-bash scripts/run_quantization.sh
-```
+## Quantization Results

### imagenet (val, top-1 acc)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
@@ -63,6 +59,23 @@ bash scripts/run_quantization.sh
| x3d_s | [mmaction2](https://github.com/open-mmlab/mmaction2) | 93.68 | 93.57 | 1.13 | 90.21 | 93.57 |
| x3d_m | [mmaction2](https://github.com/open-mmlab/mmaction2) | 96.40 | 96.40 | 0.81 | 95.24 | 96.29 |


+## Sparsity Results
+* Q - Fixed16 Quantization
+* AS - Activation Sparsity
+* WS - Weight Sparsity (applying a global pruning threshold)
+* Post-training, without fine-tuning
+
+### imagenet
+
+| Model | Experiment | Top-1 Acc (%) | Avg Sparsity (%) |
+|----------|----------------|---------------|------------------|
+| resnet18 | Q+AS | 69.74 | 50.75 |
+| resnet18 | Q+AS+WS(0.005) | 69.42 | 56.33 |
+| resnet18 | Q+AS+WS(0.010) | 67.36 | 61.47 |
+| resnet18 | Q+AS+WS(0.015) | 58.38 | 65.91 |
+| resnet18 | Q+AS+WS(0.020) | 27.91 | 69.63 |
+
## Links to other repos
* Optimizer: https://github.com/AlexMontgomerie/fpgaconvnet-optimiser; https://github.com/AlexMontgomerie/samo
* Model: https://github.com/AlexMontgomerie/fpgaconvnet-model
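In the quantization tables above, BFP8 denotes 8-bit block floating point: every value in a block shares a single exponent while keeping its own sign and mantissa, the block being either a whole layer (BFP8 Layer) or one channel (BFP8 Channel). The repo's quantizer is not part of this diff, so the following is only a minimal sketch of the idea under the common sign-plus-7-bit-mantissa convention; `bfp8_quantize` is a hypothetical name, not the repo's API.

```python
import torch

def bfp8_quantize(w: torch.Tensor, per_channel: bool = True) -> torch.Tensor:
    # Illustrative BFP8: one shared exponent per block (per output channel
    # when per_channel=True, otherwise per tensor), 7-bit mantissa per value.
    mant_bits = 7
    if per_channel and w.dim() > 1:
        reduce_dims = tuple(range(1, w.dim()))
        max_abs = w.abs().amax(dim=reduce_dims, keepdim=True)
    else:
        max_abs = w.abs().max()
    # choose the shared exponent so the largest magnitude in the block fits
    shared_exp = torch.ceil(torch.log2(max_abs.clamp_min(1e-12)))
    scale = torch.exp2(shared_exp - mant_bits)
    q = torch.round(w / scale).clamp(-(2 ** mant_bits), 2 ** mant_bits - 1)
    return q * scale
```

Per-channel blocking gives each channel its own dynamic range, which is consistent with the tables above, where the channel variant recovers accuracy that the layer variant loses (e.g. x3d_s: 90.21 with BFP8 Layer vs 93.57 with BFP8 Channel).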
optimiser_interface/utils.py (2 changes: 1 addition & 1 deletion)
@@ -20,7 +20,7 @@ def opt_cli_launcher(model_name, onnx_path, output_dir,
                     opt_obj='throughput', opt_solver='greedy_partition', opt_cfg="single_partition_throughput"):

    platform_path = os.path.join(os.environ['FPGACONVNET_OPTIMISER'], f'examples/platforms/{device}.toml')
-    opt_cfg_path = os.path.join(os.environ['FPGACONVNET_OPTIMISER'], f'examples/{opt_cfg}.toml')
+    opt_cfg_path = os.path.join(os.environ['FPGACONVNET_OPTIMISER'], f'examples/optimisers/{opt_cfg}.toml')
    saved_argv = sys.argv
    sys.argv = ['cli.py']
    sys.argv += ['--name', model_name]
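For context, `opt_cli_launcher` drives the optimiser's own command-line interface by temporarily rewriting `sys.argv` (presumably restored from `saved_argv` after the run, below the fold of this hunk). A hypothetical invocation, assuming `FPGACONVNET_OPTIMISER` points at a checkout of the optimiser repo; all paths are placeholders:

```python
import os

# placeholder checkout location for the optimiser repository
os.environ.setdefault('FPGACONVNET_OPTIMISER', '/opt/fpgaconvnet-optimiser')

# the opt_* arguments keep the defaults shown in the signature above; `device`
# is used in the function body but defined outside this hunk, so it is assumed
# to have a sensible default here
opt_cli_launcher('resnet18', 'outputs/resnet18.onnx', 'outputs/optimised')
```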
sparsity/prune_utils.py (13 changes: 13 additions & 0 deletions)
@@ -0,0 +1,13 @@
+import torch
+import torch.nn as nn
+
+WEIGHT_PRUNE_MODULES = (nn.Conv2d, nn.Conv3d, nn.Linear, nn.ConvTranspose2d, nn.ConvTranspose3d)
+
+def apply_weight_pruning(model_wrapper, threshold):
+    for name, module in model_wrapper.named_modules():
+        if isinstance(module, WEIGHT_PRUNE_MODULES):
+            module.weight.data = torch.where(
+                torch.abs(module.weight.data) < threshold,
+                torch.tensor(0.0).to(module.weight.device),
+                module.weight.data)
+
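A standalone usage sketch of `apply_weight_pruning`: since the function only relies on `named_modules()`, any `nn.Module` works, so a torchvision model stands in for the repo's wrapper here. The weight-only sparsity it prints need not match the README's Avg Sparsity column, which is produced by `measure_model_sparsity` in `sparsity/utils.py` (not shown in this diff) and, judging by the Q+AS rows, reflects activation sparsity as well.

```python
import torchvision.models as models

from sparsity.prune_utils import WEIGHT_PRUNE_MODULES, apply_weight_pruning

def weight_sparsity(model):
    # fraction of exactly-zero weights across all prunable modules
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, WEIGHT_PRUNE_MODULES):
            zeros += (module.weight.data == 0).sum().item()
            total += module.weight.data.numel()
    return zeros / total

model = models.resnet18(weights='IMAGENET1K_V1')
apply_weight_pruning(model, 0.005)   # zero every weight with |w| < 0.005
print(f"weight sparsity: {weight_sparsity(model):.4f}")
```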
activation_sparsity_example.py → sparsity_example.py (17 changes: 15 additions & 2 deletions)
@@ -7,6 +7,7 @@

from models import initialize_wrapper
from quantization.utils import QuantMode, quantize_model
+from sparsity.prune_utils import apply_weight_pruning
from sparsity.utils import measure_model_sparsity

def main():
@@ -25,6 +26,9 @@ def main():
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')

+    parser.add_argument('--weight_threshold', default=None, type=float,
+                        help='threshold for weight pruning')
+
    parser.add_argument('--output_path', default=None, type=str,
                        help='output path')

@@ -48,10 +52,19 @@ def main():
    quantize_model(model_wrapper, {'weight_width': 16, 'data_width': 16, 'mode': QuantMode.NETWORK_FP})
    top1, top5 = model_wrapper.inference("test")

-    # post-activation sparsity has zero impact on accuracy
+    if args.weight_threshold is None:
+        # post-activation sparsity has zero impact on accuracy
+        print("POST-ACTIVATION SPARSITY")
+    else:
+        # apply weight pruning
+        print("WEIGHT PRUNING")
+        apply_weight_pruning(model_wrapper, args.weight_threshold)
+        top1, top5 = model_wrapper.inference("test")

    # measure sparsity-related stats on calibration set
-    measure_model_sparsity(model_wrapper)
+    avg_sparsity = measure_model_sparsity(model_wrapper)
+    model_wrapper.generate_onnx_files(os.path.join(args.output_path, "sparse"))
+    print(f"Average sparsity: {avg_sparsity}")

if __name__ == '__main__':
    main()
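Taken together, the pieces in this commit suggest how the README's resnet18 sweep was produced. A sketch under stated assumptions: `initialize_wrapper`'s real signature is not visible in this diff, so the call below is a placeholder, while `quantize_model`, `apply_weight_pruning`, and `inference("test")` are used exactly as in the script above.

```python
from models import initialize_wrapper              # real signature not shown in this diff
from quantization.utils import QuantMode, quantize_model
from sparsity.prune_utils import apply_weight_pruning

for threshold in [0.005, 0.010, 0.015, 0.020]:
    # rebuild the wrapper each iteration so thresholds apply to fresh weights
    model_wrapper = initialize_wrapper('imagenet', 'resnet18')   # hypothetical arguments
    quantize_model(model_wrapper, {'weight_width': 16, 'data_width': 16,
                                   'mode': QuantMode.NETWORK_FP})
    apply_weight_pruning(model_wrapper, threshold)
    top1, top5 = model_wrapper.inference("test")
    print(f"Q+AS+WS({threshold}): top-1 {top1}")
```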
