diff --git a/tests/torchtune/models/llama2/test_lora_llama2.py b/tests/torchtune/models/llama2/test_lora_llama2.py
index 3452b7a2e..4dcdaa081 100644
--- a/tests/torchtune/models/llama2/test_lora_llama2.py
+++ b/tests/torchtune/models/llama2/test_lora_llama2.py
@@ -15,6 +15,7 @@ from torchtune import utils
 from torchtune.models.llama2 import llama2, lora_llama2
 from torchtune.models.llama2._component_builders import lora_llama2_self_attention
+from torchtune.modules.low_precision import FrozenNF4Linear
 from torchtune.modules.peft import LoRALinear
 from torchtune.modules.peft.peft_utils import get_merged_lora_ckpt
 from torchtune.utils.seed import set_seed
 
@@ -188,7 +189,7 @@ def test_lora_llama2_state_dict_parity(
         assert not unexpected
         assert all(["lora" in key for key in missing])
 
-    def test_lora_linear_quantize_base(self):
+    def test_qlora_linear_quantize_base(self):
         model = self.get_lora_llama2(
             lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
             apply_lora_to_mlp=True,
@@ -203,6 +204,25 @@ def test_lora_linear_quantize_base(self):
             if isinstance(module, LoRALinear):
                 assert module._quantize_base
 
+    def test_qlora_linear_quantize_base_weights(self):
+        # this test checks that modules that don't have LoRA applied to them
+        # have their base weights quantized
+        model = self.get_lora_llama2(
+            lora_modules=["q_proj", "v_proj"],
+            apply_lora_to_mlp=True,
+            apply_lora_to_output=False,
+            vocab_size=50,
+            quantize_base=True,
+            embed_dim=512,
+            dtype=torch.bfloat16,
+        )
+        for name, module in model.named_modules():
+            if isinstance(module, LoRALinear):
+                assert module._quantize_base
+            elif name.split(".")[-1] in ["k_proj", "output_proj"]:
+                assert isinstance(module, FrozenNF4Linear)
+                assert isinstance(module.weight, NF4Tensor)
+
     @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
     def test_qlora_llama2_parity(self, dtype, inputs):
         with utils.set_default_dtype(dtype):
diff --git a/tests/torchtune/modules/low_precision/test_nf4_linear.py b/tests/torchtune/modules/low_precision/test_nf4_linear.py
new file mode 100644
index 000000000..5408561f1
--- /dev/null
+++ b/tests/torchtune/modules/low_precision/test_nf4_linear.py
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import bitsandbytes as bnb
+import pytest
+import torch
+from torchao.dtypes.nf4tensor import NF4Tensor
+from torchtune.modules.low_precision import FrozenNF4Linear
+from torchtune.utils.seed import set_seed
+
+
+@pytest.fixture(autouse=True)
+def random():
+    set_seed(31)
+
+
+def _build_bnb_linear(input_weight):
+    """
+    Builds a bnb.nn.LinearNF4 from a given input weight
+    """
+    param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4")
+    bnb_linear = bnb.nn.LinearNF4(
+        input_weight.size(0), input_weight.size(1), bias=False
+    )
+    bnb_linear.weight = param
+    bnb_linear.cuda()
+    return bnb_linear
+
+
+class TestNF4Linear:
+    """
+    Class for testing our NF4Linear implementation.
+    """
+
+    def test_bias_unsupported(self):
+        with pytest.raises(RuntimeError, match="does not currently support biases"):
+            _ = FrozenNF4Linear(1, 1, bias=True)
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_parameters(self, dtype):
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
+        params = list(nf4_linear.parameters())
+        assert len(params) == 1
+        assert isinstance(params[0], NF4Tensor)
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_state_dict(self, dtype):
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
+        state_dict = nf4_linear.state_dict()
+        assert len(state_dict) == 1
+        assert isinstance(state_dict["weight"], NF4Tensor)
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_output_dtype(self, dtype):
+        # Test to ensure the output dtype matches the activation dtype (W4A16 -> A16, W4A32 -> A32)
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
+        inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
+        out = nf4_linear(inp)
+        assert out.dtype == dtype
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_backward_dtype(self, dtype):
+        # Test to ensure the backward pass gives the activation a gradient in its
+        # original dtype and no gradient to the linear's weight, as it is frozen.
+        nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
+        inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
+        nf4_linear(inp).sum().backward()
+        assert inp.grad is not None and inp.grad.dtype == dtype
+        assert nf4_linear.weight.grad is None
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_nf4_reconstruction_vs_bnb(self, dtype):
+        """
+        Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
+        reconstructing their respective original weights.
+        """
+        dim = 512
+        nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
+        orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
+        bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
+
+        # From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
+        bnb_reconstruction = bnb_nf4_linear(
+            torch.eye(dim, dim, dtype=dtype, device="cuda")
+        )
+        # Ensure nf4_linear and bnb reconstructions are close to each other.
+        assert torch.allclose(
+            bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2
+        )
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
+    def test_nf4_bnb_linear(self, dtype):
+        """
+        This test ensures that nf4_linear is "no worse" than BNB: its error
+        relative to an unquantized linear of the same dtype is no larger than BNB's.
+ """ + dim = 512 + nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype) + orig_weight = nf4_linear.weight.get_original_weight().clone().detach() + bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight) + bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype) + + inp = torch.randn(2, 512, dtype=dtype, device="cuda") + + out_nf4 = nf4_linear(inp) + out_bnb = bnb_nf4_linear(inp) + out_ref = bf16_linear(inp) + + err_bnb = out_bnb - out_ref + err_native = out_nf4 - out_ref + assert torch.allclose(err_bnb, err_native, 1.0e-2, 1.0e-2) diff --git a/torchtune/models/gemma/_component_builders.py b/torchtune/models/gemma/_component_builders.py index 8335df540..36e8ec450 100644 --- a/torchtune/models/gemma/_component_builders.py +++ b/torchtune/models/gemma/_component_builders.py @@ -12,6 +12,7 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, RotaryPositionalEmbeddings, TransformerDecoderLayer, ) @@ -113,7 +114,7 @@ def gemma( return model -def gemma_mlp(dim: int, hidden_dim: int) -> FeedForward: +def gemma_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Gemma model. @@ -121,9 +122,9 @@ def gemma_mlp(dim: int, hidden_dim: int) -> FeedForward: dim (int): input dimension to the MLP hidden_dim (int): hidden dimension of the MLP """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) activation = nn.GELU(approximate="tanh") return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation) @@ -212,7 +213,7 @@ def lora_gemma( lora_dropout=lora_dropout, ) else: - mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -283,7 +284,11 @@ def lora_gemma_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -295,7 +300,11 @@ def lora_gemma_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -307,7 +316,11 @@ def lora_gemma_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -319,7 +332,11 @@ def lora_gemma_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else 
nn.Linear(num_heads * head_dim, embed_dim, bias=False) + else ( + nn.Linear(num_heads * head_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(num_heads * head_dim, embed_dim, bias=False) + ) ) rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) diff --git a/torchtune/models/llama2/_component_builders.py b/torchtune/models/llama2/_component_builders.py index f9420a568..368834b07 100644 --- a/torchtune/models/llama2/_component_builders.py +++ b/torchtune/models/llama2/_component_builders.py @@ -15,13 +15,13 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, RMSNorm, RotaryPositionalEmbeddings, TransformerDecoder, TransformerDecoderLayer, ) - from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear """ @@ -120,13 +120,13 @@ def llama2( ) -def llama2_mlp(dim: int, hidden_dim: int) -> FeedForward: +def llama2_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Llama model. """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) @@ -221,7 +221,7 @@ def lora_llama2( lora_dropout=lora_dropout, ) else: - mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim) + mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -328,7 +328,11 @@ def lora_llama2_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -340,7 +344,11 @@ def lora_llama2_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -352,7 +360,11 @@ def lora_llama2_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -364,7 +376,11 @@ def lora_llama2_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(embed_dim, embed_dim, bias=False) + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) ) rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len) self_attn = CausalSelfAttention( diff --git a/torchtune/models/llama3/_component_builders.py 
b/torchtune/models/llama3/_component_builders.py index 8010a48ec..155bb5f93 100644 --- a/torchtune/models/llama3/_component_builders.py +++ b/torchtune/models/llama3/_component_builders.py @@ -14,6 +14,7 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, KVCache, RMSNorm, RotaryPositionalEmbeddings, @@ -116,13 +117,13 @@ def llama3( output=output_proj, ) -def llama3_mlp(dim: int, hidden_dim: int) -> FeedForward: +def llama3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Llama model. """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) @@ -218,7 +219,7 @@ def lora_llama3( lora_dropout=lora_dropout, ) else: - mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim) + mlp = llama3_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -320,7 +321,11 @@ def lora_llama3_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -332,7 +337,11 @@ def lora_llama3_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -344,7 +353,11 @@ def lora_llama3_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -356,7 +369,11 @@ def lora_llama3_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(embed_dim, embed_dim, bias=False) + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) ) rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) self_attn = CausalSelfAttention( diff --git a/torchtune/models/mistral/_component_builders.py b/torchtune/models/mistral/_component_builders.py index 7a908dc83..959b93b90 100644 --- a/torchtune/models/mistral/_component_builders.py +++ b/torchtune/models/mistral/_component_builders.py @@ -13,6 +13,7 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, RMSNorm, RotaryPositionalEmbeddings, TransformerDecoder, @@ -114,13 +115,13 @@ def mistral( ) -def mistral_mlp(dim: int, hidden_dim: int) -> FeedForward: +def mistral_mlp(dim: int, hidden_dim: 
int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Mistral model. """ - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) @@ -209,7 +210,7 @@ def lora_mistral( quantize_base=quantize_base, ) else: - mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + mlp = mistral_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -316,7 +317,11 @@ def lora_mistral_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -328,7 +333,11 @@ def lora_mistral_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -340,7 +349,11 @@ def lora_mistral_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -352,7 +365,11 @@ def lora_mistral_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(embed_dim, embed_dim, bias=False) + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) ) rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) self_attn = CausalSelfAttention( diff --git a/torchtune/models/phi3/_component_builders.py b/torchtune/models/phi3/_component_builders.py index 16aacf309..3e84fe441 100644 --- a/torchtune/models/phi3/_component_builders.py +++ b/torchtune/models/phi3/_component_builders.py @@ -13,6 +13,7 @@ from torchtune.modules import ( CausalSelfAttention, FeedForward, + FrozenNF4Linear, RMSNorm, RotaryPositionalEmbeddings, TransformerDecoder, @@ -105,13 +106,13 @@ def phi3( output=output_proj, ) -def phi3_mlp(dim: int, hidden_dim: int) -> FeedForward: +def phi3_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward: """ Build the MLP layer associated with the Phi3 Mini 4K Instruct model. 
""" - gate_proj = nn.Linear(dim, hidden_dim, bias=False) - down_proj = nn.Linear(hidden_dim, dim, bias=False) - up_proj = nn.Linear(dim, hidden_dim, bias=False) + gate_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) + down_proj = nn.Linear(hidden_dim, dim, bias=False) if not quantize_base else FrozenNF4Linear(hidden_dim, dim, bias=False) + up_proj = nn.Linear(dim, hidden_dim, bias=False) if not quantize_base else FrozenNF4Linear(dim, hidden_dim, bias=False) return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj) @@ -206,7 +207,7 @@ def lora_phi3( lora_dropout=lora_dropout, ) else: - mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim) + mlp = phi3_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base) layer = TransformerDecoderLayer( attn=self_attn, @@ -312,7 +313,11 @@ def lora_phi3_self_attention( quantize_base=quantize_base, ) if "q_proj" in lora_modules - else nn.Linear(embed_dim, num_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False) + ) ) k_proj = ( LoRALinear( @@ -324,7 +329,11 @@ def lora_phi3_self_attention( quantize_base=quantize_base, ) if "k_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) v_proj = ( LoRALinear( @@ -336,7 +345,11 @@ def lora_phi3_self_attention( quantize_base=quantize_base, ) if "v_proj" in lora_modules - else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + else ( + nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False) + ) ) output_proj = ( LoRALinear( @@ -348,7 +361,11 @@ def lora_phi3_self_attention( quantize_base=quantize_base, ) if "output_proj" in lora_modules - else nn.Linear(embed_dim, embed_dim, bias=False) + else ( + nn.Linear(embed_dim, embed_dim, bias=False) + if not quantize_base + else FrozenNF4Linear(embed_dim, embed_dim, bias=False) + ) ) rope = Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base) self_attn = CausalSelfAttention( diff --git a/torchtune/modules/__init__.py b/torchtune/modules/__init__.py index 1798f58b4..3942b82b1 100644 --- a/torchtune/modules/__init__.py +++ b/torchtune/modules/__init__.py @@ -9,6 +9,7 @@ from .feed_forward import FeedForward # noqa from .kv_cache import KVCache # noqa from .layer_norm import Fp32LayerNorm # noqa +from .low_precision import FrozenNF4Linear # noqa from .lr_schedulers import get_cosine_schedule_with_warmup # noqa from .position_embeddings import RotaryPositionalEmbeddings # noqa from .rms_norm import RMSNorm # noqa @@ -18,6 +19,7 @@ __all__ = [ "CausalSelfAttention", "FeedForward", + "FrozenNF4Linear", "get_cosine_schedule_with_warmup", "KVCache", "RotaryPositionalEmbeddings", diff --git a/torchtune/modules/low_precision/__init__.py b/torchtune/modules/low_precision/__init__.py index 2e41cd717..8bf6448ec 100644 --- a/torchtune/modules/low_precision/__init__.py +++ b/torchtune/modules/low_precision/__init__.py @@ -3,3 +3,9 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ +from .nf4_linear import FrozenNF4Linear + +__all__ = [ + "FrozenNF4Linear", +] diff --git a/torchtune/modules/low_precision/nf4_linear.py b/torchtune/modules/low_precision/nf4_linear.py new file mode 100644 index 000000000..6626688d4 --- /dev/null +++ b/torchtune/modules/low_precision/nf4_linear.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch + +import torch.nn as nn +from torch import Tensor +from torchao.dtypes.nf4tensor import linear_nf4, to_nf4 + + +class FrozenNF4Linear(nn.Linear): + """ + A linear layer similar to ``torch.nn.Linear`` but uses a quantized + NF4Tensor as its weight. This class also freezes its ``weight`` parameter + and is meant to be used as the base Linear layer for modeling + use cases such as QLoRA where base model parameters are frozen. + NOTE: biases are currently not supported. + + Args: + in_dim (int): input dimension + out_dim (int): output dimension + device (Optional[torch.device]): device to use for the underlying weight. If ``None``, uses the default + device given by `torch.get_default_device()`. + **kwargs: any additional arguments to pass to the underlying Linear layer. + + Raises: + RuntimeError: if ``bias`` is set to ``True`` + """ + + def __init__( + self, in_dim: int, out_dim: int, device: Optional[torch.device] = None, **kwargs + ): + if "bias" in kwargs and kwargs.pop("bias"): + raise RuntimeError("FrozenNF4Linear does not currently support biases!") + + super().__init__(in_dim, out_dim, device=device, bias=False, **kwargs) + self.weight.requires_grad_(False) + self.nf4_weight = to_nf4(self.weight) + # re-register self.weight as the nf4 weight, so that the nf4 weight + # shows up as expected in .parameters, state_dict, etc. + torch.utils.swap_tensors( + self.weight, torch.nn.Parameter(self.nf4_weight, requires_grad=False) + ) + + def forward(self, input: Tensor) -> Tensor: + """ + Runs linear operation with input tensor as given by `input`. Computation happens in higher + precision, though only the nf4 weight is saved for backward for gradient computation to ensure + additional memory is not used. + Args: + input (Tensor): input tensor + + Returns: + Tensor: output tensor + """ + return linear_nf4(input=input, weight=self.weight)
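Reviewer note (not part of the diff): a minimal usage sketch of the new `FrozenNF4Linear`, exercising the same properties the unit tests above assert. All names and import paths come from this diff; the layer sizes are illustrative.

```python
# Minimal sketch of FrozenNF4Linear behavior, mirroring the assertions in
# tests/torchtune/modules/low_precision/test_nf4_linear.py above.
import torch
from torchao.dtypes.nf4tensor import NF4Tensor
from torchtune.modules.low_precision import FrozenNF4Linear

linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)

# The base weight is quantized to NF4 and frozen, so it receives no gradient
# during QLoRA fine-tuning.
assert isinstance(linear.weight, NF4Tensor)
assert not linear.weight.requires_grad

# Forward and backward run in the activation dtype; only the input gets a grad.
x = torch.randn(2, 512, dtype=torch.bfloat16, requires_grad=True)
out = linear(x)
out.sum().backward()
assert out.dtype == torch.bfloat16
assert x.grad is not None and linear.weight.grad is None
```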
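The component-builder changes all follow the same pattern: when `quantize_base=True`, any projection that does not get a LoRA adapter is built as a `FrozenNF4Linear` instead of an `nn.Linear`. A hedged sketch of that plumbing using the `llama2_mlp` builder from this diff (the dim/hidden_dim values are illustrative, not taken from any config):

```python
# Sketch of the quantize_base plumbing added to the component builders: with
# quantize_base=True, the non-LoRA projections become FrozenNF4Linear modules.
from torchtune.models.llama2._component_builders import llama2_mlp
from torchtune.modules import FrozenNF4Linear

mlp = llama2_mlp(dim=512, hidden_dim=2048, quantize_base=True)

# The gate/down/up projections are all frozen NF4 linears, so in a QLoRA setup
# the only trainable parameters are the LoRA adapters attached elsewhere.
nf4_layers = [m for m in mlp.modules() if isinstance(m, FrozenNF4Linear)]
assert len(nf4_layers) == 3
assert all(not p.requires_grad for m in nf4_layers for p in m.parameters())
```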