
Commit

Add backend functions and classes for Flux implementation, update the way Flux encoders/tokenizers are loaded for prompt encoding, and update the way the Flux VAE is loaded
brandonrising committed Aug 16, 2024
1 parent 53052cf commit f4f5c46
Showing 19 changed files with 1,340 additions and 197 deletions.
31 changes: 12 additions & 19 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -1,6 +1,9 @@
import torch


from einops import repeat

Check failure on line 4 (GitHub Actions / python-checks, Ruff F401): invokeai/app/invocations/flux_text_encoder.py:4:20: `einops.repeat` imported but unused
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline

Check failure on line 5 (GitHub Actions / python-checks, Ruff F401): invokeai/app/invocations/flux_text_encoder.py:5:52: `diffusers.pipelines.flux.pipeline_flux.FluxPipeline` imported but unused
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
@@ -9,6 +12,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice

Check failure on line 14 (GitHub Actions / python-checks, Ruff F401): invokeai/app/invocations/flux_text_encoder.py:14:43: `invokeai.backend.util.devices.TorchDevice` imported but unused
from invokeai.backend.flux.modules.conditioner import HFEncoder


@invocation(

Check failure on line 18 (GitHub Actions / python-checks, Ruff I001): invokeai/app/invocations/flux_text_encoder.py:1:1: Import block is un-sorted or un-formatted
@@ -69,26 +73,15 @@ def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
assert isinstance(clip_text_encoder, CLIPTextModel)
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(clip_tokenizer, CLIPTokenizer)
assert isinstance(t5_tokenizer, T5TokenizerFast)
assert isinstance(t5_tokenizer, T5Tokenizer)

clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, max_seq_len)

pipeline = FluxPipeline(
scheduler=None,
vae=None,
text_encoder=clip_text_encoder,
tokenizer=clip_tokenizer,
text_encoder_2=t5_text_encoder,
tokenizer_2=t5_tokenizer,
transformer=None,
)
prompt = [self.positive_prompt]
prompt_embeds = t5_encoder(prompt)

# prompt_embeds: T5 embeddings
# pooled_prompt_embeds: CLIP embeddings
prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
prompt=self.positive_prompt,
prompt_2=self.positive_prompt,
device=TorchDevice.choose_torch_device(),
max_sequence_length=max_seq_len,
)
pooled_prompt_embeds = clip_encoder(prompt)

assert isinstance(prompt_embeds, torch.Tensor)
assert isinstance(pooled_prompt_embeds, torch.Tensor)
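The hunk above replaces the temporary FluxPipeline/encode_prompt path with two HFEncoder wrappers from invokeai.backend.flux.modules.conditioner: the T5 wrapper produces the sequence embeddings and the CLIP wrapper the pooled embeddings. That conditioner module is not shown in this view, so the sketch below is only a rough approximation of what such a wrapper could look like, assuming it pads/truncates to a fixed length and returns the pooled output for CLIP and the last hidden state for T5; the class name and details are illustrative, not the actual code.

import torch
from torch import Tensor, nn


class HFEncoderSketch(nn.Module):
    """Illustrative stand-in for an HFEncoder-style wrapper (an assumption, not the real implementation)."""

    def __init__(self, encoder: nn.Module, tokenizer, is_clip: bool, max_length: int):
        super().__init__()
        self.encoder = encoder
        self.tokenizer = tokenizer
        self.max_length = max_length
        # CLIP conditioning typically uses the pooled output; T5 uses the per-token hidden states.
        self.output_key = "pooler_output" if is_clip else "last_hidden_state"

    @torch.no_grad()
    def forward(self, text: list[str]) -> Tensor:
        batch = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        device = next(self.encoder.parameters()).device
        outputs = self.encoder(input_ids=batch["input_ids"].to(device), output_hidden_states=False)
        return outputs[self.output_key]

With a wrapper like this, the call sites above reduce to exactly what the hunk shows: prompt_embeds = t5_encoder(prompt) and pooled_prompt_embeds = clip_encoder(prompt) on the single-element prompt list.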
4 changes: 1 addition & 3 deletions invokeai/app/invocations/flux_text_to_image.py
@@ -85,17 +85,15 @@ def _run_diffusion(
clip_embeddings: torch.Tensor,
t5_embeddings: torch.Tensor,
):
scheduler_info = context.models.load(self.transformer.scheduler)
transformer_info = context.models.load(self.transformer.transformer)

# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

with transformer_info as transformer, scheduler_info as scheduler:
with transformer_info as transformer:
assert isinstance(transformer, FluxTransformer2DModel)
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

flux_pipeline_with_transformer = FluxPipeline(
scheduler=scheduler,

Check failure on line 99 (GitHub Actions / python-checks, Ruff F821): invokeai/app/invocations/flux_text_to_image.py:99:27: Undefined name `scheduler`
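The F821 failure above exists because the surviving FluxPipeline(...) call still passes scheduler=scheduler even though the scheduler submodel load was removed in this hunk. Purely as a hypothetical illustration, not something this commit does and not necessarily how the name was later resolved, one way to satisfy that argument without loading a submodel would be to construct the flow-matching scheduler locally:

from diffusers import FlowMatchEulerDiscreteScheduler

# Hypothetical: build a default flow-matching scheduler in place of the removed submodel.
scheduler = FlowMatchEulerDiscreteScheduler()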
81 changes: 66 additions & 15 deletions invokeai/app/invocations/model.py
@@ -1,5 +1,6 @@
import copy
from typing import List, Optional
from time import sleep
from typing import List, Optional, Literal, Dict

from pydantic import BaseModel, Field

@@ -13,7 +14,8 @@
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat


class ModelIdentifierField(BaseModel):

Check failure on line 21 (GitHub Actions / python-checks, Ruff I001): invokeai/app/invocations/model.py:1:1: Import block is un-sorted or un-formatted
@@ -62,7 +64,6 @@ class CLIPField(BaseModel):

class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")


class T5EncoderField(BaseModel):
@@ -131,6 +132,30 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:

return ModelIdentifierOutput(model=self.model)

T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
"base": {
"text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
"tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
"text_encoder_name": "FLUX.1-schnell_text_encoder_2",
"tokenizer_name": "FLUX.1-schnell_tokenizer_2",
"format": ModelFormat.T5Encoder,
},
"8b_quantized": {
"text_encoder_repo": "hf_repo1",
"tokenizer_repo": "hf_repo1",
"text_encoder_name": "hf_repo1",
"tokenizer_name": "hf_repo1",
"format": ModelFormat.T5Encoder8b,
},
"4b_quantized": {
"text_encoder_repo": "hf_repo2",
"tokenizer_repo": "hf_repo2",
"text_encoder_name": "hf_repo2",
"tokenizer_name": "hf_repo2",
"format": ModelFormat.T5Encoder8b,
},
}

@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
@@ -151,29 +176,55 @@ class FluxModelLoaderInvocation(BaseInvocation):
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)

Check failure on line 179 (GitHub Actions / python-checks, Ruff W293): invokeai/app/invocations/model.py:179:1: Blank line contains whitespace
t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")

def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
model_key = self.model.key

# TODO: not found exceptions
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")

transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
transformer = self._get_model(context, SubModelType.Transformer)
tokenizer = self._get_model(context, SubModelType.Tokenizer)
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)

return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2),
transformer=TransformerField(transformer=transformer),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)

def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField:
match(submodel):
case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
case SubModelType.TextEncoder2:
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
case SubModelType.Tokenizer2:
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")

Check failure on line 212 (GitHub Actions / python-checks, Ruff W291): invokeai/app/invocations/model.py:212:99: Trailing whitespace

def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
if len(models) != 1:
raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
model_path = context.models.download_and_cache_model(repo_id)
config = ModelRecordChanges(name = name, base = base, type=type, format=format)
model_install_job = context.models.import_local_model(model_path=model_path, config=config)
while not model_install_job.in_terminal_state:
sleep(0.01)
if not model_install_job.config_out:
raise Exception(f"Failed to install {name}")
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})

@invocation(
"main_model_loader",
1 change: 1 addition & 0 deletions invokeai/app/services/model_records/model_records_base.py
@@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
format: Optional[str] = Field(description="format of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None
2 changes: 1 addition & 1 deletion invokeai/app/services/model_records/model_records_sql.py
@@ -301,7 +301,7 @@ def search_by_attr(
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError:
except pydantic.ValidationError as e:

Check failure on line 304 (GitHub Actions / python-checks, Ruff F841): invokeai/app/services/model_records/model_records_sql.py:304:48: Local variable `e` is assigned to but never used
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
15 changes: 15 additions & 0 deletions invokeai/app/services/shared/invocation_context.py
@@ -13,6 +13,7 @@
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import (
@@ -463,6 +464,20 @@ def download_and_cache_model(
"""
return self._services.model_manager.install.download_and_cache_model(source=source)

def import_local_model(
self,
model_path: Path,
config: Optional[ModelRecordChanges] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
):
"""
TODO: Fill out description of this method
"""
if not model_path.exists():
raise Exception("Models provided to import_local_model must already exist on disk")
return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace)

def load_local_model(
self,
model_path: Path,
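For context, the import_local_model helper added above is what the Flux model loader's new _install_model method (earlier in this commit) relies on: it validates the path, hands off to heuristic_import, and returns the install job, which the caller polls until it reaches a terminal state. The condensed sketch below mirrors that call site using the VAE values from model.py in this same commit; it assumes it runs inside an invocation's invoke(), where context is the InvocationContext.

from time import sleep

from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType

# Assumes `context` is the InvocationContext available inside invoke().
model_path = context.models.download_and_cache_model("black-forest-labs/FLUX.1-schnell::ae.safetensors")
config = ModelRecordChanges(name="FLUX.1-schnell_ae", base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)
job = context.models.import_local_model(model_path=model_path, config=config)
while not job.in_terminal_state:
    sleep(0.01)
assert job.config_out is not None  # no config_out means the install failed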
30 changes: 30 additions & 0 deletions invokeai/backend/flux/math.py
@@ -0,0 +1,30 @@
import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)

x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
x = rearrange(x, "B H L D -> B L (H D)")

return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
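For reference, rope and apply_rope above implement standard rotary position embedding. With frequencies

    \omega_d = \theta^{-2d/\mathrm{dim}}, \qquad d = 0, \dots, \mathrm{dim}/2 - 1,

each consecutive channel pair of a query or key at position p is rotated:

    \begin{pmatrix} x'_{2d} \\ x'_{2d+1} \end{pmatrix}
    =
    \begin{pmatrix} \cos(p\,\omega_d) & -\sin(p\,\omega_d) \\ \sin(p\,\omega_d) & \cos(p\,\omega_d) \end{pmatrix}
    \begin{pmatrix} x_{2d} \\ x_{2d+1} \end{pmatrix}.

attention() applies this rotation to q and k before scaled_dot_product_attention, which is what makes the attention logits depend only on relative positions.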
111 changes: 111 additions & 0 deletions invokeai/backend/flux/model.py
@@ -0,0 +1,111 @@
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
MLPEmbedder, SingleStreamBlock,
timestep_embedding)

@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool


class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""

def __init__(self, params: FluxParams):
super().__init__()

self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)

self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)

self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")

# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)

ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]

img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
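A quick way to read the new Flux module is as a two-stage transformer: double-stream blocks that keep image and text tokens in separate streams while attending jointly, followed by single-stream blocks over the concatenated sequence. Below is a minimal smoke-test sketch; it assumes this commit's invokeai.backend.flux package is importable and that its layer modules behave like the reference FLUX implementation. The tiny hyperparameters are invented for illustration (the published FLUX.1 configs are far larger); the one hard constraint is that sum(axes_dim) must equal hidden_size // num_heads.

import torch

from invokeai.backend.flux.model import Flux, FluxParams

# Tiny, made-up config purely for a shape check; real FLUX.1 checkpoints use
# hidden_size=3072, depth=19, depth_single_blocks=38, and similar.
params = FluxParams(
    in_channels=64,          # packed 2x2 patches of 16-channel latents
    vec_in_dim=768,          # pooled CLIP embedding size
    context_in_dim=4096,     # T5 hidden size
    hidden_size=512,
    mlp_ratio=4.0,
    num_heads=8,
    depth=2,
    depth_single_blocks=2,
    axes_dim=[16, 24, 24],   # sums to 512 // 8 = 64
    theta=10_000,
    qkv_bias=True,
    guidance_embed=False,
)
model = Flux(params)

n, l_img, l_txt = 1, 16, 8
img = torch.randn(n, l_img, params.in_channels)
img_ids = torch.zeros(n, l_img, 3)                   # (t, h, w) position ids per image token
txt = torch.randn(n, l_txt, params.context_in_dim)   # T5 token embeddings
txt_ids = torch.zeros(n, l_txt, 3)
timesteps = torch.rand(n)
y = torch.randn(n, params.vec_in_dim)                # pooled CLIP embedding

out = model(img, img_ids, txt, txt_ids, timesteps, y)
print(out.shape)  # expected (1, 16, 64): same shape as the packed latent input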