Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test #96

Closed
wants to merge 11 commits into from
Closed

test #96

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ jobs:
max-parallel: 4
matrix:
include:
- python: 3.9
torch: 2.0.0
torchaudio: 2.0.1
- python: 3.12
torch: 2.4.0
torchaudio: 2.4.0
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ PLATFORM := cu118

venv:
test -d venv || python$(PYTHON_VERSION) -m venv venv
. ./venv/bin/activate && python -m pip install pip --upgrade
. ./venv/bin/activate && python -m pip install --upgrade pip
. ./venv/bin/activate && python -m pip install --upgrade wheel icc-rt
. ./venv/bin/activate && python -m pip install torch==$(TORCH_VERSION)+$(PLATFORM) torchaudio==$(TORCHAUDIO_VERSION)+$(PLATFORM) \
--index-url https://download.pytorch.org/whl/$(PLATFORM)
. ./venv/bin/activate && python -m pip install -e .[dev]
. ./venv/bin/activate && python -m pip install icc-rt

dist:
. ./venv/bin/activate && python -m build
Expand Down
99 changes: 64 additions & 35 deletions diffsptk/modules/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from tqdm import tqdm

from ..misc.utils import to_dataloader
from .lbg import LindeBuzoGrayAlgorithm


class GaussianMixtureModeling(nn.Module):
Expand Down Expand Up @@ -65,14 +64,15 @@ class GaussianMixtureModeling(nn.Module):
Batch size.

verbose : bool
If True, print progress.
If 1, show distance at each iteration; if 2, show progress bar.

"""

def __init__(
self,
order,
n_mixture,
*,
n_iter=100,
eps=1e-5,
weight_floor=1e-5,
Expand Down Expand Up @@ -155,6 +155,8 @@ def __init__(
handler.setFormatter(formatter)
self.logger.addHandler(handler)

self.hide_progress_bar = self.verbose <= 1

def set_params(self, params):
"""Set model parameters.

Expand Down Expand Up @@ -192,6 +194,8 @@ def warmup(self, x, **lbg_params):
x = to_dataloader(x, batch_size=self.batch_size)
device = self.w.device

from .lbg import LindeBuzoGrayAlgorithm

lbg = LindeBuzoGrayAlgorithm(self.order, self.n_mixture, **lbg_params).to(
device
)
Expand All @@ -202,15 +206,15 @@ def warmup(self, x, **lbg_params):
mu = codebook

idx = indices.view(-1, 1, 1).expand(-1, self.order + 1, self.order + 1)
kxx = torch.zeros_like(self.sigma) # [K, L, L]
kxx = torch.zeros_like(self.sigma) # (K, L, L)
b = 0
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
e = b + batch_x.size(0)
xp = batch_x.to(device)
xx = torch.matmul(xp.unsqueeze(-1), xp.unsqueeze(-2))
kxx.scatter_add_(0, idx[b:e], xx)
b = e
mm = torch.matmul(mu.unsqueeze(-1), mu.unsqueeze(-2)) # [K, L, L]
mm = torch.matmul(mu.unsqueeze(-1), mu.unsqueeze(-2)) # (K, L, L)
sigma = kxx / count.view(-1, 1, 1) - mm
sigma = sigma * self.mask

Expand Down Expand Up @@ -287,7 +291,7 @@ def forward(self, x, return_posterior=False):
# Update mean vectors.
px = []
b = 0
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
e = b + batch_x.size(0)
px.append(torch.matmul(posterior[b:e].t(), batch_x.to(device)))
b = e
Expand All @@ -301,7 +305,7 @@ def forward(self, x, return_posterior=False):
if self.is_diag:
pxx = []
b = 0
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
e = b + batch_x.size(0)
xx = batch_x.to(device) ** 2
pxx.append(torch.matmul(posterior[b:e].t(), xx))
Expand All @@ -322,7 +326,7 @@ def forward(self, x, return_posterior=False):
else:
pxx = []
b = 0
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
e = b + batch_x.size(0)
xp = batch_x.to(device)
xx = torch.matmul(xp.unsqueeze(-1), xp.unsqueeze(-2))
Expand Down Expand Up @@ -364,27 +368,44 @@ def forward(self, x, return_posterior=False):
return ret

def transform(self, x):
"""Transform input vectors to mixture indices.
"""Transform input vectors based on a single mixture sequence.

