complete the ada-ln zero conditioning used in DiT
lucidrains committed Jun 19, 2024
1 parent 41a3285 commit 830db47
Showing 4 changed files with 77 additions and 25 deletions.
README.md (17 changes: 12 additions & 5 deletions)
@@ -1515,11 +1515,7 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)

```bibtex
@article{DBLP:journals/corr/abs-1907-01470,
- author = {Sainbayar Sukhbaatar and
- Edouard Grave and
- Guillaume Lample and
- Herv{\'{e}} J{\'{e}}gou and
- Armand Joulin},
+ author = {Sainbayar Sukhbaatar and Edouard Grave and Guillaume Lample and Herv{\'{e}} J{\'{e}}gou and Armand Joulin},
title = {Augmenting Self-attention with Persistent Memory},
journal = {CoRR},
volume = {abs/1907.01470},
@@ -2162,4 +2158,15 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

+ ```bibtex
+ @article{Peebles2022ScalableDM,
+ title = {Scalable Diffusion Models with Transformers},
+ author = {William S. Peebles and Saining Xie},
+ journal = {2023 IEEE/CVF International Conference on Computer Vision (ICCV)},
+ year = {2022},
+ pages = {4172-4182},
+ url = {https://api.semanticscholar.org/CorpusID:254854389}
+ }
+ ```
+
*solve intelligence... then use that to solve everything else.* - Demis Hassabis
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
- version = '1.30.20',
+ version = '1.30.21',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
tests/test_x_transformers.py (3 changes: 2 additions & 1 deletion)
@@ -83,7 +83,8 @@ def test_adaptive_layernorm():
dim_condition = 768,
depth = 12,
heads = 8,
- use_adaptive_layernorm = True
+ use_adaptive_layernorm = True,
+ use_conditioned_layerscale = True
)
)

x_transformers/x_transformers.py (80 changes: 62 additions & 18 deletions)
@@ -710,6 +710,27 @@ def forward(self, x, **kwargs):
out, *rest = out
return out * self.gamma, *rest

+ class ConditionedLayerScale(Module):
+     def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
+         super().__init__()
+         self.fn = fn
+
+         dim_condition = default(dim_condition, dim)
+         self.to_gamma = nn.Linear(dim_condition, dim)
+
+         nn.init.zeros_(self.to_gamma.weight)
+         nn.init.constant_(self.to_gamma.bias, init_bias_value)
+
+     def forward(self, x, *, condition, **kwargs):
+         out = self.fn(x, **kwargs)
+         gamma = self.to_gamma(condition).sigmoid()
+
+         if isinstance(out, Tensor):
+             return out * gamma
+
+         out, *rest = out
+         return out * gamma, *rest
+
# feedforward

class GLU(Module):
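
The new ConditionedLayerScale class above gates each branch output with a sigmoid of a linear projection of the condition. Since the projection weight is zero-initialized and the bias starts at init_bias_value = -2, the gate begins near zero (sigmoid(-2) ≈ 0.12), so residual branches contribute little at initialization and open up as training adjusts to_gamma. A minimal standalone sketch of that gating, with made-up dimensions:

```python
# Standalone sketch of the gating in ConditionedLayerScale (dimensions are illustrative).
import torch
import torch.nn as nn

dim, dim_condition = 512, 768

to_gamma = nn.Linear(dim_condition, dim)
nn.init.zeros_(to_gamma.weight)                # zero weight ...
nn.init.constant_(to_gamma.bias, -2.)          # ... and a bias of -2, mirroring init_bias_value = -2.

branch_out = torch.randn(2, 1024, dim)         # output of an attention or feedforward branch
condition = torch.randn(2, 1, dim_condition)   # per-sample condition, broadcast over the sequence

gamma = to_gamma(condition).sigmoid()          # ≈ 0.12 everywhere at init, since sigmoid(-2) ≈ 0.119
gated = branch_out * gamma                     # branch contribution starts small and is learned per condition
```
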
@@ -1160,6 +1181,7 @@ def __init__(
use_simple_rmsnorm = False,
use_adaptive_layernorm = False,
use_adaptive_rmsnorm = False,
+ use_conditioned_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
dim_condition = None,
adaptive_condition_mlp = False,
adaptive_condition_mlp_expansion = 4,
@@ -1269,7 +1291,7 @@ def __init__(

assert at_most_one_of(use_scalenorm, use_rmsnorm, use_simple_rmsnorm, use_adaptive_layernorm, use_adaptive_rmsnorm), 'you can only use either scalenorm, rmsnorm, adaptive layernorm, adaptive rmsnorm, or simple rmsnorm'

- need_condition = False
+ norm_need_condition = False
dim_condition = default(dim_condition, dim)
dim_condition_mult = 1

@@ -1283,25 +1305,17 @@
elif use_simple_rmsnorm:
norm_class = SimpleRMSNorm
elif use_adaptive_layernorm:
- need_condition = True
+ norm_need_condition = True
norm_class = partial(AdaptiveLayerNorm, dim_condition = dim_condition * dim_condition_mult)
elif use_adaptive_rmsnorm:
- need_condition = True
+ norm_need_condition = True
norm_class = partial(AdaptiveRMSNorm, dim_condition = dim_condition * dim_condition_mult)
else:
norm_class = LayerNorm

norm_fn = partial(norm_class, dim)

- self.adaptive_mlp = nn.Identity()
-
- if need_condition and adaptive_condition_mlp:
- self.adaptive_mlp = nn.Sequential(
- nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
- nn.SiLU()
- )
-
- self.need_condition = need_condition
+ self.norm_need_condition = norm_need_condition
self.dim_condition = dim_condition

# determine default block layer type order
@@ -1318,10 +1332,30 @@

# determine post branch wrapper

+ assert at_most_one_of(use_layerscale, use_conditioned_layerscale)
+
post_branch_fn = None
+ post_branch_fn_needs_condition = False

if use_layerscale:
post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
+ elif use_conditioned_layerscale:
+ post_branch_fn = partial(ConditionedLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
+ post_branch_fn_needs_condition = True

+ self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
+
+ # setup mlp for conditioning
+
+ self.need_condition = norm_need_condition or post_branch_fn_needs_condition
+
+ self.adaptive_mlp = nn.Identity()
+
+ if self.need_condition and adaptive_condition_mlp:
+ self.adaptive_mlp = nn.Sequential(
+ nn.Linear(dim_condition, dim_condition * dim_condition_mult, bias = False),
+ nn.SiLU()
+ )
+
# zero init

@@ -1455,24 +1489,32 @@ def forward(
assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True'
assert not (exists(condition) ^ self.need_condition), 'condition needs to be passed in if using adaptive layernorm or vice versa'

- # setup maybe layernorm kwarg
+ # handle condition

- norm_kwargs = dict()
+ if exists(condition):
+ assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'

- if self.need_condition:
assert condition.ndim in {2, 3}

if condition.ndim == 2:
condition = rearrange(condition, 'b d -> b 1 d')

- assert condition.shape[-1] == self.dim_condition, f'expected condition dimension of {self.dim_condition} but received {condition.shape[-1]}'
+ condition = self.adaptive_mlp(condition)

- # maybe mlp
+ # setup maybe layernorm kwarg

- condition = self.adaptive_mlp(condition)
+ norm_kwargs = dict()
+
+ if self.norm_need_condition:
norm_kwargs.update(condition = condition)

+ # maybe post branch fn conditioning (DiT paper's ada-ln-zero)
+
+ block_forward_kwargs = dict()
+
+ if self.post_branch_fn_needs_condition:
+ block_forward_kwargs.update(condition = condition)

# initialize accums

hiddens = []
@@ -1573,6 +1615,8 @@ def forward(
if layer_type == 'a' and exists(layer_mem):
layer_mem = pre_norm(layer_mem)

+ block = partial(block, **block_forward_kwargs)
+
if layer_type == 'a':
out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, return_intermediates = True)
elif layer_type == 'c':

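
Taken together, the ada-ln-zero path is enabled by pairing use_adaptive_layernorm with the new use_conditioned_layerscale flag and passing a condition tensor at forward time, whose last dimension must match dim_condition. A hedged usage sketch in the spirit of the updated test, assuming the usual TransformerWrapper / Decoder entry points and that forward keyword arguments are passed through to the attention layers:

```python
# Usage sketch (not part of the diff): ada-ln-zero style conditioning.
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        dim_condition = 768,
        depth = 12,
        heads = 8,
        use_adaptive_layernorm = True,       # condition modulates the pre-branch norms
        use_conditioned_layerscale = True    # condition gates branch outputs (ada-ln-zero from DiT)
    )
)

x = torch.randint(0, 256, (2, 1024))
condition = torch.randn(2, 768)              # e.g. a class or timestep embedding; last dim must equal dim_condition

logits = model(x, condition = condition)
```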