Skip to content

Commit

Permalink
Clean up observer defaulting logic, better error message (#200)
Browse files Browse the repository at this point in the history
  • Loading branch information
kylesayrs authored Nov 1, 2024
1 parent 2b79056 commit 37df2dd
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
24 changes: 10 additions & 14 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,6 @@ def get_observer(self):
"""
:return: torch quantization FakeQuantize built based on these QuantizationArgs
"""

# No observer required for the dynamic case
if self.dynamic:
self.observer = None
return self.observer

return self.observer

@field_validator("type", mode="before")
Expand Down Expand Up @@ -203,6 +197,7 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
"activation ordering"
)

# infer observer w.r.t. dynamic
if dynamic:
if strategy not in (
QuantizationStrategy.TOKEN,
Expand All @@ -214,18 +209,19 @@ def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
"quantization",
)
if observer is not None:
warnings.warn(
"No observer is used for dynamic quantization, setting to None"
)
model.observer = None
if observer != "memoryless": # avoid annoying users with old configs
warnings.warn(
"No observer is used for dynamic quantization, setting to None"
)
observer = None

# if we have not set an observer and we
# are running static quantization, use minmax
if not observer and not dynamic:
model.observer = "minmax"
elif observer is None:
# default to minmax for non-dynamic cases
observer = "minmax"

# write back modified values
model.strategy = strategy
model.observer = observer
return model

def pytorch_dtype(self) -> torch.dtype:
Expand Down
2 changes: 1 addition & 1 deletion src/compressed_tensors/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def get_from_registry(
retrieved_value = _import_and_get_value_from_module(module_path, value_name)
else:
# look up name in alias registry
name = _ALIAS_REGISTRY[parent_class].get(name)
name = _ALIAS_REGISTRY[parent_class].get(name, name)
# look up name in registry
retrieved_value = _REGISTRY[parent_class].get(name)
if retrieved_value is None:
Expand Down

0 comments on commit 37df2dd

Please sign in to comment.