move hyper connections to external lib
lucidrains committed Dec 26, 2024
1 parent cd2fc43 commit d7c20e3
Showing 2 changed files with 11 additions and 107 deletions.
115 changes: 9 additions & 106 deletions pi_zero_pytorch/pi_zero.py
@@ -33,6 +33,8 @@

from pi_zero_pytorch.tensor_typing import Float, Int, Bool

from hyper_connections import HyperConnections

import tqdm

# ein notation
@@ -218,97 +220,6 @@ def noise_assignment(data, noise):
_, assign = linear_sum_assignment(dist.cpu())
return torch.from_numpy(assign).to(device)

# hyper connections - multiple residual streams

class Residual(Module):
def __init__(self, **kwargs):
super().__init__()

def prepare_with_inverse(self, residuals):
branch_input, residuals, residual_kwargs = self.prepare(residuals)

def inverse(branch_out):
return self(branch_out, residuals, **residual_kwargs)

return branch_input, inverse

def prepare(self, residuals):
return residuals, residuals, dict()

def forward(self, branch_out, residuals, **kwargs):
return branch_out + residuals

class HyperConnections(Module):
def __init__(
self,
dim,
*,
num_residual_streams,
layer_index = None,
tanh = True,
**kwargs
):
"""
https://arxiv.org/abs/2409.19606
Appendix J - Algorithm 2, Dynamic only
"""
super().__init__()

self.act = nn.Tanh() if tanh else nn.Identity()

self.norm = nn.RMSNorm(dim)

self.num_residual_streams = num_residual_streams
layer_index = default(layer_index, randrange(num_residual_streams)) # just choose one random residual stream if layer index not given

self.static_beta = nn.Parameter(torch.ones(num_residual_streams))

init_alpha0 = torch.zeros((num_residual_streams, 1))
init_alpha0[layer_index % num_residual_streams, 0] = 1.

self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))

self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)

def prepare_with_inverse(self, residuals):
branch_input, residuals, residual_kwargs = self.prepare(residuals)

def inverse(branch_out):
return self(branch_out, residuals, **residual_kwargs)

return branch_input, inverse

def prepare(self, residuals):

residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)

normed = self.norm(residuals)

wc_weight = self.act(normed @ self.dynamic_alpha_fn)
dynamic_alpha = wc_weight * self.dynamic_alpha_scale
alpha = dynamic_alpha + self.static_alpha

dc_weight = self.act(normed @ self.dynamic_beta_fn)
dynamic_beta = dc_weight * self.dynamic_beta_scale
beta = dynamic_beta + self.static_beta

# width connection

mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')

branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]

return branch_input, residuals, dict(beta = beta)

def forward(self, branch_output, residuals, *, beta):
# 'depth' connection

residuals = einsum(branch_output, beta, 'b n d, b n s -> b n s d') + residuals
return rearrange(residuals, 'b n s d -> (b s) n d')

# attention

class Attention(Module):
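
For reference, the deleted wrappers were consumed through a prepare / inverse pattern: prepare mixes the residual streams and extracts the branch input (the width connection), and the returned closure adds the branch output back onto the streams (the depth connection). A minimal sketch, assuming the HyperConnections class removed above is still in scope; dim, branch, and the shapes are illustrative:

import torch
from torch import nn

dim, num_streams = 64, 4

# the class deleted in this hunk; layer_index picks which stream seeds the branch input
residual = HyperConnections(dim = dim, num_residual_streams = num_streams, layer_index = 0)
branch = nn.Linear(dim, dim)  # stands in for an attention or feedforward block

# tokens arrive with the residual streams folded into the batch: (b s) n d
tokens = torch.randn(2 * num_streams, 16, dim)

branch_input, add_residual = residual.prepare_with_inverse(tokens)  # width connection
tokens = add_residual(branch(branch_input))                         # depth connection, back to ((b s), n, d)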
@@ -789,19 +700,11 @@ def __init__(
# residual functions, with maybe hyper connections

assert num_residual_streams >= 1
is_hyper_connection = num_residual_streams > 1
residual_klass = Residual if not is_hyper_connection else HyperConnections
init_residual_fn, self.maybe_expand_residuals, self.maybe_reduce_residuals = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

residual_fns = []
counter = count()

self.maybe_expand_residuals = identity
self.maybe_reduce_residuals = identity

if is_hyper_connection:
self.maybe_expand_residuals = maybe(partial(repeat, pattern = 'b n d -> (b s) n d', s = num_residual_streams))
self.maybe_reduce_residuals = maybe(partial(reduce, reduction = 'sum', pattern = '(b s) n d -> b n d', s = num_residual_streams))

# attention and feedforward

layers = []
@@ -818,8 +721,8 @@ def __init__(
]))

residual_fns.append(ModuleList([
residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = next(counter)),
residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = next(counter)),
init_residual_fn(dim = dim, layer_index = next(counter)),
init_residual_fn(dim = dim, layer_index = next(counter)),
]))

cond_layers.append(ModuleList([
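
The per-layer wrappers above now come from the external hyper-connections package instead of the deleted in-file classes. As the new __init__ code shows, get_init_and_expand_reduce_stream_functions returns a partial constructor for the residual wrappers together with the expand and reduce helpers for the streams, and disabling it (a single stream) falls back to a plain residual. A rough usage sketch of that pattern, with dim, branch, and the shapes chosen for illustration:

import torch
from torch import nn
from hyper_connections import HyperConnections

num_residual_streams = 4
dim = 64

init_residual_fn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

residual = init_residual_fn(dim = dim, layer_index = 0)
branch = nn.Linear(dim, dim)  # stands in for an attention or feedforward block

tokens = torch.randn(2, 16, dim)

tokens = expand_streams(tokens)                # b n d -> (b s) n d
branch_input, add_residual = residual(tokens)  # replaces the old prepare_with_inverse call
tokens = add_residual(branch(branch_input))
tokens = reduce_streams(tokens)                # (b s) n d -> b n d, summing the streams like the removed helper

This is also why the call sites in forward below change from attn_residual.prepare_with_inverse(action_tokens) to a direct attn_residual(action_tokens).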
@@ -1336,7 +1239,7 @@ def forward(

# joint attention

action_tokens, add_action_residual = attn_residual.prepare_with_inverse(action_tokens)
action_tokens, add_action_residual = attn_residual(action_tokens)

action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

@@ -1374,7 +1277,7 @@ def forward(

# action feedforward

action_tokens, add_action_ff_residual = actions_ff_residual.prepare_with_inverse(action_tokens)
action_tokens, add_action_ff_residual = actions_ff_residual(action_tokens)

action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)

@@ -1403,7 +1306,7 @@ def forward(

# actions attention

action_tokens, add_action_residual = attn_residual.prepare_with_inverse(action_tokens)
action_tokens, add_action_residual = attn_residual(action_tokens)

action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

@@ -1424,7 +1327,7 @@ def forward(

# actions feed forward

action_tokens, add_action_ff_residual = actions_ff_residual.prepare_with_inverse(action_tokens)
action_tokens, add_action_ff_residual = actions_ff_residual(action_tokens)

action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.1.2"
version = "0.1.4"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
@@ -30,6 +30,7 @@ dependencies = [
"einops>=0.8.0",
"ema-pytorch>=0.7.3",
"jaxtyping",
'hyper-connections>=0.0.10',
"rotary-embedding-torch>=0.8.5",
'scipy',
"torch>=2.5",
