Skip to content

Commit

Permalink
Add support for norm layer
Browse files Browse the repository at this point in the history
  • Loading branch information
StAlKeR7779 authored and hipsterusername committed Aug 8, 2024
1 parent 7da6120 commit 68f9939
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion invokeai/backend/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,39 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
self.on_input = self.on_input.to(device=device, dtype=dtype)


AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
class NormLayer(LoRALayerBase):
    """Patch layer for normalization weights, loaded from `w_norm`/`b_norm` state-dict keys.

    Only the weight tensor is managed here; the bias tensor is accounted for
    by LoRALayerBase in both calc_size() and to().
    """

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        super().__init__(layer_key, values)

        self.weight = values["w_norm"]
        self.bias = values.get("b_norm")

        # Norm weights are applied unscaled, so there is no rank concept here.
        self.rank = None
        self.check_keys(values, {"w_norm", "b_norm"})

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # The stored tensor is returned as-is; the original weight is not consulted.
        return self.weight

    def calc_size(self) -> int:
        # Base-class size (covers the bias) plus this layer's weight footprint in bytes.
        return super().calc_size() + self.weight.nelement() * self.weight.element_size()

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        super().to(device=device, dtype=dtype)
        self.weight = self.weight.to(device=device, dtype=dtype)


AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]


class LoRAModelRaw(RawModel): # (torch.nn.Module):
Expand Down Expand Up @@ -519,6 +551,10 @@ def from_checkpoint(
elif "on_input" in values:
layer = IA3Layer(layer_key, values)

# norms
elif "w_norm" in values:
layer = NormLayer(layer_key, values)

else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
Expand Down

0 comments on commit 68f9939

Please sign in to comment.