Skip to content

Commit

Permalink
Address: Review comments from @kylesayrs
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Nov 27, 2024
1 parent 305904c commit c54699a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 32 deletions.
10 changes: 6 additions & 4 deletions src/compressed_tensors/compressors/sparse_compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,17 @@ def decompress(
yield other_name, value

@staticmethod
def should_compress(name: str, targets: Optional[Set[str]] = None) -> bool:
def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
"""
Check if a parameter should be compressed
:param name: name of the parameter
:param targets: set of layer prefixes to compress
:param expanded_targets: set of layer names to compress; a parameter matches when its name with the ".weight" suffix removed is in this set
:return: whether or not the parameter should be compressed
"""
if targets is None:
if expanded_targets is None:
return name.endswith(".weight")

return name.endswith(".weight") and name[: -(len(".weight"))] in targets
return (
name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
)
23 changes: 3 additions & 20 deletions src/compressed_tensors/utils/safetensors_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@
"get_nested_weight_mappings",
"get_quantization_state_dict",
"is_quantization_param",
"get_nested_mappings_from_state_dict",
]

WEIGHT_MAPPING_TYPE = Dict[str, str]
NESTED_WEIGHT_MAPPING_TYPE = Dict[str, WEIGHT_MAPPING_TYPE]
WeightMappingType = Dict[str, str]
NestedWeightMappingType = Dict[str, WeightMappingType]


def get_safetensors_folder(
Expand Down Expand Up @@ -181,9 +180,7 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:

def get_nested_weight_mappings(
model_path: str, params_to_nest: List[str], return_other_params: bool = False
) -> Union[
NESTED_WEIGHT_MAPPING_TYPE, Tuple[NESTED_WEIGHT_MAPPING_TYPE, WEIGHT_MAPPING_TYPE]
]:
) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
"""
Takes a path to a state dict saved in safetensors format and returns a nested
mapping from uncompressed parameterized layer names to the file locations of each
Expand Down Expand Up @@ -256,17 +253,3 @@ def is_quantization_param(name: str) -> bool:
return True

return False


def get_nested_mappings_from_state_dict(state_dict, params_to_nest):
    """
    Group entries of a state dict by the result of match_param_name

    Every key of the state dict is tested against every entry of params_to_nest
    using match_param_name. Each non-None result becomes an outer key of the
    returned dict; under it, the tested entry of params_to_nest maps to the
    state dict value stored at the tested key

    :param state_dict: mapping whose keys are tested and whose values are
        placed in the result
    :param params_to_nest: iterable of names passed as the second argument to
        match_param_name
    :return: dict of {match result: {entry of params_to_nest: state dict value}}
    """
    grouped = {}
    for full_name in state_dict.keys():
        for nest_name in params_to_nest:
            matched = match_param_name(full_name, nest_name)
            if matched is None:
                # this key does not correspond to this entry, try the next one
                continue
            # the inner dict is created the first time a match result is seen
            grouped.setdefault(matched, {})[nest_name] = state_dict[full_name]

    return grouped
16 changes: 8 additions & 8 deletions tests/test_quantization/lifecycle/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):


@pytest.mark.parametrize(
"targets, ignore, expected",
"targets, ignore, expected_targets",
[
([], [], set()),
(["layer1", "layer2"], [], {"layer1", "layer2"}),
Expand All @@ -305,13 +305,13 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning):
(["re:layer.*"], ["layer3"], {"layer1", "layer2"}),
],
)
def test_expand_targets_with_mock(mock_model, targets, ignore, expected):
result = expand_targets(mock_model, targets, ignore)
assert result == expected
def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets):
expanded_targets = expand_targets(mock_model, targets, ignore)
assert expanded_targets == expected_targets


@pytest.mark.parametrize(
"targets, ignore, expected",
"targets, ignore, expected_targets",
[
(
["re:model.layers.[01].self_attn.q_proj"],
Expand Down Expand Up @@ -344,10 +344,10 @@ def test_expand_targets_with_mock(mock_model, targets, ignore, expected):
],
)
def test_expand_targets_with_llama_stories(
llama_stories_model, targets, ignore, expected
llama_stories_model, targets, ignore, expected_targets
):
actual_targets = expand_targets(llama_stories_model, targets, ignore)
assert actual_targets == expected
expanded_targets = expand_targets(llama_stories_model, targets, ignore)
assert expanded_targets == expected_targets


@pytest.mark.parametrize(
Expand Down

0 comments on commit c54699a

Please sign in to comment.