Skip to content

[PT2E] Fix per-tensor observer issue with varying shape & rank #2177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torchao/quantization/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_block_size(
granularity: The granularity type of the quantization
"""
if isinstance(granularity, PerTensor):
return input_shape
return [-1]
elif isinstance(granularity, PerAxis):
block_size = list(input_shape)
block_size[granularity.axis] = 1
Expand Down
3 changes: 2 additions & 1 deletion torchao/quantization/pt2e/_affine_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def _get_reduction_params(block_size, input_size):
shape_for_reduction: (3, 3, 5, 2, 10)
reduction_dim: [0, 1, 3, 4]
"""
assert len(block_size) == len(input_size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this still used? we should be using the code in quant_primitives.py I think

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It's still used when running the prepared model (model after prepare_pt2e). Is it a bug? Do I need to fix it, too?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am using the observers defined here: torchao/quantization/pt2e/_affine_quantization.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jerryzh168 May I know your suggestion on this? Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be using the ones in torchao/quantization/observer.py eventually

only occurrence seems to be

AffineQuantizedMinMaxObserver,
and we want to update it I think

so if you are adding new things I'd recommend use the one from torchao.quantization

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Xia-Weiwen sorry for the delay, please feel free to work on this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we already use the one from torchao:

but if you saw we are using torch.ao please go ahead and change them

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jerryzh168 Thanks for the reply. I did not mean torch.ao. I meant there are two versions of such utilities in torchao, torchao.quantization.pt2e and torchao.quantization. For example,

class PartialWrapper:

and
class _PartialWrapper:

The PT2E flow in torchao uses those in torchao.quantization.pt2e while you said you wanted to switch to torchao/quantization/observer.py.
So, I was asking whether you would switch to torchao/quantization/observer.py in PT2E flow first. Do you have any suggestions on that? Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I see, yeah for now use torchao/quantization/observer.py would be better I think, we haven't finalized the folder structure for this one yet

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jerryzh168 Am I supposed to wait until you finalize the folder structure? Thanks.

assert block_size == [-1] or len(block_size) == len(input_size)
block_size = input_size if block_size == [-1] else block_size
shape_for_reduction = []
reduction_dims = []
cur_dim = 0
Expand Down
2 changes: 1 addition & 1 deletion torchao/quantization/pt2e/observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1793,7 +1793,7 @@ def get_block_size(
"Please provide an instance of Granularity, not subclass of it"
)
if isinstance(granularity, PerTensor):
return input_shape
return [-1]
elif isinstance(granularity, PerAxis):
block_size = list(input_shape)
block_size[granularity.axis] = 1
Expand Down
12 changes: 11 additions & 1 deletion torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def _get_reduction_params(block_size, input_size):
shape_for_reduction: (3, 3, 5, 2, 10)
reduction_dim: [0, 1, 3, 4]
"""
assert len(block_size) == len(input_size)
assert block_size == [-1] or len(block_size) == len(input_size)
block_size = input_size if block_size == [-1] else block_size
shape_for_reduction = []
reduction_dims = []
cur_dim = 0
Expand Down Expand Up @@ -365,6 +366,9 @@ def _quantize_affine(
# torch.uintx dtypes yet
if output_dtype in _SUB_BYTE_UINT_BOUNDS:
output_dtype = torch.uint8
if block_size == [-1]:
# per-tensor quantization
block_size = input.shape
return _quantize_affine_no_dtype_cast(
input,
block_size,
Expand Down Expand Up @@ -520,6 +524,9 @@ def _dequantize_affine(
torch.float16,
torch.bfloat16,
], f"Unsupported output dtype: {output_dtype}"
if block_size == [-1]:
# per-tensor quantization
block_size = input.shape
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
return _dequantize_affine_no_dtype_check(
input,
Expand Down Expand Up @@ -878,6 +885,9 @@ def _choose_qparams_affine(
scale_dtype = input.dtype
if eps is None:
eps = torch.finfo(input.dtype).eps
if block_size == [-1]:
# per-tensor quantization
block_size = input.shape

assert len(block_size) == input.dim(), (
f"Got input dim:{input.dim()}, block_size: {block_size}"
Expand Down
Loading