diff --git a/src/brevitas/nn/quant_conv.py b/src/brevitas/nn/quant_conv.py index 5c6c5afea..18fb3282e 100644 --- a/src/brevitas/nn/quant_conv.py +++ b/src/brevitas/nn/quant_conv.py @@ -46,7 +46,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): padding = 0 is_same_padded_strided = True else: @@ -132,7 +132,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): padding = 0 is_same_padded_strided = True else: @@ -220,7 +220,7 @@ def __init__( dtype: Optional[torch.dtype] = None, **kwargs) -> None: # avoid an init error in the super class by setting padding to 0 - if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)): + if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, list(stride))): padding = 0 is_same_padded_strided = True else: