From c9d46a414e76bb8bd119eb84cb4cb33f57b33e20 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 29 Oct 2024 07:08:01 -0700
Subject: [PATCH] start incorporating einx for extra clarity, fix an issue with dpo

---
 setup.py                         |  1 +
 x_transformers/dpo.py            | 11 +++++++----
 x_transformers/x_transformers.py | 18 +++++++++---------
 3 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/setup.py b/setup.py
index 1d741661..5ea9b59d 100644
--- a/setup.py
+++ b/setup.py
@@ -17,6 +17,7 @@
   ],
   install_requires=[
     'torch>=2.0',
+    'einx>=0.3.0',
     'einops>=0.8.0'
   ],
   setup_requires=[
diff --git a/x_transformers/dpo.py b/x_transformers/dpo.py
index da12234e..3fc4ec21 100644
--- a/x_transformers/dpo.py
+++ b/x_transformers/dpo.py
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 
 from x_transformers.x_transformers import TransformerWrapper
+import einx
 from einops import rearrange
 
 # helper functions
@@ -17,16 +18,18 @@ def freeze_all_layers_(module):
         param.requires_grad = False
 
 def log_prob_from_model_and_seq(model, seq):
-    logits = model(seq)
+    src_seq, tgt_seq = seq[:, :-1], seq[:, 1:]
+    logits = model(src_seq)
     log_prob = logits.log_softmax(dim = -1)
-    indices = rearrange(seq, '... -> ... 1')
-    log_probs = log_prob.gather(-1, indices)
-    return rearrange(log_probs, '... 1 -> ...')
+    return einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq)
 
 def masked_mean(log_probs, mask = None):
     if not exists(mask):
         return log_probs.mean(dim = -1)
 
+    if mask.shape[-1] == (log_probs.shape[-1] + 1):
+        mask = mask[:, :-1]
+
     log_probs = log_probs.masked_fill(~mask, 0.)
     num = log_probs.sum(dim = -1)
     den = mask.sum(dim = -1)
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index f980345e..078231f7 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -16,8 +16,9 @@
 from contextlib import nullcontext
 from dataclasses import dataclass
 
-from einops import rearrange, repeat, reduce, pack, unpack
+import einx
 from einops.layers.torch import Rearrange
+from einops import rearrange, repeat, reduce, pack, unpack
 
 from x_transformers.attend import Attend, Intermediates
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -423,7 +424,7 @@ def forward(self, i, j):
         # get the (n x n) matrix of distances
         seq_arange = torch.arange(n, device = device)
         context_arange = torch.arange(n, device = device)
-        indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
+        indices = einx.subtract('i, j -> i j', seq_arange, context_arange)
         indices += (n - 1)
 
         # input to continuous positions MLP
@@ -453,9 +454,9 @@ def __init__(self, heads, total_heads = None, **kwargs):
         self.register_buffer('bias', None, persistent = False)
 
     def get_bias(self, i, j, device):
-        i_arange = torch.arange(j - i, j, device = device)
-        j_arange = torch.arange(j, device = device)
-        bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1'))
+        seq_arange = torch.arange(j - i, j, device = device)
+        context_arange = torch.arange(j, device = device)
+        bias = -torch.abs(einx.subtract('j, i -> 1 i j', context_arange, seq_arange))
         return bias
 
     @staticmethod
@@ -1236,7 +1237,7 @@ def forward(
         if exists(self.max_attend_past):
             range_q = torch.arange(j - i, j, device = device)
             range_k = torch.arange(j, device = device)
-            dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j')
+            dist = einx.subtract('i, j -> 1 1 i j', range_q, range_k)
             max_attend_past_mask = dist > self.max_attend_past
             max_attend_past_mask = pad_at_dim(max_attend_past_mask, (num_mem_kv, 0), value = False, dim = -1) # handle memory key / values
             masks.append(max_attend_past_mask)
@@ -1291,7 +1292,7 @@ def forward(
 
         if exists(self.to_v_head_gate):
             head_gate = self.to_v_head_gate(x)
-            out = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
+            out = einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out)
 
         # merge heads
 
@@ -1308,8 +1309,7 @@ def forward(
         out = self.to_out(out)
 
         if exists(mask):
-            mask = rearrange(mask, 'b n -> b n 1')
-            out = out.masked_fill(~mask, 0.)
+            out = einx.where('b n, b n d, -> b n d', mask, out, 0.)
 
         if not return_intermediates:
             return out
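
Note (not part of the commit): the sketch below checks, on toy tensors, that each einx call introduced above matches the einops/rearrange expression it replaces, which may help reviewers unfamiliar with einx. The dpo fix itself is the target shift: logits are now computed from seq[:, :-1] and log probs gathered at seq[:, 1:], instead of gathering at the unshifted sequence. Variable names and shapes here are illustrative only; it assumes torch, einops and einx >= 0.3 are installed.

import torch
import einx
from einops import rearrange

b, n, l, h, d = 2, 5, 11, 3, 4  # toy batch, seq len, vocab, heads, head dim

# einx.get_at (dpo.py): pick out the log prob of each target token along the vocab axis
log_prob = torch.randn(b, n, l).log_softmax(dim = -1)
tgt_seq = torch.randint(0, l, (b, n))
gathered = rearrange(log_prob.gather(-1, rearrange(tgt_seq, '... -> ... 1')), '... 1 -> ...')
assert torch.allclose(gathered, einx.get_at('b n [l], b n -> b n', log_prob, tgt_seq))

# einx.subtract (relative position bias): (i, j) matrix of pairwise distances
seq_arange = torch.arange(n)
context_arange = torch.arange(n)
indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j')
assert torch.equal(indices, einx.subtract('i, j -> i j', seq_arange, context_arange))

# einx.multiply (attention output): broadcast a (b, n, h) gate over (b, h, n, d) values
head_gate = torch.randn(b, n, h)
out = torch.randn(b, h, n, d)
gated = out * rearrange(head_gate, 'b n h -> b h n 1').sigmoid()
assert torch.allclose(gated, einx.multiply('b n h, b h n d -> b h n d', head_gate.sigmoid(), out))

# einx.where (attention output): zero out masked positions, scalar fill as the last argument
mask = torch.randint(0, 2, (b, n)).bool()
out = torch.randn(b, n, d)
masked = out.masked_fill(~rearrange(mask, 'b n -> b n 1'), 0.)
assert torch.allclose(masked, einx.where('b n, b n d, -> b n d', mask, out, 0.))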