diff --git a/concept_erasure/shrinkage.py b/concept_erasure/shrinkage.py index d85a3be..3ee2fb1 100644 --- a/concept_erasure/shrinkage.py +++ b/concept_erasure/shrinkage.py @@ -31,12 +31,12 @@ def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor: trace_S = trace(S_n) sigma0 = eye * trace_S / p - sigma0_norm_sq = sigma0.pow(2).sum(dim=(-2, -1), keepdim=True) - S_norm_sq = S_n.pow(2).sum(dim=(-2, -1), keepdim=True) + sigma0_norm_sq = sigma0.norm(dim=(-2, -1), keepdim=True) ** 2 + S_norm_sq = S_n.norm(dim=(-2, -1), keepdim=True) ** 2 prod_trace = trace(S_n @ sigma0) - top = trace_S.pow(2) * sigma0_norm_sq / n - bottom = S_norm_sq * sigma0_norm_sq - prod_trace**2 + top = trace_S * trace_S.conj() * sigma0_norm_sq / n + bottom = S_norm_sq * sigma0_norm_sq - prod_trace * prod_trace.conj() # Epsilon prevents dividing by zero for the zero matrix. In that case we end up # setting alpha = 0, beta = 1, but it doesn't matter since we're shrinking toward diff --git a/tests/test_shrinkage.py b/tests/test_shrinkage.py index d0c3644..7581c70 100644 --- a/tests/test_shrinkage.py +++ b/tests/test_shrinkage.py @@ -1,6 +1,5 @@ import pytest import torch -from torch.distributions import MultivariateNormal from concept_erasure import optimal_linear_shrinkage @@ -18,21 +17,21 @@ (16, 256), ], ) -def test_olse_shrinkage(p: int, n: int): +@pytest.mark.parametrize("dtype", [torch.float32, torch.complex64]) +def test_olse_shrinkage(p: int, n: int, dtype: torch.dtype): torch.manual_seed(42) # Number of matrices N = 1000 # Generate a random covariance matrix - A = torch.randn(N, p, p) + A = torch.randn(N, p, p, dtype=dtype) S_true = A @ A.mH / p torch.linalg.diagonal(S_true).add_(1e-3) - # Generate data with this covariance - mean = torch.zeros(N, p) - dist = MultivariateNormal(mean, S_true) - X = dist.sample([n]).movedim(1, 0) + # Generate random Gaussian vectors with this covariance matrix + scale_tril = torch.linalg.cholesky(S_true) + X = torch.randn(N, n, p, dtype=dtype) @ scale_tril.mH assert X.shape == (N, n, p) # Compute the sample covariance matrix