Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Regional Prompting (Node only) #5916

Merged
merged 23 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
46d83a3
Add a MaskField primitive, and add a mask to the ConditioningField pr…
RyanJDick Feb 13, 2024
1da8423
Add RectangleMaskInvocation.
RyanJDick Mar 8, 2024
bf3ee1f
Update compel nodes to accept an optional prompt mask.
RyanJDick Mar 8, 2024
ef9e0c9
Remove scheduler_args from ConditioningData structure.
RyanJDick Feb 28, 2024
7fe6f03
Split ip_adapter_conditioning out from ConditioningData.
RyanJDick Feb 28, 2024
8923289
Rename ConditioningData -> TextConditioningData.
RyanJDick Mar 8, 2024
b76bb45
Improve documentation of conditioning_data.py.
RyanJDick Mar 8, 2024
c059bc3
Add TextConditioningRegions to the TextConditioningData data structure.
RyanJDick Mar 8, 2024
93056e4
Add support for lists of prompt embeddings to be passed to the Denois…
RyanJDick Mar 8, 2024
dc90ff2
Add RegionalPromptData class for managing prompt region masks.
RyanJDick Mar 8, 2024
b76720f
Initialize a RegionalPromptAttnProcessor2_0 class by copying AttnProc…
RyanJDick Feb 15, 2024
203d4a6
Update CustomAttention to support both IP-Adapters and regional promp…
RyanJDick Mar 8, 2024
787a085
Create a UNetAttentionPatcher for patching UNet models with CustomAtt…
RyanJDick Mar 8, 2024
ee34091
Update the diffusion logic to use the new regional prompting feature.
RyanJDick Mar 8, 2024
4f97192
(minor) The latest ruff version has _slightly_ different formatting p…
RyanJDick Mar 11, 2024
b0edf59
Merge branch 'main' into ryan/regional-prompting-naive
psychedelicious Apr 8, 2024
3a531c5
feat(nodes): add prompt region from image nodes
psychedelicious Apr 8, 2024
98900a7
Pull the upstream changes from diffusers' AttnProcessor2_0 into Custo…
RyanJDick Apr 8, 2024
826f3d6
Fix dimensions of mask produced by ExtractMasksAndPromptsInvocation. …
RyanJDick Apr 8, 2024
26a2b23
Rename MaskField to be a generice TensorField.
RyanJDick Apr 8, 2024
eb32842
Add utility to_standard_float_mask(...) to convert various mask forma…
RyanJDick Apr 8, 2024
7bd902a
Merge branch 'main' into ryan/regional-prompting-naive
psychedelicious Apr 9, 2024
3e61d5f
Revert "feat(nodes): add prompt region from image nodes"
psychedelicious Apr 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions invokeai/app/invocations/compel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
Input,
InputField,
OutputField,
TensorField,
UIComponent,
)
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.ti_utils import generate_ti_list
Expand Down Expand Up @@ -36,7 +44,7 @@
title="Prompt",
tags=["prompt", "compel"],
category="conditioning",
version="1.1.1",
version="1.2.0",
)
class CompelInvocation(BaseInvocation):
"""Parse prompt using compel package to conditioning."""
Expand All @@ -51,6 +59,9 @@ class CompelInvocation(BaseInvocation):
description=FieldDescriptions.clip,
input=Input.Connection,
)
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
Expand Down Expand Up @@ -117,8 +128,12 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
)

conditioning_name = context.conditioning.save(conditioning_data)

return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
)
)


class SDXLPromptInvocationBase:
Expand Down Expand Up @@ -232,7 +247,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
title="SDXL Prompt",
tags=["sdxl", "compel", "prompt"],
category="conditioning",
version="1.1.1",
version="1.2.0",
)
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
Expand All @@ -255,6 +270,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
target_height: int = InputField(default=1024, description="")
clip: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1")
clip2: CLIPField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2")
mask: Optional[TensorField] = InputField(
default=None, description="A mask defining the region that this conditioning prompt applies to."
)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
Expand Down Expand Up @@ -317,7 +335,12 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:

conditioning_name = context.conditioning.save(conditioning_data)

return ConditioningOutput.build(conditioning_name)
return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
mask=self.mask,
)
)


@invocation(
Expand Down
12 changes: 11 additions & 1 deletion invokeai/app/invocations/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ class DenoiseMaskField(BaseModel):
gradient: bool = Field(default=False, description="Used for gradient inpainting")


class TensorField(BaseModel):
"""A tensor primitive field."""

tensor_name: str = Field(description="The name of a tensor.")


class LatentsField(BaseModel):
"""A latents tensor primitive field"""

Expand All @@ -226,7 +232,11 @@ class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

conditioning_name: str = Field(description="The name of conditioning tensor")
# endregion
mask: Optional[TensorField] = Field(
default=None,
description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
"included regions should be set to True.",
)


class MetadataField(RootModel[dict[str, Any]]):
Expand Down
Loading
Loading