Add support for x3d model family on UCF101 dataset (#4)
ptoupas authored Nov 16, 2023
1 parent 97c4bea commit 362d306
Showing 6 changed files with 155 additions and 39 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -51,6 +51,11 @@ bash scripts/run_quantization.sh
|-------|----------------------------------------------------------------|---------|---------|--------|--------------|----------------|
| unet | [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) | 69.10 | 69.10 | 1.98 | 61.74 | 68.43 |

### ucf101 (val-split1, top-1 acc)
| Model | Source | Float32 | Fixed16 | Fixed8 | BFP8 (Layer) | BFP8 (Channel) |
|-------|----------------------------------------------------------------|---------|---------|--------|--------------|----------------|
| x3d_m | [mmaction2](https://github.com/open-mmlab/mmaction2) | 95.58 | 95.69 | - | - | - |

## Links to other repos
* Optimizer: https://github.com/AlexMontgomerie/fpgaconvnet-optimiser; https://github.com/AlexMontgomerie/samo
* Model: https://github.com/AlexMontgomerie/fpgaconvnet-model
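
As a usage note (not part of this commit), the new ucf101 entry can be exercised with the quantization example script updated later in this commit; a hypothetical invocation, using only the flags defined in `quantization_example.py` and a placeholder dataset path, might look like:

```bash
python quantization_example.py \
    --dataset_name ucf101 \
    --model_name x3d_m \
    --dataset_path ~/dataset/ucf101 \
    --output_path ./output/ucf101_x3d_m
```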
11 changes: 9 additions & 2 deletions models/__init__.py
@@ -1,5 +1,6 @@
import os


def initialize_wrapper(dataset_name, model_name,
                       dataset_path, batch_size, workers):
    model_wrapper = None
@@ -31,10 +32,16 @@ def initialize_wrapper(dataset_name, model_name,
        if model_name in ["unet"]:
            from models.segmentation.lggmri import BrainModelWrapper
            model_wrapper = BrainModelWrapper(model_name)

    elif dataset_name == "ucf101":
        os.environ['UCF101_PATH'] = dataset_path
        if model_name in ["x3d_s", "x3d_m"]:
            from models.action_recognition.ucf101 import MmactionModelWrapper
            model_wrapper = MmactionModelWrapper(model_name)

    if model_wrapper is None:
        raise NotImplementedError("Unknown dataset/model combination")

    model_wrapper.load_data(batch_size, workers)
    model_wrapper.load_model()
    return model_wrapper
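
For orientation, a minimal sketch (not part of the commit) of how the new dispatch path is reached from user code; the dataset path is a placeholder for wherever UCF101 has been prepared:

```python
import os

from models import initialize_wrapper

# "ucf101" + "x3d_m" selects MmactionModelWrapper, exports UCF101_PATH from
# dataset_path, loads the x3d_m checkpoint via load_model(), and returns the
# wrapper (data loading itself is delegated to the mmaction2 Runner).
wrapper = initialize_wrapper("ucf101", "x3d_m",
                             os.path.expanduser("~/dataset/ucf101"),
                             batch_size=8, workers=8)
wrapper.inference("test")  # delegates to mmaction2's Runner.test()
```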
82 changes: 82 additions & 0 deletions models/action_recognition/ucf101.py
@@ -0,0 +1,82 @@
import onnx
import os
import torch

from mmengine.config import Config, DictAction
from mmengine.runner import Runner
from mmengine.runner import load_checkpoint
from models.base import TorchModelWrapper
from onnxsim import simplify


class MmactionModelWrapper(TorchModelWrapper):

    def load_model(self, val=True):
        assert self.model_name in ["x3d_s", "x3d_m"]
        # todo: add mmaction2 as submodule?
        MMACTION_PATH = os.environ.get(
            "MMACTION_PATH", os.path.expanduser("../mmaction2"))

        match self.model_name:
            case "x3d_s":
                config_path = os.path.join(
                    MMACTION_PATH, "configs/recognition/x3d/x3d_s_13x6x1_facebook-kinetics400-rgb.py")
                checkpoint_path = "https://download.openmmlab.com/mmaction/v1.0/recognition/x3d/facebook/x3d_s_13x6x1_facebook-kinetics400-rgb_20201027-623825a0.pth"
                raise NotImplementedError(
                    "x3d_s has not been trained yet on UCF101")
            case "x3d_m":
                config_path = os.path.join(
                    MMACTION_PATH, "configs/recognition/x3d/x3d_m_16x5x1_facebook-kinetics400-rgb.py")
                checkpoint_path = "https://drive.google.com/uc?export=download&id=1l6x6LOmSfpugMOSuEZYb4foRIC8jXMQU"

        cfg = Config.fromfile(config_path)
        # runner only load checkpoint when running inference, too late for compression, as model is already substituted
        # cfg.load_from = checkpoint_path

        cfg.work_dir = os.path.join('./mmaction2_work_dirs', self.model_name)
        cfg.data_root = os.path.join(os.environ.get(
            "UCF101_PATH", os.path.expanduser("~/dataset/ucf101")), "videos")
        cfg.data_root_val = cfg.data_root
        cfg.ann_file_test = os.path.join(os.environ.get("UCF101_PATH", os.path.expanduser(
            "~/dataset/ucf101")), "testlist01_mmaction_videos.txt")
        cfg.model.cls_head.num_classes = 101

        cfg.test_dataloader.dataset.data_prefix = dict(video=cfg.data_root)
        cfg.test_dataloader.dataset.ann_file = cfg.ann_file_test

        cfg.test_dataloader.batch_size = 8
        cfg.test_dataloader.num_workers = 8

        # cfg.log_level = "WARNING"
        self.runner = Runner.from_cfg(cfg)
        self.model = self.runner.model
        state_dict = torch.hub.load_state_dict_from_url(checkpoint_path)[
            'state_dict']
        self.model.load_state_dict(state_dict)
        # load_checkpoint(self.model, checkpoint_path, map_location="cpu")
    def load_data(self, batch_size, workers):  # todo: fix this
        # let the runner handle the data loading
        # todo: download ucf101 dataset (https://www.crcv.ucf.edu/data/UCF101/UCF101.rar)
        # todo: download Train/Test Splits for Action Recognition on UCF101 (https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip)
        # todo: prepare the dataset following the guidelines from mmaction (https://github.com/open-mmlab/mmaction2/blob/main/docs/en/user_guides/2_dataset_prepare.md#prepare-datasets)
        pass

    def inference(self, mode="validate"):
        mode = "validate" if mode == "test" else mode
        print("Inference mode: {}".format(mode))
        if mode in ["validate", "calibrate"]:
            results = self.runner.test()
            print(results)

    def onnx_exporter(self, onnx_path):
        # todo: support other input sizes
        random_input = torch.randn(1, 1, 3, 16, 256, 256)
        if torch.cuda.is_available():
            random_input = random_input.cuda()
        torch.onnx.export(self.model, random_input, onnx_path,
                          verbose=False, keep_initializers_as_inputs=True)

        model = onnx.load(onnx_path)
        model_simp, check = simplify(model)
        onnx.save(model_simp, onnx_path)
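
Since `load_data()` above is still a stub, the config wiring in `load_model()` quietly assumes a specific on-disk layout under `UCF101_PATH`. A small sanity-check sketch (not part of the commit; the file names are taken from `load_model()`, everything else follows the mmaction2 preparation guide linked above) could be:

```python
import os

# Check the layout assumed by MmactionModelWrapper.load_model(): a "videos" folder
# holding the UCF101 clips and a "testlist01_mmaction_videos.txt" annotation file
# for test split 1, both directly under UCF101_PATH.
ucf101_root = os.environ.get("UCF101_PATH", os.path.expanduser("~/dataset/ucf101"))
for entry in ("videos", "testlist01_mmaction_videos.txt"):
    path = os.path.join(ucf101_root, entry)
    status = "found" if os.path.exists(path) else "MISSING"
    print(f"{path}: {status}")
```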
26 changes: 17 additions & 9 deletions models/segmentation/cityscapes.py
@@ -7,28 +7,33 @@
from models.base import TorchModelWrapper
from onnxsim import simplify


class MmsegmentationModelWrapper(TorchModelWrapper):
    def load_model(self, eval=True):
        assert self.model_name == 'unet'
        # todo: add mmseg as submodule?
        MMSEG_PATH = os.environ.get(
            "MMSEG_PATH", os.path.expanduser("../mmsegmentation"))
        config_path = os.path.join(
            MMSEG_PATH, "configs/unet/unet-s5-d16_fcn_4xb4-160k_cityscapes-512x1024.py")
        cfg = Config.fromfile(config_path)
        checkpoint_path = "https://download.openmmlab.com/mmsegmentation/v0.5/unet/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes/fcn_unet_s5-d16_4x4_512x1024_160k_cityscapes_20211210_145204-6860854e.pth"
        # runner only load checkpoint when running inference, too late for compression, as model is already substituted
        # cfg.load_from = checkpoint_path
        cfg.work_dir = os.path.join('./mmseg_work_dirs', self.model_name)
        cfg.data_root = os.environ.get(
            "CITYSCAPES_PATH", os.path.expanduser("~/dataset/cityscapes"))
        cfg.train_dataloader.dataset.data_root = cfg.data_root
        cfg.test_dataloader.dataset.data_root = cfg.data_root
        cfg.val_dataloader.dataset.data_root = cfg.data_root

        self.runner = Runner.from_cfg(cfg)
        self.model = self.runner.model
        state_dict = torch.hub.load_state_dict_from_url(checkpoint_path)[
            'state_dict']
        self.model.load_state_dict(state_dict)

    def load_data(self, batch_size, workers):  # todo: fix this
        # let the runner handle the data loading
        # todo: download cityscapes dataset
        # https://github.com/open-mmlab/mmsegmentation/blob/main/docs/en/user_guides/2_dataset_prepare.md#prepare-datasets
@@ -38,15 +43,18 @@ def inference(self, mode="validate"):
mode = "validate" if mode == "test" else mode
print("Inference mode: {}".format(mode))
if mode in ["validate", "calibrate"]:
# todo: we should probably use the runner.test() method instead
results = self.runner.val()
print(results)

def onnx_exporter(self, onnx_path):
random_input = torch.randn(1,3,512,1024) # todo: support other input sizes
# todo: support other input sizes
random_input = torch.randn(1, 3, 512, 1024)
if torch.cuda.is_available():
random_input = random_input.cuda()
torch.onnx.export(self, random_input, onnx_path, verbose=False, keep_initializers_as_inputs=True)
torch.onnx.export(self, random_input, onnx_path,
verbose=False, keep_initializers_as_inputs=True)

model = onnx.load(onnx_path)
model_simp, check = simplify(model)
onnx.save(model_simp, onnx_path)
onnx.save(model_simp, onnx_path)
28 changes: 17 additions & 11 deletions quantization/utils.py
@@ -12,11 +12,14 @@ class QuantMode(Enum):
    CHANNEL_BFP = 3

# todo: support 3D layers
ACTIVA_QUANT_MODULES = (nn.Conv2d, nn.Conv3d, nn.Linear, nn.ConvTranspose2d, nn.ConvTranspose3d, nn.ReLU, nn.ReLU6, nn.MaxPool2d, nn.MaxPool3d, nn.AdaptiveAvgPool2d, nn.AdaptiveAvgPool3d, nn.AvgPool2d, nn.AvgPool3d)
WEIGHT_QUANT_MODULES = (nn.Conv2d, nn.Conv3d, nn.Linear, nn.ConvTranspose2d, nn.ConvTranspose3d)

def linear_quantize(x, scaling_factor, zero_point):
    if len(x.shape) == 5:
        scaling_factor = scaling_factor.view(-1, 1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1, 1)
    elif len(x.shape) == 4:
        scaling_factor = scaling_factor.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    elif len(x.shape) == 2:
@@ -28,7 +31,10 @@ def linear_quantize(x, scaling_factor, zero_point):
    return x_quant

def linear_dequantize(x_quant, scaling_factor, zero_point):
    if len(x_quant.shape) == 5:
        scaling_factor = scaling_factor.view(-1, 1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1, 1)
    elif len(x_quant.shape) == 4:
        scaling_factor = scaling_factor.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    elif len(x_quant.shape) == 2:
@@ -37,7 +43,7 @@ def linear_dequantize(x_quant, scaling_factor, zero_point):
    else:
        assert False
    x = (x_quant + zero_point) / scaling_factor
    return x

# Asymmetric Quantization: x_q = round((x_f - min_xf) * (2^n - 1) / (max_xf - min_xf))
def asymmetric_linear_no_clipping(word_length, x_min, x_max):
@@ -87,7 +93,7 @@ def apply(self, w, word_length, mode):
        w_max = w_max.to(w.device)
        scaling_factor, zero_point = asymmetric_linear_no_clipping(word_length, w_min, w_max)
        w_quant = linear_quantize(w, scaling_factor, zero_point)
        w_quant = saturate(w_quant, word_length)
        w_approx = linear_dequantize(w_quant, scaling_factor, zero_point)
        return w_approx

@@ -126,8 +132,8 @@ def forward(self, x):
            x_max = x.data.max()
            # in-place operation used on multi-gpus
            self.x_min += -self.x_min + torch.minimum(self.x_min, x_min)
            self.x_max += -self.x_max + torch.maximum(self.x_max, x_max)
            return x
        else:
            if self.mode == QuantMode.CHANNEL_BFP:
                x = x.transpose(0, 1)
@@ -145,9 +151,9 @@ def __init__(self, model_wrapper):
    def apply(self, word_length, mode):
        # add activation quantisation module
        replace_dict = {}
        for module in self.model_wrapper.modules():
            if isinstance(module, ACTIVA_QUANT_MODULES):
                module_quant = nn.Sequential(*[QuantAct(word_length, mode), copy.deepcopy(module), QuantAct(word_length, mode)])
                replace_dict[module] = module_quant
        self.model_wrapper.replace_modules(replace_dict)
        if torch.cuda.is_available():
@@ -180,7 +186,7 @@ def apply(self, word_length, mode):
def quantize_model(model_wrapper, info):
    model_wrapper.sideband_info['quantization'] = info
    weight_quantizer = ModelParamQuantizer(model_wrapper)
    for name, module in model_wrapper.named_modules():
        if isinstance(module, WEIGHT_QUANT_MODULES):
            if isinstance(module, nn.ConvTranspose2d):
                weight = module.weight.data.transpose(0, 1)
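
To make the new 5-D branch concrete, here is a standalone sketch (not the repo's QuantAct/ModelParamQuantizer code) of per-output-channel asymmetric quantization of a Conv3d weight, following the formula in the comment above; the tensor sizes and the clamp epsilon are illustrative:

```python
import torch

# A Conv3d weight has shape (C_out, C_in, kT, kH, kW). With per-channel (CHANNEL_BFP
# style) quantization each output channel gets its own scale/zero-point, so the (C_out,)
# vectors must be viewed as (C_out, 1, 1, 1, 1) to broadcast over the 5-D tensor,
# which is exactly the new len(shape) == 5 path added above.
word_length = 8
w = torch.randn(16, 3, 3, 3, 3)                       # (C_out, C_in, kT, kH, kW)

w_flat = w.reshape(w.shape[0], -1)                    # per-output-channel statistics
w_min, w_max = w_flat.min(dim=1).values, w_flat.max(dim=1).values

scaling_factor = (2 ** word_length - 1) / torch.clamp(w_max - w_min, min=1e-8)
zero_point = scaling_factor * w_min

scaling_factor = scaling_factor.view(-1, 1, 1, 1, 1)  # the new 5-D view
zero_point = zero_point.view(-1, 1, 1, 1, 1)

w_quant = torch.round(scaling_factor * w - zero_point)
w_approx = (w_quant + zero_point) / scaling_factor    # mirrors linear_dequantize above
print((w - w_approx).abs().max())                     # quantization error stays small
```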
42 changes: 25 additions & 17 deletions quantization_example.py
@@ -8,10 +8,11 @@
from models import initialize_wrapper
from quantization.utils import QuantMode, quantize_model


def main():
    parser = argparse.ArgumentParser(description='Quantization Example')
    parser.add_argument('--dataset_name', default="imagenet", type=str,
                        help='dataset name')
    parser.add_argument('--dataset_path', metavar='DIR', default="~/dataset/ILSVRC2012_img",
                        help='path to dataset')
    parser.add_argument('--model_name', metavar='ARCH', default='resnet18',
@@ -25,8 +26,7 @@ def main():
                        help='GPU id to use.')

    parser.add_argument('--output_path', default=None, type=str,
                        help='output path')

    args = parser.parse_args()
    if args.output_path == None:
@@ -41,43 +41,51 @@ def main():
    torch.manual_seed(0)

    model_wrapper = initialize_wrapper(args.dataset_name, args.model_name,
                                       os.path.expanduser(args.dataset_path), args.batch_size, args.workers)

    # TEST 1
    print("FLOAT32 Inference")
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "float32"))
    # TEST 2
    print("NETWORK FP16 Inference")
    # reload the model every time a new quantization mode is tested
    model_wrapper.load_model()
    quantize_model(model_wrapper, {
        'weight_width': 16, 'data_width': 16, 'mode': QuantMode.NETWORK_FP})
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "network_fp16"))

    # TEST 3
    print("NETWORK FP8 Inference")
    model_wrapper.load_model()
    quantize_model(model_wrapper, {
        'weight_width': 8, 'data_width': 8, 'mode': QuantMode.NETWORK_FP})
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "network_fp8"))

    # TEST 4
    print("LAYER BFP8 Inference")
    model_wrapper.load_model()
    quantize_model(model_wrapper, {
        'weight_width': 8, 'data_width': 8, 'mode': QuantMode.LAYER_BFP})
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "layer_bfp8"))

    # TEST 5
    print("CHANNEL BFP8 Inference")
    # note: CHANNEL_BFP can be worse than LAYER_BFP, if calibration size is small!
    model_wrapper.load_model()
    quantize_model(model_wrapper, {
        'weight_width': 8, 'data_width': 8, 'mode': QuantMode.CHANNEL_BFP})
    model_wrapper.inference("test")
    model_wrapper.generate_onnx_files(
        os.path.join(args.output_path, "channel_bfp8"))


if __name__ == '__main__':
    main()
