Skip to content

Commit

Permalink
Initial working basic Img2Img diffusion #5
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Oct 31, 2022
1 parent b99e57c commit fdcedde
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 4 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ sync_attr_models = s3cmd get --recursive --skip-existing s3://ai.h2o.wave-image-
sync_stgan_nada_models = s3cmd get --recursive --skip-existing s3://ai.h2o.wave-image-styler/public/models/stylegan_nada/ ./models/stylegan_nada/
sync_gfpgan_models = s3cmd get --recursive --skip-existing s3://ai.h2o.wave-image-styler/public/models/gfpgan/ ./models/gfpgan/
download_ffhq_model = wget -P ./models/ https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
download_sd_model = s3cmd get --recursive --skip-existing s3://h2o-model-gym/models/stable-diffusion-v1-4/ ./models/stable_diffusion_v1_4/

.PHONY: download_models

Expand All @@ -22,11 +23,13 @@ download_models:
mkdir -p models/stylegan2_attributes
mkdir -p models/stylegan_nada
mkdir -p models/gfpgan/
mkdir -p models/stable_diffusion_v1_4
$(download_ffhq_model)
$(sync_landmarks_model)
$(sync_attr_models)
$(sync_stgan_nada_models)
$(sync_gfpgan_models)
$(download_sd_model)

poetry: ## Install dependencies
poetry install -vvv
Expand Down
10 changes: 8 additions & 2 deletions img_styler/ui/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ async def update_faces(q: Q, save=False):
q.page['source_face'] = get_source_face_card(
img2buf(q.client.source_face), type='jpg', height='520px', width='500px'
)
q.page['prompt_textbox'] = ui.form_card(ui.box('main', order=1, height='200px', width='900px'), items=[
ui.textbox(name='textbox', label='Prompt', multiline=True),

q.page['prompt_form'] = ui.form_card(ui.box('main', order=1, height='200px', width='900px'), items=[
ui.textbox(name='prompt_textbox', label='Prompt', multiline=True, value=q.client.prompt_textbox),
ui.button(name='prompt_apply', label='Apply')])

del q.page['style_face']
if q.client.task_choice == 'A':
q.page['style_face'] = get_style_face_card(
Expand All @@ -51,6 +53,10 @@ async def update_processed_face(q: Q, save=False):
q.page['processed_face'] = get_processed_face_card(
img_buf, title="Fixed Image", type='jpg', layout_pos='middle_right', order=2
)
if q.client.task_choice == 'C':
q.page['prompt_form'] = ui.form_card(ui.box('main', order=1, height='200px', width='900px'), items=[
ui.textbox(name='prompt_textbox', label='Prompt', multiline=True, value=q.client.prompt_textbox),
ui.button(name='prompt_apply', label='Apply')])
if save:
await q.page.save()

Expand Down
2 changes: 1 addition & 1 deletion img_styler/ui/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def get_controls(q: Q):
Image Editing (Edit and transform an existing image).",
),
ui.checkbox(name='checkbox_without_training', label='Without re-training', value=True),
ui.checkbox(name='checkbox_re_training', label='With re-training')])
ui.checkbox(name='checkbox_re_training', label='Dreambooth fine-tuning')])


def get_source_header():
Expand Down
16 changes: 15 additions & 1 deletion img_styler/ui/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dlib
import numpy as np
from deepface import DeepFace
from ..image_prompt.stable_diffusion import generate_image_with_prompt
from h2o_wave import Q, handle_on, on, site, ui
from loguru import logger
from PIL import Image
Expand Down Expand Up @@ -132,6 +133,8 @@ async def process(q: Q):
await img_capture_done(q)
elif q.args.change_theme:
await change_theme(q)
elif q.args.prompt_apply:
await prompt_apply(q)
await q.page.save()


Expand Down Expand Up @@ -255,10 +258,21 @@ async def image_upload(q: Q):
_img = 'data:image/png;base64,{}'.format(encoded)
q.client.current_img = _img
facial_feature_analysis(q, local_path, "Uploaded Image")

await q.page.save()


@on('prompt_apply')
async def prompt_apply(q: Q):
logger.info(f"Enable prompt.")
logger.info(f"Prompt value: {q.args.prompt_textbox}")
res_path = generate_image_with_prompt(input_img_path=q.client.source_face, prompt_txt=q.args.prompt_textbox,
output_path=OUTPUT_PATH)

q.client.prompt_textbox = q.args.prompt_textbox
q.client.processedimg = res_path
await update_processed_face(q)


@on('#capture')
async def capture(q: Q):
if q.args.img_capture_save:
Expand Down

0 comments on commit fdcedde

Please sign in to comment.