Parameters
----------
x : Tensor [shape=(T, M+1)]
x : Tensor [shape=(T, N+1)]
Input vectors.

Returns
-------
y : Tensor [shape=(T, M-N)]
Output vectors.

indices : Tensor [shape=(T,)]
Mixture indices.
Selected mixture indices.

log_prob : Tensor [shape=(T,)]
Log probabilities.

"""
posterior, log_prob = self._e_step(x, reduction="none")
N = x.size(-1) - 1
posterior, log_prob = self._e_step(x, reduction="none", in_order=N)
indices = torch.argmax(posterior, dim=-1)
return indices, log_prob

def _e_step(self, x, reduction="sum"):
if self.order == N:
return None, indices, log_prob

L = N + 1
sigma_yx = self.sigma[:, L:, :L]
sigma_xx = self.sigma[:, :L, :L]
sigma_yx_xx = torch.matmul(sigma_yx, torch.inverse(sigma_xx))
mu_x = self.mu[indices, :L]
mu_y = self.mu[indices, L:]
diff = (x - mu_x).unsqueeze(-1)
E = mu_y + torch.matmul(sigma_yx_xx[indices], diff).squeeze(-1)
y = E
return y, indices, log_prob

def _e_step(self, x, reduction="sum", in_order=None):
"""Expectation step.

Parameters
Expand All @@ -395,6 +416,9 @@ def _e_step(self, x, reduction="sum"):
reduction : ['none', 'sum']
Reduction type.

in_order : int >= 0
Order of input vectors.

Returns
-------
posterior : Tensor [shape=(T, K)]
Expand All @@ -407,38 +431,43 @@ def _e_step(self, x, reduction="sum"):
x = to_dataloader(x, self.batch_size)
device = self.w.device

log_pi = (self.order + 1) * np.log(2 * np.pi)
if in_order is None:
L = self.order + 1
mu, sigma = self.mu, self.sigma
else:
L = in_order + 1
mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L]

log_pi = L * np.log(2 * np.pi)
if self.is_diag:
log_det = torch.log(torch.diagonal(self.sigma, dim1=-2, dim2=-1)).sum(
-1
) # [K]
log_det = torch.log(torch.diagonal(sigma, dim1=-2, dim2=-1)).sum(-1) # (K,)
precision = torch.reciprocal(
torch.diagonal(self.sigma, dim1=-2, dim2=-1)
) # [K, L]
torch.diagonal(sigma, dim1=-2, dim2=-1)
) # (K, L)
mahala = []
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
xp = batch_x.to(device)
diff = xp.unsqueeze(1) - self.mu.unsqueeze(0) # [B, K, L]
mahala.append((diff**2 * precision).sum(-1)) # [B, K]
mahala = torch.cat(mahala) # [T, K]
diff = xp.unsqueeze(1) - mu.unsqueeze(0) # (B, K, L)
mahala.append((diff**2 * precision).sum(-1)) # (B, K)
mahala = torch.cat(mahala) # (T, K)
else:
col = torch.linalg.cholesky(self.sigma)
col = torch.linalg.cholesky(sigma)
log_det = (
torch.log(torch.diagonal(col, dim1=-2, dim2=-1)).sum(-1) * 2
) # [K]
precision = torch.cholesky_inverse(col).unsqueeze(0) # [1, K, L, L]
) # (K,)
precision = torch.cholesky_inverse(col).unsqueeze(0) # (1, K, L, L)
mahala = []
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
xp = batch_x.to(device)
diff = xp.unsqueeze(1) - self.mu.unsqueeze(0) # [B, K, L]
right = torch.matmul(precision, diff.unsqueeze(-1)) # [B, K, L, 1]
diff = xp.unsqueeze(1) - mu.unsqueeze(0) # (B, K, L)
right = torch.matmul(precision, diff.unsqueeze(-1)) # (B, K, L, 1)
mahala.append(
torch.matmul(diff.unsqueeze(-2), right).squeeze(-1).squeeze(-1)
) # [B, K]
mahala = torch.cat(mahala) # [T, K]
numer = torch.log(self.w) - 0.5 * (log_pi + log_det + mahala) # [T, K]
denom = torch.logsumexp(numer, dim=-1, keepdim=True) # [T, 1]
posterior = torch.exp(numer - denom) # [T, K]
) # (B, K)
mahala = torch.cat(mahala) # (T, K)
numer = torch.log(self.w) - 0.5 * (log_pi + log_det + mahala) # (T, K)
denom = torch.logsumexp(numer, dim=-1, keepdim=True) # (T, 1)
posterior = torch.exp(numer - denom) # (T, K)
if reduction == "none":
log_likelihood = denom.squeeze(-1)
elif reduction == "sum":
Expand Down
36 changes: 30 additions & 6 deletions diffsptk/modules/lbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
# ------------------------------------------------------------------------ #

