Merge pull request #9 from EleutherAI/complex
Basic support for complex numbers
norabelrose authored Sep 19, 2023
2 parents 9689ef5 + c3736bf commit ff6d7bc
Showing 8 changed files with 93 additions and 88 deletions.
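Most of the diff below swaps plain transposes (`.T` / `.mT`) for conjugate transposes (`.mH`), which is what makes the covariance and projection math come out right for complex inputs while leaving real inputs unchanged. A minimal sketch of the distinction (illustrative only, not part of the commit):

import torch

A = torch.tensor([[1 + 2j, 3 - 1j],
                  [0 + 1j, 2 + 0j]])

assert torch.equal(A.mT, A.transpose(-2, -1))         # plain transpose
assert torch.equal(A.mH, A.transpose(-2, -1).conj())  # conjugate (Hermitian) transpose

# For real tensors the two coincide, so the switch is backward compatible.
B = torch.randn(3, 3)
assert torch.equal(B.mT, B.mH)

# Why it matters: a complex covariance X^H X is Hermitian with real, non-negative
# eigenvalues, whereas X^T X is generally not even Hermitian.
X = torch.randn(100, 4, dtype=torch.complex128)
assert torch.allclose(X.mH @ X, (X.mH @ X).mH)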
3 changes: 2 additions & 1 deletion concept_erasure/groupby.py
@@ -102,6 +102,7 @@ def invert_indices(indices: Tensor) -> Tensor:
reverse_indices = torch.empty_like(indices)

# Scatter the indices to reverse the permutation
reverse_indices.scatter_(0, indices, torch.arange(len(indices)))
arange = torch.arange(len(indices), device=indices.device)
reverse_indices.scatter_(0, indices, arange)

return reverse_indices
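The change here just builds the `arange` on the same device as `indices`, so the scatter no longer mixes devices when the permutation lives on GPU. A small worked example of the inversion itself (illustrative, not from the commit):

import torch

perm = torch.tensor([2, 0, 1])                        # a permutation of 0..2
inv = torch.empty_like(perm)
arange = torch.arange(len(perm), device=perm.device)  # same device as `perm`
inv.scatter_(0, perm, arange)                         # inv[perm[i]] = i

assert torch.equal(inv, torch.tensor([1, 2, 0]))
assert torch.equal(perm[inv], torch.arange(3))        # composing gives the identity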
28 changes: 14 additions & 14 deletions concept_erasure/leace.py
@@ -43,7 +43,7 @@ def __call__(self, x: Tensor) -> Tensor:
delta = x - self.bias if self.bias is not None else x

# Ensure we do the matmul in the most efficient order.
x_ = x - (delta @ self.proj_right.T) @ self.proj_left.T
x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH
return x_.type_as(x)


@@ -133,7 +133,7 @@ def __init__(
self.mean_x = torch.zeros(x_dim, device=device, dtype=dtype)
self.mean_z = torch.zeros(z_dim, device=device, dtype=dtype)

self.n = torch.tensor(0, device=device, dtype=dtype)
self.n = torch.tensor(0, device=device)
self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype)

if self.method == "leace":
@@ -162,7 +162,7 @@ def update(self, x: Tensor, z: Tensor) -> "LeaceFitter":
# Update the covariance matrix of X if needed (for LEACE)
if self.method == "leace":
assert self.sigma_xx_ is not None
self.sigma_xx_.addmm_(delta_x.mT, delta_x2)
self.sigma_xx_.addmm_(delta_x.mH, delta_x2)

z = z.reshape(n, -1).type_as(x)
assert z.shape[-1] == c, f"Unexpected number of classes {z.shape[-1]}"
@@ -172,7 +172,7 @@ def update(self, x: Tensor, z: Tensor) -> "LeaceFitter":
delta_z2 = z - self.mean_z

# Update the cross-covariance matrix
self.sigma_xz_.addmm_(delta_x.mT, delta_z2)
self.sigma_xz_.addmm_(delta_x.mH, delta_z2)

return self

@@ -192,8 +192,8 @@ def eraser(self) -> LeaceEraser:
# Assuming PSD; account for numerical error
L.clamp_min_(0.0)

W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mT
W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mT
W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mH
W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mH
else:
W, W_inv = eye, eye

@@ -203,26 +203,26 @@ def eraser(self) -> LeaceEraser:
u *= s > self.svd_tol

proj_left = W_inv @ u
proj_right = u.T @ W
proj_right = u.mH @ W

if self.constrain_cov_trace and self.method == "leace":
P = eye - proj_left @ proj_right

# Prevent the covariance trace from increasing
sigma = self.sigma_xx
old_trace = torch.trace(sigma)
new_trace = torch.trace(P @ sigma @ P.mT)
new_trace = torch.trace(P @ sigma @ P.mH)

# If applying the projection matrix increases the variance, this might
# cause instability, especially when erasure is applied multiple times.
# We regularize toward the orthogonal projection matrix to avoid this.
if new_trace > old_trace:
Q = eye - u @ u.T
if new_trace.real > old_trace.real:
Q = eye - u @ u.mH

# Set up the variables for the quadratic equation
x = new_trace
y = 2 * torch.trace(P @ sigma @ Q.mT)
z = torch.trace(Q @ sigma @ Q.mT)
y = 2 * torch.trace(P @ sigma @ Q.mH)
z = torch.trace(Q @ sigma @ Q.mH)
w = old_trace

# Solve for the mixture of P and Q that makes the trace equal to the
@@ -234,7 +234,7 @@ def eraser(self) -> LeaceEraser:
alpha2 = (-y / 2 + z + discr / 2) / (x - y + z)

# Choose the positive root
alpha = torch.where(alpha1 > 0, alpha1, alpha2).clamp(0, 1)
alpha = torch.where(alpha1.real > 0, alpha1, alpha2).clamp(0, 1)
P = alpha * P + (1 - alpha) * Q

# TODO: Avoid using SVD here
@@ -255,7 +255,7 @@ def sigma_xx(self) -> Tensor:
), "Covariance statistics are not being tracked for X"

# Accumulated numerical error may cause this to be slightly non-symmetric
S_hat = (self.sigma_xx_ + self.sigma_xx_.mT) / 2
S_hat = (self.sigma_xx_ + self.sigma_xx_.mH) / 2

# Apply Random Matrix Theory-based shrinkage
if self.shrinkage:
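In `eraser`, the whitening and unwhitening matrices are now assembled with `@ V.mH` instead of `@ V.mT`, which keeps them valid when `sigma_xx` is complex Hermitian. A minimal sketch of that step, assuming a full-rank covariance (the fitter additionally masks near-zero eigenvalues; the names here are illustrative):

import torch

d = 4
X = torch.randn(256, d, dtype=torch.complex128)
X -= X.mean(dim=0)
sigma = X.mH @ X / X.shape[0]                  # complex Hermitian PSD covariance
sigma = (sigma + sigma.mH) / 2                 # symmetrize away numerical error, as in sigma_xx

L, V = torch.linalg.eigh(sigma)                # real eigenvalues, complex eigenvectors
W = (V * L.clamp_min(1e-12).rsqrt()) @ V.mH    # whitening (full-rank assumption)
W_inv = (V * L.clamp_min(0.0).sqrt()) @ V.mH   # unwhitening

eye = torch.eye(d, dtype=sigma.dtype)
assert torch.allclose(W @ sigma @ W.mH, eye, atol=1e-8)  # whitened covariance is identity
assert torch.allclose(W_inv @ W, eye, atol=1e-6)         # they invert each other here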
8 changes: 4 additions & 4 deletions concept_erasure/optimal_transport.py
@@ -14,7 +14,7 @@ def psd_sqrt(A: Tensor) -> Tensor:
"""Compute the unique p.s.d. square root of a positive semidefinite matrix."""
L, U = torch.linalg.eigh(A)
L = L[..., None, :].clamp_min(0.0)
return U * L.sqrt() @ U.mT
return U * L.sqrt() @ U.mH


def psd_sqrt_rsqrt(A: Tensor) -> tuple[Tensor, Tensor]:
@@ -23,12 +23,12 @@ def psd_sqrt_rsqrt(A: Tensor) -> tuple[Tensor, Tensor]:
L = L[..., None, :].clamp_min(0.0)

# Square root is easy
sqrt = U * L.sqrt() @ U.mT
sqrt = U * L.sqrt() @ U.mH

# We actually compute the pseudo-inverse here for numerical stability.
# Use the same heuristic as `torch.linalg.pinv` to determine the tolerance.
thresh = L[..., None, -1] * A.shape[-1] * torch.finfo(A.dtype).eps
rsqrt = U * torch.where(L > thresh, L.rsqrt(), 0.0) @ U.mT
rsqrt = U * torch.where(L > thresh, L.rsqrt(), 0.0) @ U.mH

return sqrt, rsqrt

@@ -84,7 +84,7 @@ def ot_barycenter(

# Equation 7 from Álvarez-Esteban et al. (2016)
T = torch.sum(weights * rsqrt_mu @ inner @ rsqrt_mu, dim=0)
mu = T @ mu @ T.mT
mu = T @ mu @ T.mH

return mu
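With `U.mH` in place of `U.mT`, the same eigendecomposition pattern yields a valid p.s.d. square root for complex Hermitian inputs. A quick self-contained check of the identity the function is meant to satisfy (illustrative; it mirrors, but is not copied from, `psd_sqrt`):

import torch

d = 5
Z = torch.randn(d, d, dtype=torch.complex128)
A = Z @ Z.mH                                   # Hermitian positive semidefinite

L, U = torch.linalg.eigh(A)
sqrt = (U * L.clamp_min(0.0).sqrt()) @ U.mH    # same pattern as psd_sqrt above

assert torch.allclose(sqrt @ sqrt, A, atol=1e-8)  # it really is a square root of A
assert torch.allclose(sqrt, sqrt.mH)              # and it is itself Hermitian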

8 changes: 4 additions & 4 deletions concept_erasure/oracle.py
@@ -93,7 +93,7 @@ def __init__(
self.mean_x = torch.zeros(x_dim, device=device, dtype=dtype)
self.mean_z = torch.zeros(z_dim, device=device, dtype=dtype)

self.n = torch.tensor(0, device=device, dtype=dtype)
self.n = torch.tensor(0, device=device)
self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype)
self.sigma_zz_ = torch.zeros(z_dim, z_dim, device=device, dtype=dtype)

@@ -119,8 +119,8 @@ def update(self, x: Tensor, z: Tensor) -> "OracleFitter":
self.mean_z += delta_z.sum(dim=0) / self.n
delta_z2 = z - self.mean_z

self.sigma_xz_.addmm_(delta_x.mT, delta_z2)
self.sigma_zz_.addmm_(delta_z.mT, delta_z2)
self.sigma_xz_.addmm_(delta_x.mH, delta_z2)
self.sigma_zz_.addmm_(delta_z.mH, delta_z2)

return self

@@ -141,7 +141,7 @@ def sigma_zz(self) -> Tensor:
), "Covariance statistics are not being tracked for X"

# Accumulated numerical error may cause this to be slightly non-symmetric
S_hat = (self.sigma_zz_ + self.sigma_zz_.mT) / 2
S_hat = (self.sigma_zz_ + self.sigma_zz_.mH) / 2

# Apply Random Matrix Theory-based shrinkage
if self.shrinkage:
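Both fitters accumulate cross-covariance with a Welford-style streaming update, and using `delta_x.mH` on the left makes that accumulation conjugate-linear, as a complex cross-covariance requires. A standalone sketch of the update pattern (the function name and structure are illustrative, not the fitter's code):

import torch

def cross_cov_stream(x_batches, z_batches):
    """Streaming E[(x - mean_x)^H (z - mean_z)], accumulated batch by batch."""
    n = 0
    mean_x = mean_z = sigma_xz = None
    for x, z in zip(x_batches, z_batches):
        if sigma_xz is None:
            mean_x = torch.zeros(x.shape[1], dtype=x.dtype)
            mean_z = torch.zeros(z.shape[1], dtype=z.dtype)
            sigma_xz = torch.zeros(x.shape[1], z.shape[1], dtype=x.dtype)
        n += len(x)
        delta_x = x - mean_x                          # deviations from the *old* mean
        mean_x = mean_x + delta_x.sum(0) / n
        delta_z = z - mean_z
        mean_z = mean_z + delta_z.sum(0) / n
        delta_z2 = z - mean_z                         # deviations from the *new* mean
        sigma_xz = sigma_xz + delta_x.mH @ delta_z2   # conjugate transpose on the left
    return sigma_xz / n

x = torch.randn(1000, 3, dtype=torch.complex128)
z = torch.randn(1000, 2, dtype=torch.complex128)

streamed = cross_cov_stream(x.split(128), z.split(128))
direct = (x - x.mean(0)).mH @ (z - z.mean(0)) / len(x)
assert torch.allclose(streamed, direct)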
86 changes: 41 additions & 45 deletions concept_erasure/quadratic.py
@@ -2,7 +2,6 @@

import torch
from torch import Tensor
from torch.distributions import MultivariateNormal

from .caching import cached_property, invalidates_cache
from .groupby import groupby
@@ -17,52 +16,51 @@ class QuadraticEraser:
class_means: Tensor
"""`[k, d]` batch of class centroids."""

class_prior: Tensor
"""`[k]` prior probability of each class."""

global_mean: Tensor
"""`[d]` global centroid of the dataset."""

ot_maps: Tensor
"""`[k, d, d]` batch of optimal transport matrices to the concept barycenter."""

scale_trils: Tensor | None = None
"""`[k, d, d]` batch of covariance Cholesky factors for each class."""

@classmethod
def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "QuadraticEraser":
"""Convenience method to fit a QuadraticEraser on data and return it."""
return QuadraticFitter.fit(x, z, **kwargs).eraser

def optimal_transport(self, z: int, x: Tensor) -> Tensor:
"""Transport `x` to the barycenter, assuming it was sampled from class `z`"""
return (x - self.class_means[z]) @ self.ot_maps[z].mT + self.global_mean
x_ = x.flatten(1)
x_ = (x_ - self.class_means[z]) @ self.ot_maps[z].mH + self.global_mean
return x_.view_as(x)

def predict(self, x: Tensor) -> Tensor:
"""Compute the log posterior p(z|x) for each class z."""
assert self.scale_trils is not None, "Set store_covariance=True for prediction"
def __call__(self, x: Tensor, z: Tensor) -> Tensor:
"""Apply erasure to `x` with oracle labels `z`."""
# Efficiently group `x` by `z`, optimally transport each group, then coalesce
return groupby(x, z).map(self.optimal_transport).coalesce()

# Because we provide the Cholesky factor directly, the initialization is cheap
gaussian = MultivariateNormal(
loc=self.class_means, scale_tril=self.scale_trils, validate_args=False
)
# Bayes rule
log_prior = self.class_prior.log()
log_likelihood = gaussian.log_prob(x[:, None])
return torch.log_softmax(log_prior + log_likelihood, dim=-1)

def __call__(self, x: Tensor, z: Tensor | None = None) -> Tensor:
"""Apply erasure to `x` with oracle labels `z`.
@dataclass(frozen=True)
class QuadraticEditor:
"""Performs surgical quadratic concept editing."""

If `z` is not provided, we will estimate it from `x` using `self.predict`. This
is only possible if `store_covariance=True` was passed to `QuadraticFitter`.
"""
if z is None:
assert self.scale_trils is not None, "Set store_covariance=True"
z = self.predict(x).argmax(-1)
class_means: Tensor
"""`[k, d]` batch of class centroids."""

# Efficiently group `x` by `z`, optimally transport each group, then coalesce
return groupby(x, z).map(self.optimal_transport).coalesce()
ot_maps: Tensor
"""`[k, k, d, d]` batch of pairwise optimal transport matrices between classes."""

def transport(self, x: Tensor, source_z: int, target_z: int) -> Tensor:
"""Transport `x` from class `source_z` to class `target_z`"""
T = self.ot_maps[source_z, target_z]
return (x - self.class_means[source_z]) @ T.mH + self.class_means[target_z]

def __call__(self, x: Tensor, source_z: Tensor, target_z: int) -> Tensor:
"""Transport `x` from classes `source_z` to class `target_z`."""
return (
groupby(x, source_z)
.map(lambda src, x_grp: self.transport(x_grp, src, target_z))
.coalesce()
)


class QuadraticFitter:
@@ -96,7 +94,6 @@ def __init__(
*,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
store_covariance: bool = False,
shrinkage: bool = True,
):
"""Initialize a `QuadraticFitter`.
@@ -116,10 +113,9 @@ def __init__(

self.num_classes = num_classes
self.shrinkage = shrinkage
self.store_covariance = store_covariance

self.mean_x = torch.zeros(num_classes, x_dim, device=device, dtype=dtype)
self.n = torch.zeros(num_classes, device=device, dtype=torch.long)
self.n = torch.zeros(num_classes, device=device)
self.sigma_xx_ = torch.zeros(
num_classes, x_dim, x_dim, device=device, dtype=dtype
)
@@ -146,37 +142,37 @@ def update_single(self, x: Tensor, z: int) -> "QuadraticFitter":
self.mean_x[z] += delta_x.sum(dim=0) / self.n[z]
delta_x2 = x - self.mean_x[z]

self.sigma_xx_[z].addmm_(delta_x.mT, delta_x2)
self.sigma_xx_[z].addmm_(delta_x.mH, delta_x2)

return self

def editor(self) -> QuadraticEditor:
"""Quadratic editor for the concept."""
sigma = self.sigma_xx
return QuadraticEditor(self.mean_x, ot_map(sigma[:, None], sigma))

@cached_property
def eraser(self) -> QuadraticEraser:
"""Erasure function lazily computed given the current statistics."""

class_prior = self.n / self.n.sum()
sigmas = self.sigma_xx

# Compute Wasserstein barycenter of the classes
if self.num_classes == 2:
# Use closed form solution for the binary case
class_prior = self.n / self.n.sum()
center = ot_midpoint(sigmas[0], sigmas[1], *class_prior.tolist())
else:
# Use fixed point iteration for the general case
center = ot_barycenter(self.sigma_xx, self.n)

# Then compute the optimal ransport maps from each class to the barycenter
# Then compute the optimal transport maps from each class to the barycenter
ot_maps = ot_map(sigmas, center)

if self.store_covariance:
# Add jitter to ensure positive-definiteness
torch.linalg.diagonal(sigmas).add_(1e-3)
scale_trils = torch.linalg.cholesky(sigmas)
else:
scale_trils = None

return QuadraticEraser(
self.mean_x, class_prior, self.mean_x.mean(dim=0), ot_maps, scale_trils
self.mean_x,
self.mean_x.mean(dim=0),
ot_maps,
)

@property
@@ -187,8 +183,8 @@ def sigma_xx(self) -> Tensor:
self.sigma_xx_ is not None
), "Covariance statistics are not being tracked for X"

# Accumulated numerical error may cause this to be slightly non-symmetric
S_hat = (self.sigma_xx_ + self.sigma_xx_.mT) / 2
# Accumulated numerical error may cause this to be slightly non-Hermitian
S_hat = (self.sigma_xx_ + self.sigma_xx_.mH) / 2

# Apply Random Matrix Theory-based shrinkage
n = self.n.view(-1, 1, 1)
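The new `QuadraticEditor.transport` applies an affine map `(x - mean_src) @ T.mH + mean_tgt`, where `T` is the optimal transport map between two class Gaussians. A hedged sketch of what such a map looks like in the full-rank case, using the standard closed form for the Monge map between centered Gaussians (this is an illustration; it is not claimed to match `ot_map` line for line):

import torch

def psd_sqrt(A):
    L, U = torch.linalg.eigh(A)
    return (U * L.clamp_min(0.0).sqrt()) @ U.mH

def psd_rsqrt(A):
    L, U = torch.linalg.eigh(A)
    return (U * L.clamp_min(1e-12).rsqrt()) @ U.mH

d = 4
Zs = torch.randn(d, d, dtype=torch.complex128)
Zt = torch.randn(d, d, dtype=torch.complex128)
sigma_src = Zs @ Zs.mH + 1e-3 * torch.eye(d, dtype=torch.complex128)
sigma_tgt = Zt @ Zt.mH + 1e-3 * torch.eye(d, dtype=torch.complex128)

# Monge map between N(0, sigma_src) and N(0, sigma_tgt):
#   T = src^(-1/2) (src^(1/2) tgt src^(1/2))^(1/2) src^(-1/2)
s, r = psd_sqrt(sigma_src), psd_rsqrt(sigma_src)
T = r @ psd_sqrt(s @ sigma_tgt @ s) @ r

# Pushing the source covariance through T recovers the target covariance.
assert torch.allclose(T @ sigma_src @ T.mH, sigma_tgt, atol=1e-6)

# Editor-style transport of a batch of points between classes with nonzero means:
mean_src = torch.randn(d, dtype=torch.complex128)
mean_tgt = torch.randn(d, dtype=torch.complex128)
x = torch.randn(8, d, dtype=torch.complex128) + mean_src
x_edited = (x - mean_src) @ T.mH + mean_tgt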
8 changes: 4 additions & 4 deletions concept_erasure/shrinkage.py
@@ -31,12 +31,12 @@ def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor:
trace_S = trace(S_n)
sigma0 = eye * trace_S / p

sigma0_norm_sq = sigma0.pow(2).sum(dim=(-2, -1), keepdim=True)
S_norm_sq = S_n.pow(2).sum(dim=(-2, -1), keepdim=True)
sigma0_norm_sq = sigma0.norm(dim=(-2, -1), keepdim=True) ** 2
S_norm_sq = S_n.norm(dim=(-2, -1), keepdim=True) ** 2

prod_trace = trace(S_n @ sigma0)
top = trace_S.pow(2) * sigma0_norm_sq / n
bottom = S_norm_sq * sigma0_norm_sq - prod_trace**2
top = trace_S * trace_S.conj() * sigma0_norm_sq / n
bottom = S_norm_sq * sigma0_norm_sq - prod_trace * prod_trace.conj()

# Epsilon prevents dividing by zero for the zero matrix. In that case we end up
# setting alpha = 0, beta = 1, but it doesn't matter since we're shrinking toward
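The rewritten norms and trace products stay real and non-negative for complex input: `x.pow(2)` is the complex square, while `x * x.conj()` (equivalently `x.abs() ** 2`) is the squared magnitude the shrinkage formula needs. A tiny illustration (not from the commit):

import torch

x = torch.tensor(3.0 + 4.0j)
assert x.pow(2) == torch.tensor(-7.0 + 24.0j)  # complex square, not a magnitude
assert (x * x.conj()).real == 25.0             # |x|^2
assert x.abs() ** 2 == 25.0

S = torch.randn(4, 4, dtype=torch.complex128)
frob_sq = S.norm(dim=(-2, -1)) ** 2            # what the new code computes
assert torch.allclose(frob_sq, (S.abs() ** 2).sum())
# S.pow(2).sum() would instead be a complex number with no such interpretation.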

