Commit
Merge pull request #8 from invoke-ai/ryan/conv-lora
Add support for Conv LoRA layers
RyanJDick authored Aug 8, 2023
2 parents 3973019 + 6f5e415 commit bcf1a91
Showing 5 changed files with 338 additions and 7 deletions.
11 changes: 8 additions & 3 deletions src/invoke_training/lora/injection/stable_diffusion_v1.py
@@ -2,11 +2,11 @@

import torch
from diffusers.models import Transformer2DModel, UNet2DConditionModel
-from diffusers.models.lora import LoRACompatibleLinear
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear

from invoke_training.lora.injection.lora_layer_collection import LoRALayerCollection
from invoke_training.lora.injection.utils import inject_lora_layers
-from invoke_training.lora.layers import LoRALinearLayer
+from invoke_training.lora.layers import LoRAConv2dLayer, LoRALinearLayer


def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection:
@@ -21,7 +21,12 @@ def inject_lora_into_unet_sd1(unet: UNet2DConditionModel) -> LoRALayerCollection:

lora_layers = inject_lora_layers(
module=unet,
-lora_map={torch.nn.Linear: LoRALinearLayer, LoRACompatibleLinear: LoRALinearLayer},
+lora_map={
+    torch.nn.Linear: LoRALinearLayer,
+    LoRACompatibleLinear: LoRALinearLayer,
+    torch.nn.Conv2d: LoRAConv2dLayer,
+    LoRACompatibleConv: LoRAConv2dLayer,
+},
include_descendants_of={Transformer2DModel},
exclude_descendants_of=None,
prefix="lora_unet",
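To give a sense of the effect of the expanded lora_map, here is a minimal usage sketch (not part of the diff). The checkpoint id and layer-name inspection are illustrative assumptions; len(lora_layers) and the _names attribute are the same accessors used by the tests further down.

from diffusers.models import UNet2DConditionModel

from invoke_training.lora.injection.stable_diffusion_v1 import inject_lora_into_unet_sd1

# Illustrative checkpoint id; any SD 1.x UNet should behave the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

lora_layers = inject_lora_into_unet_sd1(unet)

# With the Conv2d / LoRACompatibleConv entries in lora_map, the proj_in / proj_out
# convolutions of each Transformer2DModel are now wrapped alongside the attention
# and feed-forward Linear layers (160 -> 192 layers in total).
print(len(lora_layers))
print([name for name in lora_layers._names if name.endswith((".proj_in", ".proj_out"))][:4])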
5 changes: 5 additions & 0 deletions src/invoke_training/lora/layers/__init__.py
@@ -1,2 +1,7 @@
from .base_lora_layer import BaseLoRALayer # noqa: F401
from .lora_conv_layer import ( # noqa: F401
LoRAConv1dLayer,
LoRAConv2dLayer,
LoRAConv3dLayer,
)
from .lora_linear_layer import LoRALinearLayer # noqa: F401
131 changes: 131 additions & 0 deletions src/invoke_training/lora/layers/lora_conv_layer.py
@@ -0,0 +1,131 @@
import math
import typing

import torch

from invoke_training.lora.layers import BaseLoRALayer


class LoRAConvLayer(BaseLoRALayer):
"""An implementation of a conv LoRA layer based on the paper 'LoRA: Low-Rank Adaptation of Large Language Models'.
(https://arxiv.org/pdf/2106.09685.pdf)
"""

@property
def conv_module(self):
"""The conv module to be set by child classes. One of torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d."""
raise NotImplementedError(
"LoRAConvLayer cannot be used directly. Use LoRAConv1dLayer, LoRAConv2dLayer, or LoRAConv3dLayer instead."
)

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: typing.Union[int, tuple[int]] = 1,
stride: typing.Union[int, tuple[int]] = 1,
padding: typing.Union[str, int, tuple[int]] = 0,
rank: int = 4,
alpha: float = 1.0,
device: torch.device = None,
dtype: torch.dtype = None,
):
"""Initialize a LoRAConvLayer.
Args:
in_channels (int): The number of channels expected on inputs to this layer.
out_channels (int): The number of channels on outputs from this layer.
kernel_size: The kernel_size of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
stride: The stride of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
padding: The padding of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
rank (int, optional): The internal rank of the layer. See the paper for details.
alpha (float, optional): A scaling factor that enables tuning the rank without having to adjust the learning
rate. The recommendation from the paper is to set alpha equal to the first rank that you try and then do
not tune it further. See the paper for more details.
device (torch.device, optional): Device where weights will be initialized.
dtype (torch.dtype, optional): Weight dtype.
Raises:
ValueError: If the rank is greater than either in_channels or out_channels.
"""
super().__init__()

if rank > min(in_channels, out_channels):
raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}")

self._down = self.conv_module(
in_channels,
rank,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=False,
device=device,
dtype=dtype,
)
self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)

# Register alpha as a buffer so that it is not trained, but still gets saved to the state_dict.
self.register_buffer("alpha", torch.tensor(alpha, device=device, dtype=dtype))

self._rank = rank

self.reset_parameters()

def reset_parameters(self):
# This initialization is based on Microsoft's implementation:
# https://github.com/microsoft/LoRA/blob/998cfe4d351f4d6b4a47f0921dec2397aa0b9dfe/loralib/layers.py#L279
torch.nn.init.kaiming_uniform_(self._down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self._up.weight)

@classmethod
def from_layer(
cls,
layer: torch.nn.Module,
rank: int = 4,
alpha: float = 1.0,
device: torch.device = None,
dtype: torch.dtype = None,
):
"""Initialize a LoRAConvLayer with dimensions that are compatible with `layer`.
Args:
layer (torch.nn.Module): The existing layer whose in/out dimensions will be matched.
rank, alpha, device, dtype: These args are forwarded to __init__(...). If device or dtype are None, they
will be inferred from `layer`.
Raises:
TypeError: If `layer` has an unsupported type.
Returns:
LoRAConvLayer: The new LoRAConvLayer.
"""
if isinstance(layer, cls.conv_module):
return cls(
in_channels=layer.in_channels,
out_channels=layer.out_channels,
kernel_size=layer.kernel_size,
stride=layer.stride,
padding=layer.padding,
rank=rank,
alpha=alpha,
device=layer.weight.device if device is None else device,
dtype=layer.weight.dtype if dtype is None else dtype,
)
else:
raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")

def forward(self, input: torch.Tensor):
down_hidden = self._down(input)
up_hidden = self._up(down_hidden)

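# Scale the LoRA output by alpha / rank so that changing the rank does not change the
# effective magnitude of the update (see the `alpha` docstring above).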
up_hidden *= self.alpha / self._rank

return up_hidden


class LoRAConv1dLayer(LoRAConvLayer):
conv_module = torch.nn.Conv1d


class LoRAConv2dLayer(LoRAConvLayer):
conv_module = torch.nn.Conv2d


class LoRAConv3dLayer(LoRAConvLayer):
conv_module = torch.nn.Conv3d
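A rough usage sketch for the new layers (not part of the diff): how a LoRAConv2dLayer built with from_layer(...) might be combined with the base convolution. The actual residual wiring inside invoke_training is handled by the injection utilities and is not reproduced here, and the rank/alpha values below are arbitrary.

import torch

from invoke_training.lora.layers import LoRAConv2dLayer

base = torch.nn.Conv2d(8, 16, kernel_size=3, padding=1)

# Mirrors the base layer's channels, kernel_size, stride, padding, device, and dtype.
lora = LoRAConv2dLayer.from_layer(base, rank=4, alpha=4.0)

x = torch.rand(1, 8, 32, 32)
with torch.no_grad():
    base_out = base(x)
    lora_out = lora(x)  # zero until trained, because _up is zero-initialized

# Conceptual residual combination; the alpha / rank scaling is applied inside lora.forward().
combined = base_out + lora_out
assert combined.shape == (1, 16, 32, 32)
assert torch.allclose(combined, base_out)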
9 changes: 5 additions & 4 deletions
@@ -22,10 +22,11 @@ def test_inject_lora_into_unet_sd1_smoke():

# These assertions are based on a manual check of the injected layers and comparison against the behaviour of
# kohya_ss. They are included here to force another manual review after any future breaking change.
-assert len(lora_layers) == 160
-# assert len(lora_layers) == 192 # TODO(ryand): Enable this check once conv layers are added.
+assert len(lora_layers) == 192
for layer_name in lora_layers._names:
-assert layer_name.endswith(("to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"))
+assert layer_name.endswith(
+    ("to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2", ".proj_in", ".proj_out")
+)


@pytest.mark.loads_model
@@ -44,7 +45,7 @@ def test_convert_lora_state_dict_to_kohya_format_sd1_smoke():

# These assertions are based on a manual check of the injected layers and comparison against the behaviour of
# kohya_ss. They are included here to force another manual review after any future breaking change.
-assert len(kohya_state_dict) == 160 * 3
+assert len(kohya_state_dict) == 192 * 3
for key in kohya_state_dict.keys():
assert key.startswith("lora_unet_")
assert key.endswith((".lora_down.weight", ".lora_up.weight", ".alpha"))
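The updated counts can be sanity-checked by hand. The accounting below assumes the SD 1.x UNet contains 16 Transformer2DModel blocks (an assumption consistent with the 160 -> 192 jump), each contributing one proj_in and one proj_out convolution, with three kohya-format tensors per LoRA layer as asserted above.

# Rough accounting (assumes 16 Transformer2DModel blocks in the SD 1.x UNet).
linear_lora_layers = 160        # to_q, to_k, to_v, to_out.0, ff.net.0.proj, ff.net.2
conv_lora_layers = 16 * 2       # proj_in + proj_out per Transformer2DModel
assert linear_lora_layers + conv_lora_layers == 192

# Each LoRA layer contributes lora_down.weight, lora_up.weight, and alpha in kohya format.
assert 192 * 3 == 576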
189 changes: 189 additions & 0 deletions tests/invoke_training/lora/layers/test_lora_conv_layer.py
@@ -0,0 +1,189 @@
import typing

import pytest
import torch

from invoke_training.lora.layers import (
LoRAConv1dLayer,
LoRAConv2dLayer,
LoRAConv3dLayer,
)
from invoke_training.lora.layers.lora_conv_layer import LoRAConvLayer


def test_lora_conv_layer_initialize_base_class():
"""Test that attempting to directly initialize a LoRAConvLayer raise a NotImplementedError."""
with pytest.raises(NotImplementedError):
_ = LoRAConvLayer(4, 8)


@pytest.mark.parametrize(
["lora_conv_cls", "conv_dims"], [(LoRAConv1dLayer, 1), (LoRAConv2dLayer, 2), (LoRAConv3dLayer, 3)]
)
class TestLoRAConvLayers:
"""Test class for applying tests to each of the LoRAConv*Layer classes."""

def test_lora_conv_layer_output_dim(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
"""Test that LoRAConv*Layer produces an output with the expected dimensions."""
batch_size = 10
in_channels = 8
out_channels = 16
layer = lora_conv_cls(in_channels, out_channels)

in_shape = (batch_size, in_channels) + (5,) * conv_dims
x = torch.rand(in_shape)
with torch.no_grad():
y = layer(x)

expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
assert y.shape == expected_out_shape

def test_lora_conv_layer_invalid_input_dim(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
"""Test that LoRAConv*Layer raises an exception if it receives an input with invalid dimensions."""
batch_size = 10
in_channels = 8
out_channels = 16
layer = lora_conv_cls(in_channels, out_channels)

in_shape = (batch_size, in_channels + 1) + (5,) * conv_dims # Bad input dimension.
x = torch.rand(in_shape)
with pytest.raises(RuntimeError):
_ = layer(x)

def test_lora_conv_layer_zero_after_init(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
"""Test that a newly-initialized LoRAConv*Layer produces all zeros before it is trained."""
batch_size = 10
in_channels = 8
out_channels = 16
layer = lora_conv_cls(in_channels, out_channels)

in_shape = (batch_size, in_channels) + (5,) * conv_dims
x = torch.rand(in_shape)
with torch.no_grad():
y = layer(x)

assert not torch.allclose(x, torch.Tensor([0.0]), rtol=0.0) # The random input was non-zero.
assert torch.allclose(y, torch.Tensor([0.0]), rtol=0.0) # The untrained outputs are zero.

def test_lora_conv_layer_from_layer(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
"""Test that a LoRAConv*Layer can be initialized correctly from a torch.nn.Conv* layer."""
batch_size = 10
in_channels = 8
out_channels = 16
original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, padding="same")

lora_layer = lora_conv_cls.from_layer(original_layer)

in_shape = (batch_size, in_channels) + (5,) * conv_dims
x = torch.rand(in_shape)
with torch.no_grad():
y = lora_layer(x)

expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
assert y.shape == expected_out_shape

def test_lora_conv_layer_from_layer_kernel_and_stride(
self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int
):
"""Test that a LoRAConv*Layer is initialized with the correct kernel_size, stride, and padding when initialized
from a torch.nn.Conv* layer."""
batch_size = 10
in_channels = 8
out_channels = 16
original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, stride=2, padding="valid")

lora_layer = lora_conv_cls.from_layer(original_layer)

# Check the internal layer config.
assert lora_layer._down.kernel_size == original_layer.kernel_size
assert lora_layer._down.stride == original_layer.stride
assert lora_layer._down.padding == original_layer.padding

in_shape = (batch_size, in_channels) + (6,) * conv_dims
x = torch.rand(in_shape)
with torch.no_grad():
y = lora_layer(x)

# The combination of kernel_size, stride, and padding should reduce the dimensions to this output shape:
expected_out_shape = (batch_size, out_channels) + (2,) * conv_dims
assert y.shape == expected_out_shape

@pytest.mark.cuda
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_lora_conv_layer_from_layer_inherit_device_and_dtype(
self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int, dtype: torch.dtype
):
"""Test that when a LoRAConv*Layer is initialized with from_layer(...), it correctly inherits the device and
dtype.
"""
batch_size = 10
in_channels = 8
out_channels = 16
original_layer = lora_conv_cls.conv_module(
in_channels, out_channels, kernel_size=3, padding="same", device=torch.device("cuda"), dtype=dtype
)

lora_layer = lora_conv_cls.from_layer(original_layer)

in_shape = (batch_size, in_channels) + (5,) * conv_dims
x = torch.rand(in_shape, device=torch.device("cuda"), dtype=dtype)
with torch.no_grad():
y = lora_layer(x)

expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
assert y.shape == expected_out_shape
# Assert that lora_layer's internal layers have correct device and dtype.
assert lora_layer._down.weight.device == original_layer.weight.device
assert lora_layer._down.weight.dtype == original_layer.weight.dtype
assert lora_layer._up.weight.device == original_layer.weight.device
assert lora_layer._up.weight.dtype == original_layer.weight.dtype

@pytest.mark.cuda
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_lora_conv_layer_from_layer_override_device_and_dtype(
self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int, dtype: torch.dtype
):
"""Test that when a LoRAConv*Layer is initialized with from_layer(...), the device and dtype can be
overridden."""
batch_size = 10
in_channels = 8
out_channels = 16
# Original layer has dtype float32 on CPU.
original_layer = lora_conv_cls.conv_module(
in_channels, out_channels, kernel_size=3, padding="same", dtype=torch.float32
)

target_device = torch.device("cuda:0")
lora_layer = lora_conv_cls.from_layer(original_layer, device=target_device, dtype=dtype)

in_shape = (batch_size, in_channels) + (5,) * conv_dims
x = torch.rand(in_shape, device=torch.device("cuda"), dtype=dtype)
with torch.no_grad():
y = lora_layer(x)

expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
assert y.shape == expected_out_shape
# Assert that lora_layer's internal layers have correct device and dtype.
assert lora_layer._down.weight.device == target_device
assert lora_layer._down.weight.dtype == dtype
assert lora_layer._up.weight.device == target_device
assert lora_layer._up.weight.dtype == dtype

def test_lora_conv_layer_state_dict_roundtrip(self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int):
original_layer = lora_conv_cls(8, 16)

state_dict = original_layer.state_dict()

roundtrip_layer = lora_conv_cls(8, 16, alpha=2.0)

# Prior to loading the state_dict, the roundtrip_layer differs from the original_layer.
# (We don't compare the _up layer, because it is initialized to zeros so should match already.)
assert not torch.allclose(roundtrip_layer._down.weight, original_layer._down.weight)
assert not torch.allclose(roundtrip_layer.alpha, original_layer.alpha)

roundtrip_layer.load_state_dict(state_dict)

# After loading the state_dict the roundtrip_layer and original_layer match.
assert torch.allclose(roundtrip_layer._down.weight, original_layer._down.weight)
assert torch.allclose(roundtrip_layer._up.weight, original_layer._up.weight)
assert torch.allclose(roundtrip_layer.alpha, original_layer.alpha)
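For reference, a small sketch (not from the diff) of what a single layer's state_dict looks like: the keys follow from the _down/_up conv submodules and the alpha buffer registered in __init__, and alpha is saved without being a trainable parameter.

from invoke_training.lora.layers import LoRAConv2dLayer

layer = LoRAConv2dLayer(8, 16, kernel_size=3, padding=1, rank=4, alpha=8.0)

# Expected keys: '_down.weight', '_up.weight', and the 'alpha' buffer.
print(sorted(layer.state_dict().keys()))

# Only the two conv weights are trainable parameters; alpha is persisted but never updated.
print([name for name, _ in layer.named_parameters()])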
