
Add 24 compressor #167

Open
wants to merge 5 commits into base: add-targets-and-ignore-support

Conversation

@rahul-tuli (Member) commented Sep 26, 2024

This PR introduces the Sparse24Compressor, designed for 2:4 sparse models. The implementation is based on #182 and corresponds to Part 3 of the [Design Document](https://www.notion.so/Design-Document-24-Compressor-25ac643aee604c298f2bb12a6c220861?pvs=4).


Key Changes

  • New Feature: Implementation of Sparse24Compressor for handling 2:4 sparsity in models.
  • Testing: Added validation, including support and tests for the torch.float8_e4m3fn dtype.

Class Hierarchy

The Sparse24Compressor follows the established compressor class hierarchy:

BaseCompressor (Abstract Class)
    |
    +-- BaseSparsityCompressor (Abstract Class)
           |
           +-- Sparse24Compressor
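
For orientation, here is a minimal sketch of what this hierarchy could look like in code. The method names (compress/decompress) mirror the existing compressor interface; the exact signatures in the PR may differ.

from abc import ABC, abstractmethod
from typing import Dict

import torch


class BaseCompressor(ABC):
    """Abstract interface every compressor implements."""

    @abstractmethod
    def compress(self, model_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Map a dense state dict to its compressed representation."""

    @abstractmethod
    def decompress(self, compressed: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Invert compress(), reconstructing the dense state dict."""


class BaseSparsityCompressor(BaseCompressor):
    """Shared logic for sparsity compressors (e.g. resolving targets/ignore)."""


class Sparse24Compressor(BaseSparsityCompressor):
    """Compresses weights that follow a 2:4 structured-sparsity pattern."""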

File Structure

The Sparse24Compressor and associated logic are placed within the sparse_compressors module:

compressors/
└── sparse_compressors/
    ├── __init__.py
    ├── base.py                 <-- Contains BaseSparsityCompressor
    ├── dense.py
    ├── sparse_bitmask.py
    └── sparse24.py             <-- New file for Sparse24Compressor

Verification Methodology

The `Sparse24Compressor` was tested using a comprehensive script that validates its behavior through the following steps:

1. **Load Model**: An uncompressed model is loaded from the Hugging Face model hub or a local directory.
2. **Compression**: The model is compressed using `ModelCompressor`, and the compressed version is saved.
3. **Decompression**: A new base model is initialized, and the compressed weights are decompressed using `ModelCompressor.decompress`.
4. **Parameter Validation**: Parameters in the decompressed model are verified to match the original uncompressed model.
5. **Inference Check**: The decompressed model is used to generate text, ensuring correctness and functionality.
Verification Script
import torch
from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor
from transformers import AutoTokenizer
from llmcompressor.transformers import oneshot
from compressed_tensors.config import Sparse24Config

# Load uncompressed model
hf_model_stub = "nm-testing/TinyLlama-1.1B-Chat-v1.0-pruned_50.2of4-uncompressed"
uncompressed_model = AutoModelForCausalLM.from_pretrained(hf_model_stub, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(hf_model_stub)

# Compress the model using Sparse24Compressor
oneshot(model=uncompressed_model)
compressed_save_dir = "temp-model"
sparsity_config = Sparse24Config(targets=["Linear"], ignore=["lm_head"])
uncompressed_model.save_pretrained(save_directory=compressed_save_dir, sparsity_config=sparsity_config)
tokenizer.save_pretrained(save_directory=compressed_save_dir)

# Decompress the model
base_stub = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
decompressed_model = AutoModelForCausalLM.from_pretrained(base_stub, torch_dtype="auto", device_map="auto")
compressor = ModelCompressor.from_pretrained(compressed_save_dir)
compressor.decompress(model_path=compressed_save_dir, model=decompressed_model)

# Verify parameters match
decompressed_state_dict = decompressed_model.state_dict()
uncompressed_state_dict = uncompressed_model.state_dict()

for key in decompressed_state_dict.keys():
    assert key in uncompressed_state_dict.keys()
    decompressed_tensor = decompressed_state_dict[key]
    uncompressed_tensor = uncompressed_state_dict[key]
    assert torch.equal(decompressed_tensor, uncompressed_tensor), f"Tensor {key} mismatch."

print("All parameters match the original model.")
print("Inference on the decompressed model:")

# Inference check
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = decompressed_model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
Sample generation from the decompressed model
All parameters match the original model.
Inference on the decompressed model:

========== SAMPLE GENERATION ==============
<s> Hello my name is John. I am a student at the University of California. I am a student at the University of California. I am a student at the University of California. I am a student at the University of California. I am a student at the University of California. I am a student at the University of California. I am a student at the University of California. I am a student at the University of California. I am a student at the University of California.
==========================================

Note: the fp8 test can only run on GPUs with CUDA compute capability 9.0 or higher (Hopper-class hardware).
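
For reference, a minimal sketch (not from the PR) of a pytest gate that skips the fp8 test on unsupported hardware:

import pytest
import torch


def fp8_supported() -> bool:
    # fp8 kernels require a Hopper-class GPU (compute capability 9.0+)
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (9, 0)


requires_fp8 = pytest.mark.skipif(
    not fp8_supported(), reason="requires CUDA compute capability >= 9.0"
)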
Proof that it passes on the right device:

(.venv) ➜  compressed-tensors git:(add-24-compressor) ✗ pytest tests/test_utils/test_semi_structured_conversions.py
=========================================================== test session starts ============================================================
platform linux -- Python 3.10.12, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/rahul/compressed-tensors
configfile: pyproject.toml
collected 4 items                                                                                                                          

tests/test_utils/test_semi_structured_conversions.py ....                                                                            [100%]

============================================================ 4 passed in 1.51s =============================================================

@rahul-tuli rahul-tuli changed the base branch from main to update-folder-structure-compressors September 26, 2024 15:42
horheynm previously approved these changes Sep 26, 2024

@horheynm (Member) left a comment:

Very clean. lgtm after tests!!

@rahul-tuli rahul-tuli force-pushed the update-folder-structure-compressors branch 2 times, most recently from 2f69d16 to fc4b23c Compare October 2, 2024 20:56
@rahul-tuli rahul-tuli force-pushed the update-folder-structure-compressors branch 2 times, most recently from dd16499 to 7155e61 Compare October 2, 2024 21:06
Base automatically changed from update-folder-structure-compressors to main October 3, 2024 00:43
@mgoin mgoin dismissed horheynm’s stale review October 3, 2024 00:43

The base branch was changed.

@markurtz (Member) left a comment:

Overall the code looks simple. I'd like to reformulate the scope, though. Specifically, I'm not following why we are restricting to just 2:4 right now when we could easily expand this to handle all sparsity cases: detect whether it is 2:4 format or some other type of structured pruning, and if neither, fall back to unstructured. cc @dsikka

@dsikka (Contributor) left a comment:

testing?

@rahul-tuli rahul-tuli changed the base branch from main to add-targets-and-ignore-support November 27, 2024 13:58
@dsikka (Contributor) left a comment:

LGTM. The only question I had was related to the changes we're making to the sparse_semi_structured to/from methods. Are we making these changes based on kernel compatibility with what was originally in vLLM?

@@ -85,7 +86,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
     device = dense.device

     meta_dtype = torch.int8
-    if dense.dtype == torch.int8:
+    if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn:
A reviewer (Contributor) commented:

when is meta ever int8?

@rahul-tuli (Author) replied:

Made None
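
For context, a hedged sketch of the dtype-selection logic after that change, modeled on the PyTorch reference implementation of this function (the exact PR code may differ):

# meta_dtype starts as None and is only set for supported input dtypes,
# so an unsupported dtype now fails loudly instead of silently using int8.
meta_dtype = None
if dense.dtype == torch.int8 or dense.dtype == torch.float8_e4m3fn:
    meta_dtype = torch.int32
elif dense.dtype in (torch.half, torch.bfloat16, torch.float):
    meta_dtype = torch.int16
if meta_dtype is None:
    raise RuntimeError(f"Unsupported dense dtype: {dense.dtype}")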

@@ -165,11 +166,15 @@ def sparse_semi_structured_from_dense_cutlass(dense):
     idxs1 = bit2 | (bit3.to(torch.int64) << 1)

     if dense.dtype != torch.float:
+        if dense.dtype == torch.float8_e4m3fn:
+            dense_4 = dense_4.view(torch.int8)
A reviewer (Contributor) asked:

is this required by the kernel only for fp8?

@rahul-tuli (Author) replied:

This is a quirk only for the fp8 dtype: certain operations are not implemented for it, so we use this hack of viewing the tensor as int8. This does not move the data.
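
To illustrate the point (a minimal sketch, not code from the PR): Tensor.view(dtype) with an equal-width dtype reinterprets the same storage, so no copy occurs.

import torch

# float8_e4m3fn and int8 are both 1 byte wide, so view() merely relabels the
# existing storage with a new dtype; nothing is copied or moved.
x = torch.randn(4, 4).to(torch.float8_e4m3fn)
x_int8 = x.view(torch.int8)

assert x_int8.data_ptr() == x.data_ptr()  # same underlying memory
# round-tripping through the fp8 view recovers identical bytes
assert torch.equal(x_int8.view(torch.float8_e4m3fn).view(torch.int8), x_int8)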

@kylesayrs (Contributor) left a comment:

Lots of nits that can land later, lgtm!



@pytest.mark.parametrize("dtype", supported_dtypes())
def test_inverse_property_from_dense_then_to_dense(dtype):
Reviewer comment:

I'd personally create a test for sparse_semi_structured_from_dense_cutlass and sparse_semi_structured_to_dense_cutlass. This combines both important operations, but it's not a blocker. (A sketch of such a round-trip test follows below.)
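
A hedged sketch of the suggested round-trip test, written as if it lived in tests/test_utils/test_semi_structured_conversions.py next to the existing helpers (supported_dtypes and generate_pruned_semi_structured_mat come from that file; the conversion-function import path is an assumption):

import pytest
import torch

# Assumed import path for the conversion functions under review.
from compressed_tensors.utils.semi_structured_conversions import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)


@pytest.mark.parametrize("dtype", supported_dtypes())
def test_from_dense_then_to_dense_round_trip(dtype):
    # build a 2:4-pruned dense matrix, compress it, then reconstruct it
    dense = generate_pruned_semi_structured_mat(1024, 1024, dtype)
    compressed, meta = sparse_semi_structured_from_dense_cutlass(dense)
    reconstructed = sparse_semi_structured_to_dense_cutlass(compressed, meta)
    if dtype == torch.float8_e4m3fn:
        # equality ops are not implemented for fp8; compare raw bytes instead
        dense = dense.view(torch.int8)
        reconstructed = reconstructed.view(torch.int8)
    assert torch.equal(dense, reconstructed)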

@pytest.mark.parametrize("dtype", supported_dtypes())
def test_inverse_property_from_dense_then_to_dense(dtype):
M, K = 1024, 1024
dense_matrix = generate_pruned_semi_structured_mat(M, K, dtype)
Reviewer comment:

Also, this outputs booleans in the pattern [False, False, True, True]; is this dense?
