diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bd5d07ba..03479203 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -24,20 +24,20 @@ repos:
       - id: ruff-format
       - id: ruff
 
-  - repo: local
-    hooks:
-      - id: fast-test
-        name: fast-test
-        entry: make
-        args: ["fast-test"]
-        language: system
-        pass_filenames: false
-      - id: clean
-        name: clean
-        entry: make
-        args: ["clean"]
-        language: system
-        pass_filenames: false
+  # - repo: local
+  #   hooks:
+  #     - id: fast-test
+  #       name: fast-test
+  #       entry: make
+  #       args: ["fast-test"]
+  #       language: system
+  #       pass_filenames: false
+  #     - id: clean
+  #       name: clean
+  #       entry: make
+  #       args: ["clean"]
+  #       language: system
+  #       pass_filenames: false
 
   - repo: local
     hooks:
diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py
index 7841cf7c..5101952d 100644
--- a/inseq/data/attribution.py
+++ b/inseq/data/attribution.py
@@ -1,3 +1,4 @@
+import base64
 import logging
 from copy import deepcopy
 from dataclasses import dataclass, field
@@ -8,6 +9,8 @@
 import torch
 
 from ..utils import (
+    convert_to_safetensor,
+    dequantize_safetensor,
     drop_padding,
     get_sequences_from_batched_steps,
     json_advanced_dump,
@@ -159,6 +162,59 @@ def __post_init__(self):
         if self.attr_pos_end is None or self.attr_pos_end > len(self.target):
             self.attr_pos_end = len(self.target)
 
+    def _convert_to_safetensors(self, scores_precision="float32"):
+        """
+        Serializes the tensor attributes of this object to safetensors bytes.
+        The stored precision is selected by `scores_precision`.
+        float32 stores the tensors as they are, float16 casts them to half precision.
+        float8 stores scaled 8-bit integers plus the scale and zero-point needed to dequantize on reload.
+
+        Args:
+            scores_precision (str, optional): One of "float32", "float16" or "float8". Defaults to "float32".
+        Returns:
+            self: The attributes are replaced in-place by their serialized bytes.
+        """
+
+        if self.source_attributions is not None:
+            self.source_attributions = convert_to_safetensor(
+                self.source_attributions.contiguous(), quantization=scores_precision
+            )
+        if self.target_attributions is not None:
+            self.target_attributions = convert_to_safetensor(
+                self.target_attributions.contiguous(), quantization=scores_precision
+            )
+        if self.step_scores is not None:
+            self.step_scores = {
+                k: convert_to_safetensor(v.contiguous(), quantization=scores_precision)
+                for k, v in self.step_scores.items()
+            }
+        if self.sequence_scores is not None:
+            self.sequence_scores = {
+                k: convert_to_safetensor(v.contiguous(), quantization=scores_precision)
+                for k, v in self.sequence_scores.items()
+            }
+        return self
+
+    def _recover_from_safetensors(self):
+        """
+        Converts tensor attributes within the class from b64-encoded safetensors to torch tensors.
+        Args:
+            self
+        Returns:
+            self: The attributes are replaced in-place by the decoded tensors.
+        """
+        if self.source_attributions is not None:
+            self.source_attributions = dequantize_safetensor(base64.b64decode(self.source_attributions))
+        if self.target_attributions is not None:
+            self.target_attributions = dequantize_safetensor(base64.b64decode(self.target_attributions))
+        if self.step_scores is not None:
+            self.step_scores = {k: dequantize_safetensor(base64.b64decode(v)) for k, v in self.step_scores.items()}
+        if self.sequence_scores is not None:
+            self.sequence_scores = {
+                k: dequantize_safetensor(base64.b64decode(v)) for k, v in self.sequence_scores.items()
+            }
+        return self
+
     @staticmethod
     def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable:
         if attr.source_attributions is None or name.startswith("decoder"):
@@ -546,6 +602,7 @@ def save(
         ndarray_compact: bool = True,
         use_primitives: bool = False,
         split_sequences: bool = False,
+        scores_precision: str = "float32",
     ) -> None:
         """Save class contents to a JSON file.
 
@@ -572,17 +629,24 @@ def save(
             raise ValueError(f"{path} already exists. Override with overwrite=True.")
         save_outs = []
         paths = []
+        self_out = deepcopy(self)
         if split_sequences:
-            for i, seq in enumerate(self.sequence_attributions):
+            for i, seq in enumerate(self_out.sequence_attributions):
                 attr_out = deepcopy(self)
-                attr_out.sequence_attributions = [seq]
+                attr_out.sequence_attributions = [
+                    seq._convert_to_safetensors(scores_precision=scores_precision)
+                ]  # converts the copied sequence in place; `self` is left untouched
                 attr_out.step_attributions = None
                 attr_out.info["input_texts"] = [attr_out.info["input_texts"][i]]
                 attr_out.info["generated_texts"] = [attr_out.info["generated_texts"][i]]
                 save_outs.append(attr_out)
                 paths.append(f"{str(path).split('.json')[0]}_{i}.json{'.gz' if compress else ''}")
         else:
-            save_outs.append(self)
+            self_out.sequence_attributions = [
+                seq._convert_to_safetensors(scores_precision=scores_precision)
+                for seq in self_out.sequence_attributions
+            ]
+            save_outs.append(self_out)
             paths.append(path)
         for attr_out, path_out in zip(save_outs, paths):
             with open(path_out, f"w{'b' if compress else ''}") as f:
@@ -615,9 +679,9 @@ def load(
             :class:`~inseq.data.FeatureAttributionOutput`: Loaded attribution output
         """
         out = json_advanced_load(path, decompression=decompress)
-        out.sequence_attributions = [seq.torch() for seq in out.sequence_attributions]
+        out.sequence_attributions = [seq._recover_from_safetensors() for seq in out.sequence_attributions]
         if out.step_attributions is not None:
             out.step_attributions = [step.torch() for step in out.step_attributions]
         return out
 
     def aggregate(
diff --git a/inseq/utils/__init__.py b/inseq/utils/__init__.py
index f632ba32..53c39c5e 100644
--- a/inseq/utils/__init__.py
+++ b/inseq/utils/__init__.py
@@ -49,6 +49,8 @@
 from .torch_utils import (
     aggregate_contiguous,
     check_device,
+    convert_to_safetensor,
+    dequantize_safetensor,
     euclidean_distance,
     filter_logits,
     find_block_stack,
@@ -69,6 +71,8 @@
     "UnknownAttributionMethodError",
     "MissingAlignmentsError",
     "cache_results",
+    "convert_to_safetensor",
+    "dequantize_safetensor",
     "optional",
     "pad",
     "pretty_list",
diff --git a/inseq/utils/serialization.py b/inseq/utils/serialization.py
index b45f8580..f7966191 100644
--- a/inseq/utils/serialization.py
+++ b/inseq/utils/serialization.py
@@ -29,6 +29,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
+import base64
 import json
 from collections import OrderedDict
 from json import JSONEncoder
@@ -59,6 +60,8 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
     """
     if isinstance(obj, (list, dict)):
         return obj
+    if isinstance(obj, bytes):
+        return base64.b64encode(obj).decode("UTF8")
     if hasattr(obj, "__class__") and hasattr(obj, "__dict__"):
         if not hasattr(obj, "__new__"):
             raise TypeError(f"class '{obj.__class__}' does not have a __new__ method; ")
@@ -84,9 +87,7 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k
             dct["attributes"] = hashodict(obj.__dict__)
         if use_primitives:
             attrs = dct.get("attributes", {})
-            return attrs
-        else:
-            return dct
+        return attrs if use_primitives else dct
     return obj
 
 
diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py
index 86acd635..7d9fb9fc 100644
--- a/inseq/utils/torch_utils.py
+++ b/inseq/utils/torch_utils.py
@@ -1,7 +1,10 @@
+import json
 import logging
+import struct
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Callable, Literal, Optional, Union
 
+import safetensors.torch
 import torch
 import torch.nn.functional as F
 from jaxtyping import Int, Num
@@ -38,6 +41,64 @@ def remap_from_filtered(
     return new_source.scatter(0, index, filtered)
 
 
+def convert_to_safetensor(tensor: torch.Tensor, quantization="float32") -> bytes:
+    """
+    Converts a torch tensor to a safetensor, and optionally quantizes the weights with zero-point quantization.
+    Quantization parameters are saved in the safetensor metadata to be used on reloading.
+    Adapted from https://towardsdatascience.com/introduction-to-weight-quantization-2494701b9c0c
+
+    Args:
+        tensor (torch.Tensor): some torch tensor
+        quantization (str): format to quantize weights to [float32, float16, float8]
+    Returns:
+        bytes: A safetensor in bytes format
+    Raises:
+        ValueError if `quantization` doesn't match the possible options
+
+    """
+    metadata_dict = {"quantization": quantization}
+    if quantization == "float32":
+        return safetensors.torch.save({"attribution": tensor}, metadata=metadata_dict)
+    if quantization == "float16":
+        return safetensors.torch.save({"attribution": tensor.to(torch.float16)}, metadata=metadata_dict)
+    if quantization != "float8":
+        raise ValueError("`quantization` has to be one of [float32, float16, float8]")
+
+    tensor_min = torch.min(tensor)
+    xrange = torch.max(tensor) - tensor_min
+    if xrange > 0:
+        scale = 255 / xrange
+    else:
+        # Constant tensor: 255 / 0 would give an infinite scale, so scale by the value's magnitude instead.
+        # NOTE(review): exact recovery assumes dequantization computes (q - zeropoint) / scale -- confirm.
+        magnitude = tensor_min.abs()
+        scale = 255 / magnitude if magnitude > 0 else torch.ones_like(magnitude)
+    if tensor_min < 0:
+        # Negative values present: shift the quantized range to [-128, 127] and store as int8.
+        zeropoint = (-scale * tensor_min).round() - 128
+        quant_tensor = torch.clip((tensor * scale + zeropoint).round(), -128, 127).to(torch.int8)
+    else:
+        zeropoint = (-scale * tensor_min).round()
+        quant_tensor = torch.clip((tensor * scale + zeropoint).round(), 0, 255).to(torch.uint8)
+
+    metadata_dict["scale"], metadata_dict["zeropoint"] = f"{scale}", f"{zeropoint}"
+    return safetensors.torch.save({"attribution": quant_tensor}, metadata=metadata_dict)
+
+
+def dequantize_safetensor(safetensor: bytes) -> torch.Tensor:
+    """
+    Convert a safetensor to a torch tensor and dequantize weights to float32.
+    Adapted from https://huggingface.co/docs/safetensors/metadata_parsing
+    """
+    header_length = struct.unpack("