BestRQ implementation #63

Open · wants to merge 5 commits into main
2 changes: 2 additions & 0 deletions i6_models/parts/best_rq/__init__.py
@@ -0,0 +1,2 @@
from .mask import *
from .quantizer import *
74 changes: 74 additions & 0 deletions i6_models/parts/best_rq/mask.py
@@ -0,0 +1,74 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
import numpy as np

__all__ = ["RandomMask"]


class RandomMask(nn.Module):
    """
    Randomly masks out spans of consecutive frames along the time dimension; the masked frames are
    replaced either with zeros or with a learnable embedding.
    Simplified version of the Fairseq compute_mask_indices function,
    cf. https://github.com/facebookresearch/fairseq/blob/ecbf110e1eb43861214b05fa001eff584954f65a/fairseq/data/data_utils.py#L399
    """

    def __init__(
        self,
        input_dim: int,
        mask_replace_val: str,
        mask_percentage: float,
        mask_length: int,
    ):
        """
        :param input_dim: feature dimension of the input
        :param mask_replace_val: how to replace masked frames, either "zero" or "learnable" embedding
        :param mask_percentage: percentage of frames to be masked out
        :param mask_length: length of each mask span
        """
        super().__init__()

        assert mask_replace_val in ["learnable", "zero"], "not implemented yet"
        if mask_replace_val == "learnable":
            self.mask_emb = nn.Parameter(torch.FloatTensor(input_dim).uniform_())
        elif mask_replace_val == "zero":
            self.mask_emb = torch.zeros(input_dim)
        self.mask_percentage = mask_percentage
        self.mask_length = mask_length

    def forward(
        self,
        tensor: torch.Tensor,
        padding_mask: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param tensor: input features of shape [batch, time, feature]
        :param padding_mask: bool tensor of shape [batch, time], True at padded positions
        :return: tuple of the masked features and the boolean mask of shape [batch, time]
        """
        ndim_batch, ndim_time, _ = tensor.size()

        mask = torch.zeros((ndim_batch, ndim_time), dtype=torch.bool)

        for i in range(ndim_batch):
            if padding_mask is not None:
                seq_len = ndim_time - padding_mask[i].long().sum().item()
                assert seq_len >= 0
            else:
                seq_len = ndim_time

            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_percentage * seq_len / float(self.mask_length)
                + np.random.rand()
            )

            min_len = self.mask_length
            if seq_len - min_len <= num_mask:
                min_len = seq_len - num_mask - 1

            # sample the start indices of the mask spans without replacement
            mask_idc = np.random.choice(seq_len - min_len, num_mask, replace=False)

            for j in mask_idc:
                mask[i, j : j + self.mask_length] = True

        tensor[mask] = self.mask_emb.to(tensor.device)

        return tensor, mask.to(tensor.device)
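
Below is a minimal usage sketch of RandomMask (not part of the diff); the batch shape, sequence lengths, and hyperparameter values are assumptions made up for illustration, and padding_mask follows the convention of forward() above (True at padded frames):

import torch

from i6_models.parts.best_rq import RandomMask

# assumed example shapes: 3 sequences, 100 frames, 80-dim features
features = torch.randn(3, 100, 80)
seq_lens = torch.tensor([100, 80, 60])
padding_mask = torch.arange(100)[None, :] >= seq_lens[:, None]  # True at padded positions

masking = RandomMask(input_dim=80, mask_replace_val="learnable", mask_percentage=0.15, mask_length=10)
masked_features, mask = masking(features, padding_mask)
# masked_features: [3, 100, 80] with masked spans replaced by the learnable embedding
# mask: [3, 100] bool, True at masked frames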
37 changes: 37 additions & 0 deletions i6_models/parts/best_rq/quantizer.py
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.linalg import vector_norm

__all__ = [
"RandomProjectionQuantizer",
]


class RandomProjectionQuantizer(nn.Module):
    """
    Implements the fixed random projection quantizer from BEST-RQ,
    cf. https://arxiv.org/pdf/2202.01855 for the theoretical background.
    Code adapted from https://github.com/speechbrain/speechbrain/blob/16b6420d4ff23210cfca2e888be8853264e0cb17/speechbrain/nnet/quantisers.py#L127
    """

    def __init__(self, input_dim: int, codebook_dim: int, codebook_num_vars: int):
        """
        :param input_dim: feature dimension of the input
        :param codebook_dim: dimension of the codebook vectors
        :param codebook_num_vars: vocab size of the codebook
        """
        super().__init__()

        self.input_dim = input_dim

        # fixed projection matrix with Xavier initialization, registered as a non-trainable buffer
        P_init = torch.empty((input_dim, codebook_dim))
        self.register_buffer("P", nn.init.xavier_uniform_(P_init))

        # fixed random codebook, normalized to unit length
        self.register_buffer("CB", F.normalize(torch.randn(codebook_num_vars, codebook_dim)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # project and normalize the input, then assign each frame to its nearest codebook entry
        x = F.normalize(x @ self.P)
        return vector_norm((self.CB.unsqueeze(1) - x.unsqueeze(1)), dim=-1).argmin(dim=1)
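
As a sketch of how the two parts could be combined for BEST-RQ-style pre-training targets (not part of the diff; the shapes, hyperparameter values, and the clone used to keep the clean features are assumptions for illustration), the quantizer labels the unmasked features while an encoder would consume the masked ones:

import torch

from i6_models.parts.best_rq import RandomMask, RandomProjectionQuantizer

features = torch.randn(3, 100, 80)  # assumed [batch, time, feature] input
padding_mask = torch.zeros(3, 100, dtype=torch.bool)  # no padded frames in this example

quantizer = RandomProjectionQuantizer(input_dim=80, codebook_dim=16, codebook_num_vars=1024)
masking = RandomMask(input_dim=80, mask_replace_val="learnable", mask_percentage=0.15, mask_length=10)

targets = quantizer(features)  # [3, 100] codebook indices computed on the clean features
masked_features, mask = masking(features.clone(), padding_mask)  # clone, since masking is in-place
# a BEST-RQ loss would compare the encoder output at masked frames against targets[mask]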