Skip to content

Commit

Permalink
Merge pull request #110 from sp-nitech/pca
Browse files Browse the repository at this point in the history
Update pca
  • Loading branch information
takenori-y authored Dec 10, 2024
2 parents c058537 + 82c65f5 commit 438a218
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 33 deletions.
4 changes: 2 additions & 2 deletions diffsptk/modules/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def forward(self, x, return_posterior=False):
z = posterior.sum(dim=0) + xi
self.w = z / (T + self.alpha)
z = 1 / z
self.w = torch.clamp(self.w, min=self.weight_floor)
self.w = torch.clip(self.w, min=self.weight_floor)
sum_floor = self.weight_floor * self.n_mixture
a = (1 - sum_floor) / (self.w.sum() - sum_floor)
b = self.weight_floor * (1 - a)
Expand Down Expand Up @@ -345,7 +345,7 @@ def forward(self, x, return_posterior=False):
c = xi.view(-1, 1, 1) * outer(self.ubm_mu - self.mu)
sigma = (a + b + c) * z.view(-1, 1, 1)
self.sigma = sigma * self.mask
self.sigma.diagonal(dim1=-2, dim2=-1).clamp_(min=self.var_floor)
self.sigma.diagonal(dim1=-2, dim2=-1).clip_(min=self.var_floor)

# Check convergence.
if self.verbose:
Expand Down
5 changes: 3 additions & 2 deletions diffsptk/modules/lbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ def forward(self, x, return_indices=False):
T = 0
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
assert batch_x.dim() == 2
s += batch_x.sum(0)
T += batch_x.size(0)
batch_xp = batch_x.to(device)
s += batch_xp.sum(0)
T += batch_xp.size(0)
self.vq.codebook[0] = s / T
else:
raise ValueError(f"init {self.init} is not supported.")
Expand Down
1 change: 1 addition & 0 deletions diffsptk/modules/nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
n_data,
order,
n_comp,
*,
beta=0,
n_iter=100,
eps=1e-5,
Expand Down
81 changes: 56 additions & 25 deletions diffsptk/modules/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
self,
order,
n_comp,
*,
cov_type="sample",
sort="descending",
batch_size=None,
Expand Down Expand Up @@ -90,7 +91,8 @@ def cov(x0, x1, x2):
raise ValueError(f"cov_type {cov_type} is not supported.")
self.cov = cov

self.register_buffer("v", torch.eye(n_comp, order + 1))
self.register_buffer("s", torch.zeros(n_comp))
self.register_buffer("V", torch.eye(n_comp, order + 1))
self.register_buffer("m", torch.zeros(order + 1))

def forward(self, x):
Expand All @@ -103,10 +105,10 @@ def forward(self, x):
Returns
-------
e : Tensor [shape=(K,)]
s : Tensor [shape=(K,)]
Eigenvalues.
v : Tensor [shape=(K, M+1)]
V : Tensor [shape=(K, M+1)]
Eigenvectors.
m : Tensor [shape=(M+1,)]
Expand All @@ -118,8 +120,8 @@ def forward(self, x):
>>> x.size()
torch.Size([10, 4])
>>> pca = diffsptk.PCA(3, 3)
>>> e, _, _ = pca(x)
>>> e
>>> s, _, _ = pca(x)
>>> s
tensor([1.3465, 0.7497, 0.4447])
>>> y = pca.transform(x)
>>> y.size()
Expand All @@ -130,19 +132,13 @@ def forward(self, x):
device = self.m.device

# Compute statistics.
first = True
x0 = x1 = x2 = 0
for (batch_x,) in tqdm(x, disable=self.hide_progress_bar):
assert batch_x.dim() == 2
xp = batch_x.to(device)
if first:
x0 = xp.size(0)
x1 = xp.sum(0)
x2 = outer(xp).sum(0)
first = False
else:
x0 += xp.size(0)
x1 += xp.sum(0)
x2 += outer(xp).sum(0)
x0 += xp.size(0)
x1 += xp.sum(0)
x2 += outer(xp).sum(0)

if x0 <= self.n_comp:
raise RuntimeError("Number of data samples is too small.")
Expand All @@ -152,15 +148,16 @@ def forward(self, x):
c = self.cov(x0, x1, x2)

