
[Inference Strategy] Will inconsistent sampling steps cause suboptimal performance? #14

Open
Luciennnnnnn opened this issue Aug 7, 2024 · 9 comments


@Luciennnnnnn

Hi, I noticed that you use the builtin DDIM scheduler at inference time; however, its discretization differs from the one used in the training stage.

Specifically, the endpoints that split the time interval into sub-trajectories during training are [0, 239, 499, 739, 999], whereas the DDIM scheduler uses [0, 249, 499, 749, 999] at inference.
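
For reference, here is a minimal sketch of the inference-side discretization (it assumes diffusers' DDIMScheduler with timestep_spacing="trailing", as used later in this thread; the exact inference config may differ):

# Sketch only: assumes trailing spacing; not necessarily the exact inference config.
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000, timestep_spacing="trailing")
scheduler.set_timesteps(4)
print(scheduler.timesteps)  # tensor([999, 749, 499, 249])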

@G-U-N
Owner

G-U-N commented Aug 7, 2024

Thanks for the reminder. It might be caused by the num_ddim_timesteps in the training configs. Using num_ddim_timesteps=50 should make it consistent.

@Luciennnnnnn
Author

Luciennnnnnn commented Aug 7, 2024

@G-U-N num_ddim_timesteps=50 indeed causes a problem. You can verify this with the following code:

import numpy as np
import torch

ddim_timesteps = 50
multiphase = 4

step_ratio = 1000 // ddim_timesteps

ddim_timesteps = (
    np.arange(1, ddim_timesteps + 1) * step_ratio
).round().astype(np.int64) - 1
ddim_timesteps = torch.from_numpy(ddim_timesteps).long()

inference_indices = np.linspace(
    0, len(ddim_timesteps), num=multiphase, endpoint=False
)
inference_indices = np.floor(inference_indices).astype(np.int64)
inference_indices = (
    torch.from_numpy(inference_indices).long().to(ddim_timesteps.device)
)

print(ddim_timesteps) # tensor([..., 719, 739, 759, 779, ...])
print(inference_indices) # tensor([ 0, 12, 25, 37])

print(ddim_timesteps[inference_indices]) # tensor([ 19, 259, 519, 759])

step_ratio = 1000 / 4
timesteps = np.round(np.arange(1000, 0, -step_ratio)).astype(np.int64)
timesteps -= 1

print(timesteps) # [999 749 499 249]

In training, the DDIM timestep preceding step 759 is 739, so the model learns to jump to step 739 from every timestep in the last phase (759 and above). In inference, however, we jump from 999 to 749, which causes an inconsistency.

@G-U-N
Owner

G-U-N commented Aug 7, 2024

Did you print end_timesteps during training? I remember it printed [0 249 499 749].

@G-U-N
Owner

G-U-N commented Aug 7, 2024

You can make a PR if you find it is indeed wrong.

@Luciennnnnnn
Author

end_timesteps is [0, 239, 499, 739] according to the minimal reproducible code below.

import numpy as np
import torch

from diffusers import DDPMScheduler


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        self.step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (
            np.arange(1, ddim_timesteps + 1) * self.step_ratio
        ).round().astype(np.int64) - 1
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_timesteps_prev = np.asarray([0] + self.ddim_timesteps[:-1].tolist())
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_timesteps_prev = torch.from_numpy(self.ddim_timesteps_prev).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)

        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape
        )
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev

    def ddim_style_multiphase_pred(self, timestep_index, multiphase):
        # Split the DDIM trajectory into `multiphase` sub-trajectories and
        # find the start index of the phase each timestep index falls into.
        inference_indices = np.linspace(
            0, len(self.ddim_timesteps), num=multiphase, endpoint=False
        )
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = (
            torch.from_numpy(inference_indices).long().to(self.ddim_timesteps.device)
        )
        expanded_timestep_index = timestep_index.unsqueeze(1).expand(
            -1, inference_indices.size(0)
        )
        # Largest phase-start index that is <= each timestep index.
        valid_indices_mask = expanded_timestep_index >= inference_indices
        last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1)
        last_valid_index = inference_indices.size(0) - 1 - last_valid_index
        timestep_index = inference_indices[last_valid_index]
        # Return the DDIM timestep just before the phase start, i.e. the
        # endpoint the model is trained to jump to.
        return self.ddim_timesteps_prev[timestep_index]
    

if __name__ == '__main__':
    ddim_timesteps = 50
    multiphase = 4

    # 1. Create the noise scheduler and the desired noise schedule.
    noise_scheduler = DDPMScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="scheduler"
    )

    # The scheduler calculates the alpha and sigma schedule for us
    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
    solver = DDIMSolver(
        noise_scheduler.alphas_cumprod.numpy(),
        timesteps=noise_scheduler.config.num_train_timesteps,
        ddim_timesteps=ddim_timesteps,
    )
    
    index = torch.arange(ddim_timesteps)

    end_timesteps = solver.ddim_style_multiphase_pred(index, multiphase)

    print(end_timesteps)
    # tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 239, 239,
    #     239, 239, 239, 239, 239, 239, 239, 239, 239, 239, 239, 499, 499, 499,
    #     499, 499, 499, 499, 499, 499, 499, 499, 499, 739, 739, 739, 739, 739,
    #     739, 739, 739, 739, 739, 739, 739, 739])

@Luciennnnnnn
Author

I'm happy to make a PR once we agree on the proper fix.

I think a suitable division of the time interval would be:

endpoints = np.linspace(
    -1, noise_scheduler.config.num_train_timesteps - 1, num=num_phase + 1, endpoint=True
)  # [-1., 249., 499., 749., 999.]
# endpoints[0] = 0  # we could also let the start point be 0; that is fine too.
endpoints = np.floor(endpoints).astype(np.int64)
endpoints = (
    torch.from_numpy(endpoints).long().to(start_timesteps.device)
)
# (num_phase, noise_scheduler, and start_timesteps come from the training script.)

In training, for every timestep in (endpoints[i], endpoints[i + 1]], we enforce a jump to endpoints[i]. This division is consistent with the following DDIM scheduler:

DDIMScheduler(
    timestep_spacing="trailing",
    set_alpha_to_one=True,
)
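
To make the rule concrete, here is a small self-contained sketch; the jump_target helper is hypothetical, written only for illustration:

# Hypothetical helper (not from the repo): every timestep t in
# (endpoints[i], endpoints[i + 1]] jumps to endpoints[i].
import numpy as np

num_train_timesteps, num_phase = 1000, 4
endpoints = np.floor(
    np.linspace(-1, num_train_timesteps - 1, num=num_phase + 1, endpoint=True)
).astype(np.int64)  # [-1, 249, 499, 749, 999]

def jump_target(t):
    # index of the largest endpoint strictly below t
    i = np.searchsorted(endpoints, t, side="left") - 1
    return max(int(endpoints[i]), 0)  # clamp the -1 start point to 0

for t in (999, 759, 749, 250, 249):
    print(t, "->", jump_target(t))
# 999 -> 749, 759 -> 749, 749 -> 499, 250 -> 249, 249 -> 0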

Furthermore, a proper condition is that endpoints be a subset of solver.ddim_timesteps (so that we learn from the endpoints directly). To satisfy this condition, num_train_timesteps / num_phase needs to be a multiple of num_train_timesteps / ddim_timesteps, so num_phase=4 with ddim_timesteps=40, or num_phase=5 with ddim_timesteps=50, are feasible choices; a quick check is sketched below.
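
A small sketch of that check; the endpoints_are_learned helper is hypothetical, written only for this verification:

# Hypothetical check (not from the repo): the interior endpoints must all
# appear among the DDIM solver's training timesteps.
import numpy as np

def endpoints_are_learned(num_train_timesteps, ddim_timesteps, num_phase):
    step = num_train_timesteps // ddim_timesteps
    solver_steps = np.arange(1, ddim_timesteps + 1) * step - 1           # e.g. [19, 39, ..., 999]
    endpoints = np.linspace(-1, num_train_timesteps - 1, num_phase + 1)  # e.g. [-1, 249, 499, 749, 999]
    interior = np.floor(endpoints[1:]).astype(np.int64)                  # skip the -1/0 start point
    return set(interior).issubset(set(solver_steps))

print(endpoints_are_learned(1000, 50, 4))  # False: 249 and 749 are not solver timesteps
print(endpoints_are_learned(1000, 40, 4))  # True
print(endpoints_are_learned(1000, 50, 5))  # True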

@JaySimple

(quoting @Luciennnnnnn's minimal reproducible code and its output from the comment above)

I think you're right. I ran the same test and got the same results. The inconsistency between the training and inference timesteps may therefore lead to suboptimal results. Have you tried your revised version, and how are the results? Thanks.

@Luciennnnnnn
Author

@JaySimple I have tested my revised version, and it functions as intended. However, I'm currently working on a different configuration, so a direct comparison with the original implementation is not available. I'm confident that my implementation at least maintains the existing performance level, if not improves on it.

@G-U-N
Owner

G-U-N commented Sep 12, 2024

Great, I will check it when I have time.
