diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 77c50c9f4..655fdb9c3 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -721,8 +721,6 @@ def find_module(self, model, regions: List): """ Iterate through the model looking at immediate children of every module to look for supported modules. This allows us to stop the search when we meet a top-level module that is supported. - Specifically, it allows to map nn.MultiheadAttetion to its quantized counterpart and not its - Linear submodules. """ if isinstance(model, _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):