Commit c9d46a4
start incorporating einx for extra clarity, fix an issue with dpo
lucidrains committed Oct 29, 2024
1 parent 11b0af9 commit c9d46a4
Showing 3 changed files with 17 additions and 13 deletions.
1 change: 1 addition & 0 deletions setup.py
@@ -17,6 +17,7 @@
   ],
   install_requires=[
     'torch>=2.0',
+    'einx>=0.3.0',
     'einops>=0.8.0'
   ],
   setup_requires=[
11 changes: 7 additions & 4 deletions x_transformers/dpo.py
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 from x_transformers.x_transformers import TransformerWrapper
 
+import einx
 from einops import rearrange
 
 # helper functions
@@ -17,16 +18,18 @@ def freeze_all_layers_(module):
         param.requires_grad = False
 
 def log_prob_from_model_and_seq(model, seq):
-    logits = model(seq)
+    src_seq, tgt_seq = seq[:, :-1], seq[:, 1:]
+    logits = model(src_seq)
     log_prob = logits.log_softmax(dim = -1)
-    indices = rearrange(seq, '... -> ... 1')
-    log_probs = log_prob.gather(-1, indices)
-    return rearrange(log_probs, '... 1 -> ...')
+    return einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)
 
 def masked_mean(log_probs, mask = None):
     if not exists(mask):
         return log_probs.mean(dim = -1)
 
+    if mask.shape[-1] == (log_probs.shape[-1] + 1):
+        mask = mask[:, :-1]
+
     log_probs = log_probs.masked_fill(~mask, 0.)
     num = log_probs.sum(dim = -1)
     den = mask.sum(dim = -1)
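
The hunk above fixes an off-by-one in the DPO log-probability computation: the model is fed seq[:, :-1], so the logits at position t score the target token seq[:, t + 1], and einx.get_at gathers the per-token log-probabilities along the vocabulary axis. Since the log-probabilities are now one step shorter than the sequence, masked_mean trims a full-length mask by one. A minimal sketch of the gather step, with illustrative shapes and random logits standing in for the model output:

```python
# Sketch only: illustrative shapes; random logits stand in for model(seq[:, :-1])
import torch
import einx

batch, seq_len, vocab = 2, 8, 32                       # made-up sizes
seq = torch.randint(0, vocab, (batch, seq_len))
tgt_seq = seq[:, 1:]                                   # targets are the sequence shifted by one
log_prob = torch.randn(batch, seq_len - 1, vocab).log_softmax(dim = -1)

# einx.get_at indexes the bracketed vocabulary axis [l] with tgt_seq
per_token = einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)

# the same gather written in plain torch, as a cross-check
reference = log_prob.gather(-1, tgt_seq.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(per_token, reference)
```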
18 changes: 9 additions & 9 deletions x_transformers/x_transformers.py
@@ -16,8 +16,9 @@
 from contextlib import nullcontext
 from dataclasses import dataclass
 
-from einops import rearrange, repeat, reduce, pack, unpack
+import einx
 from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, reduce, pack, unpack
 
 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -423,7 +424,7 @@ def forward(self, i, j):
         # get the (n x n) matrix of distances
         seq_arange = torch.arange(n, device = device)
         context_arange = torch.arange(n, device = device)
-        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
+        indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)
 
         # input to continuous positions MLP
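
einx.subtract builds the pairwise (i, j) matrix of differences in one call, replacing the rearrange-then-broadcast idiom; the same pattern appears in the two hunks below (the positional bias and the max-attend-past mask). A small sketch of the equivalence, with an illustrative length:

```python
# Sketch only: illustrative length n
import torch
import einx
from einops import rearrange

n = 5
seq_arange = torch.arange(n)
context_arange = torch.arange(n)

old = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
new = einx.subtract('i, j -> i j', seq_arange, context_arange)

assert torch.equal(old, new)                           # both are the (n x n) matrix of distances
```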
@@ -453,9 +454,9 @@ def __init__(self, heads, total_heads = None, **kwargs):
         self.register_buffer('bias', None, persistent = False)
 
     def get_bias(self, i, j, device):
-        i_arange = torch.arange(j - i, j, device = device)
-        j_arange = torch.arange(j, device = device)
-        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
         return bias
 
     @staticmethod
@@ -1236,7 +1237,7 @@ def forward(
         if exists(self.max_attend_past):
             range_q = torch.arange(j - i, j, device = device)
             range_k = torch.arange(j, device = device)
-            dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
+            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
             max_attend_past_mask = dist > self.max_attend_past
             max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
             masks.append(max_attend_past_mask)
@@ -1291,7 +1292,7 @@ def forward(
 
         if exists(self.to_v_head_gate):
             head_gate = self.to_v_head_gate(x)
-            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+            out = einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out)
 
         # merge heads
 
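einx.multiply folds the gate's broadcast into a single call: the 'b n h' gate is aligned by axis name against the 'b h n d' attention output, so the explicit rearrange to 'b h n 1' is no longer needed. A sketch of the equivalence with made-up sizes:

```python
# Sketch only: made-up batch / head / length / dim sizes
import torch
import einx
from einops import rearrange

b, h, n, d = 2, 4, 6, 16
head_gate = torch.randn(b, n, h)                       # one gate per position and head
out = torch.randn(b, h, n, d)                          # attention output, heads first

old = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
new = einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out)

assert torch.allclose(old, new)
```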
@@ -1308,8 +1309,7 @@ def forward(
         out = self.to_out(out)
 
         if exists(mask):
-            mask = rearrange(mask, 'b n -> b n 1')
-            out = out.masked_fill(~mask, 0.)
+            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
         if not return_intermediates:
             return out
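
einx.where replaces the rearrange-plus-masked_fill pair: the 'b n' mask broadcasts over the feature axis, and the empty pattern before the arrow marks the scalar fill value. A sketch with assumed shapes:

```python
# Sketch only: assumed shapes for the attention output and key padding mask
import torch
import einx
from einops import rearrange

b, n, d = 2, 7, 16
out = torch.randn(b, n, d)
mask = torch.rand(b, n) > 0.3                          # True where positions are kept

old = out.masked_fill(~rearrange(mask, 'b n -> b n 1'), 0.)
new = einx.where('b n, b n d, -> b n d', mask, out, 0.)

assert torch.equal(old, new)
```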
