Run ruff, setup initial text to image node
brandonrising committed Aug 19, 2024
1 parent f4f5c46 commit 043df07
Showing 15 changed files with 330 additions and 155 deletions.
7 changes: 1 addition & 6 deletions invokeai/app/invocations/flux_text_encoder.py
@@ -1,18 +1,13 @@
import torch


from einops import repeat
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.flux.modules.conditioner import HFEncoder
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo


@invocation(
147 changes: 78 additions & 69 deletions invokeai/app/invocations/flux_text_to_image.py
@@ -1,11 +1,6 @@
from typing import Literal

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from einops import rearrange, repeat
from PIL import Image
from transformers.models.auto import AutoModelForTextEncoding

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
@@ -19,20 +14,11 @@
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo

TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}


class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
base_class = FluxTransformer2DModel


class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
auto_class = AutoModelForTextEncoding
from invokeai.backend.util.devices import TorchDevice


@invocation(
@@ -75,7 +61,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
assert isinstance(flux_conditioning, FLUXConditioningInfo)

latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
image = self._run_vae_decoding(context, latents)
image = self._run_vae_decoding(context, flux_ae_path, latents)

Check failure on line 64 in invokeai/app/invocations/flux_text_to_image.py (GitHub Actions / python-checks):
Ruff (F821): invokeai/app/invocations/flux_text_to_image.py:64:49: F821 Undefined name `flux_ae_path`
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
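
About the Ruff F821 failure flagged above: `flux_ae_path` is not defined anywhere in this file, and the reworked `_run_vae_decoding` further down in this diff loads the VAE from `self.vae.vae` and takes only `context` and `latents`. A hedged guess at the follow-up fix, not part of this commit, is simply to drop the stray argument:

# Hypothetical follow-up, not in this commit: the call matches the new
# _run_vae_decoding(context, latents) signature, so the undefined name goes away.
latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
image = self._run_vae_decoding(context, latents)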

Expand All @@ -86,42 +72,79 @@ def _run_diffusion(
t5_embeddings: torch.Tensor,
):
transformer_info = context.models.load(self.transformer.transformer)
inference_dtype = TorchDevice.choose_torch_dtype()

# Prepare input noise.
# TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
# CPU RNG?
x = get_noise(
num_samples=1,
height=self.height,
width=self.width,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
seed=self.seed,
)
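
On the TODO above about seed behavior across devices: a common pattern is to draw the noise from a CPU-seeded generator and then move it to the target device, so the same seed yields the same latents everywhere. A sketch of that idea, not part of this commit; the (1, 16, 2*ceil(h/16), 2*ceil(w/16)) latent shape mirrors the upstream flux get_noise and is an assumption here:

# Hypothetical CPU-seeded noise, not in this commit.
import math

generator = torch.Generator(device="cpu").manual_seed(self.seed)
x = torch.randn(
    1,  # num_samples
    16,  # latent channels, assumed to match upstream flux
    2 * math.ceil(self.height / 16),
    2 * math.ceil(self.width / 16),
    generator=generator,
    dtype=torch.float32,
).to(device=TorchDevice.choose_torch_device(), dtype=inference_dtype)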

img, img_ids = self._prepare_latent_img_patches(x)

# HACK(ryand): Find a better way to determine if this is a schnell model or not.
is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
timesteps = get_schedule(
num_steps=self.num_steps,
image_seq_len=img.shape[1],
shift=not is_schnell,
)
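
A side note on the schnell detection above: the probe string is spelled "shnell", which is not a substring of "schnell", so the check cannot match paths like "black-forest-labs/FLUX.1-schnell"; and when transformer_info.config is falsy the expression evaluates to an empty string rather than a boolean. A tightened version, sketched here and not part of this commit, could read:

# Hypothetical variant, not in this commit: spell out "schnell" and always produce a bool.
# Assumes config.path is a plain string.
config_path = transformer_info.config.path if transformer_info.config else ""
is_schnell = "schnell" in config_path
timesteps = get_schedule(
    num_steps=self.num_steps,
    image_seq_len=img.shape[1],
    shift=not is_schnell,
)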

bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

# HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
# disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
# if the cache is not empty.
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

with transformer_info as transformer:
assert isinstance(transformer, FluxTransformer2DModel)

flux_pipeline_with_transformer = FluxPipeline(
scheduler=scheduler,
vae=None,
text_encoder=None,
tokenizer=None,
text_encoder_2=None,
tokenizer_2=None,
transformer=transformer,
)
assert isinstance(transformer, Flux)

x = denoise(
model=transformer,
img=img,
img_ids=img_ids,
txt=t5_embeddings,
txt_ids=txt_ids,
vec=clip_embeddings,
timesteps=timesteps,
guidance=self.guidance,
)

t5_embeddings = t5_embeddings.to(dtype=transformer.dtype)
clip_embeddings = clip_embeddings.to(dtype=transformer.dtype)
x = unpack(x.float(), self.height, self.width)

return x

latents = flux_pipeline_with_transformer(
height=self.height,
width=self.width,
num_inference_steps=self.num_steps,
guidance_scale=self.guidance,
generator=torch.Generator().manual_seed(self.seed),
prompt_embeds=t5_embeddings,
pooled_prompt_embeds=clip_embeddings,
output_type="latent",
return_dict=False,
)[0]
assert isinstance(latents, torch.Tensor)
return latents

def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Convert an input image in latent space to patches for diffusion.
This implementation was extracted from:
https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
Returns:
tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
"""
bs, c, h, w = latent_img.shape

# Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
if img.shape[0] == 1 and bs > 1:
img = repeat(img, "1 ... -> bs ...", bs=bs)

# Generate patch position ids.
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

return img, img_ids
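
To make the patch shapes concrete, here is a small standalone run of the same rearrangement; the 1x16x128x128 latent is purely an illustrative assumption:

# Standalone sketch of the patchification above (illustrative shapes only).
import torch
from einops import rearrange, repeat

latent_img = torch.randn(1, 16, 128, 128)  # (b, c, h, w)
bs, c, h, w = latent_img.shape

# 2x2 pixel-unshuffle: every 2x2 block of latent pixels becomes one token of c*4 features.
img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(img.shape)  # torch.Size([1, 4096, 64]) -> 64*64 tokens, 16*2*2 features each

# Position ids: channel 1 holds the patch row index, channel 2 the patch column index.
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
print(img_ids.shape)  # torch.Size([1, 4096, 3])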

def _run_vae_decoding(
self,
@@ -130,27 +153,13 @@ def _run_vae_decoding(
) -> Image.Image:
vae_info = context.models.load(self.vae.vae)
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)

flux_pipeline_with_vae = FluxPipeline(
scheduler=None,
vae=vae,
text_encoder=None,
tokenizer=None,
text_encoder_2=None,
tokenizer_2=None,
transformer=None,
)
assert isinstance(vae, AutoEncoder)
# TODO(ryand): Test that this works with both float16 and bfloat16.
with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
img = vae.decode(latents)

latents = flux_pipeline_with_vae._unpack_latents(
latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
)
latents = (
latents / flux_pipeline_with_vae.vae.config.scaling_factor
) + flux_pipeline_with_vae.vae.config.shift_factor
latents = latents.to(dtype=vae.dtype)
image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]

assert isinstance(image, Image.Image)
return image
img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

return img_pil
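
One detail worth noting in the new decode path: Tensor.clamp returns a new tensor rather than clamping in place, so the bare img.clamp(-1, 1) above discards its result. A sketch of the conversion with the clamp actually applied, assuming the decoded img is a (1, C, H, W) tensor roughly in [-1, 1], and not part of this commit:

# Sketch only, not in this commit.
img = img.clamp(-1, 1)  # reassign: clamp() is out-of-place (clamp_() would modify in place)
img = rearrange(img[0], "c h w -> h w c")  # drop the batch dim, move channels last
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())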
95 changes: 60 additions & 35 deletions invokeai/app/invocations/model.py
@@ -1,6 +1,6 @@
import copy
from time import sleep
from typing import List, Optional, Literal, Dict
from typing import Dict, List, Literal, Optional

from pydantic import BaseModel, Field

@@ -12,10 +12,10 @@
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType


class ModelIdentifierField(BaseModel):
@@ -132,31 +132,22 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:

return ModelIdentifierOutput(model=self.model)

T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]

T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
"base": {
"text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
"tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
"text_encoder_name": "FLUX.1-schnell_text_encoder_2",
"tokenizer_name": "FLUX.1-schnell_tokenizer_2",
"repo": "invokeai/flux_dev::t5_xxl_encoder/base",
"name": "t5_base_encoder",
"format": ModelFormat.T5Encoder,
},
"8b_quantized": {
"text_encoder_repo": "hf_repo1",
"tokenizer_repo": "hf_repo1",
"text_encoder_name": "hf_repo1",
"tokenizer_name": "hf_repo1",
"format": ModelFormat.T5Encoder8b,
},
"4b_quantized": {
"text_encoder_repo": "hf_repo2",
"tokenizer_repo": "hf_repo2",
"text_encoder_name": "hf_repo2",
"tokenizer_name": "hf_repo2",
"format": ModelFormat.T5Encoder8b,
"repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
"name": "t5_8b_quantized_encoder",
"format": ModelFormat.T5Encoder,
},
}


@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
@@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)

t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")

def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
@@ -189,7 +180,15 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
clip_encoder = self._get_model(context, SubModelType.TextEncoder)
t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)
vae = self._install_model(
context,
SubModelType.VAE,
"FLUX.1-schnell_ae",
"black-forest-labs/FLUX.1-schnell::ae.safetensors",
ModelFormat.Checkpoint,
ModelType.VAE,
BaseModelType.Flux,
)

return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer),
@@ -198,33 +197,59 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
vae=VAEField(vae=vae),
)

def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField:
match(submodel):
def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
match submodel:
case SubModelType.Transformer:
return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
case SubModelType.TextEncoder2:
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
case SubModelType.Tokenizer2:
return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
return self._install_model(
context,
submodel,
"clip-vit-large-patch14",
"openai/clip-vit-large-patch14",
ModelFormat.Diffusers,
ModelType.CLIPEmbed,
BaseModelType.Any,
)
case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
return self._install_model(
context,
submodel,
T5_ENCODER_MAP[self.t5_encoder]["name"],
T5_ENCODER_MAP[self.t5_encoder]["repo"],
ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]),
ModelType.T5Encoder,
BaseModelType.Any,
)
case _:
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")

def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
raise Exception(f"{submodel.value} is not a supported submodule for a flux model")

def _install_model(
self,
context: InvocationContext,
submodel: SubModelType,
name: str,
repo_id: str,
format: ModelFormat,
type: ModelType,
base: BaseModelType,
):
if models := context.models.search_by_attrs(name=name, base=base, type=type):
if len(models) != 1:
raise Exception(f"Multiple models detected for selected model with name {name}")
return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
else:
model_path = context.models.download_and_cache_model(repo_id)
config = ModelRecordChanges(name = name, base = base, type=type, format=format)
config = ModelRecordChanges(name=name, base=base, type=type, format=format)
model_install_job = context.models.import_local_model(model_path=model_path, config=config)
while not model_install_job.in_terminal_state:
sleep(0.01)
if not model_install_job.config_out:
raise Exception(f"Failed to install {name}")
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})
return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
update={"submodel_type": submodel}
)


@invocation(
"main_model_loader",
Expand Down
2 changes: 1 addition & 1 deletion invokeai/app/services/model_records/model_records_sql.py
@@ -301,7 +301,7 @@ def search_by_attr(
for row in result:
try:
model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
except pydantic.ValidationError as e:
except pydantic.ValidationError:
# We catch this error so that the app can still run if there are invalid model configs in the database.
# One reason that an invalid model config might be in the database is if someone had to rollback from a
# newer version of the app that added a new model type.
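
Since only the except line changes in this hunk, here is a minimal sketch of the surrounding pattern under a stated assumption (the loop simply skips rows that fail validation; the exception object is never used, which is why the "as e" binding can go):

# Assumed shape of the surrounding loop; only the except clause appears in this hunk.
for row in result:
    try:
        model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
    except pydantic.ValidationError:
        # One invalid stored config (e.g. after rolling back to an older app version)
        # should not break the whole query, so skip the record.
        continue
    # ... collect model_config as before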