Quantization refactor #5

Merged
merged 6 commits into from
Apr 12, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion setup.py
@@ -25,7 +25,7 @@ def _setup_install_requires() -> List:
     return ["torch>=1.7.0", "transformers<=4.40", "pydantic<2.7"]

 def _setup_extras() -> Dict:
-    return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0"]}
+    return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0",]}

 setup(
     name="sparsetensors",
21 changes: 21 additions & 0 deletions src/sparsetensors/quantization/lifecycle/__init__.py
@@ -0,0 +1,21 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .calibration import *
from .forward import *
from .frozen import *
from .initialize import *
from .status import *
43 changes: 43 additions & 0 deletions src/sparsetensors/quantization/lifecycle/calibration.py
@@ -0,0 +1,43 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from torch.nn import Module


__all__ = [
    "set_module_for_calibration",
]


_LOGGER = logging.getLogger(__name__)


def set_module_for_calibration(module: Module):
    if not getattr(module, "quantization_scheme", None):
        # no quantization scheme, nothing to do
        return
    status = getattr(module, "quantization_status", None)
    if status != QuantizationStatus.INITIALIZED:
        _LOGGER.warning(
            f"Attempting to set module with status {status} to calibration mode, "
            f"but status is not {QuantizationStatus.INITIALIZED} - you may be "
            "calibrating an uninitialized module, which may fail, or attempting "
            "to re-calibrate a frozen module"
        )

    module.quantization_status = QuantizationStatus.CALIBRATION
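
A usage note (not part of the diff): a minimal sketch of how this helper behaves, where the placeholder scheme object stands in for a real QuantizationScheme:

import logging

from torch.nn import Linear

logging.basicConfig(level=logging.WARNING)

layer = Linear(4, 4)
set_module_for_calibration(layer)     # no quantization_scheme attached: returns silently

layer.quantization_scheme = object()  # placeholder scheme, for illustration only
set_module_for_calibration(layer)     # status is None, not INITIALIZED: logs the warning
assert layer.quantization_status == QuantizationStatus.CALIBRATION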
127 changes: 127 additions & 0 deletions src/sparsetensors/quantization/lifecycle/forward.py
@@ -0,0 +1,127 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps

import torch
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module


__all__ = ["wrap_module_forward_quantized"]


def quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    q_max: torch.Tensor,
) -> torch.Tensor:
    return torch.clamp(
        torch.round(x / scale + zero_point),
        0,
        q_max,
    )


def dequantize(
    x_q: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
) -> torch.Tensor:
    return (x_q - zero_point) * scale


def fake_quantize(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    args: QuantizationArgs,
) -> torch.Tensor:
    max_q = torch.tensor(2**args.num_bits - 1)
    Q = quantize(x, scale, zero_point, max_q)
    return dequantize(Q, scale, zero_point)


def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
    # expects a module already initialized and injected with the parameters in
    # initialize_module_for_quantization
    forward_func_orig = module.forward.__func__

    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
    def wrapped_forward(self, *args, **kwargs):
        input_ = args[0]

        if scheme.input_activations is not None:
            # calibrate and (fake) quantize input activations when applicable
            input_ = _maybe_calibrate_or_quantize(
                module, input_, "input", scheme.input_activations
            )

        if scheme.weights is not None:
            # calibrate and (fake) quantize weights when applicable
            self.weight.data = _maybe_calibrate_or_quantize(
                module, self.weight, "weight", scheme.weights
            )

        # perform wrapped forward call
        output = forward_func_orig.__get__(module, module.__class__)(
            input_, *args[1:], **kwargs
        )

        if scheme.output_activations is not None:
            # calibrate and (fake) quantize output activations when applicable
            output = _maybe_calibrate_or_quantize(
                module, output, "output", scheme.output_activations
            )

        return output

    # bind wrapped forward to module class so reference to `self` is correct
    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
    # set forward to wrapped forward
    setattr(module, "forward", bound_wrapped_forward)


def _maybe_calibrate_or_quantize(
    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
) -> torch.Tensor:
    # only run quantization for the included stages
    if module.quantization_status not in {
        QuantizationStatus.CALIBRATION,
        QuantizationStatus.FROZEN,
    }:
        return value

    scale = getattr(module, f"{base_name}_scale")
    zero_point = getattr(module, f"{base_name}_zero_point")

    if module.quantization_status == QuantizationStatus.CALIBRATION:
        # get observer and get new quant params from observation
        observer = getattr(module, f"{base_name}_observer")
        updated_scale, updated_zero_point = observer(value)

        # update scale and zero point
        scale.data = updated_scale
        zero_point.data = updated_zero_point

    return fake_quantize(value, scale, zero_point, args)
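
For intuition, a worked sketch of the affine scheme implemented by quantize/dequantize above. The min/max scale and zero-point derivation here is an illustrative assumption; in this PR the real values come from the observers, which live outside this file:

import torch

# 8-bit affine quantization of a toy tensor, mirroring quantize()/dequantize()
x = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
num_bits = 8
q_max = torch.tensor(2**num_bits - 1)            # 255
scale = (x.max() - x.min()) / (2**num_bits - 1)  # 2.0 / 255 ~= 0.00784
zero_point = torch.round(-x.min() / scale)       # 128, so x.min() maps to ~0

x_q = torch.clamp(torch.round(x / scale + zero_point), 0, q_max)  # quantize
x_dq = (x_q - zero_point) * scale                                 # dequantize
# round-trip error is bounded by scale / 2 (~0.004 here)
assert torch.allclose(x, x_dq, atol=float(scale) / 2 + 1e-6)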
36 changes: 36 additions & 0 deletions src/sparsetensors/quantization/lifecycle/frozen.py
@@ -0,0 +1,36 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from torch.nn import Module


__all__ = [
    "freeze_module_quantization",
]


def freeze_module_quantization(module: Module):
    if not getattr(module, "quantization_scheme", None):
        # no quantization scheme, nothing to do
        return

    # delete observers that belong directly to this module; collect names first
    # so the submodule dict is not mutated while iterating over it
    observer_names = [
        submodule_name
        for submodule_name, _ in module.named_modules()
        if "." not in submodule_name and submodule_name.endswith("_observer")
    ]
    for submodule_name in observer_names:
        delattr(module, submodule_name)

    module.quantization_status = QuantizationStatus.FROZEN
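
A runnable sketch of the direct-child filter above; the placeholder scheme object and the bare observer module are illustrative assumptions:

from torch.nn import Linear, Module

class Demo(Module):
    def __init__(self):
        super().__init__()
        self.inner = Linear(4, 4)  # nested submodules ("inner.*") are left untouched

demo = Demo()
demo.quantization_scheme = object()               # placeholder so freeze does not early-return
demo.register_module("input_observer", Module())  # direct child whose name ends in _observer
freeze_module_quantization(demo)
assert not hasattr(demo, "input_observer")        # observer deleted
assert demo.quantization_status == QuantizationStatus.FROZEN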
71 changes: 71 additions & 0 deletions src/sparsetensors/quantization/lifecycle/initialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

import torch
from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
from sparsetensors.quantization.lifecycle.status import QuantizationStatus
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter


__all__ = [
    "initialize_module_for_quantization",
]


_LOGGER = logging.getLogger(__name__)


def initialize_module_for_quantization(module: Module, scheme: QuantizationScheme):
    if scheme.input_activations is not None:
        _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
    if scheme.weights is not None:
        if hasattr(module, "weight"):
            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
        else:
            _LOGGER.warning(
                f"module type {type(module)} targeted for weight quantization "
                "but has no weight attribute, skipping weight quantization"
            )
    if scheme.output_activations is not None:
        _initialize_scale_zero_point_observer(
            module, "output", scheme.output_activations
        )

    module.quantization_scheme = scheme
    module.quantization_status = QuantizationStatus.INITIALIZED

    # wrap forward call of module to perform quantized actions based on call-time status
    wrap_module_forward_quantized(module, scheme)


def _initialize_scale_zero_point_observer(
    module: Module, base_name: str, quantization_args: QuantizationArgs
):
    # initializes empty scale and zero point parameters for the module
    init_scale = Parameter(torch.empty(0), requires_grad=False)
    module.register_parameter(f"{base_name}_scale", init_scale)

    init_zero_point = Parameter(torch.empty(0, dtype=int), requires_grad=False)
    module.register_parameter(f"{base_name}_zero_point", init_zero_point)

    # initialize observer module and attach as submodule
    observer = quantization_args.get_observer()
    module.register_module(f"{base_name}_observer", observer)
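
Putting the lifecycle together, a hypothetical end-to-end sketch. The QuantizationScheme and QuantizationArgs constructor fields shown are assumptions, since quant_scheme.py and quant_args.py are not part of this diff:

import torch
from torch.nn import Linear

layer = Linear(16, 16)
scheme = QuantizationScheme(                       # field names assumed for illustration
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8),
    input_activations=QuantizationArgs(num_bits=8),
)

initialize_module_for_quantization(layer, scheme)  # attach scale/zero_point/observers, wrap forward
set_module_for_calibration(layer)                  # status -> CALIBRATION
layer(torch.randn(1, 16))                          # observers update scale/zero_point, fake-quantize
freeze_module_quantization(layer)                  # status -> FROZEN, observers deleted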
26 changes: 26 additions & 0 deletions src/sparsetensors/quantization/lifecycle/status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


__all__ = [
    "QuantizationStatus",
]


class QuantizationStatus(Enum):
    INITIALIZED = "INITIALIZED"
    CALIBRATION = "CALIBRATION"
    FROZEN = "FROZEN"
19 changes: 19 additions & 0 deletions src/sparsetensors/quantization/observers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import *
from .memoryless import *
from .min_max import *