Skip to content

Commit

Permalink
Implement Missing Modifier Methods (#166)
Browse files Browse the repository at this point in the history
* implement missing functions

* do not require on_initialize_structure and on_event to be implemented

* make abc

* remove requirements

* fix import name

* dummy commit

* dummy 2

* dummy commit

* dummy commit revert

---------

Co-authored-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
  • Loading branch information
3 people authored Sep 30, 2024
1 parent 5cb95b7 commit f12b3c7
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 87 deletions.
19 changes: 11 additions & 8 deletions src/llmcompressor/modifiers/modifier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
from typing import Optional

from pydantic import BaseModel
Expand All @@ -9,7 +10,7 @@
__all__ = ["Modifier"]


class Modifier(BaseModel, ModifierInterface):
class Modifier(BaseModel, ModifierInterface, ABC):
"""
A base class for all modifiers to inherit from.
Modifiers are used to modify the training process for a model.
Expand Down Expand Up @@ -224,15 +225,17 @@ def should_end(self, event: Event):
def on_initialize_structure(self, state: State, **kwargs):
"""
on_initialize_structure is called before the model is initialized
with the modifier structure. Must be implemented by the inheriting
modifier.
with the modifier structure.
TODO: Depreciate this function as part of the lifecycle
:param state: The current state of the model
:param kwargs: Additional arguments for initializing the structure
of the model in question
"""
raise NotImplementedError()
pass

@abstractmethod
def on_initialize(self, state: State, **kwargs) -> bool:
"""
on_initialize is called on modifier initialization and
Expand All @@ -255,7 +258,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
:return: True if the modifier was finalized successfully,
False otherwise
"""
raise NotImplementedError()
return True

def on_start(self, state: State, event: Event, **kwargs):
"""
Expand All @@ -266,7 +269,7 @@ def on_start(self, state: State, event: Event, **kwargs):
:param event: The event that triggered the start
:param kwargs: Additional arguments for starting the modifier
"""
raise NotImplementedError()
pass

def on_update(self, state: State, event: Event, **kwargs):
"""
Expand All @@ -278,7 +281,7 @@ def on_update(self, state: State, event: Event, **kwargs):
:param event: The event that triggered the update
:param kwargs: Additional arguments for updating the model
"""
raise NotImplementedError()
pass

def on_end(self, state: State, event: Event, **kwargs):
"""
Expand All @@ -289,7 +292,7 @@ def on_end(self, state: State, event: Event, **kwargs):
:param event: The event that triggered the end
:param kwargs: Additional arguments for ending the modifier
"""
raise NotImplementedError()
pass

def on_event(self, state: State, event: Event, **kwargs):
"""
Expand Down
26 changes: 1 addition & 25 deletions src/llmcompressor/modifiers/obcq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.nn import Module
from tqdm import tqdm

from llmcompressor.core.state import State
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.obcq.utils.sgpt_wrapper import SparseGptWrapper
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
Expand Down Expand Up @@ -83,25 +83,12 @@ class SparseGPTModifier(Modifier):
prunem_: Optional[int] = None
compressible_layers_: Optional[List] = None

def on_initialize_structure(self, state: State, **kwargs):
"""
Initialize the structure of the model for compression.
This modifier does not modifiy the model structure, so this method
is a no-op.
:param state: session state storing input model and calibration data
"""
return True

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)

