Skip to content

Commit

Permalink
few fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 18, 2024
1 parent bf3ed9c commit 8bb76d0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions pi_zero_pytorch/pi_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def default_sample_times(
""" they propose to sample times from Beta distribution - last part of appendix part B """

uniform = torch.rand(shape, device = device)
sampled = Beta(alpha, beta).sample()
return ((s - uniform) / s) * sampled
sampled = Beta(alpha, beta).sample().to(device)
return ((s - uniform) / s).clamp(0., 1.) * sampled

def noise_assignment(data, noise):
device = data.device
Expand Down Expand Up @@ -403,9 +403,11 @@ def forward(

# handle read, write memories

assert not (self.accept_memories ^ exists(memories))
has_memories = exists(memories) and any([m.numel() > 0 for m in memories])

if exists(memories):
assert not (self.accept_memories ^ has_memories)

if has_memories:
memories, unpack_memories = pack_with_inverse(memories, 'b * d')
memories = self.mem_rmsnorm(memories)
mqkv = self.to_mem_qkv(memories)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.46"
version = "0.0.47"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
Expand Down

0 comments on commit 8bb76d0

Please sign in to comment.