diff --git a/img_styler/image_prompt/stable_diffusion.py b/img_styler/image_prompt/stable_diffusion.py
index c98f3be..cc4f5cb 100644
--- a/img_styler/image_prompt/stable_diffusion.py
+++ b/img_styler/image_prompt/stable_diffusion.py
@@ -3,19 +3,27 @@
 from typing import Optional
 
 import torch
-from diffusers import DDIMScheduler, LMSDiscreteScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
+from diffusers import (
+    DDIMScheduler,
+    LMSDiscreteScheduler,
+    StableDiffusionImg2ImgPipeline,
+    StableDiffusionPipeline,
+)
 from PIL import Image
 from torch import autocast
 
 
-def generate_image_with_prompt(input_img_path: Optional[str]=None, prompt_txt: str = "Face portrait",
-                               n_steps: int = 50,
-                               guidance_scale: int = 7.5,
-                               sampler_type: str = "K-LMS",
-                               output_path: str=None):
+def generate_image_with_prompt(
+    input_img_path: Optional[str] = None,
+    prompt_txt: str = "Face portrait",
+    n_steps: int = 50,
+    guidance_scale: int = 7.5,
+    sampler_type: str = "DDIM",
+    output_path: str = None,
+):
     # License: https://huggingface.co/spaces/CompVis/stable-diffusion-license
     torch.cuda.empty_cache()
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     model_path = "./models/stable_diffusion_v1_4"
     # Default Scheduler K-LMS(Katherine Crowson)
@@ -23,15 +31,22 @@ def generate_image_with_prompt(input_img_path: Optional[str]=None, prompt_txt: s
     sampler = None
     if sampler_type == "K-LMS":
         sampler = LMSDiscreteScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+    elif sampler_type == "DDIM":
+        # https://arxiv.org/abs/2010.02502
+        sampler = DDIMScheduler(
             beta_start=0.00085,
             beta_end=0.012,
-            beta_schedule="scaled_linear")
-    elif sampler_type == "DDIM":
-        sampler = DDIMScheduler()
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
 
     if input_img_path:
-        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_path, revision="fp16",
-                                                              torch_dtype=torch.float16).to(device)
+        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
+            model_path, revision="fp16", torch_dtype=torch.float16
+        ).to(device)
         if sampler:
             pipe.scheduler = sampler
         # Open image
@@ -39,22 +54,31 @@ def generate_image_with_prompt(input_img_path: Optional[str]=None, prompt_txt: s
         init_image = image_input.resize((512, 512))
 
         with autocast(device):
-            images = pipe(prompt=prompt_txt, init_image=init_image, strength=0.5, guidance_scale=guidance_scale,
-                          num_inference_steps=n_steps)["sample"]
-    else: # Default prompt
+            images = pipe(
+                prompt=prompt_txt,
+                init_image=init_image,
+                strength=0.5,
+                guidance_scale=guidance_scale,
+                num_inference_steps=n_steps,
+            )["sample"]
+    else:  # Default prompt
         generator = torch.Generator(device=device).manual_seed(42)
-        pipe = StableDiffusionPipeline.from_pretrained(model_path, revision="fp16",
-                                                       torch_dtype=torch.float16).to(device)
+        pipe = StableDiffusionPipeline.from_pretrained(
+            model_path, revision="fp16", torch_dtype=torch.float16
+        ).to(device)
         if sampler:
             pipe.scheduler = sampler
         with autocast(device):
             # One sample for now.
             # TODO Extend for multiple samples.
-            images = pipe(prompt=[prompt_txt]*1, num_inference_steps=n_steps, guidance_scale=guidance_scale,
-                          generator=generator).images
-
+            images = pipe(
+                prompt=[prompt_txt] * 1,
+                num_inference_steps=n_steps,
+                guidance_scale=guidance_scale,
+                generator=generator,
+            ).images
 
-    file_name = output_path + '/result.jpg'
+    file_name = output_path + "/result.jpg"
     if output_path:
         images[0].save(file_name)
     # Release resources
diff --git a/img_styler/ui/handlers.py b/img_styler/ui/handlers.py
index 1e7d68e..8eae25b 100644
--- a/img_styler/ui/handlers.py
+++ b/img_styler/ui/handlers.py
@@ -14,12 +14,24 @@
 
 from GFPGAN.inference_gfpgan import init_gfpgan, restore_image
 
-from ..caller import apply_projection, generate_gif, generate_projection, generate_style_frames, synthesize_new_img
+from ..caller import (
+    apply_projection,
+    generate_gif,
+    generate_projection,
+    generate_style_frames,
+    synthesize_new_img,
+)
 from ..image_prompt.stable_diffusion import generate_image_with_prompt
 from ..latent_editor import edit_image, load_latent_vectors
 from ..utils.dataops import buf2img, get_files_in_dir, remove_file
 from .capture import capture_img, draw_boundary, html_str, js_schema
-from .common import progress_generate_gif, update_controls, update_faces, update_gif, update_processed_face
+from .common import (
+    progress_generate_gif,
+    update_controls,
+    update_faces,
+    update_gif,
+    update_processed_face,
+)
 from .components import get_header, get_meta, get_user_title
 
 PRE_COMPUTED_PROJECTION_PATH = "./z_output"
@@ -29,7 +41,7 @@
 
 @on()
 async def close_dialog(q: Q):
-    q.page['meta'].dialog = None
+    q.page["meta"].dialog = None
     await q.page.save()
 
 
@@ -52,12 +64,12 @@ async def change_theme(q: Q):
 
 
 async def process(q: Q):
-    logger.debug(f'Source_face {q.args.source_face}')
-    logger.debug(f'Style_face {q.args.style_face}')
-    logger.debug(f'Z Low {q.args.z_low}')
-    logger.debug(f'Z High {q.args.z_high}')
+    logger.debug(f"Source_face {q.args.source_face}")
+    logger.debug(f"Style_face {q.args.style_face}")
+    logger.debug(f"Z Low {q.args.z_low}")
+    logger.debug(f"Z High {q.args.z_high}")
 
-    hash = q.args['#']
+    hash = q.args["#"]
     if q.args.task_dropdown and q.client.task_choice != q.args.task_dropdown:
         logger.info(f"Task selection: {q.args.task_dropdown}")
         q.client.task_choice = q.args.task_dropdown
@@ -75,9 +87,9 @@ async def process(q: Q):
         q.client.z_high = int(q.args.z_high)
     if q.args.generate_gif:
         await progress_generate_gif(q)
-        style_type = q.client.source_style[len('style_'):]
+        style_type = q.client.source_style[len("style_") :]
         q.client.gif_path = generate_gif(q.client.source_face, 15, style_type)
-    out_path = 'GFPGAN/output/temp.png'
+    out_path = "GFPGAN/output/temp.png"
    if q.args.fix_resolution and (q.client.processedimg or q.client.source_face):
        q.client.restorer = q.client.restorer or init_gfpgan()
        img_path = q.client.processedimg or q.client.source_face
@@ -86,25 +98,31 @@ async def process(q: Q):
     if q.args.save_img_to_list:
         new_img_path = os.path.join(INPUT_PATH, q.args.img_name)
         if os.path.exists(new_img_path):
-            q.page['meta'] = ui.meta_card(box='', notification_bar=ui.notification_bar(
-                text=f'Image by the name "{q.args.img_name}" already exists!',
-                type='error',
-                position='bottom-left',
-            ))
+            q.page["meta"] = ui.meta_card(
+                box="",
+                notification_bar=ui.notification_bar(
+                    text=f'Image by the name "{q.args.img_name}" already exists!',
+                    type="error",
+                    position="bottom-left",
+                ),
+            )
         else:
             os.rename(out_path, new_img_path)
             temp_img = Image.open(new_img_path)
-            temp_img.save(os.path.join(INPUT_PATH, 'portrait.jpg'))
+            temp_img.save(os.path.join(INPUT_PATH, "portrait.jpg"))
             q.app.source_faces = get_files_in_dir(dir_path=INPUT_PATH)
             q.client.processedimg = new_img_path
 
-            q.page['meta'] = ui.meta_card(box='', notification_bar=ui.notification_bar(
-                text='Image added to list!',
-                type='success',
-                position='bottom-left',
-            ))
+            q.page["meta"] = ui.meta_card(
+                box="",
+                notification_bar=ui.notification_bar(
+                    text="Image added to list!",
+                    type="success",
+                    position="bottom-left",
+                ),
+            )
             await q.page.save()
-            del q.page['meta']
+            del q.page["meta"]
 
     await update_controls(q)
     await update_faces(q)
@@ -112,7 +130,7 @@ async def process(q: Q):
     await update_gif(q)
     if q.args.apply:
         await apply(q)
-    if hash == 'capture':
+    if hash == "capture":
         await capture(q)
     elif q.args.upload_image_dialog:
         await upload_image_dialog(q)
@@ -153,23 +171,23 @@ async def img_capture_save(q: Q):
     # Set the current source face as the captured image.
     q.app.source_face = q.client.source_face = file_name
     # Return to home page.
-    q.page['meta'].redirect = "/"
+    q.page["meta"].redirect = "/"
     await q.page.save()
 
 
 @on()
 async def img_capture_done(q: Q):
     logger.debug(f"Exit image capture.")
-    q.page['meta'].redirect = "/"
+    q.page["meta"].redirect = "/"
     await q.page.save()
 
 
 @on()
 async def home(q: Q):
     q.page.drop()
-    q.page['meta'] = get_meta(q)
-    q.page['header'] = get_header(q)
-    q.page['user_title'] = get_user_title(q)
+    q.page["meta"] = get_meta(q)
+    q.page["header"] = get_header(q)
+    q.page["user_title"] = get_user_title(q)
     await update_controls(q)
     await update_faces(q)
     await q.page.save()
@@ -184,15 +202,15 @@ def reset_edit_results(q: Q):
     q.client.gif_path = ""
 
 
-@on('source_face', source_face_check)
+@on("source_face", source_face_check)
 async def source_face(q: Q):
-    logger.debug('Calling source_face')
+    logger.debug("Calling source_face")
     await process(q)
 
 
-@on('task_dropdown')
+@on("task_dropdown")
 async def on_task_selection(q: Q):
-    logger.info('Selecting task choice')
+    logger.info("Selecting task choice")
     await process(q)
 
 
@@ -200,9 +218,9 @@ def style_face_check(q: Q, style_face_arg: str) -> bool:
     return style_face_arg != q.client.style_face
 
 
-@on('style_face', style_face_check)
+@on("style_face", style_face_check)
 async def style_face(q: Q):
-    logger.debug('Calling style_face')
+    logger.debug("Calling style_face")
     await process(q)
 
 
@@ -210,9 +228,9 @@ def z_low_check(q: Q, z_low_arg: str) -> bool:
     return int(z_low_arg) != q.client.z_low
 
 
-@on('z_low', z_low_check)
+@on("z_low", z_low_check)
 async def z_low(q: Q):
-    logger.debug('Calling z_low')
+    logger.debug("Calling z_low")
     await process(q)
 
 
@@ -220,18 +238,18 @@ def z_high_check(q: Q, z_high_arg: str) -> bool:
     return int(z_high_arg) != q.client.z_high
 
 
-@on('z_high', z_high_check)
+@on("z_high", z_high_check)
 async def z_high(q: Q):
-    logger.debug('Calling z_high')
+    logger.debug("Calling z_high")
     await process(q)
 
 
 @on()
 async def upload_image_dialog(q: Q):
-    q.page['meta'].dialog = ui.dialog(
-        title='Upload Image',
+    q.page["meta"].dialog = ui.dialog(
+        title="Upload Image",
         closable=True,
-        items=[ui.file_upload(name='image_upload', label='Upload')],
+        items=[ui.file_upload(name="image_upload", label="Upload")],
     )
     await q.page.save()
 
@@ -240,34 +258,42 @@ async def upload_image_dialog(q: Q):
 @on()
 async def image_upload(q: Q):
     q.page.drop()
-    q.page['meta'] = get_meta(q)
-    q.page['header'] = get_header(q)
-    q.page['user_title'] = get_user_title(q)
+    q.page["meta"] = get_meta(q)
+    q.page["header"] = get_header(q)
+    q.page["user_title"] = get_user_title(q)
     if q.args.image_upload:
-        local_path = await q.site.download(q.args.image_upload[0], './images/')
-        encoded = base64.b64encode(open(local_path, "rb").read()).decode('ascii')
-        _img = 'data:image/png;base64,{}'.format(encoded)
+        local_path = await q.site.download(q.args.image_upload[0], "./images/")
+        encoded = base64.b64encode(open(local_path, "rb").read()).decode("ascii")
+        _img = "data:image/png;base64,{}".format(encoded)
         q.client.current_img = _img
         facial_feature_analysis(q, local_path, "Uploaded Image")
     await q.page.save()
 
 
-@on('prompt_apply')
+@on("prompt_apply")
 async def prompt_apply(q: Q):
     logger.info(f"Enable prompt.")
     logger.info(f"Prompt value: {q.args.prompt_textbox}")
     logger.info(f"Number of steps: {q.args.diffusion_n_steps}")
     logger.info(f"Guidance scale: {q.args.prompt_guidance_scale}")
+    logger.info(f"Sampler choice: {q.args.df_sampling_dropdown}")
     if q.args.prompt_use_source_img:
-        res_path = generate_image_with_prompt(input_img_path=q.client.source_face, prompt_txt=q.args.prompt_textbox,
-                                              n_steps=q.args.diffusion_n_steps,
-                                              guidance_scale=q.args.prompt_guidance_scale,
-                                              output_path=OUTPUT_PATH)
-    else: # Don't initialize with source image
-        res_path = generate_image_with_prompt(prompt_txt=q.args.prompt_textbox,
-                                              n_steps=q.args.diffusion_n_steps,
-                                              guidance_scale=q.args.prompt_guidance_scale,
-                                              output_path=OUTPUT_PATH)
+        res_path = generate_image_with_prompt(
+            input_img_path=q.client.source_face,
+            prompt_txt=q.args.prompt_textbox,
+            n_steps=q.args.diffusion_n_steps,
+            guidance_scale=q.args.prompt_guidance_scale,
+            sampler_type=q.args.df_sampling_dropdown,
+            output_path=OUTPUT_PATH,
+        )
+    else:  # Don't initialize with source image
+        res_path = generate_image_with_prompt(
+            prompt_txt=q.args.prompt_textbox,
+            n_steps=q.args.diffusion_n_steps,
+            guidance_scale=q.args.prompt_guidance_scale,
+            sampler_type=q.args.df_sampling_dropdown,
+            output_path=OUTPUT_PATH,
+        )
 
     q.client.prompt_textbox = q.args.prompt_textbox
     q.client.diffusion_n_steps = q.args.diffusion_n_steps
@@ -277,16 +303,16 @@ async def prompt_apply(q: Q):
     await update_processed_face(q)
 
 
-@on('#capture')
+@on("#capture")
 async def capture(q: Q):
     if q.args.img_capture_save:
         await img_capture_save(q)
     else:
         logger.debug("Capture clicked.")
         q.page.drop()
-        q.page['meta'] = get_meta(q)
-        q.page['header'] = get_header(q)
-        q.page['user_title'] = get_user_title(q)
+        q.page["meta"] = get_meta(q)
+        q.page["header"] = get_header(q)
+        q.page["user_title"] = get_user_title(q)
 
         _img = await capture_img(q)
         q.client.current_img = _img
@@ -296,21 +322,21 @@ async def capture(q: Q):
             facial_feature_analysis(q, _img)
         elif q.args.exit_camera:
             # Return to home page.
-            q.page['meta'].redirect = "/"
+            q.page["meta"].redirect = "/"
             await q.page.save()
         else:
-            q.page['meta'].script = ui.inline_script(
-                content=js_schema, requires=[], targets=['video']
+            q.page["meta"].script = ui.inline_script(
+                content=js_schema, requires=[], targets=["video"]
             )
-            q.page['plot'] = ui.markup_card(
-                box=ui.box('middle_left', order=2, height='950px', width='950px'),
-                title='',
+            q.page["plot"] = ui.markup_card(
+                box=ui.box("middle_left", order=2, height="950px", width="950px"),
+                title="",
                 content=html_str,
             )
             # TODO Replace css styling
-            q.page['meta'].stylesheets = [
+            q.page["meta"].stylesheets = [
                 ui.stylesheet(
-                    path='https://cdn.jsdelivr.net/npm/bootstrap@5.1.0/dist/css/bootstrap.min.css'
+                    path="https://cdn.jsdelivr.net/npm/bootstrap@5.1.0/dist/css/bootstrap.min.css"
                 )
             ]
 
@@ -334,7 +360,7 @@ def rotate_face(image_path: str):
 
 def facial_feature_analysis(q: Q, img_path: str, title="Clicked Image"):
     models = {}
-    models['emotion'] = DeepFace.build_model('Emotion')
+    models["emotion"] = DeepFace.build_model("Emotion")
     # MTCNN (performed better than RetinaFace for the sample images tried).
     # If face is not detected; it's probably b'cauz of orientation
     # Naive approach:
@@ -342,7 +368,12 @@ def facial_feature_analysis(q: Q, img_path: str, title="Clicked Image"):
     # Rotate -> ['Left', 'Right', 'Up', 'Down']
     for _ in range(4):
         try:
-            obj = DeepFace.analyze(img_path=img_path, models=models, actions=['emotion'], detector_backend='mtcnn')
+            obj = DeepFace.analyze(
+                img_path=img_path,
+                models=models,
+                actions=["emotion"],
+                detector_backend="mtcnn",
+            )
             if obj and len(obj) > 0:
                 break
         except ValueError as ve:
@@ -352,42 +383,42 @@ def facial_feature_analysis(q: Q, img_path: str, title="Clicked Image"):
             pass
 
     logger.info(f"Facial Attributes: {obj}")
-    dominant_emotion = obj['dominant_emotion']
+    dominant_emotion = obj["dominant_emotion"]
     logger.info(f"Dominant emotion: {dominant_emotion}")
 
     # Draw bounding box around the face
     _im = q.client.current_img
-    _img = _im.split(',')[1]
+    _img = _im.split(",")[1]
     base64_decoded = base64.b64decode(_img)
     image = Image.open(io.BytesIO(base64_decoded))
     img_np = np.array(image)
 
-    x = obj['region']['x']
-    y = obj['region']['y']
-    w = obj['region']['w']
-    h = obj['region']['h']
+    x = obj["region"]["x"]
+    y = obj["region"]["y"]
+    w = obj["region"]["w"]
+    h = obj["region"]["h"]
     img_w_box2 = draw_boundary(img_np, x, y, w, h, text=dominant_emotion)
     pil_img = Image.fromarray(img_w_box2)
 
     buff = BytesIO()
-    pil_img = pil_img.convert('RGB')
+    pil_img = pil_img.convert("RGB")
     pil_img.save(buff, format="JPEG")
     new_image_encoded = base64.b64encode(buff.getvalue()).decode("utf-8")
     img_format = "data:image/png;base64,"
 
     # Update image
     new_image = img_format + new_image_encoded
-    q.page['capture_img'] = ui.form_card(
-        box=ui.box('middle_left'),
+    q.page["capture_img"] = ui.form_card(
+        box=ui.box("middle_left"),
         title=title,
         items=[
-            ui.image("Captured Image", path=new_image, width='550px'),
+            ui.image("Captured Image", path=new_image, width="550px"),
             ui.buttons(
                 items=[
-                    ui.button('img_capture_save', 'Save & Exit', icon='Save'),
-                    ui.button('img_capture_done', 'Ignore & Exit', icon='ChromeClose'),
+                    ui.button("img_capture_save", "Save & Exit", icon="Save"),
+                    ui.button("img_capture_done", "Ignore & Exit", icon="ChromeClose"),
                 ]
-            )
-        ]
+            ),
+        ],
     )
 
@@ -402,18 +433,18 @@ async def apply(q: Q):
     logger.debug(f"Other values: {z_low}/{z_high}")
 
     # Use pre-computed projections
-    source_img_name = source_face.rsplit('.', 1)[0].split('./images/')[1]
-    style_img_name = style_face.rsplit('.', 1)[0].split('./images/')[1]
+    source_img_name = source_face.rsplit(".", 1)[0].split("./images/")[1]
+    style_img_name = style_face.rsplit(".", 1)[0].split("./images/")[1]
     source_img_proj = f"{PRE_COMPUTED_PROJECTION_PATH}/{source_img_name}.npz"
     style_img_proj = f"{PRE_COMPUTED_PROJECTION_PATH}/{style_img_name}.npz"
     source_img_proj_path = Path(source_img_proj)
     style_img_proj_path = Path(style_img_proj)
 
     new_img = None
-    style_type = ''
-    file_name = ''
-    if q.client.task_choice == 'C': # Image Editing
-        q.client.source_style = q.args.source_style or 'style_none'
+    style_type = ""
+    file_name = ""
+    if q.client.task_choice == "C":  # Image Editing
+        q.client.source_style = q.args.source_style or "style_none"
         q.client.age_slider = q.args.age_slider if q.args.age_slider else 0
         q.client.eye_distance = q.args.eye_distance if q.args.eye_distance else 0
         q.client.eyebrow_distance = (
@@ -435,7 +466,7 @@ async def apply(q: Q):
         q.client.roll = q.args.roll if q.args.roll else 0
         q.client.smile = q.args.smile if q.args.smile else 0
         q.client.yaw = q.args.yaw if q.args.yaw else 0
-        latent_info = load_latent_vectors('./models/stylegan2_attributes/')
+        latent_info = load_latent_vectors("./models/stylegan2_attributes/")
 
         # Update feature info for image editing
         # Dictionary format:
@@ -443,27 +474,27 @@ async def apply(q: Q):
         #   "feature_name": "value"
         # }
         f_i = {
-            'age': q.client.age_slider,
-            'eye_distance': q.client.eye_distance,
-            'eye_eyebrow_distance': q.client.eyebrow_distance,
-            'eye_ratio': q.client.eye_ratio,
-            'eyes_open': q.client.eyes_open,
-            'gender': q.client.gender,
-            'lip_ratio': q.client.lip_ratio,
-            'mouth_open': q.client.mouth_open,
-            'mouth_ratio': q.client.mouth_ratio,
-            'nose_mouth_distance': q.client.nose_mouth_distance,
-            'nose_ratio': q.client.nose_ratio,
-            'nose_tip': q.client.nose_tip,
-            'pitch': q.client.pitch,
-            'roll': q.client.roll,
-            'smile': q.client.smile,
-            'yaw': q.client.yaw,
+            "age": q.client.age_slider,
+            "eye_distance": q.client.eye_distance,
+            "eye_eyebrow_distance": q.client.eyebrow_distance,
+            "eye_ratio": q.client.eye_ratio,
+            "eyes_open": q.client.eyes_open,
+            "gender": q.client.gender,
+            "lip_ratio": q.client.lip_ratio,
+            "mouth_open": q.client.mouth_open,
+            "mouth_ratio": q.client.mouth_ratio,
+            "nose_mouth_distance": q.client.nose_mouth_distance,
+            "nose_ratio": q.client.nose_ratio,
+            "nose_tip": q.client.nose_tip,
+            "pitch": q.client.pitch,
+            "roll": q.client.roll,
+            "smile": q.client.smile,
+            "yaw": q.client.yaw,
         }
 
-        style_type = q.client.source_style[len('style_'):]
+        style_type = q.client.source_style[len("style_") :]
         logger.debug(f"Source style: {style_type}")
-        if source_img_proj_path.is_file() and style_type == 'none':
+        if source_img_proj_path.is_file() and style_type == "none":
             mlc = edit_image(
                 latent_info,
                 source_img_proj_path,
@@ -480,8 +511,8 @@ async def apply(q: Q):
             source_img_proj_path = Path(source_img_proj)
             mlc = edit_image(latent_info, source_img_proj_path, f_i)
 
-        if style_type != 'none':
-            file_name = OUTPUT_PATH + f'/{source_img_name}.jpg'
+        if style_type != "none":
+            file_name = OUTPUT_PATH + f"/{source_img_name}.jpg"
             new_img = generate_style_frames(mlc, style_type, file_name)
         else:
             new_img = synthesize_new_img(mlc)
@@ -490,7 +521,7 @@ async def apply(q: Q):
         logger.debug(f"Saving to {edit_img_lc}")
         edit_img_lc_path = Path(edit_img_lc)
         np.savez(edit_img_lc_path, x=mlc)
-    elif q.client.task_choice == 'B': # Image Styling
+    elif q.client.task_choice == "B":  # Image Styling
         # Check if precomputed latent space for source img exists
         if source_img_proj_path.is_file() & style_img_proj_path.is_file():
             swap_idxs = (z_low, z_high)
@@ -521,7 +552,9 @@ async def apply(q: Q):
 
     # Save new generated img locally
     if not file_name:
-        file_name = f"{OUTPUT_PATH}/{source_img_name}_{style_img_name}_{z_low}-{z_high}.jpg"
+        file_name = (
+            f"{OUTPUT_PATH}/{source_img_name}_{style_img_name}_{z_low}-{z_high}.jpg"
+        )
     logger.debug(f"Generate img: {file_name}")
     if new_img:
         new_img.save(file_name)
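
Usage sketch for the updated generate_image_with_prompt entry point, assuming the Stable Diffusion weights are present under ./models/stable_diffusion_v1_4 (the model_path hard-coded in the patched module) and that "./output" stands in for a writable OUTPUT_PATH; the call mirrors the handler in prompt_apply, and "K-LMS" remains selectable for the previous scheduler.

from img_styler.image_prompt.stable_diffusion import generate_image_with_prompt

# Text-to-image with the new DDIM default scheduler; pass sampler_type="K-LMS"
# to keep the earlier behaviour. The handler uses the return value as the path
# of the saved result image ("./output" is an illustrative output directory).
result_path = generate_image_with_prompt(
    prompt_txt="Face portrait, studio lighting",
    n_steps=50,
    guidance_scale=7.5,
    sampler_type="DDIM",
    output_path="./output",
)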