if self.sparsity == 0.0:
raise ValueError(
"To use the SparseGPTModifier, target sparsity must be > 0.0"
Expand All @@ -121,17 +108,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool:

return True

def on_finalize(self, state: State, **kwargs):
"""
Nothing to do on finalize, on this level.
Quantization Modifier if any will be finalized in the subclass
:param state: session state storing input model and calibration data
:param kwargs: additional arguments
:return: True
"""
return True

def initialize_compression(
self,
model: Module,
Expand Down
3 changes: 0 additions & 3 deletions src/llmcompressor/modifiers/pruning/constant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ class ConstantPruningModifier(Modifier, LayerParamMasking):
_save_masks: bool = False
_use_hooks: bool = False

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier

def on_initialize(self, state: State, **kwargs) -> bool:
if "save_masks" in kwargs:
self._save_masks = kwargs["save_masks"]
Expand Down
3 changes: 0 additions & 3 deletions src/llmcompressor/modifiers/pruning/magnitude/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
mask_creator_function_: MaskCreatorType = None
current_sparsity_: float = None

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier

def on_initialize(self, state: State, **kwargs) -> bool:
if self.apply_globally:
raise NotImplementedError("global pruning not implemented yet for PyTorch")
Expand Down
21 changes: 1 addition & 20 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.nn import Module
from tqdm import tqdm

from llmcompressor.core.state import State
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.pruning.wanda.utils.wanda_wrapper import WandaWrapper
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
Expand Down Expand Up @@ -61,15 +61,6 @@ class WandaPruningModifier(Modifier):
prunen_: Optional[int] = None
prunem_: Optional[int] = None

def on_initialize_structure(self, state: State, **kwargs):
"""
This modifier does not alter the model structure.
This method is a no-op.
:param state: Unused, kept to conform to the parent method signature
:param kwargs: Unused, kept to conform to the parent method signature
"""

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run the WANDA algorithm on the current state
Expand All @@ -91,16 +82,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

return True

def on_finalize(self, state: State, **kwargs):
"""
Nothing to clean up for this module
:param state: Unused, kept to conform to the parent method signature
:param kwargs: Unused, kept to conform to the parent method signature
"""

return True

def compressible_layers(self) -> Dict:
"""
Retrieves the modules corresponding to a list of
Expand Down
4 changes: 3 additions & 1 deletion src/llmcompressor/modifiers/quantization/gptq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pydantic import Field, field_validator
from torch.nn import Module

from llmcompressor.core.state import State
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier, ModifierFactory
from llmcompressor.modifiers.quantization.gptq.utils import (
GPTQWrapper,
Expand Down Expand Up @@ -130,6 +130,8 @@ def on_initialize_structure(self, state: State, **kwargs):
Check the model's quantization state matches that expected by this modifier,
adding a default quantization scheme if needed
TODO: Depreciate and fold into `on_initialize`
:param state: session state storing input model and calibration data
"""
quantization_already_active = qat_active(state.model)
Expand Down
9 changes: 0 additions & 9 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,6 @@ class QuantizationModifier(Modifier):
calibration_dataloader_: Any = None
calibration_function_: Any = None

def on_initialize_structure(self, state: State, **kwargs):
pass

def on_initialize(self, state: State, **kwargs) -> bool:
if self.end and self.end != -1:
raise ValueError(
Expand All @@ -99,9 +96,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

return True

def on_finalize(self, state: State, **kwargs) -> bool:
return True

def on_start(self, state: State, event: Event, **kwargs):
module = state.model
module.apply(set_module_for_calibration)
Expand All @@ -116,9 +110,6 @@ def on_end(self, state: State, event: Event, **kwargs):
module = state.model
module.apply(freeze_module_quantization)

def on_event(self, state: State, event: Event, **kwargs):
pass

def create_init_config(self) -> QuantizationConfig:
if self.targets is not None and isinstance(self.targets, str):
self.targets = [self.targets]
Expand Down
17 changes: 1 addition & 16 deletions src/llmcompressor/modifiers/smoothquant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from loguru import logger
from torch.nn import Module

from llmcompressor.core import Event, State
from llmcompressor.core import State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
Expand Down Expand Up @@ -102,9 +102,6 @@ class SmoothQuantModifier(Modifier):
resolved_mappings_: Optional[List] = None
scales_: Optional[Dict] = None

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run SmoothQuant on the given state
Expand Down Expand Up @@ -136,18 +133,6 @@ def on_initialize(self, state: State, **kwargs) -> bool:

return True

def on_start(self, state: State, event: Event, **kwargs):
pass

def on_update(self, state: State, event: Event, **kwargs):
pass

def on_end(self, state: State, event: Event, **kwargs):
pass

def on_event(self, state: State, event: Event, **kwargs):
pass

def on_finalize(self, state: State, **kwargs) -> bool:
"""
Clean up by clearing the scale and mapping data
Expand Down
6 changes: 4 additions & 2 deletions tests/llmcompressor/recipe/test_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,13 @@ def test_recipe_can_be_created_from_modifier_instances():


class A_FirstDummyModifier(Modifier):
pass
def on_initialize(self, *args, **kwargs) -> bool:
return True


class B_SecondDummyModifier(Modifier):
pass
def on_initialize(self, *args, **kwargs) -> bool:
return True


def test_create_recipe_string_from_modifiers_with_default_group_name():
Expand Down

0 comments on commit f12b3c7

Please sign in to comment.