From b76dabcb210a6b993298e3885c3fdd4e5d466e09 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 6 Sep 2024 18:43:47 +0100 Subject: [PATCH] test (graph/layerwise_quantize): Added test for blacklist. --- tests/brevitas/graph/test_quantize.py | 39 +++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/brevitas/graph/test_quantize.py diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py new file mode 100644 index 000000000..1d4ce0b67 --- /dev/null +++ b/tests/brevitas/graph/test_quantize.py @@ -0,0 +1,39 @@ +import pytest_cases +import torch.nn as nn + +from brevitas.graph.quantize import layerwise_quantize + + +@pytest_cases.parametrize( + 'kwargs', + [ + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'name_blacklist': [], + 'key': '0', + 'expected': ""}, + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'name_blacklist': ['0'], + 'key': '0', + 'expected': ""}, + { + 'model': nn.Sequential(nn.Sequential(nn.Linear(2, 3))), + 'name_blacklist': ['0.0'], + 'key': '0.0', + 'expected': ""},]) +def test_layerwise_quantize_blacklist(kwargs): + key = kwargs['key'] + exp = kwargs['expected'] + del kwargs['key'] + del kwargs['expected'] + qmodel = layerwise_quantize(**kwargs) + checked = False + found_names = [] + for n, m in qmodel.named_modules(): + found_names.append(n) + if n == key: + mt = str(type(m)) + assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}" + checked = True + assert checked, f"Layer named {key} not found. Layer names are: {found_names}"