import logging
import math

import torch
from torch import nn
from tqdm import tqdm

from ..misc.utils import is_power_of_two
from ..misc.utils import to_dataloader
from .gmm import GaussianMixtureModeling
from .vq import VectorQuantization


Expand Down Expand Up @@ -52,23 +54,28 @@ class LindeBuzoGrayAlgorithm(nn.Module):
init : ['none', 'mean'] or torch.Tensor [shape=(1~K, M+1)]
Initialization type.

metric : ['none, 'aic', 'bic']
Metric used as a reference for model selection.

batch_size : int >= 1 or None
Batch size.

verbose : bool or int
If True, print progress.
If 1, show distance at each iteration; if 2, show progress bar.

"""

def __init__(
self,
order,
codebook_size,
*,
min_data_per_cluster=1,
n_iter=100,
eps=1e-5,
perturb_factor=1e-5,
init="mean",
metric="none",
batch_size=None,
verbose=False,
):
Expand All @@ -87,6 +94,7 @@ def __init__(
self.n_iter = n_iter
self.eps = eps
self.perturb_factor = perturb_factor
self.metric = metric
self.batch_size = batch_size
self.verbose = verbose

Expand All @@ -111,6 +119,8 @@ def __init__(
handler.setFormatter(formatter)
self.logger.addHandler(handler)

self.hide_progress_bar = self.verbose <= 1

def forward(self, x, return_indices=False):
"""Design a codebook.

Expand Down Expand Up @@ -157,7 +167,7 @@ def forward(self, x, return_indices=False):
if self.verbose:
self.logger.info("K = 1")
first = True
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
assert batch_x.dim() == 2
if first:
s = batch_x.sum(0)
Expand All @@ -168,13 +178,13 @@ def forward(self, x, return_indices=False):
T += batch_x.size(0)
self.vq.codebook[0] = s / T
else:
raise ValueError(f"Invalid initialization type: {self.init}")
raise ValueError(f"init {self.init} is not supported.")
self.vq.codebook[self.curr_codebook_size :] = 1e10

def e_step(x):
indices = []
distance = 0
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
batch_xp = batch_x.to(device)
batch_xq, batch_indices, _ = self.vq(batch_xp)
indices.append(batch_indices)
Expand Down Expand Up @@ -203,7 +213,7 @@ def e_step(x):
# E-step: evaluate model.
indices, distance = e_step(x)
if self.verbose:
self.logger.info(f"iter {n+1:5d}: distance = {distance:g}")
self.logger.info(f" iter {n+1:5d}: distance = {distance:g}")

# Check convergence.
change = (prev_distance - distance).abs()
Expand All @@ -228,7 +238,7 @@ def e_step(x):
)
idx = indices.unsqueeze(1).expand(-1, self.order + 1)
b = 0
for (batch_x,) in tqdm(x, disable=self.verbose <= 1):
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
e = b + batch_x.size(0)
centroids.scatter_add_(0, idx[b:e], batch_x.to(device))
b = e
Expand All @@ -244,6 +254,20 @@ def e_step(x):

self.vq.codebook[: self.curr_codebook_size] = centroids

if self.metric != "none":
gmm = GaussianMixtureModeling(self.order, self.curr_codebook_size)
gmm.set_params((None, centroids, None))
_, log_likelihood = gmm._e_step(x)
n_param = self.curr_codebook_size * (self.order + 1)
if self.metric == "aic":
metric = -2 * log_likelihood + n_param * 2
elif self.metric == "bic":
metric = -2 * log_likelihood + n_param * math.log(len(indices))
else:
raise ValueError(f"metric {self.metric} is not supported.")
if self.verbose:
self.logger.info(f" {self.metric.upper()} = {metric:g}")

ret = [self.vq.codebook]

if return_indices:
Expand Down
Loading