Use loguru, not logging module #212

Open · wants to merge 1 commit into base: main
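At a glance, the change is mechanical: each module-level `logging.getLogger(__name__)` logger is dropped in favor of loguru's single shared `logger` object, and the call sites are updated in place. A minimal before/after sketch of the pattern (the message text below is illustrative, not taken from the diff):

```python
# Before: stdlib logging with a per-module logger object
import logging

_LOGGER: logging.Logger = logging.getLogger(__name__)
_LOGGER.warning("compression config will not be saved")

# After: loguru ships one ready-to-use logger, no per-module setup needed
from loguru import logger

logger.warning("compression config will not be saved")
```

One detail worth noting: loguru's logger exposes `warning` but has no `warn` alias, so stdlib `_LOGGER.warn(...)` calls must become `logger.warning(...)`.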
2 changes: 1 addition & 1 deletion setup.py
@@ -46,7 +46,7 @@ def _setup_packages() -> List:
)

def _setup_install_requires() -> List:
return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "loguru",]

def _setup_extras() -> Dict:
return {
@@ -13,7 +13,6 @@
# limitations under the License.

import json
import logging
import operator
import os
import re
@@ -48,6 +47,7 @@
fix_fsdp_module_name,
is_compressed_tensors_config,
)
from loguru import logger
from torch import Tensor
from torch.nn import Module
from tqdm import tqdm
@@ -57,8 +57,6 @@

__all__ = ["ModelCompressor", "map_modules_to_quant_args"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


if TYPE_CHECKING:
# dummy type if not available from transformers
@@ -332,7 +330,7 @@ def update_config(self, save_directory: str):

config_file_path = os.path.join(save_directory, CONFIG_NAME)
if not os.path.exists(config_file_path):
_LOGGER.warning(
logger.warning(
f"Could not find a valid model config file in "
f"{save_directory}. Compression config will not be saved."
)
@@ -12,20 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from loguru import logger
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm


_LOGGER: logging.Logger = logging.getLogger(__name__)

__all__ = ["BaseQuantizationCompressor"]


@@ -77,7 +75,7 @@ def compress(
"""
compressed_dict = {}
weight_suffix = ".weight"
_LOGGER.debug(
logger.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

8 changes: 3 additions & 5 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
@@ -12,20 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from loguru import logger
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm


__all__ = ["BaseSparseCompressor"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


class BaseSparseCompressor(BaseCompressor):
"""
@@ -67,14 +65,14 @@ def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
:return: compressed state dict
"""
compressed_dict = {}
_LOGGER.debug(
logger.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)
for name, value in tqdm(model_state.items(), desc="Compressing model"):
compression_data = self.compress_weight(name, value)
for key in compression_data.keys():
if key in compressed_dict:
_LOGGER.warn(
logger.warning(
f"Expected all compressed state_dict keys to be unique, but "
f"found an existing entry for {key}. The existing entry will "
"be replaced."
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

import numpy as np
@@ -28,13 +27,11 @@
sparse_semi_structured_from_dense_cutlass,
tensor_follows_mask_structure,
)
from loguru import logger
from torch import Tensor
from tqdm import tqdm


_LOGGER: logging.Logger = logging.getLogger(__name__)


@BaseCompressor.register(name=CompressionFormat.marlin_24.value)
class Marlin24Compressor(BaseCompressor):
"""
@@ -124,7 +121,7 @@ def compress(

compressed_dict = {}
weight_suffix = ".weight"
_LOGGER.debug(
logger.debug(
f"Compressing model with {len(model_state)} parameterized layers..."
)

9 changes: 3 additions & 6 deletions src/compressed_tensors/quantization/lifecycle/apply.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re
from collections import OrderedDict, defaultdict
from copy import deepcopy
@@ -44,6 +43,7 @@
from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
from compressed_tensors.utils.offload import update_parameter_data
from compressed_tensors.utils.safetensors_load import get_safetensors_folder
from loguru import logger
from torch.nn import Module


@@ -58,9 +58,6 @@
from compressed_tensors.utils.safetensors_load import get_quantization_state_dict


_LOGGER = logging.getLogger(__name__)


def load_pretrained_quantization(model: Module, model_name_or_path: str):
"""
Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -176,7 +173,7 @@ def apply_quantization_config(

if config.ignore is not None and ignored_submodules is not None:
if set(config.ignore) - set(ignored_submodules):
_LOGGER.warning(
logger.warning(
"Some layers that were to be ignored were "
"not found in the model: "
f"{set(config.ignore) - set(ignored_submodules)}"
@@ -211,7 +208,7 @@ def process_kv_cache_config(
:return: the QuantizationConfig with additional "kv_cache" group
"""
if targets == KV_CACHE_TARGETS:
_LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
logger.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")

kv_cache_dict = config.kv_cache_scheme.model_dump()
kv_cache_scheme = QuantizationScheme(
6 changes: 1 addition & 5 deletions src/compressed_tensors/quantization/lifecycle/compressed.py
@@ -13,11 +13,10 @@
# limitations under the License.


import logging

import torch
from compressed_tensors.quantization.lifecycle.forward import quantize
from compressed_tensors.quantization.quant_config import QuantizationStatus
from loguru import logger
from torch.nn import Module


@@ -26,9 +25,6 @@
]


_LOGGER = logging.getLogger(__name__)


def compress_quantized_weights(module: Module):
"""
Quantizes the module weight representation to use fewer bits in memory
6 changes: 1 addition & 5 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -13,7 +13,6 @@
# limitations under the License.


import logging
from enum import Enum
from typing import Optional

@@ -40,9 +39,6 @@
]


_LOGGER = logging.getLogger(__name__)


class KVCacheScaleType(Enum):
KEY = "k_scale"
VALUE = "v_scale"
@@ -97,7 +93,7 @@ def initialize_module_for_quantization(
force_zero_point=force_zero_point,
)
else:
_LOGGER.warning(
logger.warning(
f"module type {type(module)} targeted for weight quantization but "
"has no attribute weight, skipping weight quantization "
f"for {type(module)}"
6 changes: 2 additions & 4 deletions src/compressed_tensors/quantization/utils/helpers.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Generator, List, Optional, Tuple

import torch
@@ -23,6 +22,7 @@
QuantizationType,
)
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from loguru import logger
from torch import FloatTensor, IntTensor, Tensor
from torch.nn import Module
from tqdm import tqdm
@@ -50,8 +50,6 @@
# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
KV_CACHE_TARGETS = ["re:.*self_attn$"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


def calculate_qparams(
min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
@@ -305,7 +303,7 @@ def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:
bit_depth = get_torch_bit_depth(value)
requested_depth = quant_args.num_bits
if bit_depth < quant_args.num_bits:
_LOGGER.warn(
logger.warning(
f"Can't quantize tensor with bit depth {bit_depth} to {requested_depth}."
"The QuantizationArgs provided are not compatible with the input tensor."
)
28 changes: 17 additions & 11 deletions tests/test_quantization/lifecycle/test_apply.py
@@ -28,6 +28,7 @@
apply_quantization_status,
)
from compressed_tensors.quantization.utils import iter_named_leaf_modules
from loguru import logger
from tests.testing_utils import requires_accelerate
from transformers import AutoModelForCausalLM

@@ -233,10 +234,8 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
[("lm_head", "re:.*foobarbaz"), True],
],
)
def test_apply_quantization_status(caplog, ignore, should_raise_warning):
import logging

# load a dense, unquantized tiny llama model
def test_apply_quantization_status(ignore, should_raise_warning):
# Load a dense, unquantized Tiny Llama model
model = get_tinyllama_model()
quantization_config_dict = {
"quant_method": "sparseml",
@@ -259,10 +258,17 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):
config = QuantizationConfig(**quantization_config_dict)
config.quantization_status = QuantizationStatus.CALIBRATION

# mismatch in the ignore key of quantization_config_dict
with caplog.at_level(logging.WARNING):
apply_quantization_config(model, config)
if should_raise_warning:
assert len(caplog.text) > 0
else:
assert len(caplog.text) == 0
log_messages = []

def capture_log(msg):
log_messages.append(msg)

logger.remove() # Remove default handler
logger.add(capture_log)

apply_quantization_config(model, config)

if should_raise_warning:
assert len(log_messages) > 0
else:
assert len(log_messages) == 0
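Because loguru does not route records through the stdlib `logging` machinery by default, pytest's `caplog` fixture no longer sees these messages, which is why the test above swaps `caplog` for an explicit loguru sink. A self-contained sketch of the same capture idea, assuming a hypothetical `emit_warning` helper that is not part of this PR:

```python
from loguru import logger


def emit_warning() -> None:
    logger.warning("example warning")


def test_warning_is_captured() -> None:
    messages = []
    # Register a list-appending sink; logger.add returns a handler id.
    handler_id = logger.add(messages.append, level="WARNING")
    try:
        emit_warning()
    finally:
        # Detach only the sink added above, leaving other handlers intact.
        logger.remove(handler_id)
    assert any("example warning" in m for m in messages)
```

Removing only the returned handler id avoids the side effect of a bare `logger.remove()`, which also strips loguru's default stderr handler for the rest of the test session.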