Skip to content

Commit

Permalink
add immiscible diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 16, 2024
1 parent ec0a1c7 commit 5a0e07f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 2 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -355,3 +355,14 @@ You could consider adding a suitable metric to the training loop yourself after
url = {https://api.semanticscholar.org/CorpusID:265659032}
}
```

```bibtex
@article{Li2024ImmiscibleDA,
title = {Immiscible Diffusion: Accelerating Diffusion Training with Noise Assignment},
author = {Yiheng Li and Heyang Jiang and Akio Kodaira and Masayoshi Tomizuka and Kurt Keutzer and Chenfeng Xu},
journal = {ArXiv},
year = {2024},
volume = {abs/2406.12303},
url = {https://api.semanticscholar.org/CorpusID:270562607}
}
```
19 changes: 18 additions & 1 deletion denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange

from scipy.optimize import linear_sum_assignment

from PIL import Image
from tqdm.auto import tqdm
from ema_pytorch import EMA
Expand Down Expand Up @@ -488,7 +490,8 @@ def __init__(
auto_normalize = True,
offset_noise_strength = 0., # https://www.crosslabs.org/blog/diffusion-with-offset-noise
min_snr_loss_weight = False, # https://arxiv.org/abs/2303.09556
min_snr_gamma = 5
min_snr_gamma = 5,
immiscible = False
):
super().__init__()
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
Expand Down Expand Up @@ -564,6 +567,10 @@ def __init__(
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

# immiscible diffusion

self.immiscible = immiscible

# offset noise strength - in blogpost, they claimed 0.1 was ideal

self.offset_noise_strength = offset_noise_strength
Expand Down Expand Up @@ -759,10 +766,20 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):

return img

def noise_assignment(self, x_start, noise):
x_start, noise = tuple(rearrange(t, 'b ... -> b (...)') for t in (x_start, noise))
dist = torch.cdist(x_start, noise)
_, assign = linear_sum_assignment(dist.cpu())
return torch.from_numpy(assign).to(dist.device)

@autocast(enabled = False)
def q_sample(self, x_start, t, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))

if self.immiscible:
assign = self.noise_assignment(x_start, noise)
noise = noise[assign]

return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
Expand Down
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.12'
__version__ = '2.0.15'
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
'numpy',
'pillow',
'pytorch-fid',
'scipy',
'torch',
'torchvision',
'tqdm'
Expand Down

0 comments on commit 5a0e07f

Please sign in to comment.