Merge pull request #8 from invoke-ai/ryan/conv-lora
Add support for Conv LoRA layers
Showing 5 changed files with 338 additions and 7 deletions.
invoke_training/lora/layers/__init__.py
@@ -1,2 +1,7 @@
from .base_lora_layer import BaseLoRALayer  # noqa: F401
from .lora_conv_layer import (  # noqa: F401
    LoRAConv1dLayer,
    LoRAConv2dLayer,
    LoRAConv3dLayer,
)
from .lora_linear_layer import LoRALinearLayer  # noqa: F401
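With these re-exports in place, the new conv layers can be imported directly from `invoke_training.lora.layers`. A minimal sketch (assuming the package is installed; the channel sizes and hyperparameters below are illustrative, not taken from this PR):

import torch

from invoke_training.lora.layers import LoRAConv2dLayer

# Rank-4 LoRA layer mirroring a hypothetical 3x3 conv with 8 input and 16 output channels.
lora = LoRAConv2dLayer(8, 16, kernel_size=3, padding=1, rank=4, alpha=4.0)
with torch.no_grad():
    y = lora(torch.rand(1, 8, 32, 32))
print(y.shape)  # torch.Size([1, 16, 32, 32]); all zeros until trained, since the up-projection starts at zero.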
invoke_training/lora/layers/lora_conv_layer.py
@@ -0,0 +1,131 @@
import math
import typing

import torch

from invoke_training.lora.layers import BaseLoRALayer


class LoRAConvLayer(BaseLoRALayer):
    """An implementation of a conv LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
    (https://arxiv.org/pdf/2106.09685.pdf)
    """

    @property
    def conv_module(self):
        """The conv module to be set by child classes. One of torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d."""
        raise NotImplementedError(
            "LoRAConvLayer cannot be used directly. Use LoRAConv1dLayer, LoRAConv2dLayer, or LoRAConv3dLayer instead."
        )

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: typing.Union[int, tuple[int]] = 1,
        stride: typing.Union[int, tuple[int]] = 1,
        padding: typing.Union[str, int, tuple[int]] = 0,
        rank: int = 4,
        alpha: float = 1.0,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        """Initialize a LoRAConvLayer.
        Args:
            in_channels (int): The number of channels expected on inputs to this layer.
            out_channels (int): The number of channels on outputs from this layer.
            kernel_size: The kernel_size of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
            stride: The stride of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
            padding: The padding of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
            rank (int, optional): The internal rank of the layer. See the paper for details.
            alpha (float, optional): A scaling factor that enables tuning the rank without having to adjust the
                learning rate. The recommendation from the paper is to set alpha equal to the first rank that you
                try and then do not tune it further. See the paper for more details.
            device (torch.device, optional): Device where weights will be initialized.
            dtype (torch.dtype, optional): Weight dtype.
        Raises:
            ValueError: If the rank is greater than either in_channels or out_channels.
        """
        super().__init__()

        if rank > min(in_channels, out_channels):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}")
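
        # The LoRA update is factored into two convolutions: `_down` projects from in_channels down to a low-rank
        # bottleneck using the mirrored layer's kernel_size, stride, and padding, while `_up` is a 1x1 conv that
        # projects from the bottleneck back up to out_channels.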
        self._down = self.conv_module(
            in_channels,
            rank,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            device=device,
            dtype=dtype,
        )
        self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)

        # Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
        self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))

        self._rank = rank

        self.reset_parameters()

    def reset_parameters(self):
        # This initialization is based on Microsoft's implementation:
        # https://github.com/microsoft/LoRA/blob/998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe/loralib/layers.py#L279
        torch.nn.init.kaiming_uniform_(self._down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self._up.weight)

    @classmethod
    def from_layer(
        cls,
        layer: torch.nn.Module,
        rank: int = 4,
        alpha: float = 1.0,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        """Initialize a LoRAConvLayer with dimensions that are compatible with `layer`.
        Args:
            layer (torch.nn.Module): The existing layer whose in/out dimensions will be matched.
            rank, alpha, device, dtype: These args are forwarded to __init__(...). If device or dtype are None, they
                will be inferred from `layer`.
        Raises:
            TypeError: If `layer` has an unsupported type.
        Returns:
            LoRAConvLayer: The new LoRAConvLayer.
        """
        if isinstance(layer, cls.conv_module):
            return cls(
                in_channels=layer.in_channels,
                out_channels=layer.out_channels,
                kernel_size=layer.kernel_size,
                stride=layer.stride,
                padding=layer.padding,
                rank=rank,
                alpha=alpha,
                device=layer.weight.device if device is None else device,
                dtype=layer.weight.dtype if dtype is None else dtype,
            )
        else:
            raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")

    def forward(self, input: torch.Tensor):
        down_hidden = self._down(input)
        up_hidden = self._up(down_hidden)
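
        # Scale the update by alpha / rank so that its magnitude stays comparable when the rank is changed
        # (following the LoRA paper).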
        up_hidden *= self.alpha / self._rank

        return up_hidden


class LoRAConv1dLayer(LoRAConvLayer):
    conv_module = torch.nn.Conv1d


class LoRAConv2dLayer(LoRAConvLayer):
    conv_module = torch.nn.Conv2d


class LoRAConv3dLayer(LoRAConvLayer):
    conv_module = torch.nn.Conv3d
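For context, a hedged usage sketch of `from_layer(...)`. How the LoRA output is combined with the wrapped layer's output is not part of this diff; the residual addition below is only the usual LoRA pattern, shown for illustration with hypothetical channel sizes:

import torch

from invoke_training.lora.layers import LoRAConv2dLayer

base = torch.nn.Conv2d(8, 16, kernel_size=3, padding="same")  # Hypothetical frozen base layer.
lora = LoRAConv2dLayer.from_layer(base, rank=4, alpha=4.0)  # Mirrors base's kernel_size, stride, and padding.

x = torch.rand(2, 8, 32, 32)
with torch.no_grad():
    # Illustrative LoRA composition (not this repository's wiring): frozen base output plus the low-rank update.
    y = base(x) + lora(x)
print(y.shape)  # torch.Size([2, 16, 32, 32])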
tests/invoke_training/lora/layers/test_lora_conv_layer.py (189 additions, 0 deletions)
@@ -0,0 +1,189 @@
import typing

import pytest
import torch

from invoke_training.lora.layers import (
    LoRAConv1dLayer,
    LoRAConv2dLayer,
    LoRAConv3dLayer,
)
from invoke_training.lora.layers.lora_conv_layer import LoRAConvLayer


def test_lora_conv_layer_initialize_base_class():
    """Test that attempting to directly initialize a LoRAConvLayer raises a NotImplementedError."""
    with pytest.raises(NotImplementedError):
        _ = LoRAConvLayer(4, 8)


@pytest.mark.parametrize(
    ["lora_conv_cls", "conv_dims"], [(LoRAConv1dLayer, 1), (LoRAConv2dLayer, 2), (LoRAConv3dLayer, 3)]
)
class TestLoRAConvLayers:
    """Test class for applying tests to each of the LoRAConv*Layer classes."""

    def test_lora_conv_layer_output_dim(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
        """Test that LoRAConv*Layer produces an output with the expected dimensions."""
        batch_size = 10
        in_channels = 8
        out_channels = 16
        layer = lora_conv_cls(in_channels, out_channels)

        in_shape = (batch_size, in_channels) + (5,) * conv_dims
        x = torch.rand(in_shape)
        with torch.no_grad():
            y = layer(x)

        expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
        assert y.shape == expected_out_shape

    def test_lora_conv_layer_invalid_input_dim(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
        """Test that LoRAConv*Layer raises an exception if it receives an input with invalid dimensions."""
        batch_size = 10
        in_channels = 8
        out_channels = 16
        layer = lora_conv_cls(in_channels, out_channels)

        in_shape = (batch_size, in_channels + 1) + (5,) * conv_dims  # Bad input dimension.
        x = torch.rand(in_shape)
        with pytest.raises(RuntimeError):
            _ = layer(x)

    def test_lora_conv_layer_zero_after_init(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
        """Test that a newly-initialized LoRAConv*Layer produces all zeros before it is trained."""
        batch_size = 10
        in_channels = 8
        out_channels = 16
        layer = lora_conv_cls(in_channels, out_channels)

        in_shape = (batch_size, in_channels) + (5,) * conv_dims
        x = torch.rand(in_shape)
        with torch.no_grad():
            y = layer(x)

        assert not torch.allclose(x, torch.Tensor([0.0]), rtol=0.0)  # The random input was non-zero.
        assert torch.allclose(y, torch.Tensor([0.0]), rtol=0.0)  # The untrained outputs are zero.

    def test_lora_conv_layer_from_layer(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
        """Test that a LoRAConv*Layer can be initialized correctly from a torch.nn.Conv* layer."""
        batch_size = 10
        in_channels = 8
        out_channels = 16
        original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, padding="same")

        lora_layer = lora_conv_cls.from_layer(original_layer)

        in_shape = (batch_size, in_channels) + (5,) * conv_dims
        x = torch.rand(in_shape)
        with torch.no_grad():
            y = lora_layer(x)

        expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
        assert y.shape == expected_out_shape

    def test_lora_conv_layer_from_layer_kernel_and_stride(
        self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int
    ):
        """Test that a LoRAConv*Layer is initialized with the correct kernel_size, stride, and padding when
        initialized from a torch.nn.Conv* layer."""
        batch_size = 10
        in_channels = 8
        out_channels = 16
        original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, stride=2, padding="valid")

        lora_layer = lora_conv_cls.from_layer(original_layer)

        # Check the internal layer config.
        assert lora_layer._down.kernel_size == original_layer.kernel_size
        assert lora_layer._down.stride == original_layer.stride
        assert lora_layer._down.padding == original_layer.padding

        in_shape = (batch_size, in_channels) + (6,) * conv_dims
        x = torch.rand(in_shape)
        with torch.no_grad():
            y = lora_layer(x)

        # The combination of kernel_size, stride, and padding should reduce the dimensions to this output shape:
        expected_out_shape = (batch_size, out_channels) + (2,) * conv_dims
        assert y.shape == expected_out_shape

    @pytest.mark.cuda
    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_lora_conv_layer_from_layer_inherit_device_and_dtype(
        self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int, dtype: torch.dtype
    ):
        """Test that when a LoRAConv*Layer is initialized with from_layer(...), it correctly inherits the device and
        dtype.
        """
        batch_size = 10
        in_channels = 8
        out_channels = 16
        original_layer = lora_conv_cls.conv_module(
            in_channels, out_channels, kernel_size=3, padding="same", device=torch.device("cuda"), dtype=dtype
        )

        lora_layer = lora_conv_cls.from_layer(original_layer)

        in_shape = (batch_size, in_channels) + (5,) * conv_dims
        x = torch.rand(in_shape, device=torch.device("cuda"), dtype=dtype)
        with torch.no_grad():
            y = lora_layer(x)

        expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
        assert y.shape == expected_out_shape
        # Assert that lora_layer's internal layers have correct device and dtype.
        assert lora_layer._down.weight.device == original_layer.weight.device
        assert lora_layer._down.weight.dtype == original_layer.weight.dtype
        assert lora_layer._up.weight.device == original_layer.weight.device
        assert lora_layer._up.weight.dtype == original_layer.weight.dtype

    @pytest.mark.cuda
    @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
    def test_lora_conv_layer_from_layer_override_device_and_dtype(
        self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int, dtype: torch.dtype
    ):
        """Test that when a LoRAConv*Layer is initialized with from_layer(...), the device and dtype can be
        overridden."""
        batch_size = 10
        in_channels = 8
        out_channels = 16
        # Original layer has dtype float32 on CPU.
        original_layer = lora_conv_cls.conv_module(
            in_channels, out_channels, kernel_size=3, padding="same", dtype=torch.float32
        )

        target_device = torch.device("cuda:0")
        lora_layer = lora_conv_cls.from_layer(original_layer, device=target_device, dtype=dtype)

        in_shape = (batch_size, in_channels) + (5,) * conv_dims
        x = torch.rand(in_shape, device=torch.device("cuda"), dtype=dtype)
        with torch.no_grad():
            y = lora_layer(x)

        expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
        assert y.shape == expected_out_shape
        # Assert that lora_layer's internal layers have correct device and dtype.
        assert lora_layer._down.weight.device == target_device
        assert lora_layer._down.weight.dtype == dtype
        assert lora_layer._up.weight.device == target_device
        assert lora_layer._up.weight.dtype == dtype

    def test_lora_conv_layer_state_dict_roundtrip(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
        original_layer = lora_conv_cls(8, 16)

        state_dict = original_layer.state_dict()

        roundtrip_layer = lora_conv_cls(8, 16, alpha=2.0)

        # Prior to loading the state_dict, the roundtrip_layer is different than the original_layer.
        # (We don't compare the _up layer, because it is initialized to zeros so should match already.)
        assert not torch.allclose(roundtrip_layer._down.weight, original_layer._down.weight)
        assert not torch.allclose(roundtrip_layer.alpha, original_layer.alpha)

        roundtrip_layer.load_state_dict(state_dict)

        # After loading the state_dict the roundtrip_layer and original_layer match.
        assert torch.allclose(roundtrip_layer._down.weight, original_layer._down.weight)
        assert torch.allclose(roundtrip_layer._up.weight, original_layer._up.weight)
        assert torch.allclose(roundtrip_layer.alpha, original_layer.alpha)
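To run just these tests locally, something like the following should work (a sketch, assuming pytest is invoked from the repository root and the `cuda` marker used above is registered in the project's pytest configuration; `-m "not cuda"` deselects the GPU-only cases on machines without CUDA):

pytest tests/invoke_training/lora/layers/test_lora_conv_layer.py -m "not cuda"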