Skip to content

Commit 87f1249

Browse files
committed
[PT2E] Fix per-tensor observer issue with varing shape & rank
1 parent 07ca637 commit 87f1249

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

torchao/quantization/pt2e/observer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1793,7 +1793,7 @@ def get_block_size(
17931793
"Please provide an instance of Granularity, not subclass of it"
17941794
)
17951795
if isinstance(granularity, PerTensor):
1796-
return input_shape
1796+
return (-1,) * len(input_shape)
17971797
elif isinstance(granularity, PerAxis):
17981798
block_size = list(input_shape)
17991799
block_size[granularity.axis] = 1
@@ -1891,6 +1891,10 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
18911891
assert self.original_dtype is not None, (
18921892
"Expecting original_dtype to be populated"
18931893
)
1894+
# Since input shape & rank may change (e.g. Resnet18), here we need to update block_size for each input
1895+
self.block_size = get_block_size(
1896+
observer_node.args[0].meta["tensor_meta"].shape, self.granularity
1897+
)
18941898
if hasattr(self, "is_dynamic") and self.is_dynamic:
18951899
choose_qparams_affine = model.graph.call_function(
18961900
torch.ops.torchao.choose_qparams_affine,

0 commit comments

Comments
 (0)