diff --git a/ai_diffusion/api.py b/ai_diffusion/api.py
index c473fb24c1..80e80e022c 100644
--- a/ai_diffusion/api.py
+++ b/ai_diffusion/api.py
@@ -96,6 +96,7 @@ class RegionInput:
     bounds: Bounds
     positive: str
     control: list[ControlInput] = field(default_factory=list)
+    loras: list[LoraInput] = field(default_factory=list)
 
 
 @dataclass
diff --git a/ai_diffusion/client.py b/ai_diffusion/client.py
index 75f2b57d9f..47d79fccde 100644
--- a/ai_diffusion/client.py
+++ b/ai_diffusion/client.py
@@ -327,21 +327,27 @@ def filter_supported_styles(styles: Iterable[Style], client: Client | None = Non
 
 
 def loras_to_upload(workflow: WorkflowInput, client_models: ClientModels):
+    workflow_loras = []
     if models := workflow.models:
-        for lora in models.loras:
-            if lora.name in client_models.loras:
-                continue
-            if not lora.storage_id and lora.name in _lcm_loras:
-                raise ValueError(_lcm_warning)
-            if not lora.storage_id:
-                raise ValueError(f"Lora model is not available: {lora.name}")
-            lora_file = FileLibrary.instance().loras.find_local(lora.name)
-            if lora_file is None or lora_file.path is None:
-                raise ValueError(f"Can't find Lora model: {lora.name}")
-            if not lora_file.path.exists():
-                raise ValueError(_("LoRA model file not found") + f" {lora_file.path}")
-            assert lora.storage_id == lora_file.hash
-            yield lora_file
+        workflow_loras.extend(models.loras)
+    if cond := workflow.conditioning:
+        for region in cond.regions:
+            workflow_loras.extend(region.loras)
+
+    for lora in workflow_loras:
+        if lora.name in client_models.loras:
+            continue
+        if not lora.storage_id and lora.name in _lcm_loras:
+            raise ValueError(_lcm_warning)
+        if not lora.storage_id:
+            raise ValueError(f"Lora model is not available: {lora.name}")
+        lora_file = FileLibrary.instance().loras.find_local(lora.name)
+        if lora_file is None or lora_file.path is None:
+            raise ValueError(f"Can't find Lora model: {lora.name}")
+        if not lora_file.path.exists():
+            raise ValueError(_("LoRA model file not found") + f" {lora_file.path}")
+        assert lora.storage_id == lora_file.hash
+        yield lora_file
 
 
 _lcm_loras = ["lcm-lora-sdv1-5.safetensors", "lcm-lora-sdxl.safetensors"]
diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py
index 190bc5f110..2f3189762a 100644
--- a/ai_diffusion/comfy_workflow.py
+++ b/ai_diffusion/comfy_workflow.py
@@ -547,6 +547,9 @@ def define_region(self, regions: Output, mask: Output, conditioning: Output):
             "ETN_DefineRegion", 1, regions=regions, mask=mask, conditioning=conditioning
         )
 
+    def list_region_masks(self, regions: Output):
+        return self.add("ETN_ListRegionMasks", 1, regions=regions)
+
     def attention_mask(self, model: Output, regions: Output):
         return self.add("ETN_AttentionMask", 1, model=model, regions=regions)
 
@@ -796,6 +799,9 @@ def invert_image(self, image: Output):
     def batch_image(self, batch: Output, image: Output):
         return self.add("ImageBatch", 1, image1=batch, image2=image)
 
