From 345fabdf78a0e0d616e3d5be37ae7d16e8ee2aa7 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 25 Jun 2024 06:47:17 -0700 Subject: [PATCH 01/13] NF4 quantization of linear layers without LoRA applied --- torchtune/models/gemma/_component_builders.py | 35 ++++++++--- .../models/llama2/_component_builders.py | 36 +++++++---- .../models/llama3/_component_builders.py | 35 ++++++++--- .../models/mistral/_component_builders.py | 35 ++++++++--- torchtune/models/phi3/_component_builders.py | 35 ++++++++--- torchtune/modules/__init__.py | 2 + torchtune/modules/low_precision/__init__.py | 6 ++ torchtune/modules/low_precision/nf4_linear.py | 59 +++++++++++++++++++ 8 files changed, 197 insertions(+), 46 deletions(-) create mode 100644 torchtune/modules/low_precision/nf4_linear.py diff --git a/torchtune/models/gemma/_component_builders.py b/torchtune/models/gemma/_component_builders.py index 8335df540..36e8ec450 100644 --- a/torchtune/models/gemma/_component_builders.py +++ b/torchtune/models/gemma/_component_builders.py @@ -12,6 +12,7 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, RotaryPositionalEmbeddings, TransformerDecoderLayer, ) @@ -113,7 +114,7 @@ def gemma( return model -def gemma_mlp(dim: int, hidden_dim: int) -> FeedForward: +def gemma_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Gemma model. @@ -121,9 +122,9 @@ def gemma_mlp(dim: int, hidden_dim: int) -> FeedForward: dim (int): input dimension to the MLP hidden_dim (int): hidden dimension of the MLP """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) activation = nn.GELU(approximate="tanh") return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation) @@ -212,7 +213,7 @@ def lora_gemma( lora_dropout=lora_dropout, ) else: - mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -283,7 +284,11 @@ def lora_gemma_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -295,7 +300,11 @@ def lora_gemma_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -307,7 +316,11 @@ def lora_gemma_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else 
FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -319,7 +332,11 @@ def lora_gemma_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(num_heads * head_dim, embed_dim, bias=False) + else ( + nn.Linear(num_heads * head_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(num_heads * head_dim, embed_dim, bias=False) + ) ) rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) diff --git a/torchtune/models/llama2/_component_builders.py b/torchtune/models/llama2/_component_builders.py index f9420a568..368834b07 100644 --- a/torchtune/models/llama2/_component_builders.py +++ b/torchtune/models/llama2/_component_builders.py @@ -15,13 +15,13 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, RMSNorm, RotaryPositionalEmbeddings, TransformerDecoder, TransformerDecoderLayer, ) - from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear """ @@ -120,13 +120,13 @@ def llama2( ) -def llama2_mlp(dim: int, hidden_dim: int) -> FeedForward: +def llama2_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Llama model. """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) @@ -221,7 +221,7 @@ def lora_llama2( lora_dropout=lora_dropout, ) else: - mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -328,7 +328,11 @@ def lora_llama2_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -340,7 +344,11 @@ def lora_llama2_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -352,7 +360,11 @@ def lora_llama2_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -364,7 +376,11 @@ def lora_llama2_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(embed_dim, embed_dim, bias=False) + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, 
bias=False) + ) ) rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) self_attn = CausalSelfAttention( diff --git a/torchtune/models/llama3/_component_builders.py b/torchtune/models/llama3/_component_builders.py index 8010a48ec..155bb5f93 100644 --- a/torchtune/models/llama3/_component_builders.py +++ b/torchtune/models/llama3/_component_builders.py @@ -14,6 +14,7 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, KVCache, RMSNorm, RotaryPositionalEmbeddings, @@ -116,13 +117,13 @@ def llama3( output=output_proj, ) -def llama3_mlp(dim: int, hidden_dim: int) -> FeedForward: +def llama3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Llama model. """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) @@ -218,7 +219,7 @@ def lora_llama3( lora_dropout=lora_dropout, ) else: - mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -320,7 +321,11 @@ def lora_llama3_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -332,7 +337,11 @@ def lora_llama3_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -344,7 +353,11 @@ def lora_llama3_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -356,7 +369,11 @@ def lora_llama3_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(embed_dim, embed_dim, bias=False) + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) ) rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) self_attn = CausalSelfAttention( diff --git a/torchtune/models/mistral/_component_builders.py b/torchtune/models/mistral/_component_builders.py index 7a908dc83..959b93b90 100644 --- a/torchtune/models/mistral/_component_builders.py +++ b/torchtune/models/mistral/_component_builders.py @@ -13,6 +13,7 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, RMSNorm, 
RotaryPositionalEmbeddings, TransformerDecoder, @@ -114,13 +115,13 @@ def mistral( ) -def mistral_mlp(dim: int, hidden_dim: int) -> FeedForward: +def mistral_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Mistral model. """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) @@ -209,7 +210,7 @@ def lora_mistral( quantize_base=quantize_base, ) else: - mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -316,7 +317,11 @@ def lora_mistral_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -328,7 +333,11 @@ def lora_mistral_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -340,7 +349,11 @@ def lora_mistral_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -352,7 +365,11 @@ def lora_mistral_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(embed_dim, embed_dim, bias=False) + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) ) rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) self_attn = CausalSelfAttention( diff --git a/torchtune/models/phi3/_component_builders.py b/torchtune/models/phi3/_component_builders.py index 16aacf309..3e84fe441 100644 --- a/torchtune/models/phi3/_component_builders.py +++ b/torchtune/models/phi3/_component_builders.py @@ -13,6 +13,7 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, RMSNorm, RotaryPositionalEmbeddings, TransformerDecoder, @@ -105,13 +106,13 @@ def phi3( output=output_proj, ) -def phi3_mlp(dim: int, hidden_dim: int) -> FeedForward: +def phi3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Phi3 Mini 4K Instruct model. 
""" - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) @@ -206,7 +207,7 @@ def lora_phi3( lora_dropout=lora_dropout, ) else: - mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -312,7 +313,11 @@ def lora_phi3_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -324,7 +329,11 @@ def lora_phi3_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -336,7 +345,11 @@ def lora_phi3_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -348,7 +361,11 @@ def lora_phi3_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(embed_dim, embed_dim, bias=False) + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) ) rope = Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) self_attn = CausalSelfAttention( diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 1798f58b4..3942b82b1 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -9,6 +9,7 @@ from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa from .layer_norm import Fp32LayerNorm # noqa +from .low_precision import FrozenNF4Linear # noqa from .lr_schedulers import get_cosine_schedule_with_warmup # noqa from .position_embeddings import RotaryPositionalEmbeddings # noqa from .rms_norm import RMSNorm # noqa @@ -18,6 +19,7 @@ __all__ = [ "CausalSelfAttention", "FeedForward", + "FrozenNF4Linear", "get_cosine_schedule_with_warmup", "KVCache", "RotaryPositionalEmbeddings", diff --git a/torchtune/modules/low_precision/__init__.py b/torchtune/modules/low_precision/__init__.py index 2e41cd717..f49b27f46 100644 --- a/torchtune/modules/low_precision/__init__.py +++ b/torchtune/modules/low_precision/__init__.py @@ -3,3 +3,9 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ +from .nf4_linear import FrozenNF4Linear + +__all__ = [ + "FrozenNF4Linear", +] \ No newline at end of file diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py new file mode 100644 index 000000000..50a5bd313 --- /dev/null +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +import torch.nn as nn +from torch import Tensor +from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 + + +class FrozenNF4Linear(nn.Linear): + """ + A linear layer similar to ``torch.nn.Linear`` but uses a quantized + NF4Tensor as its weight. This class also freezes its ``weight`` parameter + and is meant to be used as the base Linear layer for modeling + use cases such as QLoRA where base model parameters are frozen. + NOTE: biases are currently not supported. + + Args: + in_dim (int): input dimension + out_dim (int): output dimension + device (Optional[torch.device]): device to use for the underlying weight. If ``None``, uses the default + device given by `torch.get_default_device()`. + **kwargs: any additional arguments to pass to the underlying Linear layer. + + Raises: + RuntimeError: if ``bias`` is set to ``True`` + """ + + def __init__( + self, in_dim: int, out_dim: int, device: Optional[torch.device] = None, **kwargs + ): + if "bias" in kwargs and kwargs.pop("bias"): + raise RuntimeError("FrozenNF4Linear does not currently support biases!") + + super().__init__(in_dim, out_dim, device=device, bias=False, **kwargs) + self.weight.requires_grad_(False) + self.nf4_weight = to_nf4(self.weight.data) + # re-register self.weight as the nf4 weight, so that the nf4 weight + # shows up as expected in .parameters, state_dict, etc. + self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False) + + def forward(self, input: Tensor) -> Tensor: + """ + Runs linear operation with input tensor as given by `input`. Computation happens in higher + precision, though only the nf4 weight is saved for backward for gradient computation to ensure + additional memory is not used. 
+ Args: + input (Tensor): input tensor + + Returns: + Tensor: output tensor + """ + return linear_nf4(input=input, weight=self.weight) \ No newline at end of file From fd288ae1ade087f3c2b3c96aba3ec54f50150ff1 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 1 Jul 2024 12:07:28 -0700 Subject: [PATCH 02/13] chore: lint --- torchtune/modules/low_precision/__init__.py | 2 +- torchtune/modules/low_precision/nf4_linear.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtune/modules/low_precision/__init__.py b/torchtune/modules/low_precision/__init__.py index f49b27f46..8bf6448ec 100644 --- a/torchtune/modules/low_precision/__init__.py +++ b/torchtune/modules/low_precision/__init__.py @@ -8,4 +8,4 @@ __all__ = [ "FrozenNF4Linear", -] \ No newline at end of file +] diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py index 50a5bd313..8c297cef5 100644 --- a/torchtune/modules/low_precision/nf4_linear.py +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -33,7 +33,7 @@ class FrozenNF4Linear(nn.Linear): """ def __init__( - self, in_dim: int, out_dim: int, device: Optional[torch.device] = None, **kwargs + self, in_dim: int, out_dim: int, device: Optional[torch.device] = None, **kwargs ): if "bias" in kwargs and kwargs.pop("bias"): raise RuntimeError("FrozenNF4Linear does not currently support biases!") @@ -56,4 +56,4 @@ def forward(self, input: Tensor) -> Tensor: Returns: Tensor: output tensor """ - return linear_nf4(input=input, weight=self.weight) \ No newline at end of file + return linear_nf4(input=input, weight=self.weight) From 9b68db2df3a5c85b71ce25114387976265d55d49 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jul 2024 08:58:01 -0700 Subject: [PATCH 03/13] don't reference .data directly for to_nf4 call --- torchtune/modules/low_precision/nf4_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py index 8c297cef5..432186a0d 100644 --- a/torchtune/modules/low_precision/nf4_linear.py +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -40,7 +40,7 @@ def __init__( super().__init__(in_dim, out_dim, device=device, bias=False, **kwargs) self.weight.requires_grad_(False) - self.nf4_weight = to_nf4(self.weight.data) + self.nf4_weight = to_nf4(self.weight) # re-register self.weight as the nf4 weight, so that the nf4 weight # shows up as expected in .parameters, state_dict, etc. self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False) From 859b1b7d6e716018a5d124188b2e87c1fc9edced Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 5 Jul 2024 09:01:19 -0700 Subject: [PATCH 04/13] restore previously removed test_nf4_linear.py as well --- .../modules/low_precision/test_nf4_linear.py | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 tests/torchtune/modules/low_precision/test_nf4_linear.py diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py new file mode 100644 index 000000000..668defa91 --- /dev/null +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import pytest +import torch +from torchtune.modules.low_precision import FrozenNF4Linear +from torchtune.utils.seed import set_seed + +try: + from torchao.dtypes.nf4tensor import NF4Tensor +except ImportError as e: + raise RuntimeError( + "Please install torchao to run this test." + "Example: pip install git+https://github.com/pytorch-labs/ao.git" + ) from e + +import bitsandbytes as bnb + + +@pytest.fixture(autouse=True) +def random(): + set_seed(31) + + +def _build_bnb_linear(input_weight): + """ + Builds a bnb.nn.LinearNF4 from a given input weight + """ + param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4") + bnb_linear = bnb.nn.LinearNF4( + input_weight.size(0), input_weight.size(1), bias=False + ) + bnb_linear.weight = param + bnb_linear.cuda() + return bnb_linear + + +class TestNF4Linear: + """ + Class for testing our NF4Linear implementation. + """ + + def test_bias_unsupported(self): + with pytest.raises(RuntimeError, match="does not currently support biases"): + _ = FrozenNF4Linear(1, 1, bias=True) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_parameters(self, dtype): + nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) + params = list(nf4_linear.parameters()) + assert len(params) == 1 + assert isinstance(params[0], NF4Tensor) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_state_dict(self, dtype): + nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) + state_dict = nf4_linear.state_dict() + assert len(state_dict) == 1 + assert isinstance(state_dict["weight"], NF4Tensor) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_output_dtype(self, dtype): + # Test to ensure W4 A16 produces A16 / W4A32 produces A32 + nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) + inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) + out = nf4_linear(inp) + assert out.dtype == dtype + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_backward_dtype(self, dtype): + # Test to ensure backward pass gives activation a bf16 gradient and no gradient + # to the linear's weight, as it is frozen. + nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype) + inp = torch.randn(2, 512, dtype=dtype, requires_grad=True) + nf4_linear(inp).sum().backward() + assert inp.grad is not None and inp.grad.dtype == dtype + assert nf4_linear.weight.grad is None + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_nf4_reconstruction_vs_bnb(self, dtype): + """ + Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when + reconstructing the respective original weights. + """ + dim = 512 + nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) + orig_weight = nf4_linear.weight.get_original_weight().clone().detach() + bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) + + # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65 + bnb_reconstruction = bnb_nf4_linear( + torch.eye(dim, dim, dtype=dtype, device="cuda") + ) + # Ensure nf4_linear and bnb reconstructions are close to each other. 
+ diff = ( + (bnb_reconstruction.T - nf4_linear.weight.get_original_weight()).abs().max() + ) + assert diff.item() < 1e-2 + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) + def test_nf4_bnb_linear(self, dtype): + """ + This test ensures that nf4_linear is "no worse" than BNB by ensuring the + error compared to a bf16 linear is not more than BNB's implementation. + """ + dim = 512 + nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) + orig_weight = nf4_linear.weight.get_original_weight().clone().detach() + bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) + bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype) + + inp = torch.randn(2, 512, dtype=dtype, device="cuda") + + out_nf4 = nf4_linear(inp) + out_bnb = bnb_nf4_linear(inp) + out_ref = bf16_linear(inp) + + err_bnb = (out_bnb - out_ref).sum().abs().max() + err_native = (out_nf4 - out_ref).sum().abs().max() + assert err_native.item() <= err_bnb From 9d3c2cb51117bb970b3ce5035f27f4885614eb76 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Jul 2024 08:12:23 -0700 Subject: [PATCH 05/13] add test to check for quantized base weights and remove torchao check since it is a requirement --- .../models/llama2/test_lora_llama2.py | 23 ++++++++++++++++++- .../modules/low_precision/test_nf4_linear.py | 9 +------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/tests/torchtune/models/llama2/test_lora_llama2.py b/tests/torchtune/models/llama2/test_lora_llama2.py index 3452b7a2e..4dcdaa081 100644 --- a/tests/torchtune/models/llama2/test_lora_llama2.py +++ b/tests/torchtune/models/llama2/test_lora_llama2.py @@ -15,6 +15,7 @@ from torchtune import utils from torchtune.models.llama2 import llama2, lora_llama2 from torchtune.models.llama2._component_builders import lora_llama2_self_attention +from torchtune.modules.low_precision import FrozenNF4Linear from torchtune.modules.peft import LoRALinear from torchtune.modules.peft.peft_utils import get_merged_lora_ckpt from torchtune.utils.seed import set_seed @@ -188,7 +189,7 @@ def test_lora_llama2_state_dict_parity( assert not unexpected assert all(["lora" in key for key in missing]) - def test_lora_linear_quantize_base(self): + def test_qlora_linear_quantize_base(self): model = self.get_lora_llama2( lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"], apply_lora_to_mlp=True, @@ -203,6 +204,26 @@ def test_lora_linear_quantize_base(self): if isinstance(module, LoRALinear): assert module._quantize_base + def test_qlora_linear_quantize_base_weights(self): + # this test checks that modules that don't have LoRA applied to them + # have their base weights quantized + model = self.get_lora_llama2( + lora_modules=["q_proj", "v_proj"], + apply_lora_to_mlp=True, + # quantize_base + apply_lora_to_output=False, + vocab_size=50, + quantize_base=True, + embed_dim=512, + dtype=torch.bfloat16, + ) + for name, module in model.named_modules(): + if isinstance(module, LoRALinear): + assert module._quantize_base + elif name in ["k_proj", "output_proj"]: + assert isinstance(module, FrozenNF4Linear) + assert isinstance(module.weight, NF4Tensor) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_qlora_llama2_parity(self, dtype, inputs): with utils.set_default_dtype(dtype): diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index 668defa91..0342aa9ef 100644 --- 
a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -7,17 +7,10 @@ import pytest import torch +from torchao.dtypes.nf4tensor import NF4Tensor from torchtune.modules.low_precision import FrozenNF4Linear from torchtune.utils.seed import set_seed -try: - from torchao.dtypes.nf4tensor import NF4Tensor -except ImportError as e: - raise RuntimeError( - "Please install torchao to run this test." - "Example: pip install git+https://github.com/pytorch-labs/ao.git" - ) from e - import bitsandbytes as bnb From 5fef21d1234f0c35cffa718920b49ef457081363 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Jul 2024 08:22:49 -0700 Subject: [PATCH 06/13] use swap_tensors per PR feedback --- torchtune/modules/low_precision/nf4_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py index 432186a0d..13150eb65 100644 --- a/torchtune/modules/low_precision/nf4_linear.py +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -43,7 +43,7 @@ def __init__( self.nf4_weight = to_nf4(self.weight) # re-register self.weight as the nf4 weight, so that the nf4 weight # shows up as expected in .parameters, state_dict, etc. - self.weight = torch.nn.Parameter(self.nf4_weight, requires_grad=False) + torch.utils.swap_tensors(self.weight, torch.nn.Parameter(self.nf4_weight, requires_grad=False)) def forward(self, input: Tensor) -> Tensor: """ From 731fe8d5b4f9a5c9e6f09fe86a0977aedf733bc4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 10 Jul 2024 10:24:50 -0700 Subject: [PATCH 07/13] chore: lint --- tests/torchtune/modules/low_precision/test_nf4_linear.py | 3 +-- torchtune/modules/low_precision/nf4_linear.py | 4 +++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index 0342aa9ef..ec585a5d1 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -5,14 +5,13 @@ # LICENSE file in the root directory of this source tree. +import bitsandbytes as bnb import pytest import torch from torchao.dtypes.nf4tensor import NF4Tensor from torchtune.modules.low_precision import FrozenNF4Linear from torchtune.utils.seed import set_seed -import bitsandbytes as bnb - @pytest.fixture(autouse=True) def random(): diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py index 13150eb65..6626688d4 100644 --- a/torchtune/modules/low_precision/nf4_linear.py +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -43,7 +43,9 @@ def __init__( self.nf4_weight = to_nf4(self.weight) # re-register self.weight as the nf4 weight, so that the nf4 weight # shows up as expected in .parameters, state_dict, etc. 
- torch.utils.swap_tensors(self.weight, torch.nn.Parameter(self.nf4_weight, requires_grad=False)) + torch.utils.swap_tensors( + self.weight, torch.nn.Parameter(self.nf4_weight, requires_grad=False) + ) def forward(self, input: Tensor) -> Tensor: """ From 4210a2ad76b78bd9fd4afde13a3c7667da3f1502 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 10 Jul 2024 13:43:20 -0700 Subject: [PATCH 08/13] fix tests to use torch.allclose --- .../modules/low_precision/test_nf4_linear.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index ec585a5d1..17965e62d 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -89,10 +89,7 @@ def test_nf4_reconstruction_vs_bnb(self, dtype): torch.eye(dim, dim, dtype=dtype, device="cuda") ) # Ensure nf4_linear and bnb reconstructions are close to each other. - diff = ( - (bnb_reconstruction.T - nf4_linear.weight.get_original_weight()).abs().max() - ) - assert diff.item() < 1e-2 + assert torch.allclose(bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) @@ -113,6 +110,6 @@ def test_nf4_bnb_linear(self, dtype): out_bnb = bnb_nf4_linear(inp) out_ref = bf16_linear(inp) - err_bnb = (out_bnb - out_ref).sum().abs().max() - err_native = (out_nf4 - out_ref).sum().abs().max() - assert err_native.item() <= err_bnb + err_bnb = (out_bnb - out_ref).sum() + err_native = (out_nf4 - out_ref).sum() + assert torch.allclose(err_bnb, err_native, 1e-2) From 56bf223aad0a61462598e3de439b8206b4ff5ba6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 10 Jul 2024 13:49:20 -0700 Subject: [PATCH 09/13] chore: lint --- tests/torchtune/modules/low_precision/test_nf4_linear.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index 17965e62d..4c507f014 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -89,7 +89,9 @@ def test_nf4_reconstruction_vs_bnb(self, dtype): torch.eye(dim, dim, dtype=dtype, device="cuda") ) # Ensure nf4_linear and bnb reconstructions are close to each other. 
- assert torch.allclose(bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2) + assert torch.allclose( + bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2 + ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) From 30c30e6763a965143de4326e3945a338b5e0149d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 23 Jul 2024 14:06:57 -0700 Subject: [PATCH 10/13] make the test deterministic --- tests/torchtune/modules/low_precision/test_nf4_linear.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index 4c507f014..ed7caab1d 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -106,6 +106,7 @@ def test_nf4_bnb_linear(self, dtype): bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype) + torch.manual_seed(42) inp = torch.randn(2, 512, dtype=dtype, device="cuda") out_nf4 = nf4_linear(inp) From ab6c00c3eebcf8926d2108ccf973c54d800180d8 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 29 Jul 2024 11:02:51 -0700 Subject: [PATCH 11/13] include relative tolerance too --- tests/torchtune/modules/low_precision/test_nf4_linear.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index ed7caab1d..4c658e8da 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -106,7 +106,6 @@ def test_nf4_bnb_linear(self, dtype): bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype) - torch.manual_seed(42) inp = torch.randn(2, 512, dtype=dtype, device="cuda") out_nf4 = nf4_linear(inp) @@ -115,4 +114,4 @@ def test_nf4_bnb_linear(self, dtype): err_bnb = (out_bnb - out_ref).sum() err_native = (out_nf4 - out_ref).sum() - assert torch.allclose(err_bnb, err_native, 1e-2) + assert torch.allclose(err_bnb, err_native, 1.0e-2, 1.0e-2) From c4c81c05e91d091d60f8cc210622269262793f5a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 29 Jul 2024 11:26:08 -0700 Subject: [PATCH 12/13] include relative tolerance on each element, not the sum --- tests/torchtune/modules/low_precision/test_nf4_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index 4c658e8da..7b7729c6f 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -112,6 +112,6 @@ def test_nf4_bnb_linear(self, dtype): out_bnb = bnb_nf4_linear(inp) out_ref = bf16_linear(inp) - err_bnb = (out_bnb - out_ref).sum() - err_native = (out_nf4 - out_ref).sum() + err_bnb = (out_bnb - out_ref) + err_native = (out_nf4 - out_ref) assert torch.allclose(err_bnb, err_native, 1.0e-2, 1.0e-2) From 44edc45e588b0b6486c1b7728bc9b22dfffabf75 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 29 Jul 2024 11:36:21 -0700 Subject: [PATCH 13/13] chore: lint --- tests/torchtune/modules/low_precision/test_nf4_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py index 7b7729c6f..5408561f1 100644 --- a/tests/torchtune/modules/low_precision/test_nf4_linear.py +++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py @@ -112,6 +112,6 @@ def test_nf4_bnb_linear(self, dtype): out_bnb = bnb_nf4_linear(inp) out_ref = bf16_linear(inp) - err_bnb = (out_bnb - out_ref) - err_native = (out_nf4 - out_ref) + err_bnb = out_bnb - out_ref + err_native = out_nf4 - out_ref assert torch.allclose(err_bnb, err_native, 1.0e-2, 1.0e-2)
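
Reviewer note (not part of the patch series above): the new FrozenNF4Linear ultimately reduces to two torchao calls that the patch imports, to_nf4 for the one-time quantization of the frozen weight and linear_nf4 for the matmul. A minimal sketch of those primitives, under the same torchao NF4 API used in nf4_linear.py:

    import torch
    from torchao.dtypes.nf4tensor import to_nf4, linear_nf4

    # A pretrained weight that stays frozen during fine-tuning.
    w = torch.randn(512, 512, dtype=torch.bfloat16)
    w_nf4 = to_nf4(w)  # quantize once to the 4-bit NF4 representation

    x = torch.randn(2, 512, dtype=torch.bfloat16)
    # linear_nf4 dequantizes on the fly, so the computation runs in the
    # activation dtype while only the compact NF4 weight is kept around.
    y = linear_nf4(input=x, weight=w_nf4)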
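
A short usage sketch of the module itself, mirroring what the new unit tests in test_nf4_linear.py assert (a single frozen NF4 weight registered as the parameter, outputs and input gradients in the activation dtype, no gradient on the weight):

    import torch
    from torchao.dtypes.nf4tensor import NF4Tensor
    from torchtune.modules.low_precision import FrozenNF4Linear

    layer = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
    assert isinstance(layer.weight, NF4Tensor)   # visible in .parameters() / state_dict()
    assert not layer.weight.requires_grad        # base weight is frozen

    x = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
    y = layer(x)
    assert y.dtype == torch.bfloat16
    y.sum().backward()
    assert x.grad is not None and layer.weight.grad is None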
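
Finally, a design note on the component builders: the "nn.Linear(...) if not quantize_base else FrozenNF4Linear(...)" conditional is now repeated for every non-LoRA projection across the Gemma, Llama2, Llama3, Mistral, and Phi3 builders. A small helper along these lines (hypothetical, not part of this patch) would collapse that duplication:

    import torch.nn as nn
    from torchtune.modules import FrozenNF4Linear

    def _base_linear(in_dim: int, out_dim: int, quantize_base: bool) -> nn.Module:
        """Plain nn.Linear normally; frozen NF4 linear when the base model is quantized."""
        if quantize_base:
            return FrozenNF4Linear(in_dim, out_dim, bias=False)
        return nn.Linear(in_dim, out_dim, bias=False)

    # e.g. in lora_llama2_self_attention, for a projection without LoRA applied:
    # k_proj = _base_linear(embed_dim, num_kv_heads * head_dim, quantize_base)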