From 7db89eec09da86320695798f4d4547312720cbe4 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 6 Sep 2024 17:57:48 +0100 Subject: [PATCH 1/2] Fix (graph/quant): Bugfix in blacklist matching in `find_module` --- src/brevitas/graph/quantize_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index ed6382907..535f9a8f9 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -516,7 +516,7 @@ def find_module( else: for name, module in model.named_children(): full_name = prefix + '.' + name if prefix != '' else name - if name_blacklist is not None and name in name_blacklist: + if name_blacklist is not None and full_name in name_blacklist: continue find_module(module, layer_map, module_to_replace, name_blacklist, full_name) From b76dabcb210a6b993298e3885c3fdd4e5d466e09 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 6 Sep 2024 18:43:47 +0100 Subject: [PATCH 2/2] test (graph/layerwise_quantize): Added test for blacklist. --- tests/brevitas/graph/test_quantize.py | 39 +++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/brevitas/graph/test_quantize.py diff --git a/tests/brevitas/graph/test_quantize.py b/tests/brevitas/graph/test_quantize.py new file mode 100644 index 000000000..1d4ce0b67 --- /dev/null +++ b/tests/brevitas/graph/test_quantize.py @@ -0,0 +1,39 @@ +import pytest_cases +import torch.nn as nn + +from brevitas.graph.quantize import layerwise_quantize + + +@pytest_cases.parametrize( + 'kwargs', + [ + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'name_blacklist': [], + 'key': '0', + 'expected': ""}, + { + 'model': nn.Sequential(nn.Linear(2, 3)), + 'name_blacklist': ['0'], + 'key': '0', + 'expected': ""}, + { + 'model': nn.Sequential(nn.Sequential(nn.Linear(2, 3))), + 'name_blacklist': ['0.0'], + 'key': '0.0', + 'expected': ""},]) +def test_layerwise_quantize_blacklist(kwargs): + key = kwargs['key'] + exp = kwargs['expected'] + del kwargs['key'] + del kwargs['expected'] + qmodel = layerwise_quantize(**kwargs) + checked = False + found_names = [] + for n, m in qmodel.named_modules(): + found_names.append(n) + if n == key: + mt = str(type(m)) + assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}" + checked = True + assert checked, f"Layer named {key} not found. Layer names are: {found_names}"