
Noisy results with "order == 1" (trying to replicate DDIM results) #36

Open
WikiChao opened this issue Apr 24, 2023 · 3 comments

Comments

@WikiChao

Hi authors,

Thank you for the nice paper and clear code and documentation!!

I am trying DPM-Solver in my project to accelerate sampling. Previously, I obtained reasonable results with DDIM (steps=10, 100, ...), but the results I get with DPM-Solver are much worse. Could you give some suggestions on the implementation?

Here are the details of my model:
(1) Training: DDPM (L1 loss, noise prediction), T=1000, UNet with additional condition inputs, trained on audio data.
(2) Beta schedule: sigmoid schedule (following https://arxiv.org/abs/2212.11972)

Code snippet that uses DPM-solver in my project:

    self.betas = sigmoid_beta_schedule(timesteps=1000)
    self.noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)
    self.model_fn = model_wrapper(
        self.net,
        self.noise_schedule,
        model_type="noise",  # or "x_start" or "v" or "score"
        model_kwargs={},
    )
    self.dpm_solver = DPM_Solver(self.model_fn, self.noise_schedule, algorithm_type="dpmsolver")

After the definition:

    x_T = torch.randn(input.shape, device = "cuda")
    pred = self.dpm_solver.sample(
        x_T,
        condition,
        steps=20,
        order=1,
        skip_type="time_uniform",
        method="singlestep",
    )
    pred = unnormalize_to_zero_to_one(pred)

Thanks a lot!

@LuChengTHU
Owner

Hi @WikiChao, does your code contain this line? https://github.com/LuChengTHU/dpm-solver/blob/main/dpm_solver_pytorch.py#L105

If so, could you please print the first 5 and last 5 items of log_alphas?
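
For anyone following along, here is a minimal sketch of the check being requested, written in plain Python so it runs without PyTorch. It assumes the standard variance-preserving relation log(alpha_t) = 0.5 * sum over s <= t of log(1 - beta_s). The linear `betas` below are a hypothetical stand-in for the output of `sigmoid_beta_schedule`, and `log_alphas_from_betas` and `half_log_snr` are illustrative helper names, not functions from the repository.

```python
import math


def log_alphas_from_betas(betas):
    """Return log(alpha_t) = 0.5 * cumulative sum of log(1 - beta_s)."""
    out, running = [], 0.0
    for beta in betas:
        running += math.log(1.0 - beta)
        out.append(0.5 * running)
    return out


def half_log_snr(log_alpha):
    """lambda_t = log(alpha_t) - log(sigma_t), with sigma_t^2 = 1 - alpha_t^2."""
    log_sigma = 0.5 * math.log(1.0 - math.exp(2.0 * log_alpha))
    return log_alpha - log_sigma


# Hypothetical stand-in schedule (linear betas); substitute your real betas here.
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

log_alphas = log_alphas_from_betas(betas)
print("first 5:", log_alphas[:5])
print("last 5: ", log_alphas[-5:])
# A very negative value here means the final steps are almost pure noise.
print("half log-SNR at the last step:", half_log_snr(log_alphas[-1]))
```

With a real schedule, very negative values at the end of `log_alphas` would suggest that the last timesteps are numerically close to pure noise, which is where a solver working in log-SNR space is most likely to become unstable.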

@WikiChao
Author

Thanks for the prompt reply! That did help: it seems I was using an older version that was missing that line of code.

The results make sense now, but they are still worse than DDIM. I have tried different settings, e.g., method="multistep" or "singlestep", order=2 or 3, and steps from 10 to 100, but cannot beat DDIM. Are there any tricks for choosing hyperparameters, for example, clipping the log-SNR at different values?

Thanks a lot!

Chao

@LuChengTHU
Owner

Hi @WikiChao,

"but they are still worse than DDIM": In fact, DPM-Solver with order=1 is exactly DDIM. You can try to reproduce your DDIM results by manually setting the timesteps in

    def get_time_steps(self, skip_type, t_T, t_0, N, device):

to match the ones used in your DDIM code, and check which part is missing. I suspect you need to tune the timesteps carefully.
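
To illustrate the equivalence claimed above, here is a small scalar sketch in plain Python that compares a deterministic DDIM step (eta = 0) with the first-order DPM-Solver update in noise-prediction form. All function names are hypothetical, and the numeric inputs are arbitrary. The last two helpers assume that a discrete schedule places 0-based index n at continuous time (n + 1) / N and that `time_uniform` spaces times evenly between t_T and t_0; verify both against your copy of the code before relying on them.

```python
import math


def ddim_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t):
    """Deterministic DDIM (eta = 0): predict x0, then re-noise to time t."""
    x0_pred = (x_s - sigma_s * eps) / alpha_s
    return alpha_t * x0_pred + sigma_t * eps


def dpm_solver_1_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t):
    """First-order DPM-Solver update with h = lambda_t - lambda_s."""
    h = math.log(alpha_t / sigma_t) - math.log(alpha_s / sigma_s)
    return (alpha_t / alpha_s) * x_s - sigma_t * math.expm1(h) * eps


def vp_alpha_sigma(alpha_bar):
    """Variance-preserving coefficients from a cumulative alpha product."""
    return math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)


def uniform_continuous_times(steps, t_T=1.0, t_0=1e-3):
    """Evenly spaced continuous times from t_T down to t_0 (assumed grid)."""
    return [t_T + (t_0 - t_T) * i / steps for i in range(steps + 1)]


def to_discrete_index(t, total_n=1000):
    """Assumed mapping: 0-based index n sits at continuous time (n + 1) / total_n."""
    return t * total_n - 1.0


# Arbitrary example values: step from a noisier time s to a cleaner time t.
alpha_s, sigma_s = vp_alpha_sigma(0.30)
alpha_t, sigma_t = vp_alpha_sigma(0.55)
x_s, eps = 0.8, -0.3

a = ddim_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t)
b = dpm_solver_1_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t)
print("difference between the two updates:", abs(a - b))

times = uniform_continuous_times(steps=10)
print("implied discrete indices:", [round(to_discrete_index(t), 1) for t in times])
```

Since the two updates agree for the same pair of times, any remaining gap against a DDIM baseline would have to come from which times are visited. Under the mapping assumed here, a uniform continuous grid generally lands on fractional indices, whereas many DDIM implementations step through integer training timesteps.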
