From c7e3f792606e79db8d7d723664a61baac274f740 Mon Sep 17 00:00:00 2001 From: Grigory Reznikov Date: Sun, 20 Oct 2024 16:20:19 +0000 Subject: [PATCH] Support image2image mode for Flux --- flux_minimal_inference.py | 73 ++++++++++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py index 7ab224f1b..c3a340504 100644 --- a/flux_minimal_inference.py +++ b/flux_minimal_inference.py @@ -137,10 +137,9 @@ def do_sample( l_pooled: torch.Tensor, t5_out: torch.Tensor, txt_ids: torch.Tensor, - num_steps: int, + timesteps: list[float], guidance: float, t5_attn_mask: Optional[torch.Tensor], - is_schnell: bool, device: torch.device, flux_dtype: torch.dtype, neg_l_pooled: Optional[torch.Tensor] = None, @@ -148,8 +147,7 @@ def do_sample( neg_t5_attn_mask: Optional[torch.Tensor] = None, cfg_scale: Optional[float] = None, ): - logger.info(f"num_steps: {num_steps}") - timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell) + logger.info(f"num_steps: {len(timesteps)}") # denoise initial noise if accelerator: @@ -196,6 +194,7 @@ def generate_image( t5xxl, ae, prompt: str, + image_path: Optional[str], seed: Optional[int], image_width: int, image_height: int, @@ -203,13 +202,18 @@ def generate_image( guidance: float, negative_prompt: Optional[str], cfg_scale: float, + strength: float, ): seed = seed if seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {seed}") + if steps is None: + steps = 4 if is_schnell else 50 + packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) + timesteps = get_schedule(steps, packed_latent_height * packed_latent_width, shift=not is_schnell) + # make first noise with packed shape # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2 - packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16) noise_dtype = torch.float32 if is_fp8(dtype) else dtype noise = torch.randn( 1, @@ -220,14 +224,21 @@ def generate_image( generator=torch.Generator(device=device).manual_seed(seed), ) - # prepare img and img ids + if image_path: + image = Image.open(image_path).convert("RGB") + image = torch.tensor(np.array(image), device=device).permute(2, 0, 1).unsqueeze(0) + image = torch.nn.functional.interpolate(image, (image_height, image_width)) + image = image / 255.0 * 2.0 - 1.0 + image = image.to(device) + latents = ae.encode(image) + latents = flux_utils.pack_latents(latents) - # this is needed only for img2img - # img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) - # if img.shape[0] == 1 and bs > 1: - # img = repeat(img, "1 ... -> bs ...", bs=bs) + t_idx = int((1 - strength) * steps) + t = timesteps[t_idx] + timesteps = timesteps[t_idx:] + + noise = noise * t + latents * (1 - t) - # txt2img only needs img_ids img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width) # prepare fp8 models @@ -313,8 +324,6 @@ def encode(prpt: str): # generate image logger.info("Generating image...") model = model.to(device) - if steps is None: - steps = 4 if is_schnell else 50 img_ids = img_ids.to(device) t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None @@ -327,10 +336,9 @@ def encode(prpt: str): l_pooled, t5_out, txt_ids, - steps, + timesteps, guidance, t5_attn_mask, - is_schnell, device, flux_dtype, neg_l_pooled, @@ -362,13 +370,13 @@ def encode(prpt: str): x = x.clamp(-1, 1) x = x.permute(0, 2, 3, 1) - img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) + image = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0]) # save image output_dir = args.output_dir os.makedirs(output_dir, exist_ok=True) output_path = os.path.join(output_dir, f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png") - img.save(output_path) + image.save(output_path) logger.info(f"Saved image to {output_path}") @@ -390,6 +398,7 @@ def encode(prpt: str): parser.add_argument("--ae", type=str, required=False) parser.add_argument("--apply_t5_attn_mask", action="store_true") parser.add_argument("--prompt", type=str, default="A photo of a cat") + parser.add_argument("--image_path", type=str, default=None) parser.add_argument("--output_dir", type=str, default=".") parser.add_argument("--dtype", type=str, default="bfloat16", help="base dtype") parser.add_argument("--clip_l_dtype", type=str, default=None, help="dtype for clip_l") @@ -401,6 +410,7 @@ def encode(prpt: str): parser.add_argument("--guidance", type=float, default=3.5) parser.add_argument("--negative_prompt", type=str, default=None) parser.add_argument("--cfg_scale", type=float, default=1.0) + parser.add_argument("--strength", type=float, default=0.8) parser.add_argument("--offload", action="store_true", help="Offload to CPU") parser.add_argument( "--lora_weights", @@ -512,6 +522,7 @@ def is_fp8(dt): t5xxl, ae, args.prompt, + args.image_path, args.seed, args.width, args.height, @@ -519,6 +530,7 @@ def is_fp8(dt): args.guidance, args.negative_prompt, args.cfg_scale, + args.strength, ) else: # loop for interactive @@ -527,11 +539,12 @@ def is_fp8(dt): steps = None guidance = args.guidance cfg_scale = args.cfg_scale + strength = args.strength while True: print( - "Enter prompt (empty to exit). Options: --w --h --s --d --g --m " - " --n , `-` for empty negative prompt --c " + "Enter prompt (empty to exit). Options: --w --h --s --d --i --r " + "--g --m --n , `-` for empty negative prompt --c " ) prompt = input() if prompt == "": @@ -542,6 +555,7 @@ def is_fp8(dt): prompt = options[0].strip() seed = None negative_prompt = None + image_path = None for opt in options[1:]: try: opt = opt.strip() @@ -553,6 +567,10 @@ def is_fp8(dt): steps = int(opt[1:].strip()) elif opt.startswith("d"): seed = int(opt[1:].strip()) + elif opt.startswith("i"): + image_path = opt[1:].strip() + elif opt.startswith("r"): + strength = float(opt[1:].strip()) elif opt.startswith("g"): guidance = float(opt[1:].strip()) elif opt.startswith("m"): @@ -571,6 +589,21 @@ def is_fp8(dt): except ValueError as e: logger.error(f"Invalid option: {opt}, {e}") - generate_image(model, clip_l, t5xxl, ae, prompt, seed, width, height, steps, guidance, negative_prompt, cfg_scale) + generate_image( + model, + clip_l, + t5xxl, + ae, + prompt, + image_path, + seed, + width, + height, + steps, + guidance, + negative_prompt, + cfg_scale, + strength, + ) logger.info("Done!")