From 5fbf2cfd0070c257bd238de38c0cf8905ef31492 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 5 Sep 2024 15:51:59 +0900 Subject: [PATCH 01/11] gmm vc without dynamic feature --- diffsptk/modules/gmm.py | 93 ++++++++++++++++++++++++++--------------- tests/test_gmm.py | 17 +++++++- 2 files changed, 76 insertions(+), 34 deletions(-) diff --git a/diffsptk/modules/gmm.py b/diffsptk/modules/gmm.py index 4409272..14b114d 100644 --- a/diffsptk/modules/gmm.py +++ b/diffsptk/modules/gmm.py @@ -22,7 +22,6 @@ from tqdm import tqdm from ..misc.utils import to_dataloader -from .lbg import LindeBuzoGrayAlgorithm class GaussianMixtureModeling(nn.Module): @@ -65,7 +64,7 @@ class GaussianMixtureModeling(nn.Module): Batch size. verbose : bool - If True, print progress. + If 1, show distance at each iteration; if 2, show progress bar. """ @@ -73,6 +72,7 @@ def __init__( self, order, n_mixture, + *, n_iter=100, eps=1e-5, weight_floor=1e-5, @@ -155,6 +155,8 @@ def __init__( handler.setFormatter(formatter) self.logger.addHandler(handler) + self.hide_progress_bar = self.verbose <= 2 + def set_params(self, params): """Set model parameters. @@ -192,6 +194,8 @@ def warmup(self, x, **lbg_params): x = to_dataloader(x, batch_size=self.batch_size) device = self.w.device + from .lbg import LindeBuzoGrayAlgorithm + lbg = LindeBuzoGrayAlgorithm(self.order, self.n_mixture, **lbg_params).to( device ) @@ -202,15 +206,15 @@ def warmup(self, x, **lbg_params): mu = codebook idx = indices.view(-1, 1, 1).expand(-1, self.order + 1, self.order + 1) - kxx = torch.zeros_like(self.sigma) # [K, L, L] + kxx = torch.zeros_like(self.sigma) # (K, L, L) b = 0 - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): e = b + batch_x.size(0) xp = batch_x.to(device) xx = torch.matmul(xp.unsqueeze(-1), xp.unsqueeze(-2)) kxx.scatter_add_(0, idx[b:e], xx) b = e - mm = torch.matmul(mu.unsqueeze(-1), mu.unsqueeze(-2)) # [K, L, L] + mm = torch.matmul(mu.unsqueeze(-1), mu.unsqueeze(-2)) # (K, L, L) sigma = kxx / count.view(-1, 1, 1) - mm sigma = sigma * self.mask @@ -287,7 +291,7 @@ def forward(self, x, return_posterior=False): # Update mean vectors. px = [] b = 0 - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): e = b + batch_x.size(0) px.append(torch.matmul(posterior[b:e].t(), batch_x.to(device))) b = e @@ -301,7 +305,7 @@ def forward(self, x, return_posterior=False): if self.is_diag: pxx = [] b = 0 - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): e = b + batch_x.size(0) xx = batch_x.to(device) ** 2 pxx.append(torch.matmul(posterior[b:e].t(), xx)) @@ -322,7 +326,7 @@ def forward(self, x, return_posterior=False): else: pxx = [] b = 0 - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): e = b + batch_x.size(0) xp = batch_x.to(device) xx = torch.matmul(xp.unsqueeze(-1), xp.unsqueeze(-2)) @@ -364,27 +368,44 @@ def forward(self, x, return_posterior=False): return ret def transform(self, x): - """Transform input vectors to mixture indices. + """Transform input vectors based on a single mixture sequence. Parameters ---------- - x : Tensor [shape=(T, M+1)] + x : Tensor [shape=(T, N+1)] Input vectors. Returns ------- + y : Tensor [shape=(T, M-N)] + Input vectors. + indices : Tensor [shape=(T,)] - Mixture indices. + Selected mixture indices. log_prob : Tensor [shape=(T,)] Log probabilities. """ - posterior, log_prob = self._e_step(x, reduction="none") + N = x.size(-1) - 1 + posterior, log_prob = self._e_step(x, reduction="none", in_order=N) indices = torch.argmax(posterior, dim=-1) - return indices, log_prob - def _e_step(self, x, reduction="sum"): + if self.order == N: + return None, indices, log_prob + + L = N + 1 + sigma_yx = self.sigma[:, L:, :L] + sigma_xx = self.sigma[:, :L, :L] + sigma_yx_xx = torch.matmul(sigma_yx, torch.inverse(sigma_xx)) + mu_x = self.mu[indices, :L] + mu_y = self.mu[indices, L:] + diff = (x - mu_x).unsqueeze(-1) + E = mu_y + torch.matmul(sigma_yx_xx[indices], diff).squeeze(-1) + y = E + return y, indices, log_prob + + def _e_step(self, x, reduction="sum", in_order=None): """Expectation step. Parameters @@ -395,6 +416,9 @@ def _e_step(self, x, reduction="sum"): reduction : ['none', 'sum'] Reduction type. + in_order : int >= 0 + Order of input vectors. + Returns ------- posterior : Tensor [shape=(T, K)] @@ -407,38 +431,41 @@ def _e_step(self, x, reduction="sum"): x = to_dataloader(x, self.batch_size) device = self.w.device - log_pi = (self.order + 1) * np.log(2 * np.pi) + if in_order is None: + in_order = self.order + L = in_order + 1 + mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L] + + log_pi = (in_order + 1) * np.log(2 * np.pi) if self.is_diag: - log_det = torch.log(torch.diagonal(self.sigma, dim1=-2, dim2=-1)).sum( - -1 - ) # [K] + log_det = torch.log(torch.diagonal(sigma, dim1=-2, dim2=-1)).sum(-1) # (K,) precision = torch.reciprocal( - torch.diagonal(self.sigma, dim1=-2, dim2=-1) - ) # [K, L] + torch.diagonal(sigma, dim1=-2, dim2=-1) + ) # (K, L) mahala = [] - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): xp = batch_x.to(device) - diff = xp.unsqueeze(1) - self.mu.unsqueeze(0) # [B, K, L] - mahala.append((diff**2 * precision).sum(-1)) # [B, K] - mahala = torch.cat(mahala) # [T, K] + diff = xp.unsqueeze(1) - mu.unsqueeze(0) # (B, K, L) + mahala.append((diff**2 * precision).sum(-1)) # (B, K) + mahala = torch.cat(mahala) # (T, K) else: - col = torch.linalg.cholesky(self.sigma) + col = torch.linalg.cholesky(sigma) log_det = ( torch.log(torch.diagonal(col, dim1=-2, dim2=-1)).sum(-1) * 2 ) # [K] - precision = torch.cholesky_inverse(col).unsqueeze(0) # [1, K, L, L] + precision = torch.cholesky_inverse(col).unsqueeze(0) # (1, K, L, L) mahala = [] - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): xp = batch_x.to(device) - diff = xp.unsqueeze(1) - self.mu.unsqueeze(0) # [B, K, L] - right = torch.matmul(precision, diff.unsqueeze(-1)) # [B, K, L, 1] + diff = xp.unsqueeze(1) - mu.unsqueeze(0) # (B, K, L) + right = torch.matmul(precision, diff.unsqueeze(-1)) # (B, K, L, 1) mahala.append( torch.matmul(diff.unsqueeze(-2), right).squeeze(-1).squeeze(-1) ) # [B, K] - mahala = torch.cat(mahala) # [T, K] - numer = torch.log(self.w) - 0.5 * (log_pi + log_det + mahala) # [T, K] - denom = torch.logsumexp(numer, dim=-1, keepdim=True) # [T, 1] - posterior = torch.exp(numer - denom) # [T, K] + mahala = torch.cat(mahala) # (T, K) + numer = torch.log(self.w) - 0.5 * (log_pi + log_det + mahala) # (T, K) + denom = torch.logsumexp(numer, dim=-1, keepdim=True) # (T, 1) + posterior = torch.exp(numer - denom) # (T, K) if reduction == "none": log_likelihood = denom.squeeze(-1) elif reduction == "sum": diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 6b53ec9..9c6e49e 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -133,10 +133,25 @@ def _gmmp(x): [], f"nrand -l {B*(M+1)} -s 9", f"gmmp -m {M} -k {K} {optp} {tmp7}", - [f"rm {tmp7}"], + [], dx=M + 1, ) + def _vc(x): + return gmm.transform(x)[0] + + N = M // 2 + U.check_compatibility( + device, + _vc, + [], + f"nrand -l {B*(N+1)} -s 10 -d 10", + f"vc -m {N} -M {N} -k {K} {optp} {tmp7}", + [f"rm {tmp7}"], + dx=N + 1, + dy=N + 1, + ) + def test_posterior(M=3, K=4, B=32, n_iter=50): x = torch.randn(B, M + 1) From 2db7f4e794181583ca98777b47b91c0879ced9f3 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 5 Sep 2024 15:53:28 +0900 Subject: [PATCH 02/11] support for displaying aic or bic in lbg --- diffsptk/modules/lbg.py | 36 ++++++++++++++++++++++++++++++------ tests/test_lbg.py | 16 ++++++++++++++-- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/diffsptk/modules/lbg.py b/diffsptk/modules/lbg.py index b2c3dc1..2d19499 100644 --- a/diffsptk/modules/lbg.py +++ b/diffsptk/modules/lbg.py @@ -15,6 +15,7 @@ # ------------------------------------------------------------------------ # import logging +import math import torch from torch import nn @@ -22,6 +23,7 @@ from ..misc.utils import is_power_of_two from ..misc.utils import to_dataloader +from .gmm import GaussianMixtureModeling from .vq import VectorQuantization @@ -52,11 +54,14 @@ class LindeBuzoGrayAlgorithm(nn.Module): init : ['none', 'mean'] or torch.Tensor [shape=(1~K, M+1)] Initialization type. + metric : ['none, 'aic', 'bic'] + Metric used as a reference for model selection. + batch_size : int >= 1 or None Batch size. verbose : bool or int - If True, print progress. + If 1, show distance at each iteration; if 2, show progress bar. """ @@ -64,11 +69,13 @@ def __init__( self, order, codebook_size, + *, min_data_per_cluster=1, n_iter=100, eps=1e-5, perturb_factor=1e-5, init="mean", + metric="none", batch_size=None, verbose=False, ): @@ -87,6 +94,7 @@ def __init__( self.n_iter = n_iter self.eps = eps self.perturb_factor = perturb_factor + self.metric = metric self.batch_size = batch_size self.verbose = verbose @@ -111,6 +119,8 @@ def __init__( handler.setFormatter(formatter) self.logger.addHandler(handler) + self.hide_progress_bar = self.verbose <= 2 + def forward(self, x, return_indices=False): """Design a codebook. @@ -157,7 +167,7 @@ def forward(self, x, return_indices=False): if self.verbose: self.logger.info("K = 1") first = True - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): assert batch_x.dim() == 2 if first: s = batch_x.sum(0) @@ -168,13 +178,13 @@ def forward(self, x, return_indices=False): T += batch_x.size(0) self.vq.codebook[0] = s / T else: - raise ValueError(f"Invalid initialization type: {self.init}") + raise ValueError(f"init {self.init} is not supported.") self.vq.codebook[self.curr_codebook_size :] = 1e10 def e_step(x): indices = [] distance = 0 - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): batch_xp = batch_x.to(device) batch_xq, batch_indices, _ = self.vq(batch_xp) indices.append(batch_indices) @@ -203,7 +213,7 @@ def e_step(x): # E-step: evaluate model. indices, distance = e_step(x) if self.verbose: - self.logger.info(f"iter {n+1:5d}: distance = {distance:g}") + self.logger.info(f" iter {n+1:5d}: distance = {distance:g}") # Check convergence. change = (prev_distance - distance).abs() @@ -228,7 +238,7 @@ def e_step(x): ) idx = indices.unsqueeze(1).expand(-1, self.order + 1) b = 0 - for (batch_x,) in tqdm(x, disable=self.verbose <= 1): + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): e = b + batch_x.size(0) centroids.scatter_add_(0, idx[b:e], batch_x.to(device)) b = e @@ -244,6 +254,20 @@ def e_step(x): self.vq.codebook[: self.curr_codebook_size] = centroids + if self.metric != "none": + gmm = GaussianMixtureModeling(self.order, self.curr_codebook_size) + gmm.set_params((None, centroids, None)) + _, log_likelihood = gmm._e_step(x) + n_param = self.curr_codebook_size * (self.order + 1) + if self.metric == "aic": + metric = -2 * log_likelihood + n_param * 2 + elif self.metric == "bic": + metric = -2 * log_likelihood + n_param * math.log(len(indices)) + else: + raise ValueError(f"metric {self.metric} is not supported.") + if self.verbose: + self.logger.info(f" {self.metric.upper()} = {metric:g}") + ret = [self.vq.codebook] if return_indices: diff --git a/tests/test_lbg.py b/tests/test_lbg.py index 96c4206..e8a4021 100644 --- a/tests/test_lbg.py +++ b/tests/test_lbg.py @@ -60,12 +60,24 @@ def test_special_case(M=1, K=4, B=10, n_iter=10): torch.manual_seed(1234) x = torch.randn(B, M + 1) lbg = diffsptk.LBG( - M, K, n_iter=n_iter, min_data_per_cluster=int(B * 0.9), init="none" + M, + K, + n_iter=n_iter, + min_data_per_cluster=int(B * 0.9), + init="none", + metric="aic", ) _, idx1, dist = lbg(x, return_indices=True) _, idx2 = lbg.transform(x) assert torch.all(idx1 == idx2) - extra_lbg = diffsptk.LBG(M, K * 2, n_iter=n_iter, init=lbg.vq.codebook) + extra_lbg = diffsptk.LBG( + M, + K * 2, + n_iter=n_iter, + min_data_per_cluster=1, + init=lbg.vq.codebook, + metric="bic", + ) _, extra_dist = extra_lbg(x) assert extra_dist < dist From a160909b397bd57462d9cac047a303875994e7a2 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 5 Sep 2024 15:54:10 +0900 Subject: [PATCH 03/11] update taplo --- tools/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/Makefile b/tools/Makefile index ee94d00..b3ede91 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -14,7 +14,7 @@ # limitations under the License. # # ------------------------------------------------------------------------ # -TAPLO_VERSION := 0.9.2 +TAPLO_VERSION := 0.9.3 YAMLFMT_VERSION := 0.13.0 all: SPTK taplo yamlfmt From 3e15125e3f48ee1a9c2de37d2cd23e70127e4806 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 5 Sep 2024 15:58:26 +0900 Subject: [PATCH 04/11] fix --- diffsptk/modules/gmm.py | 4 ++-- diffsptk/modules/lbg.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/diffsptk/modules/gmm.py b/diffsptk/modules/gmm.py index 14b114d..b898a28 100644 --- a/diffsptk/modules/gmm.py +++ b/diffsptk/modules/gmm.py @@ -155,7 +155,7 @@ def __init__( handler.setFormatter(formatter) self.logger.addHandler(handler) - self.hide_progress_bar = self.verbose <= 2 + self.hide_progress_bar = self.verbose <= 1 def set_params(self, params): """Set model parameters. @@ -378,7 +378,7 @@ def transform(self, x): Returns ------- y : Tensor [shape=(T, M-N)] - Input vectors. + Output vectors. indices : Tensor [shape=(T,)] Selected mixture indices. diff --git a/diffsptk/modules/lbg.py b/diffsptk/modules/lbg.py index 2d19499..ff7a47c 100644 --- a/diffsptk/modules/lbg.py +++ b/diffsptk/modules/lbg.py @@ -119,7 +119,7 @@ def __init__( handler.setFormatter(formatter) self.logger.addHandler(handler) - self.hide_progress_bar = self.verbose <= 2 + self.hide_progress_bar = self.verbose <= 1 def forward(self, x, return_indices=False): """Design a codebook. From f2c24e4428d486cc1b34840c8094a3ca0ecb9e9e Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 5 Sep 2024 16:20:09 +0900 Subject: [PATCH 05/11] avoid copying --- diffsptk/modules/gmm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/diffsptk/modules/gmm.py b/diffsptk/modules/gmm.py index b898a28..49335c9 100644 --- a/diffsptk/modules/gmm.py +++ b/diffsptk/modules/gmm.py @@ -433,8 +433,10 @@ def _e_step(self, x, reduction="sum", in_order=None): if in_order is None: in_order = self.order - L = in_order + 1 - mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L] + mu, sigma = self.mu, self.sigma + else: + L = in_order + 1 + mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L] log_pi = (in_order + 1) * np.log(2 * np.pi) if self.is_diag: From 2b169f22c2e5565937f19da430b7ac80d8852a07 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 5 Sep 2024 17:09:33 +0900 Subject: [PATCH 06/11] update Makefile --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index cb4cd51..8f3c7fd 100644 --- a/Makefile +++ b/Makefile @@ -25,11 +25,11 @@ PLATFORM := cu118 venv: test -d venv || python$(PYTHON_VERSION) -m venv venv - . ./venv/bin/activate && python -m pip install pip --upgrade + . ./venv/bin/activate && python -m pip install --upgrade pip + . ./venv/bin/activate && python -m pip install --upgrade wheel icc-rt . ./venv/bin/activate && python -m pip install torch==$(TORCH_VERSION)+$(PLATFORM) torchaudio==$(TORCHAUDIO_VERSION)+$(PLATFORM) \ --index-url https://download.pytorch.org/whl/$(PLATFORM) . ./venv/bin/activate && python -m pip install -e .[dev] - . ./venv/bin/activate && python -m pip install icc-rt dist: . ./venv/bin/activate && python -m build From 413f82f21742f8ff5fdc73cd9b2bd4e10f584200 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 5 Sep 2024 17:33:17 +0900 Subject: [PATCH 07/11] minor fix --- diffsptk/modules/gmm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/diffsptk/modules/gmm.py b/diffsptk/modules/gmm.py index 49335c9..5a3dfce 100644 --- a/diffsptk/modules/gmm.py +++ b/diffsptk/modules/gmm.py @@ -432,13 +432,13 @@ def _e_step(self, x, reduction="sum", in_order=None): device = self.w.device if in_order is None: - in_order = self.order + L = self.order + 1 mu, sigma = self.mu, self.sigma else: L = in_order + 1 mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L] - log_pi = (in_order + 1) * np.log(2 * np.pi) + log_pi = L * np.log(2 * np.pi) if self.is_diag: log_det = torch.log(torch.diagonal(sigma, dim1=-2, dim2=-1)).sum(-1) # (K,) precision = torch.reciprocal( @@ -454,7 +454,7 @@ def _e_step(self, x, reduction="sum", in_order=None): col = torch.linalg.cholesky(sigma) log_det = ( torch.log(torch.diagonal(col, dim1=-2, dim2=-1)).sum(-1) * 2 - ) # [K] + ) # (K,) precision = torch.cholesky_inverse(col).unsqueeze(0) # (1, K, L, L) mahala = [] for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): @@ -463,7 +463,7 @@ def _e_step(self, x, reduction="sum", in_order=None): right = torch.matmul(precision, diff.unsqueeze(-1)) # (B, K, L, 1) mahala.append( torch.matmul(diff.unsqueeze(-2), right).squeeze(-1).squeeze(-1) - ) # [B, K] + ) # (B, K) mahala = torch.cat(mahala) # (T, K) numer = torch.log(self.w) - 0.5 * (log_pi + log_det + mahala) # (T, K) denom = torch.logsumexp(numer, dim=-1, keepdim=True) # (T, 1) From fac80c6c6c63db77eca587690bdd0c427efda2b7 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Sat, 7 Sep 2024 13:26:36 +0900 Subject: [PATCH 08/11] test --- .github/workflows/ci.yml | 3 --- diffsptk/modules/gmm.py | 16 +++++++++------- tests/test_gmm.py | 2 ++ 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1423444..d4de61e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,9 +17,6 @@ jobs: max-parallel: 4 matrix: include: - - python: 3.9 - torch: 2.0.0 - torchaudio: 2.0.1 - python: 3.12 torch: 2.4.0 torchaudio: 2.4.0 diff --git a/diffsptk/modules/gmm.py b/diffsptk/modules/gmm.py index 5a3dfce..76cadd5 100644 --- a/diffsptk/modules/gmm.py +++ b/diffsptk/modules/gmm.py @@ -433,25 +433,27 @@ def _e_step(self, x, reduction="sum", in_order=None): if in_order is None: L = self.order + 1 - mu, sigma = self.mu, self.sigma + # mu, sigma = self.mu, self.sigma else: L = in_order + 1 - mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L] + # mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L] log_pi = L * np.log(2 * np.pi) if self.is_diag: - log_det = torch.log(torch.diagonal(sigma, dim1=-2, dim2=-1)).sum(-1) # (K,) + log_det = torch.log(torch.diagonal(self.sigma, dim1=-2, dim2=-1)).sum( + -1 + ) # (K,) precision = torch.reciprocal( - torch.diagonal(sigma, dim1=-2, dim2=-1) + torch.diagonal(self.sigma, dim1=-2, dim2=-1) ) # (K, L) mahala = [] for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): xp = batch_x.to(device) - diff = xp.unsqueeze(1) - mu.unsqueeze(0) # (B, K, L) + diff = xp.unsqueeze(1) - self.mu.unsqueeze(0) # (B, K, L) mahala.append((diff**2 * precision).sum(-1)) # (B, K) mahala = torch.cat(mahala) # (T, K) else: - col = torch.linalg.cholesky(sigma) + col = torch.linalg.cholesky(self.sigma) log_det = ( torch.log(torch.diagonal(col, dim1=-2, dim2=-1)).sum(-1) * 2 ) # (K,) @@ -459,7 +461,7 @@ def _e_step(self, x, reduction="sum", in_order=None): mahala = [] for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): xp = batch_x.to(device) - diff = xp.unsqueeze(1) - mu.unsqueeze(0) # (B, K, L) + diff = xp.unsqueeze(1) - self.mu.unsqueeze(0) # (B, K, L) right = torch.matmul(precision, diff.unsqueeze(-1)) # (B, K, L, 1) mahala.append( torch.matmul(diff.unsqueeze(-2), right).squeeze(-1).squeeze(-1) diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 9c6e49e..b307d42 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -137,6 +137,7 @@ def _gmmp(x): dx=M + 1, ) + """ def _vc(x): return gmm.transform(x)[0] @@ -151,6 +152,7 @@ def _vc(x): dx=N + 1, dy=N + 1, ) + """ def test_posterior(M=3, K=4, B=32, n_iter=50): From 9c630633d4a4220be8223b338a41a218af5cf7b2 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Sat, 7 Sep 2024 13:33:21 +0900 Subject: [PATCH 09/11] test --- diffsptk/modules/gmm.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/diffsptk/modules/gmm.py b/diffsptk/modules/gmm.py index 76cadd5..5a3dfce 100644 --- a/diffsptk/modules/gmm.py +++ b/diffsptk/modules/gmm.py @@ -433,27 +433,25 @@ def _e_step(self, x, reduction="sum", in_order=None): if in_order is None: L = self.order + 1 - # mu, sigma = self.mu, self.sigma + mu, sigma = self.mu, self.sigma else: L = in_order + 1 - # mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L] + mu, sigma = self.mu[:, :L], self.sigma[:, :L, :L] log_pi = L * np.log(2 * np.pi) if self.is_diag: - log_det = torch.log(torch.diagonal(self.sigma, dim1=-2, dim2=-1)).sum( - -1 - ) # (K,) + log_det = torch.log(torch.diagonal(sigma, dim1=-2, dim2=-1)).sum(-1) # (K,) precision = torch.reciprocal( - torch.diagonal(self.sigma, dim1=-2, dim2=-1) + torch.diagonal(sigma, dim1=-2, dim2=-1) ) # (K, L) mahala = [] for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): xp = batch_x.to(device) - diff = xp.unsqueeze(1) - self.mu.unsqueeze(0) # (B, K, L) + diff = xp.unsqueeze(1) - mu.unsqueeze(0) # (B, K, L) mahala.append((diff**2 * precision).sum(-1)) # (B, K) mahala = torch.cat(mahala) # (T, K) else: - col = torch.linalg.cholesky(self.sigma) + col = torch.linalg.cholesky(sigma) log_det = ( torch.log(torch.diagonal(col, dim1=-2, dim2=-1)).sum(-1) * 2 ) # (K,) @@ -461,7 +459,7 @@ def _e_step(self, x, reduction="sum", in_order=None): mahala = [] for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): xp = batch_x.to(device) - diff = xp.unsqueeze(1) - self.mu.unsqueeze(0) # (B, K, L) + diff = xp.unsqueeze(1) - mu.unsqueeze(0) # (B, K, L) right = torch.matmul(precision, diff.unsqueeze(-1)) # (B, K, L, 1) mahala.append( torch.matmul(diff.unsqueeze(-2), right).squeeze(-1).squeeze(-1) From 59a0440b7d659a0dc6688f29f2f8da3f94ca42d0 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Sat, 7 Sep 2024 13:46:19 +0900 Subject: [PATCH 10/11] test --- tests/test_gmm.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_gmm.py b/tests/test_gmm.py index b307d42..9c6e49e 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -137,7 +137,6 @@ def _gmmp(x): dx=M + 1, ) - """ def _vc(x): return gmm.transform(x)[0] @@ -152,7 +151,6 @@ def _vc(x): dx=N + 1, dy=N + 1, ) - """ def test_posterior(M=3, K=4, B=32, n_iter=50): From 2369f133283387d60c298997c6bc0b45561fa0fa Mon Sep 17 00:00:00 2001 From: takenori-y Date: Sat, 7 Sep 2024 14:01:41 +0900 Subject: [PATCH 11/11] test --- tests/test_gmm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_gmm.py b/tests/test_gmm.py index 9c6e49e..f185b3c 100644 --- a/tests/test_gmm.py +++ b/tests/test_gmm.py @@ -32,6 +32,8 @@ def test_compatibility( if device == "cuda" and not torch.cuda.is_available(): return + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) gmm = diffsptk.GMM( M, K,