Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support image2image mode for Flux #1713

Open
wants to merge 1 commit into
base: sd3
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
73 changes: 53 additions & 20 deletions flux_minimal_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,17 @@ def do_sample(
l_pooled: torch.Tensor,
t5_out: torch.Tensor,
txt_ids: torch.Tensor,
num_steps: int,
timesteps: list[float],
guidance: float,
t5_attn_mask: Optional[torch.Tensor],
is_schnell: bool,
device: torch.device,
flux_dtype: torch.dtype,
neg_l_pooled: Optional[torch.Tensor] = None,
neg_t5_out: Optional[torch.Tensor] = None,
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
logger.info(f"num_steps: {num_steps}")
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
logger.info(f"num_steps: {len(timesteps)}")

# denoise initial noise
if accelerator:
Expand Down Expand Up @@ -196,20 +194,26 @@ def generate_image(
t5xxl,
ae,
prompt: str,
image_path: Optional[str],
seed: Optional[int],
image_width: int,
image_height: int,
steps: Optional[int],
guidance: float,
negative_prompt: Optional[str],
cfg_scale: float,
strength: float,
):
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {seed}")

if steps is None:
steps = 4 if is_schnell else 50
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
timesteps = get_schedule(steps, packed_latent_height * packed_latent_width, shift=not is_schnell)

# make first noise with packed shape
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
noise = torch.randn(
1,
Expand All @@ -220,14 +224,21 @@ def generate_image(
generator=torch.Generator(device=device).manual_seed(seed),
)

# prepare img and img ids
if image_path:
image = Image.open(image_path).convert("RGB")
image = torch.tensor(np.array(image), device=device).permute(2, 0, 1).unsqueeze(0)
image = torch.nn.functional.interpolate(image, (image_height, image_width))
image = image / 255.0 * 2.0 - 1.0
image = image.to(device)
latents = ae.encode(image)
latents = flux_utils.pack_latents(latents)

# this is needed only for img2img
# img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
# if img.shape[0] == 1 and bs > 1:
# img = repeat(img, "1 ... -> bs ...", bs=bs)
t_idx = int((1 - strength) * steps)
t = timesteps[t_idx]
timesteps = timesteps[t_idx:]

noise = noise * t + latents * (1 - t)

# txt2img only needs img_ids
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)

# prepare fp8 models
Expand Down Expand Up @@ -313,8 +324,6 @@ def encode(prpt: str):
# generate image
logger.info("Generating image...")
model = model.to(device)
if steps is None:
steps = 4 if is_schnell else 50

img_ids = img_ids.to(device)
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
Expand All @@ -327,10 +336,9 @@ def encode(prpt: str):
l_pooled,
t5_out,
txt_ids,
steps,
timesteps,
guidance,
t5_attn_mask,
is_schnell,
device,
flux_dtype,
neg_l_pooled,
Expand Down Expand Up @@ -362,13 +370,13 @@ def encode(prpt: str):

x = x.clamp(-1, 1)
x = x.permute(0, 2, 3, 1)
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])

# save image
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
img.save(output_path)
image.save(output_path)

logger.info(f"Saved image to {output_path}")

Expand All @@ -390,6 +398,7 @@ def encode(prpt: str):
parser.add_argument("--ae", type=str, required=False)
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--image_path", type=str, default=None)
parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype")
parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l")
Expand All @@ -401,6 +410,7 @@ def encode(prpt: str):
parser.add_argument("--guidance", type=float, default=3.5)
parser.add_argument("--negative_prompt", type=str, default=None)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--strength", type=float, default=0.8)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument(
"--lora_weights",
Expand Down Expand Up @@ -512,13 +522,15 @@ def is_fp8(dt):
t5xxl,
ae,
args.prompt,
args.image_path,
args.seed,
args.width,
args.height,
args.steps,
args.guidance,
args.negative_prompt,
args.cfg_scale,
args.strength,
)
else:
# loop for interactive
Expand All @@ -527,11 +539,12 @@ def is_fp8(dt):
steps = None
guidance = args.guidance
cfg_scale = args.cfg_scale
strength = args.strength

while True:
print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --g <guidance> --m <multipliers for LoRA>"
" --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed> --i <image_path> --r <strength> "
"--g <guidance> --m <multipliers for LoRA> --n <negative prompt>, `-` for empty negative prompt --c <cfg_scale>"
)
prompt = input()
if prompt == "":
Expand All @@ -542,6 +555,7 @@ def is_fp8(dt):
prompt = options[0].strip()
seed = None
negative_prompt = None
image_path = None
for opt in options[1:]:
try:
opt = opt.strip()
Expand All @@ -553,6 +567,10 @@ def is_fp8(dt):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
elif opt.startswith("i"):
image_path = opt[1:].strip()
elif opt.startswith("r"):
strength = float(opt[1:].strip())
elif opt.startswith("g"):
guidance = float(opt[1:].strip())
elif opt.startswith("m"):
Expand All @@ -571,6 +589,21 @@ def is_fp8(dt):
except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}")

generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale)
generate_image(
model,
clip_l,
t5xxl,
ae,
prompt,
image_path,
seed,
width,
height,
steps,
guidance,
negative_prompt,
cfg_scale,
strength,
)

logger.info("Done!")