+    def image_batch_element(self, batch: Output, index: int):
+        return self.add("ImageFromBatch", 1, image=batch, batch_index=index, length=1)
+
     def inpaint_image(self, model: Output, image: Output, mask: Output):
         return self.add(
             "INPAINT_InpaintWithModel", 1, inpaint_model=model, image=image, mask=mask, seed=834729
@@ -837,6 +843,11 @@ def composite_image_masked(
     def mask_to_image(self, mask: Output):
         return self.add("MaskToImage", 1, mask=mask)
 
+    def mask_batch_element(self, mask_batch: Output, index: int):
+        image_batch = self.mask_to_image(mask_batch)
+        image = self.image_batch_element(image_batch, index)
+        return self.image_to_mask(image)
+
     def solid_mask(self, extent: Extent, value=1.0):
         return self.add("SolidMask", 1, width=extent.width, height=extent.height, value=value)
 
@@ -938,6 +949,70 @@ def estimate_pose(self, image: Output, resolution: int):
             mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx"
         return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls)
 
+    def create_hook_lora(self, loras: list[tuple[str, float]]):
+        key = "CreateHookLora" + str(loras)
+        hooks = self._cache.get(key, None)
+        if hooks is None:
+            for lora, strength in loras:
+                hooks = self.add(
+                    "CreateHookLora",
+                    1,
+                    lora_name=lora,
+                    strength_model=strength,
+                    strength_clip=strength,
+                    prev_hooks=hooks,
+                )
+            assert hooks is not None
+            self._cache[key] = hooks
+
+        assert isinstance(hooks, Output)
+        return hooks
+
+    def set_clip_hooks(self, clip: Output, hooks: Output):
+        return self.add(
+            "SetClipHooks", 1, clip=clip, hooks=hooks, apply_to_conds=True, schedule_clip=False
+        )
+
+    def combine_masked_conditioning(
+        self,
+        positive: Output,
+        negative: Output,
+        positive_conds: Output | None = None,
+        negative_conds: Output | None = None,
+        mask: Output | None = None,
+    ):
+        assert (positive_conds and negative_conds) or mask
+        if mask is None:
+            return self.add(
+                "PairConditioningSetDefaultCombine",
+                2,
+                positive=positive_conds,
+                negative=negative_conds,
+                positive_DEFAULT=positive,
+                negative_DEFAULT=negative,
+            )
+        if positive_conds is None and negative_conds is None:
+            return self.add(
+                "PairConditioningSetProperties",
+                2,
+                positive_NEW=positive,
+                negative_NEW=negative,
+                mask=mask,
+                strength=1.0,
+                set_cond_area="default",
+            )
+        return self.add(
+            "PairConditioningSetPropertiesAndCombine",
+            2,
+            positive=positive_conds,
+            negative=negative_conds,
+            positive_NEW=positive,
+            negative_NEW=negative,
+            mask=mask,
+            strength=1.0,
+            set_cond_area="default",
+        )
+
 
 def _inputs_for_node(node_inputs: dict[str, dict[str, Any]], node_name: str, filter=""):
     inputs = node_inputs.get(node_name)
diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py
index a24502030f..35b29ca88b 100644
--- a/ai_diffusion/workflow.py
+++ b/ai_diffusion/workflow.py
@@ -216,7 +216,9 @@ def from_input(i: ControlInput):
 class TextPrompt:
     text: str
     language: str
+    # Cached values to avoid re-encoding the same text for multiple regions and passes
     _output: Output | None = None
+    _clip: Output | None = None  # can be different due to Lora hooks
 
     def __init__(self, text: str, language: str):
         self.text = text
@@ -226,10 +228,11 @@ def encode(self, w: ComfyWorkflow, clip: Output, style_prompt: str | None = None
         text = self.text
         if text != "" and style_prompt:
             text = merge_prompt(text, style_prompt, self.language)
-        if self._output is None:
+        if self._output is None or self._clip != clip:
             if text and self.language:
                 text = w.translate(text)
             self._output = w.clip_text_encode(clip, text)
+            self._clip = clip
         return self._output
 
 
@@ -239,19 +242,40 @@ class Region:
     bounds: Bounds
     positive: TextPrompt
     control: list[Control] = field(default_factory=list)
+    loras: list[LoraInput] = field(default_factory=list)
     is_background: bool = False
+    clip: Output | None = None
 
     @staticmethod
     def from_input(i: RegionInput, index: int, language: str):
         control = [Control.from_input(c) for c in i.control]
         mask = ImageOutput(i.mask, is_mask=True)
         return Region(
-            mask, i.bounds, TextPrompt(i.positive, language), control, is_background=index == 0
+            mask,
+            i.bounds,
+            TextPrompt(i.positive, language),
+            control,
+            i.loras,
+            is_background=index == 0,
         )
 
+    def patch_clip(self, w: ComfyWorkflow, clip: Output):
+        if self.clip is None:
+            self.clip = clip
+            if len(self.loras) > 0:
+                hooks = w.create_hook_lora([(lora.name, lora.strength) for lora in self.loras])
+                self.clip = w.set_clip_hooks(clip, hooks)
+        return self.clip
+
+    def encode_prompt(self, w: ComfyWorkflow, clip: Output, style_prompt: str | None = None):
+        return self.positive.encode(w, self.patch_clip(w, clip), style_prompt)
+
     def copy(self):
         control = [copy(c) for c in self.control]
-        return Region(self.mask, self.bounds, self.positive, control, self.is_background)
+        loras = copy(self.loras)
+        return Region(
+            self.mask, self.bounds, self.positive, control, loras, self.is_background, self.clip
+        )
 
 
 @dataclass
@@ -323,9 +347,31 @@ def downscale_all_control_images(cond: ConditioningInput, original: Extent, targ
         downscale_control_images(region.control, original, target)
 
 
-def encode_text_prompt(w: ComfyWorkflow, cond: Conditioning, clip: Output):
-    positive = cond.positive.encode(w, clip, cond.style_prompt)
-    negative = cond.negative.encode(w, clip)
+def encode_text_prompt(
+    w: ComfyWorkflow,
+    cond: Conditioning,
+    clip: Output,
+    regions: Output | None,
+):
+    if len(cond.regions) <= 1 or all(len(r.loras) == 0 for r in cond.regions):
+        positive = cond.positive.encode(w, clip, cond.style_prompt)
+        negative = cond.negative.encode(w, clip)
+        return positive, negative
+
+    assert regions is not None
+    positive = None
+    negative = None
+    region_masks = w.list_region_masks(regions)
+
+    for i, region in enumerate(cond.regions):
+        region_positive = region.encode_prompt(w, clip, cond.style_prompt)
+        region_negative = cond.negative.encode(w, region.patch_clip(w, clip))
+        mask = w.mask_batch_element(region_masks, i)
+        positive, negative = w.combine_masked_conditioning(
+            region_positive, region_negative, positive, negative, mask
+        )
+
+    assert positive is not None and negative is not None
     return positive, negative
 
 
@@ -333,17 +379,17 @@ def apply_attention_mask(
     w: ComfyWorkflow, model: Output, cond: Conditioning, clip: Output, target_extent: Extent | None
 ):
     if len(cond.regions) == 0:
-        return model
+        return model, None
 
     if len(cond.regions) == 1:
         region = cond.regions[0]
         cond.positive = region.positive
         cond.control += region.control
-        return model
+        return model, None
 
     bottom_region = cond.regions[0]
     if bottom_region.is_background:
-        regions = w.background_region(bottom_region.positive.encode(w, clip, cond.style_prompt))
+        regions = w.background_region(bottom_region.encode_prompt(w, clip, cond.style_prompt))
         remaining = cond.regions[1:]
     else:
         regions = w.background_region(cond.positive.encode(w, clip, cond.style_prompt))
@@ -351,10 +397,11 @@ def apply_attention_mask(
 
     for region in remaining:
         mask = region.mask.load(w, target_extent)
-        prompt = region.positive.encode(w, clip, cond.style_prompt)
+        prompt = region.encode_prompt(w, clip, cond.style_prompt)
         regions = w.define_region(regions, mask, prompt)
 
-    return w.attention_mask(model, regions)
+    model = w.attention_mask(model, regions)
+    return model, regions
 
 
 def apply_control(
@@ -553,8 +600,6 @@ def scale_refine_and_decode(
     cond: Conditioning,
     sampling: SamplingInput,
     latent: Output,
-    prompt_pos: Output,
-    prompt_neg: Output,
     model: Output,
     clip: Output,
     vae: Output,
@@ -569,7 +614,7 @@ def scale_refine_and_decode(
         decoded = w.vae_decode(vae, latent)
         return scale(extent.initial, extent.desired, mode, w, decoded, models)
 
-    model = apply_attention_mask(w, model, cond, clip, extent.desired)
+    model, regions = apply_attention_mask(w, model, cond, clip, extent.desired)
     model = apply_regional_ip_adapter(w, model, cond.regions, extent.desired, models)
 
     if mode is ScaleMode.upscale_small:
@@ -585,8 +630,9 @@ def scale_refine_and_decode(
     latent = w.vae_encode(vae, upscale)
 
     params = _sampler_params(sampling, strength=0.4)
+    positive, negative = encode_text_prompt(w, cond, clip, regions)
     model, positive, negative = apply_control(
-        w, model, prompt_pos, prompt_neg, cond.all_control, extent.desired, vae, models
+        w, model, positive, negative, cond.all_control, extent.desired, vae, models
     )
     result = w.sampler_custom_advanced(model, positive, negative, latent, models.arch, **params)
     image = w.vae_decode(vae, result)
@@ -617,18 +663,18 @@ def generate(
     model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
     model = apply_ip_adapter(w, model, cond.control, models)
     model_orig = copy(model)
-    model = apply_attention_mask(w, model, cond, clip, extent.initial)
+    model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
     model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
     latent = w.empty_latent_image(extent.initial, models.arch, misc.batch_count)
-    prompt_pos, prompt_neg = encode_text_prompt(w, cond, clip)
+    positive, negative = encode_text_prompt(w, cond, clip, regions)
     model, positive, negative = apply_control(
-        w, model, prompt_pos, prompt_neg, cond.all_control, extent.initial, vae, models
+        w, model, positive, negative, cond.all_control, extent.initial, vae, models
     )
     out_latent = w.sampler_custom_advanced(
         model, positive, negative, latent, models.arch, **_sampler_params(sampling)
     )
     out_image = scale_refine_and_decode(
-        extent, w, cond, sampling, out_latent, prompt_pos, prompt_neg, model_orig, clip, vae, models
+        extent, w, cond, sampling, out_latent, model_orig, clip, vae, models
     )
     out_image = w.nsfw_filter(out_image, sensitivity=misc.nsfw_filter)
     out_image = scale_to_target(extent, w, out_image, models)
@@ -743,7 +789,7 @@ def inpaint(
     cond_base = cond.copy()
     cond_base.downscale(extent.input, extent.initial)
 
-    model = apply_attention_mask(w, model, cond_base, clip, extent.initial)
+    model, regions = apply_attention_mask(w, model, cond_base, clip, extent.initial)
 
     if params.use_reference:
         reference = get_inpaint_reference(ensure(images.initial_image), initial_bounds) or in_image
@@ -763,7 +809,7 @@ def inpaint(
 
     model = apply_ip_adapter(w, model, cond_base.control, models)
     model = apply_regional_ip_adapter(w, model, cond_base.regions, extent.initial, models)
-    positive, negative = encode_text_prompt(w, cond, clip)
+    positive, negative = encode_text_prompt(w, cond, clip, regions)
     model, positive, negative = apply_control(
         w, model, positive, negative, cond_base.all_control, extent.initial, vae, models
     )
@@ -810,9 +856,9 @@ def inpaint(
         cond_upscale.crop(target_bounds)
         res = upscale_extent.desired
 
-        positive_up, negative_up = encode_text_prompt(w, cond_upscale, clip)
-        model = apply_attention_mask(w, model, cond_upscale, clip, res)
+        model, regions = apply_attention_mask(w, model, cond_upscale, clip, res)
         model = apply_regional_ip_adapter(w, model, cond_upscale.regions, res, models)
+        positive_up, negative_up = encode_text_prompt(w, cond_upscale, clip, regions)
 
         if params.use_inpaint_model and models.control.find(ControlMode.inpaint) is not None:
             hires_image = ImageOutput(images.hires_image)
@@ -857,13 +903,13 @@ def refine(
 ):
     model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
     model = apply_ip_adapter(w, model, cond.control, models)
-    model = apply_attention_mask(w, model, cond, clip, extent.initial)
+    model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
     model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
     in_image = w.load_image(image)
     in_image = scale_to_initial(extent, w, in_image, models)
     latent = w.vae_encode(vae, in_image)
     latent = w.batch_latent(latent, misc.batch_count)
-    positive, negative = encode_text_prompt(w, cond, clip)
+    positive, negative = encode_text_prompt(w, cond, clip, regions)
     model, positive, negative = apply_control(
         w, model, positive, negative, cond.all_control, extent.desired, vae, models
     )
@@ -893,9 +939,9 @@ def refine_region(
     model = w.differential_diffusion(model)
     model = apply_ip_adapter(w, model, cond.control, models)
     model_orig = copy(model)
-    model = apply_attention_mask(w, model, cond, clip, extent.initial)
+    model, regions = apply_attention_mask(w, model, cond, clip, extent.initial)
     model = apply_regional_ip_adapter(w, model, cond.regions, extent.initial, models)
-    prompt_pos, prompt_neg = encode_text_prompt(w, cond, clip)
+    positive, negative = encode_text_prompt(w, cond, clip, regions)
 
     in_image = w.load_image(ensure(images.initial_image))
     in_image = scale_to_initial(extent, w, in_image, models)
@@ -906,7 +952,7 @@ def refine_region(
     if inpaint.use_inpaint_model and models.control.find(ControlMode.inpaint) is not None:
         cond.control.append(inpaint_control(in_image, initial_mask, models.arch))
     model, positive, negative = apply_control(
-        w, model, prompt_pos, prompt_neg, cond.all_control, extent.initial, vae, models
+        w, model, positive, negative, cond.all_control, extent.initial, vae, models
     )
     if inpaint.use_inpaint_model and models.arch is Arch.sdxl:
         positive, negative, latent_inpaint, latent = w.vae_encode_inpaint_conditioning(
@@ -924,7 +970,7 @@ def refine_region(
         inpaint_model, positive, negative, latent, models.arch, **_sampler_params(sampling)
     )
     out_image = scale_refine_and_decode(
-        extent, w, cond, sampling, out_latent, prompt_pos, prompt_neg, model_orig, clip, vae, models
+        extent, w, cond, sampling, out_latent, model_orig, clip, vae, models
     )
     out_image = w.nsfw_filter(out_image, sensitivity=misc.nsfw_filter)
     out_image = scale_to_target(extent, w, out_image, models)
@@ -1036,7 +1082,6 @@ def upscale_tiled(
 
     model, clip, vae = load_checkpoint_with_lora(w, checkpoint, models.all)
     model = apply_ip_adapter(w, model, cond.control, models)
-    positive, negative = encode_text_prompt(w, cond, clip)
 
     in_image = w.load_image(image)
     if upscale_model_name:
@@ -1073,18 +1118,19 @@ def tiled_region(region: Region, index: int, tile_bounds: Bounds):
         tile_cond = cond.copy()
         regions = [tiled_region(r, i, bounds) for r in tile_cond.regions]
         tile_cond.regions = [r for r in regions if r is not None]
-        tile_model = apply_attention_mask(w, model, tile_cond, clip, None)
+        tile_model, regions = apply_attention_mask(w, model, tile_cond, clip, None)
         tile_model = apply_regional_ip_adapter(w, tile_model, tile_cond.regions, None, models)
+        positive, negative = encode_text_prompt(w, tile_cond, clip, regions)
 
         control = [tiled_control(c, i) for c in tile_cond.all_control]
-        tile_model, tile_pos, tile_neg = apply_control(
+        tile_model, positive, negative = apply_control(
             w, tile_model, positive, negative, control, None, vae, models
         )
 
         latent = w.vae_encode(vae, tile_image)
         latent = w.set_latent_noise_mask(latent, tile_mask)
         sampler = w.sampler_custom_advanced(
-            tile_model, tile_pos, tile_neg, latent, models.arch, **_sampler_params(sampling)
+            tile_model, positive, negative, latent, models.arch, **_sampler_params(sampling)
         )
         tile_result = w.vae_decode(vae, sampler)
         out_image = w.merge_image_tile(out_image, tile_layout, i, tile_result)
@@ -1195,8 +1241,7 @@ def prepare(
     i.conditioning.style = style.style_prompt
     for idx, region in enumerate(i.conditioning.regions):
         assert region.mask or idx == 0, "Only the first/bottom region can be without a mask"
-        region.positive, region_loras = extract_loras(region.positive, files.loras)
-        extra_loras += region_loras
+        region.positive, region.loras = extract_loras(region.positive, files.loras)
     i.sampling = _sampling_from_style(style, strength, is_live)
     i.sampling.seed = seed
     i.models = style.get_models(models.checkpoints.keys())
@@ -1204,7 +1249,7 @@ def prepare(
     i.models.loras = unique(i.models.loras + extra_loras, key=lambda l: l.name)
 
     arch = i.models.version = models.arch_of(i.models.checkpoint)
-    _check_server_has_models(i.models, models, files, style.name)
+    _check_server_has_models(i.models, i.conditioning.regions, models, files, style.name)
     _check_inpaint_model(inpaint, arch, models)
 
     model_set = models.for_arch(arch)
@@ -1431,18 +1476,10 @@ def trigger_words(lora: LoraInput) -> str:
     return " " + result if result else ""
 
 
-def _check_server_has_models(
-    input: CheckpointInput, models: ClientModels, files: FileLibrary, style_name: str
+def _check_server_has_loras(
+    loras: list[LoraInput], models: ClientModels, files: FileLibrary, style_name: str, arch: Arch
 ):
-    if input.checkpoint not in models.checkpoints:
-        raise ValueError(
-            _(
-                "The checkpoint '{checkpoint}' used by style '{style}' is not available on the server",
-                checkpoint=input.checkpoint,
-                style=style_name,
-            )
-        )
-    for lora in input.loras:
+    for lora in loras:
         if lora.name not in models.loras:
             if lora_info := files.loras.find_local(lora.name):
                 lora.storage_id = lora_info.compute_hash()
@@ -1456,16 +1493,37 @@ def _check_server_has_models(
             )
         for id, res in models.resources.items():
             lora_arch = ResourceId.parse(id).arch
-            if lora.name == res and input.version is not lora_arch:
+            if lora.name == res and arch is not lora_arch:
                 raise ValueError(
                     _(
                         "Model architecture mismatch for LoRA '{lora}': Cannot use {lora_arch} LoRA with a {checkpoint_arch} checkpoint.",
                         lora=lora.name,
                         lora_arch=lora_arch.value,
-                        checkpoint_arch=input.version.value,
+                        checkpoint_arch=arch.value,
                     )
                 )
 
+
+def _check_server_has_models(
+    input: CheckpointInput,
+    regions: list[RegionInput],
+    models: ClientModels,
+    files: FileLibrary,
+    style_name: str,
+):
+    if input.checkpoint not in models.checkpoints:
+        raise ValueError(
+            _(
+                "The checkpoint '{checkpoint}' used by style '{style}' is not available on the server",
+                checkpoint=input.checkpoint,
+                style=style_name,
+            )
+        )
+
+    _check_server_has_loras(input.loras, models, files, style_name, input.version)
+    for region in regions:
+        _check_server_has_loras(region.loras, models, files, style_name, input.version)
+
     if input.vae != StyleSettings.vae.default and input.vae not in models.vae:
         raise ValueError(
             _(
diff --git a/tests/images/truck_landscape_lines.webp b/tests/images/truck_landscape_lines.webp
new file mode 100644
index 0000000000..84001b26bb
Binary files /dev/null and b/tests/images/truck_landscape_lines.webp differ
diff --git a/tests/test_workflow.py b/tests/test_workflow.py
index 4c8bc6bbe8..146a7ff28b 100644
--- a/tests/test_workflow.py
+++ b/tests/test_workflow.py
@@ -461,6 +461,35 @@ def test_regions_ip_adapter(qtapp, client: Client):
     run_and_save(qtapp, client, job, f"test_regions_ip_adapter")
 
 
+def test_regions_lora(qtapp, client: Client):
+    files = FileLibrary.instance()
+    files.loras.add(File.local(test_dir / "data" / "LowRA.safetensors"))
+    files.loras.add(File.local(test_dir / "data" / "Ink scenery.safetensors"))
+    root_text = "snowy landscape, tundra, illustration, truck in the distance"
+    lines = Image.load(image_dir / "truck_landscape_lines.webp")
+    prompt = ConditioningInput(root_text)
+    prompt.regions = [
+        RegionInput(
+            Mask.load(image_dir / "region_mask_bg.png").to_image(),
+            Bounds(0, 0, 1024, 1024),
+            "frozen lake . " + root_text,
+        ),
+        RegionInput(
+            Mask.load(image_dir / "region_mask_3.png").to_image(),
+            Bounds(600, 150, 424, 600),
+            "truck on an abandoned road . " + root_text,
+        ),
+        RegionInput(
+            Mask.load(image_dir / "region_mask_2.png").to_image(),
+            Bounds(0, 250, 355, 700),
+            "ink scenery, mountains, trees, " + root_text,
+        ),
+    ]
+    prompt.control = [ControlInput(ControlMode.soft_edge, lines, 0.5)]
+    job = create(WorkflowKind.generate, client, canvas=Extent(1024, 1024), cond=prompt, files=files)
+    run_and_save(qtapp, client, job, f"test_regions_lora")
+
+
 @pytest.mark.parametrize(
     "op", ["generate", "inpaint", "refine", "refine_region", "inpaint_upscale"]
 )