Gptq tokenized dataset (#1584)
* allow tokenized dataset

* style

* Update optimum/gptq/quantizer.py

Co-authored-by: fxmarty <[email protected]>

* Update optimum/gptq/quantizer.py

Co-authored-by: fxmarty <[email protected]>

* Update optimum/gptq/quantizer.py

Co-authored-by: fxmarty <[email protected]>

* Update optimum/gptq/quantizer.py

Co-authored-by: fxmarty <[email protected]>

* add example in docstring

---------

Co-authored-by: fxmarty <[email protected]>
SunMarc and fxmarty authored Dec 13, 2023
1 parent a0140e0 commit afe2e3c
46 changes: 26 additions & 20 deletions optimum/gptq/quantizer.py
@@ -85,9 +85,10 @@ def __init__(
         Args:
             bits (`int`):
                 The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
-            dataset (`Union[List[str],str]`, defaults to None):
-                The dataset used for quantization. You can provide your own dataset in a list of string or just use the original datasets used
-                in GPTQ paper ['wikitext2','c4','c4-new','ptb','ptb-new'].
+            dataset (`Union[List[str], str, Any]`, defaults to `None`):
+                The dataset used for quantization. You can provide your own dataset as a list of strings or as a list of tokenized data
+                (e.g. [{"input_ids": [1, 100, 15, ...], "attention_mask": [1, 1, 1, ...]}, ...]),
+                or just use the original datasets used in the GPTQ paper: ['wikitext2', 'c4', 'c4-new', 'ptb', 'ptb-new'].
             group_size (int, defaults to 128):
                 The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
             damp_percent (`float`, defaults to `0.1`):
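For reference, a minimal sketch of what the new `dataset` argument enables: constructing the quantizer with a pre-tokenized calibration set. The variable names and the tiny two-sample dataset are illustrative, not part of the commit; the `GPTQQuantizer` parameters shown are the ones documented in the hunk above.

```python
from optimum.gptq import GPTQQuantizer

# A pre-tokenized calibration set: each entry already carries
# `input_ids` and `attention_mask`, so no tokenization step is needed.
tokenized_dataset = [
    {"input_ids": [1, 100, 15, 23], "attention_mask": [1, 1, 1, 1]},
    {"input_ids": [1, 42, 7], "attention_mask": [1, 1, 1]},
]

quantizer = GPTQQuantizer(bits=4, dataset=tokenized_dataset, group_size=128)
```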
@@ -297,14 +298,14 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st
             self._replace_by_quant_layers(child, names, name + "." + name1 if name != "" else name1)

     @torch.no_grad()
-    def quantize_model(self, model: nn.Module, tokenizer: Any):
+    def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
         """
         Quantizes the model using the dataset

         Args:
             model (`nn.Module`):
                 The model to quantize
-            tokenizer (`Any`):
+            tokenizer (`Optional[Any]`, defaults to `None`):
                 The tokenizer to use in order to prepare the dataset. You can pass either:
                     - A custom tokenizer object.
                     - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
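Given the relaxed signature, `tokenizer` can now be omitted when the dataset is already tokenized. A hedged usage sketch, assuming the `quantizer` built above; the model id is a placeholder, and with a raw-text dataset a tokenizer object or model id is still required:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Pre-tokenized dataset: the tokenizer argument can simply be dropped.
quantized_model = quantizer.quantize_model(model)

# Raw-text dataset: a tokenizer (object or model id) must still be passed, e.g.
# quantized_model = quantizer.quantize_model(model, tokenizer="facebook/opt-125m")
```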
@@ -355,23 +356,28 @@ def quantize_model(self, model: nn.Module, tokenizer: Any):
         device = get_device(model)

         # Step 1: Prepare the data
-        if isinstance(tokenizer, str):
-            try:
-                tokenizer = AutoTokenizer.from_pretrained(tokenizer)
-            except Exception:
-                raise ValueError(
-                    f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
-                    with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
-                    For now, we only support quantization for text models. Support for vision, speech and multimodal models will come later."""
-                )
-        if self.dataset is None:
-            raise ValueError("You need to pass `dataset` in order to quantize your model")
-        elif isinstance(self.dataset, str):
-            dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
-        elif isinstance(self.dataset, list):
-            dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
-        else:
-            raise ValueError("You need to pass a list of strings or a string for `dataset`")
+        if isinstance(self.dataset, list) and not isinstance(self.dataset[0], str):
+            dataset = self.dataset
+            logger.info("GPTQQuantizer dataset appears to be already tokenized. Skipping tokenization.")
+        else:
+            if isinstance(tokenizer, str):
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+                except Exception:
+                    raise ValueError(
+                        f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
+                        with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
+                        For now, we only support quantization for text models. Support for vision, speech and multimodal models will come later."""
+                    )
+            if self.dataset is None:
+                raise ValueError("You need to pass `dataset` in order to quantize your model")
+            elif isinstance(self.dataset, str):
+                dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
+            elif isinstance(self.dataset, list):
+                dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
+            else:
+                raise ValueError(
+                    f"You need to pass a list of strings, a list of tokenized data, or a string for `dataset`. Found: {type(self.dataset)}."
+                )

         dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)
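The dispatch above hinges on one heuristic: a list whose first element is not a string is treated as already tokenized. A standalone illustration of that check follows; the helper name is hypothetical, not part of the commit, and, like the original code, it assumes a non-empty list:

```python
from typing import Any, List, Optional, Union

def _looks_pretokenized(dataset: Optional[Union[str, List[Any]]]) -> bool:
    # Mirrors the branch in quantize_model: a list whose first element
    # is not a string is assumed to carry input_ids/attention_mask dicts.
    return isinstance(dataset, list) and not isinstance(dataset[0], str)

assert _looks_pretokenized([{"input_ids": [1, 2], "attention_mask": [1, 1]}])
assert not _looks_pretokenized(["some calibration sentence"])
assert not _looks_pretokenized("wikitext2")  # named datasets go through get_dataset
```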

