NF4 quantization of linear layers without LoRA applied #1119

Open · wants to merge 9 commits into base: main
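At a glance, the change routes every base linear through NF4 when `quantize_base=True` and the layer has no LoRA adapter attached. Below is a minimal sketch of that selection logic; the helper name `_pick_proj` and the default `lora_rank`/`lora_alpha` values are illustrative, not part of this PR:

```python
import torch.nn as nn

from torchtune.modules.low_precision import FrozenNF4Linear
from torchtune.modules.peft import LoRALinear


def _pick_proj(
    in_dim: int,
    out_dim: int,
    *,
    use_lora: bool,
    quantize_base: bool,
    lora_rank: int = 8,
    lora_alpha: float = 16.0,
) -> nn.Module:
    # LoRA-adapted projections keep LoRALinear (optionally with an NF4 base);
    # projections without LoRA fall back to FrozenNF4Linear instead of nn.Linear
    # when quantize_base is set, so their frozen weights are stored in 4 bits too.
    if use_lora:
        return LoRALinear(
            in_dim, out_dim, rank=lora_rank, alpha=lora_alpha, quantize_base=quantize_base
        )
    if quantize_base:
        return FrozenNF4Linear(in_dim, out_dim, bias=False)
    return nn.Linear(in_dim, out_dim, bias=False)
```

This is the same three-way choice the gemma and llama2 builders below spell out inline for each of q_proj, k_proj, v_proj, and output_proj.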
23 changes: 22 additions & 1 deletion tests/torchtune/models/llama2/test_lora_llama2.py
@@ -15,6 +15,7 @@
from torchtune import utils
from torchtune.models.llama2 import llama2, lora_llama2
from torchtune.models.llama2._component_builders import lora_llama2_self_attention
from torchtune.modules.low_precision import FrozenNF4Linear
from torchtune.modules.peft import LoRALinear
from torchtune.modules.peft.peft_utils import get_merged_lora_ckpt
from torchtune.utils.seed import set_seed
@@ -188,7 +189,7 @@ def test_lora_llama2_state_dict_parity(
assert not unexpected
assert all(["lora" in key for key in missing])

def test_lora_linear_quantize_base(self):
def test_qlora_linear_quantize_base(self):
model = self.get_lora_llama2(
lora_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
apply_lora_to_mlp=True,
@@ -203,6 +204,26 @@ def test_lora_linear_quantize_base(self):
if isinstance(module, LoRALinear):
assert module._quantize_base

def test_qlora_linear_quantize_base_weights(self):
# this test checks that modules that don't have LoRA applied to them
# have their base weights quantized
model = self.get_lora_llama2(
lora_modules=["q_proj", "v_proj"],
apply_lora_to_mlp=True,
apply_lora_to_output=False,
vocab_size=50,
quantize_base=True,
embed_dim=512,
dtype=torch.bfloat16,
)
for name, module in model.named_modules():
if isinstance(module, LoRALinear):
assert module._quantize_base
            elif any(name.endswith(proj) for proj in ["k_proj", "output_proj"]):
                assert isinstance(module, FrozenNF4Linear)
                assert isinstance(module.weight, NF4Tensor)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_qlora_llama2_parity(self, dtype, inputs):
with utils.set_default_dtype(dtype):
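For context on the new test file below: NF4 quantization is lossy, so a dequantized weight only approximates the original. A small sketch of the round trip, assuming torchao's `to_nf4` helper alongside the `get_original_weight` method the tests already use:

```python
import torch
from torchao.dtypes.nf4tensor import NF4Tensor, to_nf4

w = torch.randn(512, 512, dtype=torch.bfloat16)
w_nf4 = to_nf4(w)                      # quantize to 4-bit NormalFloat blocks
assert isinstance(w_nf4, NF4Tensor)

w_back = w_nf4.get_original_weight()   # dequantize back to bf16
print((w - w_back).abs().max())        # small but nonzero reconstruction error
```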
117 changes: 117 additions & 0 deletions tests/torchtune/modules/low_precision/test_nf4_linear.py
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import bitsandbytes as bnb
import pytest
import torch
from torchao.dtypes.nf4tensor import NF4Tensor
from torchtune.modules.low_precision import FrozenNF4Linear
from torchtune.utils.seed import set_seed


@pytest.fixture(autouse=True)
def random():
set_seed(31)


def _build_bnb_linear(input_weight):
"""
Builds a bnb.nn.LinearNF4 from a given input weight
"""
param = bnb.nn.Params4bit(input_weight, requires_grad=False, quant_type="nf4")
bnb_linear = bnb.nn.LinearNF4(
input_weight.size(0), input_weight.size(1), bias=False
)
bnb_linear.weight = param
bnb_linear.cuda()
return bnb_linear


class TestNF4Linear:
"""
    Class for testing our FrozenNF4Linear implementation.
"""

def test_bias_unsupported(self):
with pytest.raises(RuntimeError, match="does not currently support biases"):
_ = FrozenNF4Linear(1, 1, bias=True)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_parameters(self, dtype):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
params = list(nf4_linear.parameters())
assert len(params) == 1
assert isinstance(params[0], NF4Tensor)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_state_dict(self, dtype):
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
state_dict = nf4_linear.state_dict()
assert len(state_dict) == 1
assert isinstance(state_dict["weight"], NF4Tensor)

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_output_dtype(self, dtype):
        # Test to ensure that W4A16 produces A16 outputs and W4A32 produces A32 outputs
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
out = nf4_linear(inp)
assert out.dtype == dtype

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_backward_dtype(self, dtype):
        # Test to ensure that the backward pass gives the input a gradient in the
        # input's dtype, and no gradient to the linear's weight, as it is frozen.
nf4_linear = FrozenNF4Linear(512, 512, device="cpu", dtype=dtype)
inp = torch.randn(2, 512, dtype=dtype, requires_grad=True)
nf4_linear(inp).sum().backward()
assert inp.grad is not None and inp.grad.dtype == dtype
assert nf4_linear.weight.grad is None

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_nf4_reconstruction_vs_bnb(self, dtype):
"""
Ensures a BNB NF4 linear and our FrozenNF4Linear have low error when
reconstructing the respective original weights.
"""
dim = 512
nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)

# From https://github.com/drisspg/transformer_nuggets/blob/f05afad68ad9086d342268f46a7f344617a02314/test/test_qlora.py#L65
bnb_reconstruction = bnb_nf4_linear(
torch.eye(dim, dim, dtype=dtype, device="cuda")
)
# Ensure nf4_linear and bnb reconstructions are close to each other.
assert torch.allclose(
bnb_reconstruction.T, nf4_linear.weight.get_original_weight(), 1e-2
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_nf4_bnb_linear(self, dtype):
"""
        This test ensures that FrozenNF4Linear is "no worse" than bnb: its error
        relative to an unquantized reference linear should match bnb's within tolerance.
"""
dim = 512
nf4_linear = FrozenNF4Linear(dim, dim, device="cuda", dtype=dtype)
orig_weight = nf4_linear.weight.get_original_weight().clone().detach()
bnb_nf4_linear = _build_bnb_linear(input_weight=orig_weight)
bf16_linear = torch.nn.Linear(dim, dim, device="cuda", dtype=dtype)

inp = torch.randn(2, 512, dtype=dtype, device="cuda")

out_nf4 = nf4_linear(inp)
out_bnb = bnb_nf4_linear(inp)
out_ref = bf16_linear(inp)

err_bnb = (out_bnb - out_ref).sum()
err_native = (out_nf4 - out_ref).sum()
assert torch.allclose(err_bnb, err_native, 1e-2)
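The tests above also double as usage documentation; here is a short sketch that mirrors the constructor arguments they exercise:

```python
import torch
from torchao.dtypes.nf4tensor import NF4Tensor
from torchtune.modules.low_precision import FrozenNF4Linear

# Build a frozen NF4 linear on CPU; the weight is quantized at construction time.
linear = FrozenNF4Linear(512, 512, device="cpu", dtype=torch.bfloat16)
assert isinstance(linear.weight, NF4Tensor)

# The forward pass dequantizes on the fly and keeps the activation dtype.
x = torch.randn(2, 512, dtype=torch.bfloat16)
assert linear(x).dtype == torch.bfloat16
```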
35 changes: 26 additions & 9 deletions torchtune/models/gemma/_component_builders.py
@@ -12,6 +12,7 @@
from torchtune.modules import (
CausalSelfAttention,
FeedForward,
FrozenNF4Linear,
RotaryPositionalEmbeddings,
TransformerDecoderLayer,
)
@@ -113,17 +114,17 @@ def gemma(
return model


def gemma_mlp(dim: int, hidden_dim: int) -> FeedForward:
def gemma_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward:
"""
Build the MLP layer associated with the Gemma model.

Args:
dim (int): input dimension to the MLP
        hidden_dim (int): hidden dimension of the MLP
        quantize_base (bool): Whether to quantize the base (non-LoRA) linear
            weights to NF4. Default: False
    """
gate_proj = nn.Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False)
    gate_proj = (
        nn.Linear(dim, hidden_dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(dim, hidden_dim, bias=False)
    )
    down_proj = (
        nn.Linear(hidden_dim, dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(hidden_dim, dim, bias=False)
    )
    up_proj = (
        nn.Linear(dim, hidden_dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(dim, hidden_dim, bias=False)
    )
activation = nn.GELU(approximate="tanh")
return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation)
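An illustrative check of the new flag (the dimensions are arbitrary and the private import path is an assumption; the count of three corresponds to the gate, down, and up projections built above):

```python
from torchtune.models.gemma._component_builders import gemma_mlp
from torchtune.modules.low_precision import FrozenNF4Linear

mlp = gemma_mlp(dim=256, hidden_dim=1024, quantize_base=True)
nf4_layers = [m for m in mlp.modules() if isinstance(m, FrozenNF4Linear)]
assert len(nf4_layers) == 3  # gate_proj, down_proj, up_proj are all NF4
```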

@@ -212,7 +213,7 @@ def lora_gemma(
lora_dropout=lora_dropout,
)
else:
mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)

layer = TransformerDecoderLayer(
attn=self_attn,
@@ -283,7 +284,11 @@ def lora_gemma_self_attention(
quantize_base=quantize_base,
)
if "q_proj" in lora_modules
else nn.Linear(embed_dim, num_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False)
)
)
k_proj = (
LoRALinear(
@@ -295,7 +300,11 @@ def lora_gemma_self_attention(
quantize_base=quantize_base,
)
if "k_proj" in lora_modules
else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
)
)
v_proj = (
LoRALinear(
@@ -307,7 +316,11 @@ def lora_gemma_self_attention(
quantize_base=quantize_base,
)
if "v_proj" in lora_modules
else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
)
)
output_proj = (
LoRALinear(
@@ -319,7 +332,11 @@ def lora_gemma_self_attention(
quantize_base=quantize_base,
)
if "output_proj" in lora_modules
else nn.Linear(num_heads * head_dim, embed_dim, bias=False)
else (
nn.Linear(num_heads * head_dim, embed_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(num_heads * head_dim, embed_dim, bias=False)
)
)

rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
36 changes: 26 additions & 10 deletions torchtune/models/llama2/_component_builders.py
@@ -15,13 +15,13 @@
from torchtune.modules import (
CausalSelfAttention,
FeedForward,
FrozenNF4Linear,
RMSNorm,
RotaryPositionalEmbeddings,
TransformerDecoder,
TransformerDecoderLayer,
)


from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear

"""
@@ -120,13 +120,13 @@ def llama2(
)


def llama2_mlp(dim: int, hidden_dim: int) -> FeedForward:
def llama2_mlp(dim: int, hidden_dim: int, quantize_base: bool = False) -> FeedForward:
"""
    Build the MLP layer associated with the Llama model. If ``quantize_base`` is True,
    the linear layers are built as frozen, NF4-quantized linears.
    """
gate_proj = nn.Linear(dim, hidden_dim, bias=False)
down_proj = nn.Linear(hidden_dim, dim, bias=False)
up_proj = nn.Linear(dim, hidden_dim, bias=False)
    gate_proj = (
        nn.Linear(dim, hidden_dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(dim, hidden_dim, bias=False)
    )
    down_proj = (
        nn.Linear(hidden_dim, dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(hidden_dim, dim, bias=False)
    )
    up_proj = (
        nn.Linear(dim, hidden_dim, bias=False)
        if not quantize_base
        else FrozenNF4Linear(dim, hidden_dim, bias=False)
    )
return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj)


@@ -221,7 +221,7 @@ def lora_llama2(
lora_dropout=lora_dropout,
)
else:
mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
mlp = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim, quantize_base=quantize_base)

layer = TransformerDecoderLayer(
attn=self_attn,
@@ -328,7 +328,11 @@ def lora_llama2_self_attention(
quantize_base=quantize_base,
)
if "q_proj" in lora_modules
else nn.Linear(embed_dim, num_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_heads * head_dim, bias=False)
)
)
k_proj = (
LoRALinear(
@@ -340,7 +344,11 @@ def lora_llama2_self_attention(
quantize_base=quantize_base,
)
if "k_proj" in lora_modules
else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
)
)
v_proj = (
LoRALinear(
@@ -352,7 +360,11 @@ def lora_llama2_self_attention(
quantize_base=quantize_base,
)
if "v_proj" in lora_modules
else nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
else (
nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, num_kv_heads * head_dim, bias=False)
)
)
output_proj = (
LoRALinear(
@@ -364,7 +376,11 @@ def lora_llama2_self_attention(
quantize_base=quantize_base,
)
if "output_proj" in lora_modules
else nn.Linear(embed_dim, embed_dim, bias=False)
else (
nn.Linear(embed_dim, embed_dim, bias=False)
if not quantize_base
else FrozenNF4Linear(embed_dim, embed_dim, bias=False)
)
)
rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
self_attn = CausalSelfAttention(
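To tie the attention builder changes together, a hedged usage sketch: the keyword names follow the diff above, but the exact builder signature is an assumption, not confirmed by this PR.

```python
from torchtune.models.llama2._component_builders import lora_llama2_self_attention
from torchtune.modules.low_precision import FrozenNF4Linear
from torchtune.modules.peft import LoRALinear

attn = lora_llama2_self_attention(
    lora_modules=["q_proj", "v_proj"],
    embed_dim=512,
    num_heads=8,
    num_kv_heads=8,
    max_seq_len=128,
    lora_rank=4,
    lora_alpha=1.0,
    quantize_base=True,
)

# LoRA-adapted projections stay LoRALinear (with a quantized base), while the
# untouched projections become FrozenNF4Linear instead of plain nn.Linear.
assert isinstance(attn.q_proj, LoRALinear)
assert isinstance(attn.v_proj, LoRALinear)
assert isinstance(attn.k_proj, FrozenNF4Linear)
assert isinstance(attn.output_proj, FrozenNF4Linear)
```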