fix/test (nn/conv): Fixed conv instantiation and added extra tests
nickfraser committed Sep 3, 2024
1 parent 3887f3b commit bc703e5
Showing 3 changed files with 24 additions and 8 deletions.
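The underlying problem: the old guard in quant_conv.py calls list(stride), which fails with a TypeError when padding='same' is combined with an integer stride (the default). A minimal sketch of that failure mode, not Brevitas code itself:

stride = 1  # an int stride, as in QuantConv2d(..., padding='same', stride=1)
try:
    any(map(lambda x: x > 1, list(stride)))  # the old check: list() of an int
except TypeError as err:
    print(err)  # 'int' object is not iterable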
9 changes: 6 additions & 3 deletions src/brevitas/nn/quant_conv.py
@@ -46,7 +46,8 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
-if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))):
+if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+        stride, int) else any(map(lambda x: x > 1, stride))):
padding = 0
is_same_padded_strided = True
else:
@@ -132,7 +133,8 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
-if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))):
+if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+        stride, int) else any(map(lambda x: x > 1, stride))):
padding = 0
is_same_padded_strided = True
else:
@@ -220,7 +222,8 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
-if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))):
+if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+        stride, int) else any(map(lambda x: x > 1, stride))):
padding = 0
is_same_padded_strided = True
else:
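The new condition branches on the type of stride. As a hedged illustration only (the helper below is hypothetical, not part of Brevitas), the strided check it performs reduces to:

from typing import Tuple, Union

def is_strided(stride: Union[int, Tuple[int, ...]]) -> bool:
    """True if any spatial stride exceeds 1, for int or tuple strides."""
    if isinstance(stride, int):
        return stride > 1
    return any(s > 1 for s in stride)

assert not is_strided(1) and not is_strided((1, 1))
assert is_strided(2) and is_strided((1, 2))

Note the parentheses around the conditional expression in the diff above: without them, Python's precedence rules would apply the isinstance branch to the whole and-chain rather than to the stride check alone.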
12 changes: 9 additions & 3 deletions tests/brevitas/nn/test_conv2d.py
@@ -1,10 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

+import pytest_cases
import torch
from torch.nn import BatchNorm2d
from torch.nn import Conv2d
from torch.nn import Module

from brevitas.inject.defaults import Int8BiasPerTensorFloatInternalScaling
from brevitas.nn import QuantConv2d
@@ -18,12 +18,18 @@

class TestQuantConv2d:

-def test_module_init(self):
+@pytest_cases.parametrize(
+    'kwargs', [{}, {
+        'padding': 'same', 'stride': 1}, {
+            'padding': 'same', 'stride': (1, 1)}],
+    ids=['defaults', 'padding="same",stride=1', 'padding="same",stride=(1,1)'])
+def test_module_init(self, kwargs):
mod = QuantConv2d(
out_channels=OUTPUT_CHANNELS,
in_channels=INPUT_CHANNELS,
kernel_size=KERNEL_SIZE,
-bias=False)
+bias=False,
+**kwargs)

def test_fp_quant_module(self):
float_mod = Conv2d(
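For reference, a rough expansion of what the three parametrized cases instantiate, using illustrative stand-ins for the module-level constants defined in the test file:

from brevitas.nn import QuantConv2d

OUTPUT_CHANNELS, INPUT_CHANNELS, KERNEL_SIZE = 8, 4, 3  # illustrative values only

for kwargs in ({}, {'padding': 'same', 'stride': 1}, {'padding': 'same', 'stride': (1, 1)}):
    QuantConv2d(
        out_channels=OUTPUT_CHANNELS,
        in_channels=INPUT_CHANNELS,
        kernel_size=KERNEL_SIZE,
        bias=False,
        **kwargs)

Before this change, the padding='same', stride=1 case would have hit the TypeError shown earlier.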
11 changes: 9 additions & 2 deletions tests/brevitas/nn/test_conv3d.py
@@ -1,6 +1,7 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

+import pytest_cases
import torch
from torch.nn import BatchNorm3d
from torch.nn import Conv3d
@@ -17,12 +18,18 @@

class TestQuantConv3d:

-def test_module_init(self):
+@pytest_cases.parametrize(
+    'kwargs', [{}, {
+        'padding': 'same', 'stride': 1}, {
+            'padding': 'same', 'stride': (1, 1, 1)}],
+    ids=['defaults', 'padding="same",stride=1', 'padding="same",stride=(1,1,1)'])
+def test_module_init(self, kwargs):
mod = QuantConv3d(
out_channels=OUTPUT_CHANNELS,
in_channels=INPUT_CHANNELS,
kernel_size=KERNEL_SIZE,
-bias=False)
+bias=False,
+**kwargs)

def test_fp_quant_module(self):
float_mod = Conv3d(
