From 6ca31b1d49350eb53dd7120c101fc348e44e48f6 Mon Sep 17 00:00:00 2001 From: Shimanogov <37351887+Shimanogov@users.noreply.github.com> Date: Wed, 3 Aug 2022 15:09:38 +0300 Subject: [PATCH] implemented inpainting mode --- discoart/create.py | 4 ++++ discoart/nn/inpainting.py | 13 +++++++++++++ discoart/resources/default.yml | 2 ++ discoart/runner.py | 9 +++++++++ 4 files changed, 28 insertions(+) create mode 100644 discoart/nn/inpainting.py diff --git a/discoart/create.py b/discoart/create.py index 598c903..d186cb4 100644 --- a/discoart/create.py +++ b/discoart/create.py @@ -46,6 +46,8 @@ def create( init_document: Optional[Union['Document', 'DocumentArray']] = None, init_image: Optional[str] = None, init_scale: Optional[Union[int, str]] = 1000, + inpaint_image: Optional[str] = None, + inpaint_mask: Optional[str] = None, n_batches: Optional[int] = 4, name_docarray: Optional[str] = None, on_misspelled_token: Optional[str] = 'ignore', @@ -120,6 +122,8 @@ def create(**kwargs) -> Optional['DocumentArray']: :param init_document: [DiscoArt] Use a Document object as the initial state for DD: its ``.tags`` will be used as parameters, ``.uri`` (if present) will be used as init image. :param init_image: Recall that in the image sequence above, the first image shown is just noise. If an init_image is provided, diffusion will replace the noise with the init_image as its starting state. To use an init_image, upload the image to the Colab instance or your Google Drive, and enter the full image path here. If using an init_image, you may need to increase skip_steps to ~ 50% of total steps to retain the character of the init. See skip_steps above for further discussion. :param init_scale: This controls how strongly CLIP will try to match the init_image provided. This is balanced against the clip_guidance_scale (CGS) above. Too much init scale, and the image won’t change much during diffusion. 
Too much CGS and the init image will be lost.[DiscoArt] Can be scheduled via syntax `[val1]*400+[val2]*600`. + :param inpaint_image: [DiscoArt] URI of the image to inpaint. If both an inpaint image and an inpaint mask are provided, the final image will match the inpaint image wherever the mask is white and will be generated by diffusion wherever the mask is black. + :param inpaint_mask: [DiscoArt] URI of the mask used with `inpaint_image`: white areas of the mask are kept from the inpaint image, black areas are generated by diffusion. It has no effect unless `inpaint_image` is also provided. :param n_batches: This variable sets the number of still images you want DD to create. If you are using an animation mode (see below for details) DD will ignore n_batches and create a single set of animated frames based on the animation settings. :param name_docarray: [DiscoArt] When specified, it overrides the default naming schema of the resulted DocumentArray. Useful when you have to know the result DocumentArray name in advance.The name also supports variable substitution via `{}`. For example, `name_docarray='test-{steps}-{perlin_init}'` will give the name of the DocumentArray as `test-250-False`. Any variable in the config can be substituted. :param on_misspelled_token: [DiscoArt] Strategy when encounter misspelled token, can be 'raise', 'correct' and 'ignore'. If 'raise', then the misspelled token in the prompt will raise a ValueError. If 'correct', then the token will be replaced with the correct token. If 'ignore', then the token will be ignored but a warning will show. 
diff --git a/discoart/nn/inpainting.py b/discoart/nn/inpainting.py new file mode 100644 index 0000000..59e94cd --- /dev/null +++ b/discoart/nn/inpainting.py @@ -0,0 +1,13 @@ +def wrap_inpaint(diff, inpaint_tensor, mask_tensor, noise): + def inpaint_sampling(f): + def func(*args, **kwargs): + result = f(*args, **kwargs) + curr_t = args[-1] + curr_image = diff.q_sample(inpaint_tensor, curr_t, noise) + result["sample"] = curr_image * mask_tensor + result["sample"] * (1 - mask_tensor) + return result + + return func + + diff.ddim_sample = inpaint_sampling(diff.ddim_sample) + diff.plms_sample = inpaint_sampling(diff.plms_sample) \ No newline at end of file diff --git a/discoart/resources/default.yml b/discoart/resources/default.yml index e0b91be..e42d100 100644 --- a/discoart/resources/default.yml +++ b/discoart/resources/default.yml @@ -4,6 +4,8 @@ text_prompts: init_image: width_height: [ 1280, 768 ] +inpaint_image: +inpaint_mask: skip_steps: 0 steps: 250 diff --git a/discoart/runner.py b/discoart/runner.py index b3a0243..6f33aed 100644 --- a/discoart/runner.py +++ b/discoart/runner.py @@ -31,6 +31,7 @@ from .nn.transform import symmetry_transformation_fn from .persist import _sample_thread, _persist_thread, _save_progress_thread from .prompt import PromptPlanner +from .nn.inpainting import wrap_inpaint def do_run(args, models, device, events) -> 'DocumentArray': @@ -325,6 +326,14 @@ def cond_fn(x, t, **kwargs): return r_grad + if args.inpaint_image and args.inpaint_mask: + d = Document(uri=args.inpaint_image).load_uri_to_image_tensor(side_x, side_y) + inpaint = TF.to_tensor(d.tensor).to(device).unsqueeze(0).mul(2).sub(1) + d = Document(uri=args.inpaint_mask).load_uri_to_image_tensor(side_x, side_y) + mask = TF.to_tensor(d.tensor).to(device).unsqueeze(0) + noise = torch.randn(*inpaint.size(), device=device) + wrap_inpaint(diffusion, inpaint, mask, noise) + if args.diffusion_sampling_mode == 'ddim': sample_fn = diffusion.ddim_sample_loop_progressive else: