
Commit

Merge pull request #287 from lucidrains/karras-magnitude-preserving-unet
complete karras unet
lucidrains authored Feb 6, 2024
2 parents 32310c3 + 42a9e79 commit bb7a9be
Showing 5 changed files with 669 additions and 3 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -343,3 +343,14 @@ You could consider adding a suitable metric to the training loop yourself after
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
```

```bibtex
@article{Karras2023AnalyzingAI,
title = {Analyzing and Improving the Training Dynamics of Diffusion Models},
author = {Tero Karras and Miika Aittala and Jaakko Lehtinen and Janne Hellsten and Timo Aila and Samuli Laine},
journal = {ArXiv},
year = {2023},
volume = {abs/2312.02696},
url = {https://api.semanticscholar.org/CorpusID:265659032}
}
```
2 changes: 2 additions & 0 deletions denoising_diffusion_pytorch/__init__.py
@@ -7,3 +7,5 @@
from denoising_diffusion_pytorch.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion

from denoising_diffusion_pytorch.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D

from denoising_diffusion_pytorch.karras_unet import KarrasUnet
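
For context, a minimal usage sketch of the newly exported `KarrasUnet` (the magnitude-preserving U-Net from Karras et al., 2023). The constructor and forward arguments below (`image_size`, `dim`, `channels`, and an `(images, times)` call) are assumptions in the usual style of this repository, not confirmed by this diff — check `denoising_diffusion_pytorch/karras_unet.py` for the actual signature.

```python
import torch
from denoising_diffusion_pytorch import KarrasUnet

# hypothetical hyperparameters - see karras_unet.py for the real defaults
unet = KarrasUnet(
    image_size = 64,
    dim = 192,
    channels = 3
)

images = torch.randn(2, 3, 64, 64)   # (batch, channels, height, width)
times = torch.rand(2)                # one diffusion time per sample

denoised = unet(images, times)       # assumed forward signature; output expected to match the input shape
```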
13 changes: 11 additions & 2 deletions denoising_diffusion_pytorch/attend.py
@@ -17,6 +17,9 @@
def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def once(fn):
called = False
@wraps(fn)
@@ -36,10 +39,12 @@ class Attend(nn.Module):
def __init__(
self,
dropout = 0.,
flash = False
flash = False,
scale = None
):
super().__init__()
self.dropout = dropout
self.scale = scale
self.attn_dropout = nn.Dropout(dropout)

self.flash = flash
@@ -65,6 +70,10 @@ def __init__(
def flash_attn(self, q, k, v):
_, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

if exists(self.scale):
    # fold the custom scale into q, since the fused attention kernel
    # will still apply its own default scale of q.shape[-1] ** -0.5
    default_scale = q.shape[-1] ** -0.5
    q = q * (self.scale / default_scale)

q, k, v = map(lambda t: t.contiguous(), (q, k, v))

# Check if there is a compatible device for flash attention
@@ -95,7 +104,7 @@ def forward(self, q, k, v):
if self.flash:
return self.flash_attn(q, k, v)

scale = q.shape[-1] ** -0.5
scale = default(self.scale, q.shape[-1] ** -0.5)

# similarity

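To illustrate the new `scale` keyword added to `Attend`: when it is set, the non-flash path uses it directly in place of the default `q.shape[-1] ** -0.5`, and the flash path folds it into `q` before calling the fused kernel. A small sketch, assuming the usual `(batch, heads, seq_len, dim_head)` tensor layout used by this module:

```python
import torch
from denoising_diffusion_pytorch.attend import Attend

# override the default 1 / sqrt(dim_head) softmax scaling with a fixed value,
# e.g. when q and k are normalized elsewhere and a constant scale is desired
attend = Attend(flash = False, scale = 1.)

q, k, v = (torch.randn(1, 8, 32, 64) for _ in range(3))  # (batch, heads, seq, dim_head)
out = attend(q, k, v)                                     # (1, 8, 32, 64)
```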