# Compute eigenvalues and eigenvectors.
e, v = torch.linalg.eigh(c)
e = e[-self.n_comp :]
v = v[:, -self.n_comp :]
val, vec = torch.linalg.eigh(c)
val = val[-self.n_comp :]
vec = vec[:, -self.n_comp :]
if self.sort == "descending":
e = e.flip(-1)
v = v.flip(-1)
self.v[:] = v.T
val = val.flip(-1)
vec = vec.flip(-1)
self.s[:] = val
self.V[:] = vec.T
self.m[:] = m
return e, self.v, self.m
return self.s, self.V, self.m

def transform(self, x):
"""Transform input vectors using estimated eigenvectors.
Expand All @@ -176,6 +173,40 @@ def transform(self, x):
Transformed vectors.
"""
v = self.v.T
v = self.v.flip(-1) if self.sort == "ascending" else v
return torch.matmul(x - self.m, v)
V = self.V.T.flip(-1) if self.sort == "ascending" else self.V.T
return torch.matmul(self.center(x), V)

def center(self, x):
    """Subtract the estimated mean from the input vectors.

    Parameters
    ----------
    x : Tensor [shape=(..., M+1)]
        Input vectors.

    Returns
    -------
    out : Tensor [shape=(..., M+1)]
        Centered vectors.

    """
    # ``self.m`` is the mean buffer; broadcasting applies it to every
    # leading dimension of ``x``.
    mean = self.m
    return x - mean

def whiten(self, x):
    """Project input vectors onto the eigenvectors and scale to unit variance.

    The input is used as given: this method does not subtract the
    estimated mean before projecting.

    Parameters
    ----------
    x : Tensor [shape=(..., M+1)]
        Input vectors.

    Returns
    -------
    out : Tensor [shape=(..., K)]
        Whitened vectors.

    """
    basis = self.V.T
    eigvals = self.s
    if self.sort == "ascending":
        # Reverse the component order of both the basis columns and the
        # eigenvalues so that they stay paired with each other.
        basis = basis.flip(-1)
        eigvals = eigvals.flip(-1)

    # Floor the eigenvalues at 1e-10 to avoid dividing by zero.
    floored = torch.clip(eigvals, min=1e-10)
    scale = torch.sqrt(floored)
    return torch.matmul(x, basis / scale)
14 changes: 10 additions & 4 deletions tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,25 @@ def test_compatibility(device, cov_type, batch_size, B=10, M=4, K=3):
f"pca -m {M} -n {K} -u {cov_type} -v {tmp1} -d 1e-8 > {tmp2}"
)
U.call(cmd, get=False)
e1 = U.call(f"bcut -e {K-1} {tmp1}")
s1 = U.call(f"bcut -e {K-1} {tmp1}")
v1 = U.call(f"cat {tmp2}").reshape(K + 1, M + 1)
y1 = U.call(f"nrand -l {B*(M+1)} | pcas -m {M} -n {K} {tmp2}").reshape(-1, K)
U.call(f"rm {tmp1} {tmp2}", get=False)

# Python
pca = diffsptk.PCA(M, K, cov_type=cov_type, batch_size=batch_size).to(device)
x = torch.from_numpy(U.call(f"nrand -l {B*(M+1)}")).reshape(B, M + 1).to(device)
e, v, m = pca(x)
e2 = e.cpu().numpy()
s, v, m = pca(x)
s2 = s.cpu().numpy()
v2 = torch.cat([m.unsqueeze(0), v], dim=0).cpu().numpy()
y2 = pca.transform(x).cpu().numpy()

assert U.allclose(e1, e2)
assert U.allclose(s1, s2)
assert U.allclose(np.abs(v1), np.abs(v2))
assert U.allclose(np.abs(y1), np.abs(y2))

z = pca.center(x)
assert torch.allclose(torch.mean(z, dim=0), torch.zeros(M + 1))
if cov_type <= 1:
z = pca.whiten(x)
assert torch.allclose(torch.cov(z.T, correction=cov_type), torch.eye(K))

0 comments on commit 438a218

Please sign in to comment.