Fix static_quantization_tutorial error in qat_model (#2661)
We need to use the QAT variant of the fuse_modules method
(fuse_modules_qat). After this fix, the tutorial runs to completion
on a Linux x86 system.

Fixes #1269

Signed-off-by: BJ Hargrave <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
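
The fix the message describes amounts to selecting the fusion function by mode: QAT fusion must run on a model in train() mode, while post-training fusion expects eval() mode. A minimal sketch of that selection (the helper name pick_fuser is illustrative, not from the tutorial):

```python
import torch

def pick_fuser(is_qat: bool):
    # QAT fusion preserves BatchNorm for fake-quantized training;
    # the non-QAT variant folds BN into the preceding Conv.
    return (torch.ao.quantization.fuse_modules_qat
            if is_qat
            else torch.ao.quantization.fuse_modules)
```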
bjhargrave and svekars authored Nov 9, 2023
1 parent 66eaf6a commit a668406
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions advanced_source/static_quantization_tutorial.rst
@@ -206,14 +206,15 @@ Note: this code is taken from
     # Fuse Conv+BN and Conv+BN+Relu modules prior to quantization
     # This operation does not change the numerics
-    def fuse_model(self):
+    def fuse_model(self, is_qat=False):
+        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
         for m in self.modules():
             if type(m) == ConvBNReLU:
-                torch.ao.quantization.fuse_modules(m, ['0', '1', '2'], inplace=True)
+                fuse_modules(m, ['0', '1', '2'], inplace=True)
             if type(m) == InvertedResidual:
                 for idx in range(len(m.conv)):
                     if type(m.conv[idx]) == nn.Conv2d:
-                        torch.ao.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
+                        fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
 
 2. Helper functions
 -------------------
@@ -533,7 +534,7 @@ We fuse modules as before
 .. code:: python
 
     qat_model = load_model(saved_model_dir + float_model_file)
-    qat_model.fuse_model()
+    qat_model.fuse_model(is_qat=True)
 
     optimizer = torch.optim.SGD(qat_model.parameters(), lr = 0.0001)
     # The old 'fbgemm' is still available but 'x86' is the recommended default.
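
The QAT fusion path can be exercised directly on a small block. A self-contained sketch — the Sequential below is a stand-in for the tutorial's ConvBNReLU module, with illustrative channel counts, not the tutorial source:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules_qat

# Stand-in for the tutorial's ConvBNReLU block (shapes are illustrative).
block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# QAT fusion is applied with the model in training mode; using the
# non-QAT fuse_modules here is what caused the tutorial's error.
block.train()
fuse_modules_qat(block, ['0', '1', '2'], inplace=True)

# Conv, BN, and ReLU are now one intrinsic fused module at index 0;
# the fused-away slots are replaced with nn.Identity.
```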
