
Commit 33ea37a: rename to adaptive layerscale

lucidrains committed Jun 19, 2024
1 parent 830db47
Showing 3 changed files with 7 additions and 7 deletions.
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.30.21',
+  version = '1.30.22',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
tests/test_x_transformers.py (2 changes: 1 addition & 1 deletion)

@@ -84,7 +84,7 @@ def test_adaptive_layernorm():
             depth = 12,
             heads = 8,
             use_adaptive_layernorm = True,
-            use_conditioned_layerscale = True
+            use_adaptive_layerscale = True
         )
     )

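For context on how the renamed flag is exercised, here is a minimal, hypothetical usage sketch in the spirit of the test above. The `num_tokens`, `max_seq_len`, `dim_condition`, and tensor shapes are assumptions, as is passing the conditioning signal through a `condition` keyword at forward time:

```python
import torch
from x_transformers import TransformerWrapper, Decoder

# hypothetical values; only the two `use_adaptive_*` flags come from the diff above
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        dim_condition = 768,             # width of the external conditioning vector (assumed)
        depth = 12,
        heads = 8,
        use_adaptive_layernorm = True,   # condition-predicted scale/shift in the norms
        use_adaptive_layerscale = True   # renamed flag: condition-gated branch outputs
    )
)

x = torch.randint(0, 256, (2, 1024))     # token ids
condition = torch.randn(2, 768)          # e.g. a class or timestep embedding

logits = model(x, condition = condition) # `condition` kwarg assumed from the adaptive-norm API
```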
x_transformers/x_transformers.py (10 changes: 5 additions & 5 deletions)

@@ -710,7 +710,7 @@ def forward(self, x, **kwargs):
         out, *rest = out
         return out * self.gamma, *rest
 
-class ConditionedLayerScale(Module):
+class AdaptiveLayerScale(Module):
     def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
         super().__init__()
         self.fn = fn
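The diff cuts the class body off after `self.fn = fn`. A plausible sketch of the remainder, assuming the condition is projected to a per-channel gate by a linear layer whose bias starts at `init_bias_value`, so every branch opens nearly closed (sigmoid(-2) ≈ 0.12) in the spirit of DiT's ada-LN-Zero; the projection name `to_gamma`, the weight zero-init, and the broadcasting details are assumptions:

```python
import torch
from torch import nn
from torch.nn import Module

class AdaptiveLayerScale(Module):
    # sketch only: the signature and `self.fn = fn` are taken from the diff; the rest is assumed
    def __init__(self, fn: Module, dim, dim_condition = None, init_bias_value = -2.):
        super().__init__()
        self.fn = fn

        dim_condition = dim_condition if dim_condition is not None else dim
        self.to_gamma = nn.Linear(dim_condition, dim)

        nn.init.zeros_(self.to_gamma.weight)
        nn.init.constant_(self.to_gamma.bias, init_bias_value)  # gate starts near zero

    def forward(self, x, *, condition, **kwargs):
        out = self.fn(x, **kwargs)
        gamma = self.to_gamma(condition).sigmoid()  # (batch, dim) gate in (0, 1)

        if gamma.ndim == 2:
            gamma = gamma.unsqueeze(1)              # broadcast over the sequence axis

        if not isinstance(out, tuple):
            return out * gamma

        out, *rest = out                            # branches may return extra intermediates
        return out * gamma, *rest
```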
@@ -1181,7 +1181,7 @@ def __init__(
         use_simple_rmsnorm = False,
         use_adaptive_layernorm = False,
         use_adaptive_rmsnorm = False,
-        use_conditioned_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
+        use_adaptive_layerscale = False, # paired with use_adaptive_layernorm for ada-ln-zero from DiT paper
         dim_condition = None,
         adaptive_condition_mlp = False,
         adaptive_condition_mlp_expansion = 4,
@@ -1332,15 +1332,15 @@ def __init__(
 
         # determine post branch wrapper
 
-        assert at_most_one_of(use_layerscale, use_conditioned_layerscale)
+        assert at_most_one_of(use_layerscale, use_adaptive_layerscale)
 
         post_branch_fn = None
         post_branch_fn_needs_condition = False
 
         if use_layerscale:
             post_branch_fn = partial(LayerScale, dim = dim, init_value = layerscale_init_value)
-        elif use_conditioned_layerscale:
-            post_branch_fn = partial(ConditionedLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
+        elif use_adaptive_layerscale:
+            post_branch_fn = partial(AdaptiveLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)
             post_branch_fn_needs_condition = True
 
         self.post_branch_fn_needs_condition = post_branch_fn_needs_condition
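To spell out the pattern in this hunk: `functools.partial` pre-binds every constructor argument except the branch itself, so `post_branch_fn(branch)` later wraps each attention or feedforward block uniformly. A small, hypothetical illustration (the `branch` stand-in and the concrete dimensions are invented):

```python
from functools import partial
from torch import nn
from x_transformers.x_transformers import AdaptiveLayerScale

dim, dim_condition, dim_condition_mult = 512, 768, 1

# everything but `fn` is bound up front, exactly as in the hunk above
post_branch_fn = partial(AdaptiveLayerScale, dim = dim, dim_condition = dim_condition * dim_condition_mult)

branch = nn.Linear(dim, dim)       # stand-in for an attention or feedforward block
wrapped = post_branch_fn(branch)   # same as AdaptiveLayerScale(branch, dim = 512, dim_condition = 768)
```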
