diff --git a/src/invoke_training/lora/layers/lora_conv_layer.py b/src/invoke_training/lora/layers/lora_conv_layer.py
index b919bdce..94e8afd4 100644
--- a/src/invoke_training/lora/layers/lora_conv_layer.py
+++ b/src/invoke_training/lora/layers/lora_conv_layer.py
@@ -1,4 +1,5 @@
 import math
+import typing
 
 import torch
 
@@ -21,6 +22,9 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
+        kernel_size: typing.Union[int, tuple[int]] = 1,
+        stride: typing.Union[int, tuple[int]] = 1,
+        padding: typing.Union[str, int, tuple[int]] = 0,
         rank: int = 4,
         alpha: float = 1.0,
         device: torch.device = None,
@@ -30,6 +34,9 @@ def __init__(
         Args:
             in_channels (int): The number of channels expected on inputs to this layer.
             out_channels (int): The number of channels on outputs from this layer.
+            kernel_size: The kernel_size of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
+            stride: The stride of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
+            padding: The padding of the conv layer that this LoRA layer is mirroring. See torch.nn.Conv* docs.
             rank (int, optional): The internal rank of the layer. See the paper for details.
             alpha (float, optional): A scaling factor that enables tuning the rank without having to adjust the learning
                 rate. The recommendation from the paper is to set alpha equal to the first rank that you try and then do
@@ -45,7 +52,14 @@ def __init__(
             raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_channels, out_channels)}")
 
         self._down = self.conv_module(
-            in_channels, rank, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype
+            in_channels,
+            rank,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=False,
+            device=device,
+            dtype=dtype,
         )
 
         self._up = self.conv_module(rank, out_channels, kernel_size=1, stride=1, bias=False, device=device, dtype=dtype)
@@ -83,12 +97,15 @@ def from_layer(
         """
         if isinstance(layer, cls.conv_module):
             return cls(
-                layer.in_channels,
-                layer.out_channels,
-                rank,
-                alpha,
-                layer.weight.device if device is None else device,
-                layer.weight.dtype if dtype is None else dtype,
+                in_channels=layer.in_channels,
+                out_channels=layer.out_channels,
+                kernel_size=layer.kernel_size,
+                stride=layer.stride,
+                padding=layer.padding,
+                rank=rank,
+                alpha=alpha,
+                device=layer.weight.device if device is None else device,
+                dtype=layer.weight.dtype if dtype is None else dtype,
             )
         else:
             raise TypeError(f"'{__class__.__name__}' cannot be initialized from a layer of type '{type(layer)}'.")
diff --git a/tests/invoke_training/lora/layers/test_lora_conv_layer.py b/tests/invoke_training/lora/layers/test_lora_conv_layer.py
index 5efcedb2..a17eade2 100644
--- a/tests/invoke_training/lora/layers/test_lora_conv_layer.py
+++ b/tests/invoke_training/lora/layers/test_lora_conv_layer.py
@@ -70,7 +70,7 @@ def test_lora_conv_layer_from_layer(self, lora_conv_cls: typing.Type[LoRAConvLay
         batch_size = 10
         in_channels = 8
         out_channels = 16
-        original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3)
+        original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, padding="same")
 
         lora_layer = lora_conv_cls.from_layer(original_layer)
 
@@ -82,6 +82,32 @@ def test_lora_conv_layer_from_layer(self, lora_conv_cls: typing.Type[LoRAConvLay
         expected_out_shape = (batch_size, out_channels) + (5,) * conv_dims
         assert y.shape == expected_out_shape
 
+    def test_lora_conv_layer_from_layer_kernel_and_stride(
+        self, lora_conv_cls: typing.Type[LoRAConvLayer], conv_dims: int
+    ):
+        """Test that a LoRAConv*Layer is initialized with the correct kernel_size, stride, and padding when initialized
+        from a torch.nn.Conv* layer."""
+        batch_size = 10
+        in_channels = 8
+        out_channels = 16
+        original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, stride=2, padding="valid")
+
+        lora_layer = lora_conv_cls.from_layer(original_layer)
+
+        # Check the internal layer config.
+        assert lora_layer._down.kernel_size == original_layer.kernel_size
+        assert lora_layer._down.stride == original_layer.stride
+        assert lora_layer._down.padding == original_layer.padding
+
+        in_shape = (batch_size, in_channels) + (6,) * conv_dims
+        x = torch.rand(in_shape)
+        with torch.no_grad():
+            y = lora_layer(x)
+
+        # The combination of kernel_size, stride, and padding should reduce the dimensions to this output shape:
+        expected_out_shape = (batch_size, out_channels) + (2,) * conv_dims
+        assert y.shape == expected_out_shape
+
     @pytest.mark.cuda
     @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
     def test_lora_conv_layer_from_layer_inherit_device_and_dtype(
@@ -94,7 +120,7 @@ def test_lora_conv_layer_from_layer_inherit_device_and_dtype(
         in_channels = 8
         out_channels = 16
         original_layer = lora_conv_cls.conv_module(
-            in_channels, out_channels, kernel_size=3, device=torch.device("cuda"), dtype=dtype
+            in_channels, out_channels, kernel_size=3, padding="same", device=torch.device("cuda"), dtype=dtype
         )
 
         lora_layer = lora_conv_cls.from_layer(original_layer)
@@ -123,7 +149,9 @@ def test_lora_conv_layer_from_layer_override_device_and_dtype(
         in_channels = 8
         out_channels = 16
         # Original layer has dtype float32 on CPU.
-        original_layer = lora_conv_cls.conv_module(in_channels, out_channels, kernel_size=3, dtype=torch.float32)
+        original_layer = lora_conv_cls.conv_module(
+            in_channels, out_channels, kernel_size=3, padding="same", dtype=torch.float32
+        )
 
         target_device = torch.device("cuda:0")
         lora_layer = lora_conv_cls.from_layer(original_layer, device=target_device, dtype=dtype)
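A minimal usage sketch of the change (not part of the patch): it assumes a concrete `LoRAConv2dLayer` subclass whose `conv_module` is `torch.nn.Conv2d`; only the base `LoRAConvLayer` and `from_layer` appear in the hunks above. With `kernel_size`, `stride`, and `padding` copied into the `_down` projection, `from_layer` now produces a LoRA branch whose output spatial dimensions match the wrapped conv's.

```python
import torch

# Assumed subclass name; the diff above only shows the base LoRAConvLayer.
from invoke_training.lora.layers.lora_conv_layer import LoRAConv2dLayer

original = torch.nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)
lora = LoRAConv2dLayer.from_layer(original)

# The down-projection mirrors the wrapped conv's geometry.
assert lora._down.kernel_size == original.kernel_size
assert lora._down.stride == original.stride
assert lora._down.padding == original.padding

# Both branches map (1, 8, 32, 32) -> (1, 16, 16, 16), so their outputs line up.
x = torch.rand(1, 8, 32, 32)
with torch.no_grad():
    assert lora(x).shape == original(x).shape == (1, 16, 16, 16)
```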