Apply lora by model patching #3583

Merged: 3 commits, Jun 28, 2023. Changes from all commits.
17 changes: 7 additions & 10 deletions invokeai/app/invocations/compel.py
@@ -65,23 +65,20 @@ def invoke(self, context: InvocationContext) -> CompelOutput:
             **self.clip.text_encoder.dict(),
         )
         with tokenizer_info as orig_tokenizer,\
-             text_encoder_info as text_encoder,\
-             ExitStack() as stack:
+             text_encoder_info as text_encoder:

-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]

             ti_list = []
             for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        stack.enter_context(
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=self.clip.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                            )
-                        )
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=self.clip.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                        ).context.model
                     )
                 except Exception:
                     #print(e)
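A note on the pattern above, which recurs in latent.py below: LoRA and textual-inversion models are no longer entered as context managers through an ExitStack, which pinned each one into the model cache for the whole block. Instead the raw, CPU-resident model object is read from `.context.model`, and the patching context manager in lora.py takes over lifetime management. A minimal sketch of the two access styles, assuming only that `get_model(...)` returns an object whose `.context.model` holds the loaded model (all other names here are illustrative):

from contextlib import ExitStack
from typing import List, Tuple

def collect_loras_old(model_manager, lora_infos, stack: ExitStack) -> List[Tuple[object, float]]:
    # Before: each LoRA model context was entered, locking the model in
    # until the ExitStack unwound.
    return [
        (stack.enter_context(model_manager.get_model(**info.dict(exclude={"weight"}))), info.weight)
        for info in lora_infos
    ]

def collect_loras_new(model_manager, lora_infos) -> List[Tuple[object, float]]:
    # After: only the underlying model object is fetched; nothing is locked,
    # since apply_lora patches weights in place and restores them on exit.
    return [
        (model_manager.get_model(**info.dict(exclude={"weight"})).context.model, info.weight)
        for info in lora_infos
    ]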
10 changes: 4 additions & 6 deletions invokeai/app/invocations/latent.py
@@ -285,8 +285,7 @@ def step_callback(state: PipelineIntermediateState):
             self.dispatch_progress(context, source_node_id, state)

         unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
-        with unet_info as unet,\
-             ExitStack() as stack:
+        with unet_info as unet:

             scheduler = get_scheduler(
                 context=context,
@@ -297,7 +296,7 @@
             pipeline = self.create_pipeline(unet, scheduler)
             conditioning_data = self.get_conditioning_data(context, scheduler)

-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]

             control_data = self.prep_control_data(
                 model=pipeline, context=context, control_input=self.control,
@@ -361,8 +360,7 @@ def step_callback(state: PipelineIntermediateState):
             **self.unet.unet.dict(),
         )

-        with unet_info as unet,\
-             ExitStack() as stack:
+        with unet_info as unet:

             scheduler = get_scheduler(
                 context=context,
@@ -391,7 +389,7 @@
                 device=unet.device,
             )

-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]

             with ModelPatcher.apply_lora_unet(pipeline.unet, loras):
                 result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
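`ModelPatcher.apply_lora_unet` above is, as far as this diff shows, a thin prefix-binding wrapper over the `apply_lora` generator rewritten in lora.py below. A plausible sketch, where the `lora_unet_` and `lora_te_` key prefixes are assumptions based on the common kohya-style LoRA key convention rather than something visible in this diff:

from contextlib import contextmanager

class ModelPatcher:
    @classmethod
    @contextmanager
    def apply_lora_unet(cls, unet, loras):
        # Assumed prefix: UNet LoRA layer keys start with "lora_unet_".
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(cls, text_encoder, loras):
        # Assumed prefix: text-encoder LoRA layer keys start with "lora_te_".
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

    # apply_lora itself is the classmethod shown in the lora.py diff below.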
12 changes: 9 additions & 3 deletions invokeai/app/invocations/model.py
@@ -177,9 +177,13 @@ class LoraLoaderInvocation(BaseInvocation):

     def invoke(self, context: InvocationContext) -> LoraLoaderOutput:

+        # TODO: ui rewrite
+        base_model = BaseModelType.StableDiffusion1
+
         if not context.services.model_manager.model_exists(
+            base_model=base_model,
             model_name=self.lora_name,
-            model_type=SDModelType.Lora,
+            model_type=ModelType.Lora,
         ):
             raise Exception(f"Unkown lora name: {self.lora_name}!")

@@ -195,8 +199,9 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
             output.unet = copy.deepcopy(self.unet)
             output.unet.loras.append(
                 LoraInfo(
+                    base_model=base_model,
                     model_name=self.lora_name,
-                    model_type=SDModelType.Lora,
+                    model_type=ModelType.Lora,
                     submodel=None,
                     weight=self.weight,
                 )
@@ -206,8 +211,9 @@ def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
             output.clip = copy.deepcopy(self.clip)
             output.clip.loras.append(
                 LoraInfo(
+                    base_model=base_model,
                     model_name=self.lora_name,
-                    model_type=SDModelType.Lora,
+                    model_type=ModelType.Lora,
                     submodel=None,
                     weight=self.weight,
                 )
55 changes: 34 additions & 21 deletions invokeai/backend/model_management/lora.py
@@ -70,7 +70,7 @@ def forward(
             op = torch.nn.functional.linear
             extra_args = {}

-        weight = self.get_weight(module)
+        weight = self.get_weight()

         bias = self.bias if self.bias is not None else 0
         scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
@@ -81,7 +81,7 @@
             **extra_args,
         ) * multiplier * scale

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         raise NotImplementedError()

     def calc_size(self) -> int:
@@ -122,7 +122,7 @@ def __init__(

         self.rank = self.down.shape[0]

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         if self.mid is not None:
             up = self.up.reshape(up.shape[0], up.shape[1])
             down = self.down.reshape(up.shape[0], up.shape[1])
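Dropping the `module` argument from `get_weight()` means each layer type now computes its weight delta purely from its own tensors. For a plain LoRA layer that delta is just the product of the two low-rank factors; a runnable sketch of the idea (names and shapes are illustrative, and the conv `mid` branch shown above is omitted):

import torch

def lora_delta(up: torch.Tensor, down: torch.Tensor) -> torch.Tensor:
    # Plain-LoRA weight delta: up (out_features, rank) times down (rank, in_features).
    return up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)

rank, n_in, n_out = 4, 768, 768
up, down = torch.randn(n_out, rank), torch.randn(rank, n_in)
assert lora_delta(up, down).shape == (n_out, n_in)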
@@ -166,7 +166,7 @@ def __init__(
         layer_key: str,
         values: dict,
     ):
-        super().__init__(module_key, rank, alpha, bias)
+        super().__init__(layer_key, values)

         self.w1_a = values["hada_w1_a"]
         self.w1_b = values["hada_w1_b"]
@@ -185,7 +185,7 @@

         self.rank = self.w1_b.shape[0]

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         if self.t1 is None:
             weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@@ -239,7 +239,7 @@ def __init__(
         layer_key: str,
         values: dict,
     ):
-        super().__init__(module_key, rank, alpha, bias)
+        super().__init__(layer_key, values)

         if "lokr_w1" in values:
             self.w1 = values["lokr_w1"]
@@ -271,7 +271,7 @@ def __init__(
         else:
             self.rank = None # unscaled

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         w1 = self.w1
         if w1 is None:
             w1 = self.w1_a @ self.w1_b
@@ -286,7 +286,7 @@ def get_weight(self, module: torch.nn.Module):
         if len(w2.shape) == 4:
             w1 = w1.unsqueeze(2).unsqueeze(2)
         w2 = w2.contiguous()
-        weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
+        weight = torch.kron(w1, w2)

         return weight

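Removing `.reshape(module.weight.shape)` from the LoKr branch is safe because a Kronecker product already lands on the full weight shape when the factor shapes multiply out: `torch.kron` of an (a, b) tensor with a (c, d) tensor yields (a*c, b*d). Any residual mismatch (the lycoris case in the TODO) is now handled once, centrally, by the reshape added to `apply_lora` below. A quick check:

import torch

w1 = torch.randn(4, 8)
w2 = torch.randn(80, 40)
# (4, 8) kron (80, 40) -> (320, 320): the factors multiply out to the full weight shape.
assert torch.kron(w1, w2).shape == (4 * 80, 8 * 40)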
@@ -471,7 +471,7 @@ def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tup
             submodule_name += "_" + key_parts.pop(0)

         module = module.get_submodule(submodule_name)
-        module_key = module_key.rstrip(".")
+        module_key = (module_key + "." + submodule_name).lstrip(".")

     return (module_key, module)

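This one-line fix matters because LoRA files flatten dotted module paths into underscore-joined keys, and the old code never actually accumulated the resolved parts: `module_key.rstrip(".")` only stripped dots from whatever value was already there. A toy illustration of what the fixed accumulation produces (the layer path is hypothetical):

# How the fixed line rebuilds a dotted module path from greedily resolved parts.
module_key = ""
for submodule_name in ["down_blocks", "0", "attentions", "0", "proj_in"]:
    module_key = (module_key + "." + submodule_name).lstrip(".")
assert module_key == "down_blocks.0.attentions.0.proj_in"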
@@ -525,23 +525,36 @@ def apply_lora(
         loras: List[Tuple[LoraModel, float]],
         prefix: str,
     ):
-        hooks = dict()
+        original_weights = dict()
         try:
-            for lora, lora_weight in loras:
-                for layer_key, layer in lora.layers.items():
-                    if not layer_key.startswith(prefix):
-                        continue
+            with torch.no_grad():
+                for lora, lora_weight in loras:
+                    #assert lora.device.type == "cpu"
+                    for layer_key, layer in lora.layers.items():
+                        if not layer_key.startswith(prefix):
+                            continue

-                    module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-                    if module_key not in hooks:
-                        hooks[module_key] = module.register_forward_hook(cls._lora_forward_hook(loras, layer_key))
+                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+                        if module_key not in original_weights:
+                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+
+                        # enable autocast to calc fp16 loras on cpu
+                        with torch.autocast(device_type="cpu"):
+                            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+                            layer_weight = layer.get_weight() * lora_weight * layer_scale
+
+                        if module.weight.shape != layer_weight.shape:
+                            # TODO: debug on lycoris
+                            layer_weight = layer_weight.reshape(module.weight.shape)
+
+                        module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)

             yield # wait for context manager exit

         finally:
-            for module_key, hook in hooks.items():
-                hook.remove()
-            hooks.clear()
+            with torch.no_grad():
+                for module_key, weight in original_weights.items():
+                    model.get_submodule(module_key).weight.copy_(weight)

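This rewrite is the heart of the PR. Instead of registering a forward hook per module, which re-ran the LoRA matmuls on every forward pass, `apply_lora` now snapshots each affected module's weight to CPU, adds the precomputed delta in place, yields, and copies the snapshot back on exit. A self-contained sketch of the pattern, with the per-layer weight computation collapsed into a plain dict of deltas (names are illustrative, not InvokeAI's API):

from contextlib import contextmanager
from typing import Dict

import torch

@contextmanager
def patch_weights(model: torch.nn.Module, deltas: Dict[str, torch.Tensor]):
    # Temporarily add per-module weight deltas; restore the originals on exit.
    original_weights: Dict[str, torch.Tensor] = {}
    try:
        with torch.no_grad():
            for module_key, delta in deltas.items():
                module = model.get_submodule(module_key)
                # Snapshot once per module, kept on CPU so the execution device
                # never holds a second copy of the weight.
                if module_key not in original_weights:
                    original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
                module.weight += delta.to(device=module.weight.device, dtype=module.weight.dtype)
        yield  # the patched model is used inside the with-block
    finally:
        with torch.no_grad():
            for module_key, weight in original_weights.items():
                model.get_submodule(module_key).weight.copy_(weight)

# Usage: one linear layer, patched only for the duration of the block.
model = torch.nn.Sequential(torch.nn.Linear(4, 4))
before = model[0].weight.clone()
with patch_weights(model, {"0": torch.ones(4, 4)}):
    assert torch.allclose(model[0].weight, before + 1)
assert torch.allclose(model[0].weight, before)

The trade-off: patching costs one weight-sized add per module at entry and a copy at exit, but inference then runs with zero per-step overhead, and the CPU-side snapshots keep VRAM usage flat.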

@classmethod
Expand Down Expand Up @@ -591,7 +604,7 @@ def _get_trigger(ti, index):
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
)

model_embeddings.weight.data[token_id] = embedding
model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype)
ti_tokens.append(token_id)

if len(ti_tokens) > 1:
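Since models are now fetched CPU-resident without entering their contexts (see compel.py above), an embedding loaded from a textual-inversion file is not guaranteed to match the text encoder's device or dtype, so the copy is made explicit. A minimal illustration with an fp16 embedding table (sizes and dtypes are illustrative):

import torch

embedding = torch.randn(768)                           # e.g. fp32, CPU, as loaded from the TI file
table = torch.zeros(49409, 768, dtype=torch.float16)   # stand-in for the token embedding table
table[49408] = embedding.to(device=table.device, dtype=table.dtype)
assert table[49408].dtype == torch.float16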