Support regional lora
Acly committed Dec 10, 2024
1 parent 4a044cf commit 5680232
Showing 6 changed files with 232 additions and 63 deletions.
1 change: 1 addition & 0 deletions ai_diffusion/api.py
@@ -96,6 +96,7 @@ class RegionInput:
bounds: Bounds
positive: str
control: list[ControlInput] = field(default_factory=list)
loras: list[LoraInput] = field(default_factory=list)


@dataclass
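A small illustrative sketch (not from this commit) of attaching a LoRA to one region via the new field; the placeholder LoRA name and a LoraInput(name, strength) constructor are assumptions:

# Hypothetical usage: add a LoRA to an existing RegionInput.
# Assumes LoraInput accepts a model name and strength; the file name is a placeholder.
region.loras.append(LoraInput("my-region-lora.safetensors", 0.7))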
34 changes: 20 additions & 14 deletions ai_diffusion/client.py
@@ -327,21 +327,27 @@ def filter_supported_styles(styles: Iterable[Style], client: Client | None = Non


def loras_to_upload(workflow: WorkflowInput, client_models: ClientModels):
workflow_loras = []
if models := workflow.models:
for lora in models.loras:
if lora.name in client_models.loras:
continue
if not lora.storage_id and lora.name in _lcm_loras:
raise ValueError(_lcm_warning)
if not lora.storage_id:
raise ValueError(f"Lora model is not available: {lora.name}")
lora_file = FileLibrary.instance().loras.find_local(lora.name)
if lora_file is None or lora_file.path is None:
raise ValueError(f"Can't find Lora model: {lora.name}")
if not lora_file.path.exists():
raise ValueError(_("LoRA model file not found") + f" {lora_file.path}")
assert lora.storage_id == lora_file.hash
yield lora_file
workflow_loras.extend(models.loras)
if cond := workflow.conditioning:
for region in cond.regions:
workflow_loras.extend(region.loras)

for lora in workflow_loras:
if lora.name in client_models.loras:
continue
if not lora.storage_id and lora.name in _lcm_loras:
raise ValueError(_lcm_warning)
if not lora.storage_id:
raise ValueError(f"Lora model is not available: {lora.name}")
lora_file = FileLibrary.instance().loras.find_local(lora.name)
if lora_file is None or lora_file.path is None:
raise ValueError(f"Can't find Lora model: {lora.name}")
if not lora_file.path.exists():
raise ValueError(_("LoRA model file not found") + f" {lora_file.path}")
assert lora.storage_id == lora_file.hash
yield lora_file


_lcm_loras = ["lcm-lora-sdv1-5.safetensors", "lcm-lora-sdxl.safetensors"]
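A minimal usage sketch for the generator above; the caller names (`client`, `upload`) are hypothetical and not part of this commit:

# Upload every LoRA file the server is missing, now including LoRAs
# referenced by individual regions (caller and upload call are assumed).
for lora_file in loras_to_upload(workflow, client.models):
    client.upload(lora_file)  # hypothetical upload method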
75 changes: 75 additions & 0 deletions ai_diffusion/comfy_workflow.py
@@ -547,6 +547,9 @@ def define_region(self, regions: Output, mask: Output, conditioning: Output):
"ETN_DefineRegion", 1, regions=regions, mask=mask, conditioning=conditioning
)

def list_region_masks(self, regions: Output):
return self.add("ETN_ListRegionMasks", 1, regions=regions)

def attention_mask(self, model: Output, regions: Output):
return self.add("ETN_AttentionMask", 1, model=model, regions=regions)

@@ -796,6 +799,9 @@ def invert_image(self, image: Output):
def batch_image(self, batch: Output, image: Output):
return self.add("ImageBatch", 1, image1=batch, image2=image)

def image_batch_element(self, batch: Output, index: int):
return self.add("ImageFromBatch", 1, image=batch, batch_index=index, length=1)

def inpaint_image(self, model: Output, image: Output, mask: Output):
return self.add(
"INPAINT_InpaintWithModel", 1, inpaint_model=model, image=image, mask=mask, seed=834729
@@ -837,6 +843,11 @@ def composite_image_masked(
def mask_to_image(self, mask: Output):
return self.add("MaskToImage", 1, mask=mask)

def mask_batch_element(self, mask_batch: Output, index: int):
image_batch = self.mask_to_image(mask_batch)
image = self.image_batch_element(image_batch, index)
return self.image_to_mask(image)

def solid_mask(self, extent: Extent, value=1.0):
return self.add("SolidMask", 1, width=extent.width, height=extent.height, value=value)

@@ -938,6 +949,70 @@ def estimate_pose(self, image: Output, resolution: int):
mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx"
return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls)

def create_hook_lora(self, loras: list[tuple[str, float]]):
key = "CreateHookLora" + str(loras)
hooks = self._cache.get(key, None)
if hooks is None:
for lora, strength in loras:
hooks = self.add(
"CreateHookLora",
1,
lora_name=lora,
strength_model=strength,
strength_clip=strength,
prev_hooks=hooks,
)
assert hooks is not None
self._cache[key] = hooks

assert isinstance(hooks, Output)
return hooks

def set_clip_hooks(self, clip: Output, hooks: Output):
return self.add(
"SetClipHooks", 1, clip=clip, hooks=hooks, apply_to_conds=True, schedule_clip=False
)

def combine_masked_conditioning(
self,
positive: Output,
negative: Output,
positive_conds: Output | None = None,
negative_conds: Output | None = None,
mask: Output | None = None,
):
assert (positive_conds and negative_conds) or mask
if mask is None:
return self.add(
"PairConditioningSetDefaultCombine",
2,
positive=positive_conds,
negative=negative_conds,
positive_DEFAULT=positive,
negative_DEFAULT=negative,
)
if positive_conds is None and negative_conds is None:
return self.add(
"PairConditioningSetProperties",
2,
positive_NEW=positive,
negative_NEW=negative,
mask=mask,
strength=1.0,
set_cond_area="default",
)
return self.add(
"PairConditioningSetPropertiesAndCombine",
2,
positive=positive_conds,
negative=negative_conds,
positive_NEW=positive,
negative_NEW=negative,
mask=mask,
strength=1.0,
set_cond_area="default",
)


def _inputs_for_node(node_inputs: dict[str, dict[str, Any]], node_name: str, filter=""):
inputs = node_inputs.get(node_name)
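A minimal sketch of how the new builder helpers could be chained to give one region its own LoRA; `w`, `clip`, `regions`, the prompt texts, and the `clip_text_encode` helper are assumed context rather than part of this commit:

# Build LoRA hooks and attach them to the CLIP used for this region's prompts.
hooks = w.create_hook_lora([("my-region-lora.safetensors", 0.8)])  # hypothetical LoRA name/weight
region_clip = w.set_clip_hooks(clip, hooks)  # conds encoded from this CLIP carry the hooks
positive = w.clip_text_encode(region_clip, "regional prompt text")  # clip_text_encode is assumed
negative = w.clip_text_encode(region_clip, "")
# Pick this region's mask from the batch of all region masks and combine.
mask = w.mask_batch_element(w.list_region_masks(regions), 0)
positive, negative = w.combine_masked_conditioning(positive, negative, mask=mask)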
