Skip to content

Commit

Permalink
Extend optimal_linear_shrinkage for complex numbers
Browse files Browse the repository at this point in the history
norabelrose committed Sep 19, 2023
1 parent bf99e10 commit 0a04406
Showing 2 changed files with 10 additions and 11 deletions.
8 changes: 4 additions & 4 deletions concept_erasure/shrinkage.py
Original file line number Diff line number Diff line change
@@ -31,12 +31,12 @@ def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor:
trace_S = trace(S_n)
sigma0 = eye * trace_S / p

sigma0_norm_sq = sigma0.pow(2).sum(dim=(-2, -1), keepdim=True)
S_norm_sq = S_n.pow(2).sum(dim=(-2, -1), keepdim=True)
sigma0_norm_sq = sigma0.norm(dim=(-2, -1), keepdim=True) ** 2
S_norm_sq = S_n.norm(dim=(-2, -1), keepdim=True) ** 2

prod_trace = trace(S_n @ sigma0)
top = trace_S.pow(2) * sigma0_norm_sq / n
bottom = S_norm_sq * sigma0_norm_sq - prod_trace**2
top = trace_S * trace_S.conj() * sigma0_norm_sq / n
bottom = S_norm_sq * sigma0_norm_sq - prod_trace * prod_trace.conj()

# Epsilon prevents dividing by zero for the zero matrix. In that case we end up
# setting alpha = 0, beta = 1, but it doesn't matter since we're shrinking toward
13 changes: 6 additions & 7 deletions tests/test_shrinkage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest
import torch
from torch.distributions import MultivariateNormal

from concept_erasure import optimal_linear_shrinkage

@@ -18,21 +17,21 @@
(16, 256),
],
)
def test_olse_shrinkage(p: int, n: int):
@pytest.mark.parametrize("dtype", [torch.float32, torch.complex64])
def test_olse_shrinkage(p: int, n: int, dtype: torch.dtype):
torch.manual_seed(42)

# Number of matrices
N = 1000

# Generate a random covariance matrix
A = torch.randn(N, p, p)
A = torch.randn(N, p, p, dtype=dtype)
S_true = A @ A.mH / p
torch.linalg.diagonal(S_true).add_(1e-3)

# Generate data with this covariance
mean = torch.zeros(N, p)
dist = MultivariateNormal(mean, S_true)
X = dist.sample([n]).movedim(1, 0)
# Generate random Gaussian vectors with this covariance matrix
scale_tril = torch.linalg.cholesky(S_true)
X = torch.randn(N, n, p, dtype=dtype) @ scale_tril.mH
assert X.shape == (N, n, p)

# Compute the sample covariance matrix

0 comments on commit 0a04406

Please sign in to comment.