Fix handling of LoRA conv layer kernel_size, stride, padding.
RyanJDick committed Aug 5, 2023
1 parent f185723 commit 6f5e415
Showing 2 changed files with 55 additions and 10 deletions.
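
In short, the fix makes the LoRA down-projection conv mirror the kernel_size, stride, and padding of the conv layer it wraps, so the LoRA branch produces the same output geometry as the original layer. A minimal usage sketch of the intended behavior, assuming the 2D variant is exposed as LoRAConv2dLayer in this module (class name and import path are assumptions, not taken from the diff):

```python
import torch

# Assumed class name/import path -- the concrete subclasses in this repo may
# be named differently (e.g. LoRAConv1d/2d/3dLayer).
from invoke_training.lora.layers.lora_conv_layer import LoRAConv2dLayer

original = torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=0)
lora = LoRAConv2dLayer.from_layer(original)

x = torch.rand(1, 8, 6, 6)
with torch.no_grad():
    # With this fix, the LoRA branch reproduces the original conv's output
    # shape, (1, 16, 2, 2); previously its down conv was hard-coded to
    # kernel_size=1, stride=1, so it produced (1, 16, 6, 6) instead.
    assert lora(x).shape == original(x).shape
```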
31 changes: 24 additions & 7 deletions src/invoke_training/lora/layers/lora_conv_layer.py
@@ -1,4 +1,5 @@
import math
import typing

import torch

@@ -21,6 +22,9 @@ def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: typing.Union[int, tuple[int]] = 1,
stride: typing.Union[int, tuple[int]] = 1,
padding: typing.Union[str, int, tuple[int]] = 0,
rank: int = 4,
alpha: float = 1.0,
device: torch.device = None,
@@ -30,6 +34,9 @@ def __init__(
Args:
in_channels (int): The number of channels expected on inputs to this layer.
out_channels (int): The number of channels on outputs from this layer.
kernel_size: The kernel_size of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
stride: The stride of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
padding: The padding of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
rank (int, optional): The internal rank of the layer. See the paper for details.
alpha (float, optional): A scaling factor that enables tuning the rank without having to adjust the learning
rate. The recommendation from the paper is to set alpha equal to the first rank that you try and then do
@@ -45,7 +52,14 @@ def __init__(
raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}")

self._down = self.conv_module(
- in_channels, rank, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype
+ in_channels,
+ rank,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ device=device,
+ dtype=dtype,
)
self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)

@@ -83,12 +97,15 @@ def from_layer(
"""
if isinstance(layer, cls.conv_module):
return cls(
- layer.in_channels,
- layer.out_channels,
- rank,
- alpha,
- layer.weight.device if device is None else device,
- layer.weight.dtype if dtype is None else dtype,
+ in_channels=layer.in_channels,
+ out_channels=layer.out_channels,
+ kernel_size=layer.kernel_size,
+ stride=layer.stride,
+ padding=layer.padding,
+ rank=rank,
+ alpha=alpha,
+ device=layer.weight.device if device is None else device,
+ dtype=layer.weight.dtype if dtype is None else dtype,
)
else:
raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")
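
Note the design in the hunk above: only the down conv takes the mirrored kernel_size, stride, and padding, while the up conv remains a 1x1 projection from rank to out_channels, so the branch's spatial geometry is determined entirely by the down projection. A rough parameter-count sketch (illustrative numbers only, not taken from the repository) of why the decomposition stays cheap even with a full-size kernel on the down conv:

```python
# Weight counts for a full k x k conv vs. the LoRA-style decomposition above
# (bias terms ignored; channel/rank/kernel values are illustrative only).
in_channels, out_channels, rank, k = 8, 16, 4, 3

full_conv = in_channels * out_channels * k * k  # 1152 weights
lora_down = in_channels * rank * k * k          # 288 weights; mirrors kernel/stride/padding
lora_up = rank * out_channels * 1 * 1           # 64 weights; always a 1x1 conv

print(full_conv, lora_down + lora_up)           # 1152 vs. 352
```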
34 changes: 31 additions & 3 deletions tests/invoke_training/lora/layers/test_lora_conv_layer.py
@@ -70,7 +70,7 @@ def test_lora_conv_layer_from_layer(self, lora_conv_cls: typing.Type[LoRAConvLay
batch_size = 10
in_channels = 8
out_channels = 16
- original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3)
+ original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, padding="same")

lora_layer = lora_conv_cls.from_layer(original_layer)

@@ -82,6 +82,32 @@ def test_lora_conv_layer_from_layer(self, lora_conv_cls: typing.Type[LoRAConvLay
expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
assert y.shape == expected_out_shape

def test_lora_conv_layer_from_layer_kernel_and_stride(
self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int
):
"""Test that a LoRAConv*Layer is initialized with the correct kernel_size, stride, and padding when initialized
from a torch.nn.Conv* layer."""
batch_size = 10
in_channels = 8
out_channels = 16
original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, stride=2, padding="valid")

lora_layer = lora_conv_cls.from_layer(original_layer)

# Check the internal layer config.
assert lora_layer._down.kernel_size == original_layer.kernel_size
assert lora_layer._down.stride == original_layer.stride
assert lora_layer._down.padding == original_layer.padding

in_shape = (batch_size, in_channels) + (6,) * conv_dims
x = torch.rand(in_shape)
with torch.no_grad():
y = lora_layer(x)

# The combination of kernel_size, stride, and padding should reduce the dimensions to this output shape:
expected_out_shape = (batch_size, out_channels) + (2,) * conv_dims
assert y.shape == expected_out_shape

@pytest.mark.cuda
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_lora_conv_layer_from_layer_inherit_device_and_dtype(
@@ -94,7 +120,7 @@ def test_lora_conv_layer_from_layer_inherit_device_and_dtype(
in_channels = 8
out_channels = 16
original_layer = lora_conv_cls.conv_module(
- in_channels, out_channels, kernel_size=3, device=torch.device("cuda"), dtype=dtype
+ in_channels, out_channels, kernel_size=3, padding="same", device=torch.device("cuda"), dtype=dtype
)

lora_layer = lora_conv_cls.from_layer(original_layer)
@@ -123,7 +149,9 @@ def test_lora_conv_layer_from_layer_override_device_and_dtype(
in_channels = 8
out_channels = 16
# Original layer has dtype float32 on CPU.
- original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, dtype=torch.float32)
+ original_layer = lora_conv_cls.conv_module(
+     in_channels, out_channels, kernel_size=3, padding="same", dtype=torch.float32
+ )

target_device = torch.device("cuda:0")
lora_layer = lora_conv_cls.from_layer(original_layer, device=target_device, dtype=dtype)
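
For reference, the expected_out_shape in the new test follows from the standard conv output-size formula. A quick check for the (6,) * conv_dims input with kernel_size=3, stride=2, and padding="valid" (i.e. no padding):

```python
import math

# Conv output size: floor((in + 2*pad - kernel) / stride) + 1.
in_size, kernel, stride, pad = 6, 3, 2, 0  # padding="valid" means pad=0
out_size = math.floor((in_size + 2 * pad - kernel) / stride) + 1
print(out_size)  # 2, hence (batch_size, out_channels) + (2,) * conv_dims
```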
