
Commit bf3ed9c

add the proposed drawing of times from a beta distribution detailed in appendix B, allow it to be customizable

lucidrains committed Dec 18, 2024
1 parent 8ad66fd
Showing 2 changed files with 22 additions and 2 deletions.
pi_zero_pytorch/pi_zero.py (21 additions, 1 deletion)
@@ -11,6 +11,8 @@
 import torch.nn.functional as F
 from torch import pi, nn, tensor, is_tensor
 from torch.nn import Module, ModuleList
+from torch.distributions.beta import Beta
+
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 
 from torchdiffeq import odeint
@@ -180,6 +182,19 @@ def pad_at_dim(
 
 # flow related
 
+def default_sample_times(
+    shape,
+    s = 0.999,
+    alpha = 1.5,
+    beta = 1.,
+    device = None
+):
+    """ they propose to sample times from a Beta distribution - last part of appendix B """
+
+    # tau = (1 - X) * s for X ~ Beta(alpha, beta), which skews times toward the noisier (low) end and caps them at s
+    sampled = Beta(torch.full(shape, alpha, device = device), torch.full(shape, beta, device = device)).sample()
+    return (1. - sampled) * s
+
 def noise_assignment(data, noise):
     device = data.device
     data, noise = tuple(rearrange(t, 'b ... -> b (...)') for t in (data, noise))
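
For intuition, a standalone sanity check (a sketch, not part of the commit): with the defaults s = 0.999, alpha = 1.5, beta = 1., sampled times land in [0, s] with mass biased toward 0, since E[X] = 1.5 / 2.5 = 0.6 for X ~ Beta(1.5, 1) and therefore E[(1 - X) * s] is roughly 0.4.

    import torch
    from torch.distributions.beta import Beta

    # mirror default_sample_times above with its default arguments
    s, alpha, beta = 0.999, 1.5, 1.
    shape = (100_000,)

    sampled = Beta(torch.full(shape, alpha), torch.full(shape, beta)).sample()
    times = (1. - sampled) * s

    print(times.min().item(), times.max().item())  # stays within [0, 0.999]
    print(times.mean().item())                     # ~0.4, i.e. skewed toward noisier timesteps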
@@ -589,6 +604,7 @@ def __init__(
         lm_loss_weight = 1.,
         flow_loss_weight = 1.,
         immiscible_flow = False, # https://arxiv.org/abs/2406.12303
+        sample_times_fn = default_sample_times,
         reward_tokens_dropout_prob = 0.,
         num_recurrent_memory_tokens = 0,
         odeint_kwargs: dict = dict(
@@ -707,6 +723,10 @@ def __init__(
 
         self.reward_tokens_dropout_prob = reward_tokens_dropout_prob
 
+        # time sampling related
+
+        self.sample_times_fn = default(sample_times_fn, torch.rand)
+
         # loss related
 
         self.lm_loss_weight = lm_loss_weight
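
Because the stored function is wrapped in default(...), passing sample_times_fn = None falls back to uniform times via torch.rand. Any replacement only needs to accept a shape tuple plus a device keyword, matching the single call site in forward below. A hypothetical example (the logit-normal sampler and the model class name are illustrative assumptions, not part of this commit):

    import torch

    def logit_normal_sample_times(shape, device = None):
        # illustration: times concentrated around 0.5 instead of skewed low
        return torch.randn(shape, device = device).sigmoid()

    # model = PiZero(..., sample_times_fn = logit_normal_sample_times)  # model class name assumed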
@@ -975,7 +995,7 @@ def forward(
         # noising the action for flow matching
 
         if not exists(times):
-            times = torch.rand((batch,), device = device)
+            times = self.sample_times_fn((batch,), device = device)
 
         if times.ndim == 0:
             times = repeat(times, '-> b', b = batch)
pyproject.toml (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.0.45"
+version = "0.0.46"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
