Skip to content

Commit

Permalink
Merge pull request #101 from sp-nitech/pca
Browse files Browse the repository at this point in the history
Support batch input in pca
  • Loading branch information
takenori-y authored Oct 10, 2024
2 parents 2ed279f + f1bad04 commit effd7b2
Show file tree
Hide file tree
Showing 14 changed files with 103 additions and 488 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ check: tool
. ./venv/bin/activate && python -m ruff format --check $(PROJECT) tests docs
. ./venv/bin/activate && python -m isort --check $(PROJECT) tests
. ./venv/bin/activate && python -m mdformat --check *.md
. ./venv/bin/activate && python -m docstrfmt --check docs
. ./venv/bin/activate && cd docs && python -m docstrfmt --check .
./tools/taplo/taplo fmt --check *.toml
./tools/yamlfmt/yamlfmt --lint *.cff *.yml .github/workflows/*.yml

Expand All @@ -60,7 +60,7 @@ format: tool
. ./venv/bin/activate && python -m ruff format $(PROJECT) tests docs
. ./venv/bin/activate && python -m isort $(PROJECT) tests
. ./venv/bin/activate && python -m mdformat *.md
. ./venv/bin/activate && python -m docstrfmt docs
. ./venv/bin/activate && cd docs && python -m docstrfmt --check .
./tools/taplo/taplo fmt *.toml
./tools/yamlfmt/yamlfmt *.cff *.yml .github/workflows/*.yml

Expand Down
96 changes: 4 additions & 92 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def c2ndps(c, fft_length):
)


def cdist(c1, c2, full=False, reduction="mean", eps=1e-8):
def cdist(c1, c2, full=False, reduction="mean"):
"""Calculate cepstral distance between two inputs.
Parameters
Expand All @@ -192,16 +192,13 @@ def cdist(c1, c2, full=False, reduction="mean", eps=1e-8):
reduction : ['none', 'mean', 'batchmean', 'sum']
Reduction type.
eps : float >= 0
A small value to prevent NaN.
Returns
-------
out : Tensor [shape=(...,) or scalar]
Cepstral distance.
"""
return nn.CepstralDistance._func(c1, c2, full=full, reduction=reduction, eps=eps)
return nn.CepstralDistance._func(c1, c2, full=full, reduction=reduction)


def chroma(x, n_channel, sample_rate, norm=float("inf")):
Expand Down Expand Up @@ -1953,7 +1950,7 @@ def rlevdur(a):
return nn.ReverseLevinsonDurbin._func(a)


def rmse(x, y, reduction="mean", eps=1e-8):
def rmse(x, y, reduction="mean"):
"""Calculate RMSE.
Parameters
Expand All @@ -1967,16 +1964,13 @@ def rmse(x, y, reduction="mean", eps=1e-8):
reduction : ['none', 'mean', 'sum']
Reduction type.
eps : float >= 0
A small value to prevent NaN.
Returns
-------
out : Tensor [shape=(...,) or scalar]
RMSE.
"""
return nn.RootMeanSquareError._func(x, y, reduction=reduction, eps=eps)
return nn.RootMeanSquareError._func(x, y, reduction=reduction)


def root_pol(a, out_format="rectangular"):
Expand Down Expand Up @@ -2074,88 +2068,6 @@ def spec(
)


def ssim(
x,
y,
reduction="mean",
*,
alpha=1,
beta=1,
gamma=1,
kernel_size=11,
sigma=1.5,
k1=0.01,
k2=0.03,
eps=1e-8,
padding="same",
dynamic_range=None,
):
"""Calculate SSIM.
Parameters
----------
x : Tensor [shape=(..., N, D)]
Input.
y : Tensor [shape=(..., N, D)]
Target.
reduction : ['none', 'mean', 'sum']
Reduction type.
alpha : float > 0
Relative importance of luminance component.
beta : float > 0
Relative importance of contrast component.
gamma : float > 0
Relative importance of structure component.
kernel_size : int >= 1
Kernel size of Gaussian filter.
sigma : float > 0
Standard deviation of Gaussian filter.
k1 : float > 0
A small constant.
k2 : float > 0
A small constant.
eps : float >= 0
A small value to prevent NaN.
padding : ['valid', 'same']
Padding type.
dynamic_range : float > 0 or None
Dynamic range of input. If None, input is automatically normalized.
Returns
-------
out : Tensor [shape=(..., N, D) or scalar]
SSIM or mean SSIM.
"""
return nn.StructuralSimilarityIndex._func(
x,
y,
reduction=reduction,
alpha=alpha,
beta=beta,
gamma=gamma,
kernel_size=kernel_size,
sigma=sigma,
k1=k1,
k2=k2,
eps=eps,
padding=padding,
dynamic_range=dynamic_range,
)


def stft(
x,
*,
Expand Down
6 changes: 6 additions & 0 deletions diffsptk/misc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,12 @@ def clog(x):
return torch.log(x.abs())


def outer(x, y=None):
return torch.matmul(
x.unsqueeze(-1), x.unsqueeze(-2) if y is None else y.unsqueeze(-2)
)


def iir(x, b=None, a=None):
if b is None:
b = torch.ones(1, dtype=x.dtype, device=x.device)
Expand Down
2 changes: 0 additions & 2 deletions diffsptk/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@
from .snr import SignalToNoiseRatio
from .snr import SignalToNoiseRatio as SNR
from .spec import Spectrum
from .ssim import StructuralSimilarityIndex
from .ssim import StructuralSimilarityIndex as SSIM
from .stft import ShortTimeFourierTransform
from .stft import ShortTimeFourierTransform as STFT
from .ulaw import MuLawCompression
Expand Down
13 changes: 4 additions & 9 deletions diffsptk/modules/cdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,15 @@ class CepstralDistance(nn.Module):
reduction : ['none', 'mean', 'batchmean', 'sum']
Reduction type.
eps : float >= 0
A small value to prevent NaN.
"""

def __init__(self, full=False, reduction="mean", eps=1e-8):
def __init__(self, full=False, reduction="mean"):
super().__init__()

assert reduction in ("none", "mean", "batchmean", "sum")
assert 0 <= eps

self.full = full
self.reduction = reduction
self.eps = eps

def forward(self, c1, c2):
"""Calculate cepstral distance between two inputs.
Expand Down Expand Up @@ -75,11 +70,11 @@ def forward(self, c1, c2):
tensor(1.6551)
"""
return self._forward(c1, c2, self.full, self.reduction, self.eps)
return self._forward(c1, c2, self.full, self.reduction)

@staticmethod
def _forward(c1, c2, full, reduction, eps):
distance = torch.sqrt((c1[..., 1:] - c2[..., 1:]).square().sum(-1) + eps)
def _forward(c1, c2, full, reduction):
distance = torch.linalg.vector_norm(c1[..., 1:] - c2[..., 1:], ord=2, dim=-1)

if reduction == "none":
pass
Expand Down
19 changes: 8 additions & 11 deletions diffsptk/modules/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
from torch import nn
from tqdm import tqdm

from ..misc.utils import outer
from ..misc.utils import to_dataloader


class GaussianMixtureModeling(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/gmm.html>`_
for details. This module is not differentiable.
for details. Note that the forward method is not differentiable.
Parameters
----------
Expand Down Expand Up @@ -210,11 +211,10 @@ def warmup(self, x, **lbg_params):
b = 0
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
e = b + batch_x.size(0)
xp = batch_x.to(device)
xx = torch.matmul(xp.unsqueeze(-1), xp.unsqueeze(-2))
xx = outer(batch_x.to(device))
kxx.scatter_add_(0, idx[b:e], xx)
b = e
mm = torch.matmul(mu.unsqueeze(-1), mu.unsqueeze(-2)) # (K, L, L)
mm = outer(mu) # (K, L, L)
sigma = kxx / count.view(-1, 1, 1) - mm
sigma = sigma * self.mask

Expand Down Expand Up @@ -328,24 +328,21 @@ def forward(self, x, return_posterior=False):
b = 0
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
e = b + batch_x.size(0)
xp = batch_x.to(device)
xx = torch.matmul(xp.unsqueeze(-1), xp.unsqueeze(-2))
xx = outer(batch_x.to(device))
pxx.append(torch.einsum("bk,blm->klm", posterior[b:e], xx))
b = e
pxx = sum(pxx)
mm = torch.matmul(self.mu.unsqueeze(-1), self.mu.unsqueeze(-2))
mm = outer(self.mu)
if self.alpha == 0:
sigma = pxx * z.view(-1, 1, 1) - mm
else:
y = posterior.sum(dim=0)
nu = px / y.view(-1, 1)
nm = torch.matmul(nu.unsqueeze(-1), self.mu.unsqueeze(-2))
nm = outer(nu, self.mu)
mn = nm.transpose(-2, -1)
a = pxx - y.view(-1, 1, 1) * (nm + mn - mm)
b = xi.view(-1, 1, 1) * self.ubm_sigma
diff = self.ubm_mu - self.mu
dd = torch.matmul(diff.unsqueeze(-1), diff.unsqueeze(-2))
c = xi.view(-1, 1, 1) * dd
c = xi.view(-1, 1, 1) * outer(self.ubm_mu - self.mu)
sigma = (a + b + c) * z.view(-1, 1, 1)
self.sigma = sigma * self.mask
self.sigma.diagonal(dim1=-2, dim2=-1).clamp_(min=self.var_floor)
Expand Down
2 changes: 1 addition & 1 deletion diffsptk/modules/lbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

class LindeBuzoGrayAlgorithm(nn.Module):
"""See `this page <https://sp-nitech.github.io/sptk/latest/main/lbg.html>`_
for details. This module is not differentiable.
for details. Note that the forward method is not differentiable.
Parameters
----------
Expand Down
Loading

0 comments on commit effd7b2

Please sign in to comment.