
Use shrinkage for (cross-)covariance estimation #261

Open
norabelrose opened this issue Jun 22, 2023 · 3 comments
norabelrose commented Jun 22, 2023

We're now using the Rao-Blackwellized Ledoit-Wolf (RBLW) shrinkage technique from this paper in the concept-erasure repo; it makes covariance estimation more robust to small sample sizes, and might make CRC-TPC, VINC, etc. work better:

import torch
from torch import Tensor


def gaussian_shrinkage(S_hat: Tensor, n: int) -> Tensor:
    """Applies Rao-Blackwellized Ledoit-Wolf shrinkage to a sample covariance matrix."""
    p = S_hat.shape[-1]
    assert n > 1 and S_hat.shape == (p, p)

    trace_S = torch.trace(S_hat)
    trace_S_sq = torch.trace(S_hat @ S_hat)  # tr(S^2); note `S_hat ** 2` would square elementwise
    trace_sq_S = trace_S ** 2  # tr(S)^2

    # RBLW shrinkage intensity, clamped to [0, 1]
    numer = (n - 2) / n * trace_S_sq + trace_sq_S
    denom = (n + 2) * (trace_S_sq - trace_sq_S / p)
    rho = torch.clamp(numer / denom, 0, 1)

    # Shrinkage target: isotropic matrix with the same trace as S_hat
    eye = torch.eye(p, dtype=S_hat.dtype, device=S_hat.device)
    F_hat = eye * trace_S / p

    return (1 - rho) * S_hat + rho * F_hat
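
For context, a minimal usage sketch (the random data and sample-covariance construction here are illustrative, not from the repo):

# Illustrative only: estimate a covariance with fewer samples than dimensions,
# the regime where shrinkage helps most.
n, p = 32, 128
x = torch.randn(n, p)
x = x - x.mean(dim=0)
S_hat = x.T @ x / n  # plain sample covariance; rank-deficient since n < p
S_shrunk = gaussian_shrinkage(S_hat, n)  # blended toward a scaled identity, now full-rank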

albanie commented Jun 23, 2023

Should the denominator be:

denom = (n + 2) * (trace_S_sq - trace_sq_S / p)

rather than

denom = (n + 2) * (trace_S_sq + trace_sq_S / p)

?

norabelrose commented

Yes, it should be the minus sign. The formula from the paper is:

$$\hat\rho_{\mathrm{RBLW}} = \frac{\frac{n-2}{n}\operatorname{tr}(\hat S^2) + \operatorname{tr}^2(\hat S)}{(n+2)\left[\operatorname{tr}(\hat S^2) - \operatorname{tr}^2(\hat S)/p\right]}$$

norabelrose commented

Actually, we should use the distribution-free, random-matrix-theoretic, asymptotically Frobenius-optimal formula from https://arxiv.org/abs/1308.2608. I just switched the concept-erasure repo to it:

import torch
from torch import Tensor


def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor:
    """Optimal linear shrinkage for a sample covariance matrix or batch thereof.

    The formula is distribution-free and asymptotically optimal in the Frobenius norm
    as the dimensionality and sample size tend to infinity.

    See "On the Strong Convergence of the Optimal Linear Shrinkage Estimator for Large
    Dimensional Covariance Matrix" <https://arxiv.org/abs/1308.2608> for details.

    Args:
        S_n: Sample covariance matrices of shape (*, p, p).
        n: Sample size.
    """
    p = S_n.shape[-1]
    assert n > 1 and S_n.shape[-2:] == (p, p)

    # Sigma0 is actually a free parameter; here we're using an isotropic
    # covariance matrix with the same trace as S_n.
    # TODO: Make this configurable, try using diag(S_n) or something
    eye = torch.eye(p, dtype=S_n.dtype, device=S_n.device).expand_as(S_n)
    trace_S = trace(S_n)
    sigma0 = eye * trace_S / p

    sigma0_norm_sq = sigma0.pow(2).sum(dim=(-2, -1), keepdim=True)
    S_norm_sq = S_n.pow(2).sum(dim=(-2, -1), keepdim=True)

    prod_trace = trace(S_n @ sigma0)
    top = trace_S.pow(2) * sigma0_norm_sq / n
    bottom = S_norm_sq * sigma0_norm_sq - prod_trace**2

    # alpha and beta are the weights on S_n and sigma0 that minimize the
    # asymptotic Frobenius loss among estimators of the form a * S_n + b * sigma0.
    alpha = 1 - top / bottom
    beta = (1 - alpha) * prod_trace / sigma0_norm_sq

    return alpha * S_n + beta * sigma0


def trace(matrices: Tensor) -> Tensor:
    """Version of `torch.trace` for batches of matrices; returns shape (*, 1, 1)."""
    diag = torch.linalg.diagonal(matrices)
    return diag.sum(dim=-1, keepdim=True).unsqueeze(-1)
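
And a quick sketch of calling it on a batch (the shapes here are illustrative):

# Illustrative only: shrink a batch of sample covariance matrices at once.
b, n, p = 4, 64, 256
x = torch.randn(b, n, p)
x = x - x.mean(dim=1, keepdim=True)
S_n = x.mT @ x / n  # (b, p, p) batch of sample covariances
S_shrunk = optimal_linear_shrinkage(S_n, n)
assert S_shrunk.shape == (b, p, p)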
