diff --git a/Makefile b/Makefile
index ffcf7d1..5d57db8 100644
--- a/Makefile
+++ b/Makefile
@@ -51,7 +51,7 @@ check: tool
 	. ./venv/bin/activate && python -m ruff format --check $(PROJECT) tests docs
 	. ./venv/bin/activate && python -m isort --check $(PROJECT) tests
 	. ./venv/bin/activate && python -m mdformat --check *.md
-	. ./venv/bin/activate && python -m docstrfmt --check docs
+	. ./venv/bin/activate && cd docs && python -m docstrfmt --check .
 	./tools/taplo/taplo fmt --check *.toml
 	./tools/yamlfmt/yamlfmt --lint *.cff *.yml .github/workflows/*.yml

@@ -60,7 +60,7 @@ format: tool
 	. ./venv/bin/activate && python -m ruff format $(PROJECT) tests docs
 	. ./venv/bin/activate && python -m isort $(PROJECT) tests
 	. ./venv/bin/activate && python -m mdformat *.md
-	. ./venv/bin/activate && python -m docstrfmt docs
+	. ./venv/bin/activate && cd docs && python -m docstrfmt .
 	./tools/taplo/taplo fmt *.toml
 	./tools/yamlfmt/yamlfmt *.cff *.yml .github/workflows/*.yml

diff --git a/diffsptk/functional.py b/diffsptk/functional.py
index 9db58e1..e70f759 100644
--- a/diffsptk/functional.py
+++ b/diffsptk/functional.py
@@ -175,7 +175,7 @@ def c2ndps(c, fft_length):
     )


-def cdist(c1, c2, full=False, reduction="mean", eps=1e-8):
+def cdist(c1, c2, full=False, reduction="mean"):
     """Calculate cepstral distance between two inputs.

     Parameters
     ----------
@@ -192,16 +192,13 @@ def cdist(c1, c2, full=False, reduction="mean", eps=1e-8):
     reduction : ['none', 'mean', 'batchmean', 'sum']
         Reduction type.

-    eps : float >= 0
-        A small value to prevent NaN.
-
     Returns
     -------
     out : Tensor [shape=(...,) or scalar]
         Cepstral distance.

     """
-    return nn.CepstralDistance._func(c1, c2, full=full, reduction=reduction, eps=eps)
+    return nn.CepstralDistance._func(c1, c2, full=full, reduction=reduction)


 def chroma(x, n_channel, sample_rate, norm=float("inf")):
@@ -1953,7 +1950,7 @@ def rlevdur(a):
     return nn.ReverseLevinsonDurbin._func(a)


-def rmse(x, y, reduction="mean", eps=1e-8):
+def rmse(x, y, reduction="mean"):
     """Calculate RMSE.

     Parameters
     ----------
     x : Tensor [shape=(..., T)]
         Input.

     y : Tensor [shape=(..., T)]
         Target.

     reduction : ['none', 'mean', 'sum']
         Reduction type.

-    eps : float >= 0
-        A small value to prevent NaN.
-
     Returns
     -------
     out : Tensor [shape=(...,) or scalar]
         RMSE.

     """
-    return nn.RootMeanSquareError._func(x, y, reduction=reduction, eps=eps)
+    return nn.RootMeanSquareError._func(x, y, reduction=reduction)


 def root_pol(a, out_format="rectangular"):
@@ -2074,88 +2068,6 @@ def spec(
     )


-def ssim(
-    x,
-    y,
-    reduction="mean",
-    *,
-    alpha=1,
-    beta=1,
-    gamma=1,
-    kernel_size=11,
-    sigma=1.5,
-    k1=0.01,
-    k2=0.03,
-    eps=1e-8,
-    padding="same",
-    dynamic_range=None,
-):
-    """Calculate SSIM.
-
-    Parameters
-    ----------
-    x : Tensor [shape=(..., N, D)]
-        Input.
-
-    y : Tensor [shape=(..., N, D)]
-        Target.
-
-    reduction : ['none', 'mean', 'sum']
-        Reduction type.
-
-    alpha : float > 0
-        Relative importance of luminance component.
-
-    beta : float > 0
-        Relative importance of contrast component.
-
-    gamma : float > 0
-        Relative importance of structure component.
-
-    kernel_size : int >= 1
-        Kernel size of Gaussian filter.
-
-    sigma : float > 0
-        Standard deviation of Gaussian filter.
-
-    k1 : float > 0
-        A small constant.
-
-    k2 : float > 0
-        A small constant.
-
-    eps : float >= 0
-        A small value to prevent NaN.
-
-    padding : ['valid', 'same']
-        Padding type.
-
-    dynamic_range : float > 0 or None
-        Dynamic range of input. If None, input is automatically normalized.
- - Returns - ------- - out : Tensor [shape=(..., N, D) or scalar] - SSIM or mean SSIM. - - """ - return nn.StructuralSimilarityIndex._func( - x, - y, - reduction=reduction, - alpha=alpha, - beta=beta, - gamma=gamma, - kernel_size=kernel_size, - sigma=sigma, - k1=k1, - k2=k2, - eps=eps, - padding=padding, - dynamic_range=dynamic_range, - ) - - def stft( x, *, diff --git a/diffsptk/misc/utils.py b/diffsptk/misc/utils.py index cda55e6..d20d7c7 100644 --- a/diffsptk/misc/utils.py +++ b/diffsptk/misc/utils.py @@ -254,6 +254,12 @@ def clog(x): return torch.log(x.abs()) +def outer(x, y=None): + return torch.matmul( + x.unsqueeze(-1), x.unsqueeze(-2) if y is None else y.unsqueeze(-2) + ) + + def iir(x, b=None, a=None): if b is None: b = torch.ones(1, dtype=x.dtype, device=x.device) diff --git a/diffsptk/modules/__init__.py b/diffsptk/modules/__init__.py index 71c7052..83b3c46 100644 --- a/diffsptk/modules/__init__.py +++ b/diffsptk/modules/__init__.py @@ -124,8 +124,6 @@ from .snr import SignalToNoiseRatio from .snr import SignalToNoiseRatio as SNR from .spec import Spectrum -from .ssim import StructuralSimilarityIndex -from .ssim import StructuralSimilarityIndex as SSIM from .stft import ShortTimeFourierTransform from .stft import ShortTimeFourierTransform as STFT from .ulaw import MuLawCompression diff --git a/diffsptk/modules/cdist.py b/diffsptk/modules/cdist.py index 3d18b11..6caa6e2 100644 --- a/diffsptk/modules/cdist.py +++ b/diffsptk/modules/cdist.py @@ -30,20 +30,15 @@ class CepstralDistance(nn.Module): reduction : ['none', 'mean', 'batchmean', 'sum'] Reduction type. - eps : float >= 0 - A small value to prevent NaN. - """ - def __init__(self, full=False, reduction="mean", eps=1e-8): + def __init__(self, full=False, reduction="mean"): super().__init__() assert reduction in ("none", "mean", "batchmean", "sum") - assert 0 <= eps self.full = full self.reduction = reduction - self.eps = eps def forward(self, c1, c2): """Calculate cepstral distance between two inputs. @@ -75,11 +70,11 @@ def forward(self, c1, c2): tensor(1.6551) """ - return self._forward(c1, c2, self.full, self.reduction, self.eps) + return self._forward(c1, c2, self.full, self.reduction) @staticmethod - def _forward(c1, c2, full, reduction, eps): - distance = torch.sqrt((c1[..., 1:] - c2[..., 1:]).square().sum(-1) + eps) + def _forward(c1, c2, full, reduction): + distance = torch.linalg.vector_norm(c1[..., 1:] - c2[..., 1:], ord=2, dim=-1) if reduction == "none": pass diff --git a/diffsptk/modules/gmm.py b/diffsptk/modules/gmm.py index 5a3dfce..9454010 100644 --- a/diffsptk/modules/gmm.py +++ b/diffsptk/modules/gmm.py @@ -21,12 +21,13 @@ from torch import nn from tqdm import tqdm +from ..misc.utils import outer from ..misc.utils import to_dataloader class GaussianMixtureModeling(nn.Module): """See `this page `_ - for details. This module is not differentiable. + for details. Note that the forward method is not differentiable. 
Parameters ---------- @@ -210,11 +211,10 @@ def warmup(self, x, **lbg_params): b = 0 for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): e = b + batch_x.size(0) - xp = batch_x.to(device) - xx = torch.matmul(xp.unsqueeze(-1), xp.unsqueeze(-2)) + xx = outer(batch_x.to(device)) kxx.scatter_add_(0, idx[b:e], xx) b = e - mm = torch.matmul(mu.unsqueeze(-1), mu.unsqueeze(-2)) # (K, L, L) + mm = outer(mu) # (K, L, L) sigma = kxx / count.view(-1, 1, 1) - mm sigma = sigma * self.mask @@ -328,24 +328,21 @@ def forward(self, x, return_posterior=False): b = 0 for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): e = b + batch_x.size(0) - xp = batch_x.to(device) - xx = torch.matmul(xp.unsqueeze(-1), xp.unsqueeze(-2)) + xx = outer(batch_x.to(device)) pxx.append(torch.einsum("bk,blm->klm", posterior[b:e], xx)) b = e pxx = sum(pxx) - mm = torch.matmul(self.mu.unsqueeze(-1), self.mu.unsqueeze(-2)) + mm = outer(self.mu) if self.alpha == 0: sigma = pxx * z.view(-1, 1, 1) - mm else: y = posterior.sum(dim=0) nu = px / y.view(-1, 1) - nm = torch.matmul(nu.unsqueeze(-1), self.mu.unsqueeze(-2)) + nm = outer(nu, self.mu) mn = nm.transpose(-2, -1) a = pxx - y.view(-1, 1, 1) * (nm + mn - mm) b = xi.view(-1, 1, 1) * self.ubm_sigma - diff = self.ubm_mu - self.mu - dd = torch.matmul(diff.unsqueeze(-1), diff.unsqueeze(-2)) - c = xi.view(-1, 1, 1) * dd + c = xi.view(-1, 1, 1) * outer(self.ubm_mu - self.mu) sigma = (a + b + c) * z.view(-1, 1, 1) self.sigma = sigma * self.mask self.sigma.diagonal(dim1=-2, dim2=-1).clamp_(min=self.var_floor) diff --git a/diffsptk/modules/lbg.py b/diffsptk/modules/lbg.py index f31a945..761cd89 100644 --- a/diffsptk/modules/lbg.py +++ b/diffsptk/modules/lbg.py @@ -29,7 +29,7 @@ class LindeBuzoGrayAlgorithm(nn.Module): """See `this page `_ - for details. This module is not differentiable. + for details. Note that the forward method is not differentiable. Parameters ---------- diff --git a/diffsptk/modules/pca.py b/diffsptk/modules/pca.py index a2b519c..eb1ea85 100644 --- a/diffsptk/modules/pca.py +++ b/diffsptk/modules/pca.py @@ -16,13 +16,15 @@ import torch from torch import nn +from tqdm import tqdm -from ..misc.utils import check_size +from ..misc.utils import outer +from ..misc.utils import to_dataloader class PrincipalComponentAnalysis(nn.Module): """See `this page `_ - for details. + for details. Note that the forward method is not differentiable. Parameters ---------- @@ -38,39 +40,66 @@ class PrincipalComponentAnalysis(nn.Module): sort : ['ascending', 'descending'] Order of eigenvalues and eigenvectors. + batch_size : int >= 1 or None + Batch size. + + verbose : bool + If True, show progress bar. 
+ """ - def __init__(self, order, n_comp, cov_type="sample", sort="descending"): + def __init__( + self, + order, + n_comp, + cov_type="sample", + sort="descending", + batch_size=None, + verbose=False, + ): super().__init__() assert 0 <= order assert 1 <= n_comp <= order + 1 assert sort in ["ascending", "descending"] - self.order = order self.n_comp = n_comp - self.cov_type = cov_type self.sort = sort + self.batch_size = batch_size + self.hide_progress_bar = not verbose + + def sample_cov(x0, x1, x2): + return x2 / x0 - torch.outer(x1, x1) / (x0 * x0) if cov_type in (0, "sample"): - self.cov = lambda x: torch.cov(x, correction=0) + + def cov(x0, x1, x2): + return sample_cov(x0, x1, x2) elif cov_type in (1, "unbiased"): - self.cov = lambda x: torch.cov(x, correction=1) + + def cov(x0, x1, x2): + c = sample_cov(x0, x1, x2) + return c * (x0 / (x0 - 1)) elif cov_type in (2, "correlation"): - self.cov = lambda x: torch.corrcoef(x) + + def cov(x0, x1, x2): + c = sample_cov(x0, x1, x2) + v = c.diag().sqrt() + return c / torch.outer(v, v) else: raise ValueError(f"cov_type {cov_type} is not supported.") + self.cov = cov - self.register_buffer("v", torch.eye(self.order + 1, self.n_comp)) - self.register_buffer("m", torch.zeros(self.order + 1)) + self.register_buffer("v", torch.eye(order + 1, n_comp)) + self.register_buffer("m", torch.zeros(order + 1)) def forward(self, x): """Perform PCA. Parameters ---------- - x : Tensor [shape=(..., M+1)] - Input vectors. + x : Tensor [shape=(T, M+1)] or DataLoader + Input vectors or dataloader yielding input vectors. Returns ------- @@ -97,19 +126,40 @@ def forward(self, x): torch.Size([10, 3]) """ - check_size(x.size(-1), self.order + 1, "dimension of input") - - x = x.reshape(-1, x.size(-1)).T - assert self.n_comp + 1 <= x.size(1), "Number of data samples is too small" - - e, v = torch.linalg.eigh(self.cov(x)) + x = to_dataloader(x, self.batch_size) + device = self.m.device + + # Compute statistics. + first = True + for (batch_x,) in tqdm(x, disable=self.hide_progress_bar): + assert batch_x.dim() == 2 + xp = batch_x.to(device) + if first: + x0 = xp.size(0) + x1 = xp.sum(0) + x2 = outer(xp).sum(0) + first = False + else: + x0 += xp.size(0) + x1 += xp.sum(0) + x2 += outer(xp).sum(0) + + if x0 <= self.n_comp: + raise RuntimeError("Number of data samples is too small.") + + # Compute mean and covariance matrix. + m = x1 / x0 + c = self.cov(x0, x1, x2) + + # Compute eigenvalues and eigenvectors. + e, v = torch.linalg.eigh(c) e = e[-self.n_comp :] v = v[:, -self.n_comp :] if self.sort == "descending": e = e.flip(-1) v = v.flip(-1) self.v[:] = v - self.m[:] = x.mean(1) + self.m[:] = m return e, self.v, self.m def transform(self, x): diff --git a/diffsptk/modules/rmse.py b/diffsptk/modules/rmse.py index 99ff1b7..48a46dc 100644 --- a/diffsptk/modules/rmse.py +++ b/diffsptk/modules/rmse.py @@ -27,29 +27,24 @@ class RootMeanSquareError(nn.Module): reduction : ['none', 'mean', 'sum'] Reduction type. - eps : float >= 0 - A small value to prevent NaN. - """ - def __init__(self, reduction="mean", eps=1e-8): + def __init__(self, reduction="mean"): super().__init__() assert reduction in ("none", "mean", "sum") - assert 0 <= eps self.reduction = reduction - self.eps = eps def forward(self, x, y): """Calculate RMSE. Parameters ---------- - x : Tensor [shape=(..., T)] + x : Tensor [shape=(..., D)] Input. - y : Tensor [shape=(..., T)] + y : Tensor [shape=(..., D)] Target. 
Returns @@ -71,11 +66,11 @@ def forward(self, x, y): tensor(1.8340) """ - return self._forward(x, y, self.reduction, self.eps) + return self._forward(x, y, self.reduction) @staticmethod - def _forward(x, y, reduction, eps): - error = torch.sqrt(torch.square(x - y).mean(-1) + eps) + def _forward(x, y, reduction): + error = torch.linalg.vector_norm(x - y, ord=2, dim=-1) / x.size(-1) ** 0.5 if reduction == "none": pass elif reduction == "sum": diff --git a/diffsptk/modules/ssim.py b/diffsptk/modules/ssim.py deleted file mode 100644 index 8c8478c..0000000 --- a/diffsptk/modules/ssim.py +++ /dev/null @@ -1,250 +0,0 @@ -# ------------------------------------------------------------------------ # -# Copyright 2022 SPTK Working Group # -# # -# Licensed under the Apache License, Version 2.0 (the "License"); # -# you may not use this file except in compliance with the License. # -# You may obtain a copy of the License at # -# # -# http://www.apache.org/licenses/LICENSE-2.0 # -# # -# Unless required by applicable law or agreed to in writing, software # -# distributed under the License is distributed on an "AS IS" BASIS, # -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # -# See the License for the specific language governing permissions and # -# limitations under the License. # -# ------------------------------------------------------------------------ # - -import torch -from torch import nn -import torch.nn.functional as F - -from ..misc.utils import to - - -class StructuralSimilarityIndex(nn.Module): - """Structural similarity index computation. - - Parameters - ---------- - reduction : ['none', 'mean', 'sum'] - Reduction type. - - alpha : float > 0 - Relative importance of luminance component. - - beta : float > 0 - Relative importance of contrast component. - - gamma : float > 0 - Relative importance of structure component. - - kernel_size : int >= 1 - Kernel size of Gaussian filter. - - sigma : float > 0 - Standard deviation of Gaussian filter. - - k1 : float > 0 - A small constant. - - k2 : float > 0 - A small constant. - - eps : float >= 0 - A small value to prevent NaN. - - padding : ['valid', 'same'] - Padding type. - - dynamic_range : float > 0 or None - Dynamic range of input. If None, input is automatically normalized. - - References - ---------- - .. [1] Z. Wang et al., "Image quality assessment: From error visibility to - structural similarity," *IEEE Transactions on Image Processing*, vol. 13, - no. 4, pp. 600-612, 2004. - - """ - - def __init__( - self, - reduction="mean", - *, - alpha=1, - beta=1, - gamma=1, - kernel_size=11, - sigma=1.5, - k1=0.01, - k2=0.03, - eps=1e-8, - padding="same", - dynamic_range=None, - ): - super().__init__() - - assert reduction in ["none", "mean", "sum"] - assert 1 <= kernel_size and kernel_size % 2 == 1 - assert 0 < sigma - assert 0 < k1 < 1 - assert 0 < k2 < 1 - assert 0 <= eps - - self.reduction = reduction - self.weights = (alpha, beta, gamma) - self.ks = (k1, k2) - self.eps = eps - self.padding = padding - self.dynamic_range = dynamic_range - self.register_buffer("kernel", self._precompute(kernel_size, sigma)) - - def forward(self, x, y): - """Calculate SSIM. - - Parameters - ---------- - x : Tensor [shape=(..., N, D)] - Input. - - y : Tensor [shape=(..., N, D)] - Target. - - Returns - ------- - out : Tensor [shape=(..., N, D) or scalar] - SSIM or mean SSIM. 
- - Examples - -------- - >>> x = diffsptk.nrand(20, 20) - >>> y = diffsptk.nrand(20, 20) - >>> ssim = diffsptk.StructuralSimilarityIndex() - >>> s = ssim(x, y) - >>> s - tensor(0.0588) - - """ - return self._forward( - x, - y, - self.reduction, - self.weights, - self.ks, - self.eps, - self.padding, - self.dynamic_range, - self.kernel, - ) - - @staticmethod - def _forward(x, y, reduction, weights, ks, eps, padding, dynamic_range, kernel): - org_shape = x.shape - x = x.view(-1, 1, x.shape[-2], x.shape[-1]) - y = y.view(-1, 1, y.shape[-2], y.shape[-1]) - - # Normalize x and y to [0, 1]. - if dynamic_range is None: - x_max = torch.amax(x, dim=(-2, -1), keepdim=True) - x_min = torch.amin(x, dim=(-2, -1), keepdim=True) - y_max = torch.amax(y, dim=(-2, -1), keepdim=True) - y_min = torch.amin(y, dim=(-2, -1), keepdim=True) - xy_max = torch.maximum(x_max, y_max) - xy_min = torch.minimum(x_min, y_min) - d = xy_max - xy_min + eps - x = (x - xy_min) / d - y = (y - xy_min) / d - dynamic_range = 1 - - # Pad x and y. - if padding == "valid": - pass - elif padding == "same": - pad_size = kernel.shape[-1] // 2 - x = F.pad(x, (pad_size, pad_size, pad_size, pad_size), mode="reflect") - y = F.pad(y, (pad_size, pad_size, pad_size, pad_size), mode="reflect") - else: - raise ValueError(f"padding {padding} is not supported.") - - # Set constants. - K1, K2 = ks - L = dynamic_range - C1 = (K1 * L) ** 2 - C2 = (K2 * L) ** 2 - C3 = 0.5 * C2 - - # Calculate luminance. - mu_x = F.conv2d(x, kernel, padding=0) - mu_y = F.conv2d(y, kernel, padding=0) - mu_x2 = mu_x**2 - mu_y2 = mu_y**2 - luminance = (2 * mu_x * mu_y + C1) / (mu_x2 + mu_y2 + C1) - - # Calculate contrast. - sigma_x2 = torch.clip(F.conv2d(x**2, kernel, padding=0) - mu_x2, min=eps) - sigma_y2 = torch.clip(F.conv2d(y**2, kernel, padding=0) - mu_y2, min=eps) - sigma_x = torch.sqrt(sigma_x2) - sigma_y = torch.sqrt(sigma_y2) - contrast = (2 * sigma_x * sigma_y + C2) / (sigma_x2 + sigma_y2 + C2) - - # Calculate structure. - mu_xy = mu_x * mu_y - sigma_xy = F.conv2d(x * y, kernel, padding=0) - mu_xy - structure = (sigma_xy + C3) / (sigma_x * sigma_y + C3) - - # Calculate SSIM. - alpha, beta, gamma = weights - ssim = (luminance**alpha) * (contrast**beta) * (structure**gamma) - ssim = ssim.view(*org_shape[:-2], *ssim.shape[-2:]) - - if reduction == "none": - pass - elif reduction == "sum": - ssim = ssim.sum() - elif reduction == "mean": - ssim = ssim.mean() - else: - raise ValueError(f"reduction {reduction} is not supported.") - return ssim - - @staticmethod - def _func( - x, - y, - reduction, - alpha, - beta, - gamma, - kernel_size, - sigma, - k1, - k2, - eps, - padding, - dynamic_range, - ): - kernel = StructuralSimilarityIndex._precompute( - kernel_size, sigma, dtype=x.dtype, device=x.device - ) - return StructuralSimilarityIndex._forward( - x, - y, - reduction, - (alpha, beta, gamma), - (k1, k2), - eps, - padding, - dynamic_range, - kernel, - ) - - @staticmethod - def _precompute(kernel_size, sigma, dtype=None, device=None): - # Generate 2D Gaussian kernel. - center = kernel_size // 2 - x = torch.arange(kernel_size, dtype=torch.double, device=device) - center - xx = x**2 - G = torch.exp(-0.5 * (xx.unsqueeze(0) + xx.unsqueeze(1)) / sigma**2) - G /= G.sum() # Normalized to unit sum. - G = G.view(1, 1, kernel_size, kernel_size) - return to(G, dtype=dtype) diff --git a/docs/modules/ssim.rst b/docs/modules/ssim.rst deleted file mode 100644 index 8990057..0000000 --- a/docs/modules/ssim.rst +++ /dev/null @@ -1,15 +0,0 @@ -.. _ssim: - -ssim -==== - -.. 
autoclass:: diffsptk.SSIM - -.. autoclass:: diffsptk.StructuralSimilarityIndex - :members: - -.. autofunction:: diffsptk.functional.ssim - -.. seealso:: - - :ref:`rmse` diff --git a/pyproject.toml b/pyproject.toml index 735978a..61822df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ dev = [ "pytest", "pytest-cov", "ruff", - "scikit-image", "sphinx", "twine", ] diff --git a/tests/test_pca.py b/tests/test_pca.py index 0107f5a..7e559e0 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -24,7 +24,8 @@ @pytest.mark.parametrize("device", ["cpu", "cuda"]) @pytest.mark.parametrize("cov_type", [0, 1, 2]) -def test_compatibility(device, cov_type, B=10, M=4, N=3): +@pytest.mark.parametrize("batch_size", [None, 5]) +def test_compatibility(device, cov_type, batch_size, B=10, M=4, N=3): if device == "cuda" and not torch.cuda.is_available(): return @@ -42,7 +43,7 @@ def test_compatibility(device, cov_type, B=10, M=4, N=3): U.call(f"rm {tmp1} {tmp2}", get=False) # Python - pca = diffsptk.PCA(M, N, cov_type=cov_type).to(device) + pca = diffsptk.PCA(M, N, cov_type=cov_type, batch_size=batch_size).to(device) x = torch.from_numpy(U.call(f"nrand -l {B*(M+1)}")).reshape(B, M + 1).to(device) e, v, m = pca(x) e2 = e.cpu().numpy() diff --git a/tests/test_ssim.py b/tests/test_ssim.py deleted file mode 100644 index aaa0ae4..0000000 --- a/tests/test_ssim.py +++ /dev/null @@ -1,73 +0,0 @@ -# ------------------------------------------------------------------------ # -# Copyright 2022 SPTK Working Group # -# # -# Licensed under the Apache License, Version 2.0 (the "License"); # -# you may not use this file except in compliance with the License. # -# You may obtain a copy of the License at # -# # -# http://www.apache.org/licenses/LICENSE-2.0 # -# # -# Unless required by applicable law or agreed to in writing, software # -# distributed under the License is distributed on an "AS IS" BASIS, # -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # -# See the License for the specific language governing permissions and # -# limitations under the License. 
# -# ------------------------------------------------------------------------ # - -import pytest -from skimage.metrics import structural_similarity -import torch - -import diffsptk -import tests.utils as U - - -@pytest.mark.parametrize("device", ["cpu", "cuda"]) -@pytest.mark.parametrize("module", [False, True]) -def test_compatibility(device, module, B=1, T=100, D=100): - if device == "cuda" and not torch.cuda.is_available(): - return - - ssim = U.choice( - module, - diffsptk.SSIM, - diffsptk.functional.ssim, - {}, - {"reduction": "mean", "dynamic_range": 1, "padding": "valid"}, - n_input=2, - ) - if hasattr(ssim, "to"): - ssim = ssim.to(device) - - x = torch.rand(B, T, D, device=device) - y = torch.rand(B, T, D, device=device) - - s1 = structural_similarity( - x.cpu().numpy(), - y.cpu().numpy(), - channel_axis=0, - data_range=1, - gaussian_weights=True, - use_sample_covariance=False, - ) - - s2 = ssim(x, y).cpu().item() - assert U.allclose(s1, s2) - - s3 = ssim(x, x).cpu().item() - assert U.allclose(1, s3) - - U.check_differentiability(device, ssim, [(B, T, D), (B, T, D)]) - - -def test_special_case(B=2, T=30, D=30): - x = torch.rand(B, T, D) - y = torch.rand(B, T, D) - s = diffsptk.SSIM(reduction="none", padding="same", dynamic_range=1)(x, y) - assert s.shape == x.shape - - x = torch.randn(B, T, D) - y = torch.randn(B, T, D) - s1 = diffsptk.SSIM(reduction="sum", padding="same", dynamic_range=None)(x, y) - s2 = diffsptk.SSIM(reduction="mean", padding="same", dynamic_range=None)(x, y) - assert U.allclose(s1, s2 * x.numel())
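
Two illustrative sketches follow; they are editorial examples, not part of the patch. First, the reason the `eps` arguments could be dropped from `cdist` and `rmse`: the old `torch.sqrt(d.square().sum(-1) + eps)` needed `eps` to keep the backward pass finite when the two inputs coincide, at the cost of slightly overestimating small distances. `torch.linalg.vector_norm` computes the distance exactly, and current PyTorch releases define a zero subgradient for the 2-norm at the origin (assumed here; verify against your PyTorch version):

    import torch

    # Old formulation: biased upward by eps for small distances.
    x = torch.tensor([3e-4, 4e-4])
    y = torch.zeros(2)
    old = torch.sqrt((x - y).square().sum(-1) + 1e-8)     # ~5.10e-4 (biased)
    new = torch.linalg.vector_norm(x - y, ord=2, dim=-1)  # 5.00e-4 (exact)

    # At x == y, the unregularized sqrt yields a NaN gradient (inf * 0)...
    w = torch.zeros(2, requires_grad=True)
    torch.sqrt(w.square().sum(-1)).backward()
    print(w.grad)  # tensor([nan, nan])

    # ...while vector_norm masks the 0/0 case and returns zeros.
    z = torch.zeros(2, requires_grad=True)
    torch.linalg.vector_norm(z, ord=2, dim=-1).backward()
    print(z.grad)  # tensor([0., 0.])

The reported RMSE value is unchanged apart from the removed bias, since torch.linalg.vector_norm(x - y, ord=2, dim=-1) / x.size(-1) ** 0.5 equals sqrt(mean((x - y) ** 2)).

Second, a minimal usage sketch for the batched PCA interface added above; the input sizes are made up for illustration:

    import torch
    import diffsptk

    x = torch.randn(1000, 5)                  # T=1000 samples of order M=4
    pca = diffsptk.PCA(4, 3, batch_size=100)  # sufficient statistics per batch
    e, v, m = pca(x)                          # eigenvalues, eigenvectors, mean
    y = pca.transform(x)                      # project onto principal axes

Passing a torch.utils.data.DataLoader that yields (T_i, M+1) batches should behave the same way, which is what allows the statistics pass to run without holding the whole dataset in memory.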