Commit

add a new technique for countering oversaturation at higher cfg guidance strength
lucidrains committed Oct 6, 2024
1 parent ef4421a commit 7c1a4cf
Showing 3 changed files with 42 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -377,3 +377,12 @@ You could consider adding a suitable metric to the training loop yourself after
url = {https://api.semanticscholar.org/CorpusID:270391454}
}
```

```bibtex
@inproceedings{Sadat2024EliminatingOA,
title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273098845}
}
```
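
The cited paper's idea, as implemented in the change to `classifier_free_guidance.py` below, is to split the classifier-free guidance update into components parallel and orthogonal to the conditional prediction, then drop (or downweight) the parallel component before scaling, since it is the parallel part that drives oversaturation at high guidance scales. A minimal standalone sketch of that idea, assuming plain PyTorch tensors; `projected_cfg`, `cond`, and `uncond` are illustrative names, not part of this repository's API:

```python
import torch
import torch.nn.functional as F

def projected_cfg(cond, uncond, cond_scale = 3., keep_parallel_frac = 0.):
    # standard classifier-free guidance update direction
    update = cond - uncond

    # flatten everything but the batch dimension so the projection is per-sample
    flat_cond   = cond.flatten(1)
    flat_update = update.flatten(1)

    # unit vector along the conditional prediction
    unit = F.normalize(flat_cond, dim = -1)

    # decompose the update into parallel and orthogonal components
    parallel   = (flat_update * unit).sum(dim = -1, keepdim = True) * unit
    orthogonal = flat_update - parallel

    # suppress the parallel component, which drives oversaturation at high guidance scales
    update = (orthogonal + parallel * keep_parallel_frac).view_as(cond)

    return cond + update * (cond_scale - 1.)
```

With `keep_parallel_frac = 1.` this reduces to standard CFG, since `cond + (cond - uncond) * (cond_scale - 1.)` equals `uncond + (cond - uncond) * cond_scale`.
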
34 changes: 32 additions & 2 deletions denoising_diffusion_pytorch/classifier_free_guidance.py
@@ -11,7 +11,7 @@
import torch.nn.functional as F
from torch.amp import autocast

from einops import rearrange, reduce, repeat
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange

from tqdm.auto import tqdm
@@ -54,6 +54,15 @@ def convert_image_to_fn(img_type, image):
return image.convert(img_type)
return image

def pack_one_with_inverse(x, pattern):
packed, packed_shape = pack([x], pattern)

def inverse(x, inverse_pattern = None):
inverse_pattern = default(inverse_pattern, pattern)
return unpack(x, packed_shape, inverse_pattern)[0]

return packed, inverse
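
For reference, a hypothetical usage of the helper above (assuming einops' `pack`/`unpack` as imported in this file): it flattens everything after the batch dimension and returns a closure that restores the original shape.

```python
import torch

x = torch.randn(2, 3, 4)

flat, inverse = pack_one_with_inverse(x, 'b *')   # flat has shape (2, 12)
restored = inverse(flat)                          # back to shape (2, 3, 4)

assert restored.shape == x.shape
```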

# normalization functions

def normalize_to_neg_one_to_one(img):
@@ -75,6 +84,19 @@ def prob_mask_like(shape, prob, device):
else:
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

def project(x, y):
x, inverse = pack_one_with_inverse(x, 'b *')
y, _ = pack_one_with_inverse(y, 'b *')

dtype = x.dtype
x, y = x.double(), y.double()
unit = F.normalize(y, dim = -1)

parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
orthogonal = x - parallel

return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
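
`project` then uses that helper to split `x`, per batch element, into a component parallel to `y` and the orthogonal remainder, with the computation done in double precision for stability. A hypothetical sanity check of the decomposition:

```python
import torch

x = torch.randn(2, 3, 8, 8)
y = torch.randn(2, 3, 8, 8)

parallel, orthogonal = project(x, y)

# the two components reconstruct x
assert torch.allclose(parallel + orthogonal, x, atol = 1e-6)

# the orthogonal component has (numerically) no projection onto y
dots = (orthogonal.flatten(1) * y.flatten(1)).sum(dim = -1)
assert torch.allclose(dots, torch.zeros_like(dots), atol = 1e-4)
```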

# small helper modules

class Residual(nn.Module):
@@ -357,6 +379,8 @@ def forward_with_cond_scale(
*args,
cond_scale = 1.,
rescaled_phi = 0.,
remove_parallel_component = True,
keep_parallel_frac = 0.,
**kwargs
):
logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
@@ -365,7 +389,13 @@ def forward_with_cond_scale(
return logits

null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
scaled_logits = null_logits + (logits - null_logits) * cond_scale
update = logits - null_logits

if remove_parallel_component:
parallel, orthog = project(update, logits)
update = orthog + parallel * keep_parallel_frac

scaled_logits = logits + update * (cond_scale - 1.)

if rescaled_phi == 0.:
return scaled_logits, null_logits
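
With this change, the guided output is `logits + update * (cond_scale - 1.)`, where `update` is the usual `logits - null_logits` direction with its component parallel to `logits` removed (or scaled by `keep_parallel_frac`). When the parallel component is kept in full, this is algebraically identical to the previous `null_logits + (logits - null_logits) * cond_scale`; a hypothetical numerical check:

```python
import torch

logits, null_logits = torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)
cond_scale = 6.

update = logits - null_logits

old = null_logits + update * cond_scale     # previous formulation
new = logits + update * (cond_scale - 1.)   # new formulation, parallel component untouched

assert torch.allclose(old, new, atol = 1e-5)
```

By default (`remove_parallel_component = True`, `keep_parallel_frac = 0.`) the parallel component is dropped entirely, which is what counters oversaturation at high `cond_scale`.
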
2 changes: 1 addition & 1 deletion denoising_diffusion_pytorch/version.py
@@ -1 +1 @@
__version__ = '2.0.17'
__version__ = '2.0.18'
