diff --git a/noxfile.py b/noxfile.py index f71510979..0d6680460 100644 --- a/noxfile.py +++ b/noxfile.py @@ -142,8 +142,6 @@ def tests_brevitas_examples_llm(session, pytorch, jit_status): install_pytorch(pytorch, session) install_torchvision(pytorch, session) # Optimum seems to require torchvision session.install('-e', '.[test, llm, export]') - session.install( - 'optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main') session.run('pytest', '-n', 'logical', '-k', 'llm', 'tests/brevitas_examples/test_llm.py') diff --git a/requirements/requirements-llm.txt b/requirements/requirements-llm.txt index 9bc21d251..0afb8b765 100644 --- a/requirements/requirements-llm.txt +++ b/requirements/requirements-llm.txt @@ -1,3 +1,7 @@ -# optimum-amd[brevitas] @ git+https://github.com/huggingface/optimum-amd.git@main +accelerate +datasets +onnx +onnxruntime +optimum tqdm transformers[sentencepiece]==4.45.2 diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md index de55258db..a2979b45b 100644 --- a/src/brevitas_examples/llm/README.md +++ b/src/brevitas_examples/llm/README.md @@ -6,7 +6,6 @@ - datasets - torch_mlir (optional for torch-mlir based export) - optimum -- optimum-amd (install from main) - accelerate ## Run diff --git a/src/brevitas_examples/llm/llm_quant/data.py b/src/brevitas_examples/llm/llm_quant/data.py index a535feae9..d58515e7a 100644 --- a/src/brevitas_examples/llm/llm_quant/data.py +++ b/src/brevitas_examples/llm/llm_quant/data.py @@ -1,19 +1,27 @@ """ -Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: +Adapted from https://github.com/huggingface/optimum-amd, released under the following LICENSE: -Copyright 2023 IST-DASLab +MIT License -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +Copyright (c) 2023 Hugging Face - http://www.apache.org/licenses/LICENSE-2.0 +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. """ import random @@ -24,50 +32,61 @@ from tqdm import tqdm -def get_c4(nsamples, seed, seqlen, tokenizer, split='train', nvalsamples=0): - if split == 'train': +def get_c4( + tokenizer: Any, + seqlen: int, + nsamples: int, + split: str = "train", + fuse_sequences: bool = True, + seed: int = 42): + random.seed(seed) + + if split == "train": data = load_dataset( - 'allenai/c4', - 'allenai--c4', - data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, - split='train', - use_auth_token=False) - - random.seed(seed) - dataloader = [] - for _ in range(nsamples): - while True: - i = random.randint(0, len(data) - 1) - trainenc = tokenizer(data[i]['text'], return_tensors='pt') - if trainenc.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = trainenc.input_ids[:, i:j] - dataloader.append(inp) - return dataloader - elif split == 'validation': + "allenai/c4", split="train", data_files={"train": "en/c4-train.00000-of-01024.json.gz"}) + elif split == "validation": data = load_dataset( - 'allenai/c4', - 'allenai--c4', - data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, - split='validation', - use_auth_token=False) - - random.seed(0) # hardcoded for validation reproducibility - valenc = [] - for _ in range(nvalsamples): - while True: - i = random.randint(0, len(data) - 1) - tmp = tokenizer(data[i]['text'], return_tensors='pt') - if tmp.input_ids.shape[1] >= seqlen: - break - i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + "allenai/c4", + split="validation", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + ) + + if fuse_sequences: + data = data.shuffle(seed=seed)[:10000] # c4 is too big. + full_text = "\n\n".join(data["text"]) + tokenized_data = tokenizer(full_text, return_tensors="pt") + + dataset = [] + for _ in range(nsamples): + i = random.randint(0, tokenized_data.input_ids.shape[1] - seqlen - 1) j = i + seqlen - valenc.append(tmp.input_ids[:, i:j]) + inp = tokenized_data.input_ids[:, i:j] + attention_mask = torch.ones((1, seqlen), dtype=torch.int64) + dataset.append({"input_ids": inp, "attention_mask": attention_mask}) + else: + dataset = [] + with tqdm(total=nsamples) as pbar: + while len(dataset) < nsamples: + data_index = random.randint(0, len(data) - 1) + + enc = tokenizer(data[data_index]["text"], return_tensors="pt") + + if enc["input_ids"].shape[1] < seqlen: + continue + + start_idx = random.randint(0, enc["input_ids"].shape[1] - seqlen) + end_idx = start_idx + seqlen - 1 + attention_mask = torch.ones((1, seqlen), dtype=torch.int64) + input_ids = enc["input_ids"][:, start_idx:end_idx + 1] + + # Add BOS token. + if tokenizer.eos_token_id is not None: + input_ids[:, 0] = tokenizer.eos_token_id + + dataset.append({"input_ids": input_ids, "attention_mask": attention_mask}) + pbar.update(1) - valenc = torch.hstack(valenc) - return valenc + return dataset def get_wikitext2( diff --git a/src/brevitas_examples/llm/llm_quant/data_utils.py b/src/brevitas_examples/llm/llm_quant/data_utils.py index 1ff82c157..7946bc0a9 100644 --- a/src/brevitas_examples/llm/llm_quant/data_utils.py +++ b/src/brevitas_examples/llm/llm_quant/data_utils.py @@ -25,18 +25,53 @@ """ import random -from typing import Any, Optional, Union +from typing import Any, Iterable, List, Optional, Union import numpy as np -from optimum.amd.brevitas.data_utils import DatasetToDevice -from optimum.amd.brevitas.data_utils import get_c4 from optimum.utils.normalized_config import NormalizedConfigManager import torch from transformers import AutoConfig +from .data import get_c4 from .data import get_wikitext2 +class DatasetToDevice(torch.utils.data.Dataset): + + def __init__(self, data: List, device: Optional[Union[str, torch.device]]): + super().__init__() + self.data = data + self.device = device + + def __getitem__(self, idx): + if self.device is not None: + return { + name: recursive_to_device(val, self.device) for name, val in self.data[idx].items()} + else: + return self.data[idx] + + def __len__(self): + return len(self.data) + + +@torch.no_grad() +def recursive_to_device(tensor_or_iterable: Union[Iterable, torch.Tensor], device) -> None: + if isinstance(tensor_or_iterable, torch.Tensor): + return tensor_or_iterable.to(device) + elif isinstance(tensor_or_iterable, + tuple): # Special handling of tuples, since they are immutable + tmp_list = [] + for i in tensor_or_iterable: + tmp_list.append(recursive_to_device(i, device)) + return tuple(tmp_list) + elif isinstance(tensor_or_iterable, Iterable): + for i in tensor_or_iterable: + tensor_or_iterable[i] = recursive_to_device(i, device) + return tensor_or_iterable + else: + raise ValueError(f"Cannot move {type(tensor_or_iterable)} to {device}") + + def get_dataset_for_model( model_name_or_path: str, dataset_name: str, diff --git a/src/brevitas_examples/llm/llm_quant/eval.py b/src/brevitas_examples/llm/llm_quant/eval.py index 27ef16c97..a69f64fdb 100644 --- a/src/brevitas_examples/llm/llm_quant/eval.py +++ b/src/brevitas_examples/llm/llm_quant/eval.py @@ -1,25 +1,39 @@ """ -Adapted from https://github.com/IST-DASLab/gptq, released under the following LICENSE: +Adapted from https://github.com/huggingface/optimum-amd, released under the following LICENSE: -Copyright 2023 IST-DASLab +MIT License -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +Copyright (c) 2023 Hugging Face - http://www.apache.org/licenses/LICENSE-2.0 +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. """ +import random +from typing import Any, Dict, List + +import numpy as np import torch from torch import nn from tqdm import tqdm +from brevitas_examples.llm.llm_quant.data_utils import recursive_to_device + def create_validation_dataloader(data, seqlen, device): nsamples = data['input_ids'].numel() // seqlen @@ -32,19 +46,63 @@ def create_validation_dataloader(data, seqlen, device): @torch.no_grad() -def model_eval(model, valenc, seqlen): - nsamples = len(valenc) - with torch.no_grad(): - nlls = [] - for inps in valenc: - lm_logits = model(**inps)['logits'] - shift_logits = lm_logits[:, :-1, :].contiguous() - dev = shift_logits.device - shift_labels = inps['input_ids'][:, 1:].to(dev) - shift_logits = shift_logits.to(dev) - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - neg_log_likelihood = loss.float() * seqlen - nlls.append(neg_log_likelihood) - ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)) +def compute_perplexity( + model: torch.nn.Module, + data: List[Dict], + context_length: int, + tokenizer: Any, + seed: int = 0): + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + model = model.eval() + + cross_entropy_loss = nn.CrossEntropyLoss() + + nlls = [] + for sample in tqdm(data, desc="Computing perplexity..."): + sample_length = sample["input_ids"].shape[1] + for start_index in range(0, sample_length, context_length * 2): + end_index = min(start_index + sample_length, sample_length - 1) + + subsample = { + "input_ids": sample["input_ids"][:, start_index:end_index + 1], + "attention_mask": sample["attention_mask"][:, start_index:end_index + 1],} + + # In case we are using torch.fx, we can not have optional inputs, and we have traced the model with past_key_values inputs, thus we need them here as well. + if "past_key_values" in sample and isinstance(model, torch.fx.GraphModule): + subsample["past_key_values"] = sample["past_key_values"] + + # Add BOS token. + if tokenizer.bos_token_id is not None: + subsample["input_ids"][:, 0] = tokenizer.bos_token_id + + use_accelerate = hasattr(model, "hf_device_map") + if not use_accelerate or (use_accelerate and not hasattr(model, "_hf_hook")): + device = next(model.parameters()).device + for name, val in subsample.items(): + subsample[name] = recursive_to_device(val, device) + else: + # In accelerate by default `io_same_device=True`, and here we want the of the model output on device. + device = model._hf_hook.execution_device + for name, val in subsample.items(): + subsample[name] = recursive_to_device(val, device) + + lm_logits = model(**subsample)["logits"] + + reference_labels = subsample["input_ids"][:, context_length:] + + shift_logits = lm_logits[:, context_length - 1:-1] + + # Fuse batch and sequence length dimensions. + reference_labels = reference_labels.view(reference_labels.shape[-1]) + shift_logits = shift_logits.view(-1, shift_logits.shape[-1]) + + loss = cross_entropy_loss(shift_logits, reference_labels) + + nlls.append(loss) + + ppl = torch.exp(torch.stack(nlls).mean()) + return ppl diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 03035dd02..0b2fbfaf5 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -7,7 +7,6 @@ from warnings import warn import numpy as np -from optimum.amd.brevitas.data_utils import compute_perplexity from optimum.exporters.onnx import onnx_export_from_model import torch from transformers import AutoModelForCausalLM @@ -30,6 +29,7 @@ from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization +from brevitas_examples.llm.llm_quant.eval import compute_perplexity from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode from brevitas_examples.llm.llm_quant.gpxq import apply_gpfq diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py index 105a1ea8b..a1d52a3af 100644 --- a/tests/brevitas_examples/test_llm.py +++ b/tests/brevitas_examples/test_llm.py @@ -17,7 +17,6 @@ from brevitas import config from brevitas import torch_version -# LLM example depends on optimum-amd, which requires PyTorch>=2.2 from brevitas_examples.llm.main import main from brevitas_examples.llm.main import parse_args from tests.marker import jit_disabled_for_export