Skip to content

Commit

Permalink
fix layer index for hyper connection
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 23, 2024
1 parent aba2cf8 commit 27817d5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from beartype.typing import Callable

from functools import partial, wraps
from itertools import count

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -792,6 +793,7 @@ def __init__(
residual_klass = Residual if not is_hyper_connection else HyperConnections

residual_fns = []
counter = count()

self.maybe_expand_residuals = identity
self.maybe_reduce_residuals = identity
Expand All @@ -815,10 +817,12 @@ def __init__(
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs) if self.has_recurrent_memories else None
]))

attn_layer_ind, ff_layer_ind = next(counter), next(counter)

residual_fns.append(ModuleList([
residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = i),
residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = i + 1),
residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = i + 1),
residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = attn_layer_ind),
residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = ff_layer_ind),
residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = ff_layer_ind),
]))

cond_layers.append(ModuleList([
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.1.0"
version = "0.1.1"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 27817d5

Please sign in to comment.