
Commit

added tests to check that avgpool is replaced with QuantConv layers and that MergeBatchNorm correctly removes batchnorm
costigt-dev committed Feb 8, 2024
1 parent 479eb9f commit 51771bc
Showing 1 changed file with 59 additions and 29 deletions.
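For context on why the new assertions can expect every batch-norm layer to disappear: in eval mode a BatchNorm module is just a per-channel affine transform, so it can be folded into the preceding convolution's weight and bias. A minimal sketch of that folding in plain PyTorch (not brevitas code and not part of this commit; the helper name fold_bn_into_conv is hypothetical):

import torch
import torch.nn as nn

def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Per-output-channel scale taken from the eval-mode BN statistics.
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    folded = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, bias=True)
    with torch.no_grad():
        # W' = W * scale, b' = (b - running_mean) * scale + bn.bias
        folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        folded.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return folded

Fusion of this kind is presumably what MergeBatchNorm applies to the traced graph before dropping the BatchNorm modules, so the new test only has to check that no nn.BatchNorm1d/2d/3d instance remains.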
88 changes: 59 additions & 29 deletions tests/brevitas/graph/test_transforms.py
@@ -10,12 +10,16 @@
 from torchvision import models
 
 from brevitas.fx import symbolic_trace
+from brevitas.graph import AvgPoolToQuantDepthwiseConv
 from brevitas.graph import DuplicateSharedStatelessModule
 from brevitas.graph import FnToModule
 from brevitas.graph import MeanMethodToAdaptiveAvgPool2d
 from brevitas.graph import MergeBatchNorm
 from brevitas.graph import MethodToModule
 from brevitas.graph.base import ModuleToModuleByInstance
+from brevitas.nn import QuantConv1d
+from brevitas.nn import QuantConv2d
+from brevitas.nn import QuantConv3d
 
 SEED = 123456
 INPUT_SIZE = (1, 3, 224, 224)
@@ -50,64 +54,90 @@ def test_rewriter_merge_bn(model_name: str, pretrained: bool):
     assert is_close
 
 
-def test_conv1d_merge_bn():
+@pytest.mark.parametrize("dims", [1, 2, 3])
+def test_conv_merge_bn(dims):
 
     class TestModel(nn.Module):
 
-        def __init__(self):
+        def __init__(self, dims):
             super(TestModel, self).__init__()
+            layers = []
 
-            self.net = nn.Sequential(nn.Conv1d(16, 33, 3, stride=2), nn.BatchNorm1d(33), nn.ReLU())
-
-        def forward(self, x):
-            return self.net(x)
-
-    model = TestModel()
-    graph = symbolic_trace(model)
-    graph = MergeBatchNorm().apply(graph)
-
-    for m in graph.modules():
-        assert not isinstance(m, nn.BatchNorm1d)
-
+            if dims == 1:
+                layers.append(nn.Conv1d(16, 33, 3, stride=2))
+                layers.append(nn.BatchNorm1d(33))
+            elif dims == 2:
+                layers.append(nn.Conv2d(16, 33, 3, stride=2))
+                layers.append(nn.BatchNorm2d(33))
+            else:
+                layers.append(nn.Conv3d(16, 33, 3, stride=2))
+                layers.append(nn.BatchNorm3d(33))
 
-def test_conv2d_merge_bn():
+            layers.append(nn.ReLU())
 
-    class TestModel(nn.Module):
-
-        def __init__(self):
-            super(TestModel, self).__init__()
-
-            self.net = nn.Sequential(nn.Conv2d(16, 33, 3, stride=2), nn.BatchNorm2d(33), nn.ReLU())
+            self.net = nn.Sequential(*layers)
 
         def forward(self, x):
             return self.net(x)
 
-    model = TestModel()
+    model = TestModel(dims)
     graph = symbolic_trace(model)
     graph = MergeBatchNorm().apply(graph)
 
     for m in graph.modules():
-        assert not isinstance(m, nn.BatchNorm2d)
+        if dims == 1:
+            assert not isinstance(m, nn.BatchNorm1d)
+        elif dims == 2:
+            assert not isinstance(m, nn.BatchNorm2d)
+        else:
+            assert not isinstance(m, nn.BatchNorm3d)
 
 
-def test_conv3d_merge_bn():
+@pytest.mark.parametrize("dims", [1, 2, 3])
+def test_avg_pool_to_quant_conv(dims):
 
     class TestModel(nn.Module):
 
-        def __init__(self):
+        def __init__(self, dims):
             super(TestModel, self).__init__()
 
-            self.net = nn.Sequential(nn.Conv3d(16, 33, 3, stride=2), nn.BatchNorm3d(33), nn.ReLU())
+            if dims == 1:
+                self.net = nn.Sequential(nn.AvgPool1d(3, stride=2), nn.ReLU())
+            elif dims == 2:
+                self.net = nn.Sequential(nn.AvgPool2d(3, stride=2), nn.ReLU())
+            else:
+                self.net = nn.Sequential(nn.AvgPool3d(3, stride=2), nn.ReLU())
 
         def forward(self, x):
             return self.net(x)
 
-    model = TestModel()
+    model = TestModel(dims)
+
+    args = None
+    if dims == 1:
+        args = torch.randn(20, 16, 10)
+    elif dims == 2:
+        args = torch.randn(20, 16, 10, 50)
+    else:
+        args = torch.randn(20, 16, 10, 50, 100)
 
     graph = symbolic_trace(model)
-    graph = MergeBatchNorm().apply(graph)
+    graph = AvgPoolToQuantDepthwiseConv().apply(graph, args)
 
+    has_quant_conv = False
     for m in graph.modules():
-        assert not isinstance(m, nn.BatchNorm3d)
+        if isinstance(m, QuantConv1d):
+            has_quant_conv = True
+        if isinstance(m, QuantConv2d):
+            has_quant_conv = True
+        if isinstance(m, QuantConv3d):
+            has_quant_conv = True
+
+        assert not isinstance(m, nn.AvgPool1d)
+        assert not isinstance(m, nn.AvgPool2d)
+        assert not isinstance(m, nn.AvgPool3d)
+
+    assert has_quant_conv
 
 
 def test_rewriter_duplicate_shared_relu():
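A related aside on the pooling test (again not part of the commit): average pooling with kernel size k over d spatial dimensions is mathematically a depthwise convolution whose weights are all 1/k**d, which is what makes a rewrite such as AvgPoolToQuantDepthwiseConv possible and why the test can require that every nn.AvgPool*d is gone while at least one QuantConv*d is present. A minimal sketch of the equivalence for the 2-D shape used above, in plain torch:

import torch
import torch.nn as nn

x = torch.randn(20, 16, 10, 50)  # same shape as the dims == 2 test input
pool = nn.AvgPool2d(3, stride=2)
conv = nn.Conv2d(16, 16, 3, stride=2, groups=16, bias=False)  # depthwise
with torch.no_grad():
    conv.weight.fill_(1.0 / 9.0)  # 1 / (3 * 3)

assert torch.allclose(pool(x), conv(x), atol=1e-5)

The concrete example inputs passed as args in the new test presumably serve the same purpose for the rewrite itself: they give it real activation shapes from which to build the replacement convolution for each dimensionality.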
