Possible to support img_callback and alternating prompts? #17

Open · opened Nov 26, 2022 · 7 comments

zwishenzug commented Nov 26, 2022

Hi, thanks for making your code work with Stable Diffusion.

I have a couple of requests if possible.

  1. Could you support the img_callback parameter? It seems to work okay in my limited testing. I'm using the version of your code included with Stable Diffusion 2.0, backported to 1.x (it works fine with no changes).

Supporting the img_callback function allows us to render a preview of the current state every N steps. I'm using multistep sampling and found that adding code like this after multistep_dpm_solver_update works:

                x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
                                                     solver_type=solver_type)
                # Hand the intermediate sample to the caller for previewing.
                if img_callback: img_callback(x, step)
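(The "every N" throttling then lives in the callback itself, e.g. something like if step % N == 0: preview(x), where preview is a hypothetical function that decodes and displays the latent.)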
  2. Could you support the ability to alternate/cycle through prompts? This one seems a bit trickier and I haven't been able to code it myself yet, but I was able to do it for DDIM.

For example, I may have three separate prompts: A, B, C.

On Step 1, A is used
On Step 2, B is used
On Step 3, C is used
On Step 4, A is used
...and so on, cycling through the list of prompts/conds.

This is one of the workarounds people use to get around the 75/77-token limit, for both cond and uncond.

This is what my code looks like in the DDIM sampler. As you can see, I essentially do x_cond = cond[i % len(cond)] to find which cond in the list to use on each step:

    for i, step in enumerate(iterator):
        index = total_steps - i - 1
        ts = torch.full((b,), step, device=device, dtype=torch.long)

        if mask is not None:
            assert x0 is not None
            img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
            img = img_orig * mask + (1. - mask) * img

        # Cycle through the cond/uncond lists: one entry per sampling step.
        cond_idx     = i % len(cond)
        neg_cond_idx = i % len(unconditional_conditioning)

        x_cond       = cond[cond_idx]
        x_uc         = unconditional_conditioning[neg_cond_idx]

        outs = self.p_sample_ddim(img, x_cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                  quantize_denoised=quantize_denoised, temperature=temperature,
                                  noise_dropout=noise_dropout, score_corrector=score_corrector,
                                  corrector_kwargs=corrector_kwargs,
                                  unconditional_guidance_scale=unconditional_guidance_scale,
                                  unconditional_conditioning=x_uc)
        img, pred_x0 = outs
        if callback: callback(i)
        if img_callback: img_callback(pred_x0, i)

        if index % log_every_t == 0 or index == total_steps - 1:
            intermediates['x_inter'].append(img)
            intermediates['pred_x0'].append(pred_x0)

Thanks

@LuChengTHU (Owner)

Hi @zwishenzug, thank you for the very valuable suggestions!

  1. The img_callback argument:

Thank you for the suggestion. This argument is important for some downstream tasks, such as image inpainting (e.g., applying masks at each step). I've supported it in the newest version of DPM-Solver and provided example code with stable-diffusion.

The corresponding argument is correcting_xt_fn, since it can be understood as correcting the sampled xt at each time t. You can find detailed example code for image inpainting with stable-diffusion and DPM-Solver at this script.
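Roughly, a correction hook of that shape looks like this (a sketch only; masked_image and mask are placeholder tensors for the known region, and names follow dpm_solver_pytorch.py):

    # Sketch of an inpainting-style correction hook. `masked_image` and `mask`
    # are placeholders: the known pixels and a binary mask (1 = keep original).
    def correcting_xt_fn(xt, t, step):
        # Re-noise the known region to the current time t, then keep the
        # solver's sample only where the mask is 0.
        alpha_t = noise_schedule.marginal_alpha(t)[:, None, None, None]
        sigma_t = noise_schedule.marginal_std(t)[:, None, None, None]
        noised_known = alpha_t * masked_image + sigma_t * torch.randn_like(xt)
        return noised_known * mask + xt * (1. - mask)

    dpm_solver = DPM_Solver(model_fn, noise_schedule,
                            correcting_xt_fn=correcting_xt_fn)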

  2. The cycling prompts:

This feature is strange to me because it actually changes the diffusion ODE at each step. Could you please give me some examples or motivation for why we need it? Thank you!

@zwishenzug (Author)

Thanks for responding.

  1. img_callback

Thanks, I understand now, and I can see how I can use correcting_xt_fn.

At the beginning of DPMSolverSampler::sample I can add the following code:

    # Adapt DPM-Solver's correcting_xt_fn hook to DDIM's img_callback interface.
    def cb(x, t, step):
        if img_callback: img_callback(x, step)
        return x
    correcting_xt_fn = cb

This is a much less intrusive way to resolve my issue while remaining compatible with the DDIM/PLMS samplers.

  2. Cycling prompts

These examples from the AUTOMATIC1111 user interface should give some insight into what people are doing with this:

https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alternating-words

It's correct that the target is being changed at each step; sometimes this can have interesting effects. For example, if you alternate between "Photo of Barack Obama" and "Photo of Donald Trump" on each step, you may get a result that is somewhat a mix of the two people.

Or, in the other example, if you start off with "a male military officer" for the first half of the process, then switch to "a female military officer" halfway through, you may get a more masculine-looking woman as a result.

@zwishenzug (Author)

I've been able to get this working in my local version. I'm only concerned with classifier-free guidance and multistep sampling, so it wasn't too hard.

It seems to be a case of modifying model_fn to become model_fn(x, t_continuous, step), then making the code look up the cond/uncond in the list via the step modulo:

    elif guidance_type == "classifier-free":
        # Pick this step's cond/uncond by cycling through the lists.
        cond = condition[step % len(condition)]
        uncond = unconditional_condition[step % len(unconditional_condition)]
        if guidance_scale == 1. or uncond is None:
            return noise_pred_fn(x, t_continuous, cond=cond)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t_continuous] * 2)
            c_in = torch.cat([uncond, cond])
            noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
            return noise_uncond + guidance_scale * (noise - noise_uncond)

And making sure that the current step gets passed through from the main multistep code (sketched below).

I still need to do some proper testing, but it seems to be working okay so far.
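For reference, the plumbing looks roughly like this (my sketch; variable names follow dpm_solver_pytorch.py's multistep loop, and the order bookkeeping is simplified):

    # Sketch of the multistep loop with `step` threaded through to model_fn.
    for step in range(order, steps + 1):
        vec_t = timesteps[step].expand(x.shape[0])
        step_order = min(order, steps + 1 - step)  # order handling simplified
        x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list,
                                             vec_t, step_order, solver_type=solver_type)
        if img_callback: img_callback(x, step)  # optional preview hook
        if step < steps:
            t_prev_list = t_prev_list[1:] + [vec_t]
            # Pass `step` so the classifier-free branch can cycle the prompt lists.
            model_prev_list = model_prev_list[1:] + [self.model_fn(x, vec_t, step)]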

@LuChengTHU (Owner)

Hi @zwishenzug ,

Yes, I suppose that is the easiest way to add img_callback in stable-diffusion.

And now I understand the cycling-prompts feature. I will try to figure out a more general API for supporting it as soon as possible. Thank you for the examples!

@zwishenzug (Author)

Thank you for your hard work.

I do have another question: is it possible to support the img2img function of stable diffusion? I can see that you have implemented stochastic_encode(), but scripts/img2img.py also requires a decode() function, and I haven't been able to work out how to implement it myself.
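For context, my rough mental model of what it would look like against the DPM-Solver API (just my guess using t_start, not anything from the repo):

    # Sketch (my guess, not the repo's method): img2img as "noise the init
    # latent to an intermediate time, then run the solver from there".
    strength = 0.75                         # fraction of the trajectory to re-run
    t_start = strength                      # continuous time in (0, 1]
    vec_t = torch.full((x0.shape[0],), t_start, device=x0.device)
    alpha_t = noise_schedule.marginal_alpha(vec_t)[:, None, None, None]
    sigma_t = noise_schedule.marginal_std(vec_t)[:, None, None, None]
    xt = alpha_t * x0 + sigma_t * torch.randn_like(x0)   # ~ stochastic_encode()
    # ~ decode(): solve from t_start down to (near) 0 instead of from T.
    img = dpm_solver.sample(xt, steps=20, t_start=t_start, t_end=1e-3,
                            order=2, method='multistep')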

Thanks

@LuChengTHU (Owner)

No problem. I will support it soon.

@zwishenzug (Author)

I've come back to this because today I realised that it's actually correcting_x0_fn which is preferable for this "preview" (to support img_callback in stable diffusion in the same way as DDIM).

xt shows a preview that still includes the noise; x0 shows one without it.

Just leaving a note in case anyone else is implementing it for themselves.
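Concretely, something like this (a sketch; as far as I can tell correcting_x0_fn only receives (x0, t), so the step has to be counted externally):

    # Sketch: preview the denoised prediction x0 instead of the noisy xt.
    # correcting_x0_fn appears to be called as fn(x0, t), with no step
    # argument, so the step counter lives in a closure here.
    step_counter = [0]
    def cb(x0, t):
        step_counter[0] += 1
        if img_callback: img_callback(x0, step_counter[0])
        return x0
    correcting_x0_fn = cb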
