Passing basic tests w/ complex numbers

norabelrose committed Sep 19, 2023
1 parent 0a04406 commit c3736bf
Showing 4 changed files with 24 additions and 15 deletions.
12 changes: 6 additions & 6 deletions concept_erasure/leace.py
@@ -43,7 +43,7 @@ def __call__(self, x: Tensor) -> Tensor:
         delta = x - self.bias if self.bias is not None else x
 
         # Ensure we do the matmul in the most efficient order.
-        x_ = x - (delta @ self.proj_right.T) @ self.proj_left.T
+        x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH
         return x_.type_as(x)
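
A quick sketch of what `.mH` does (background, not part of the diff): `Tensor.mH` is the matrix adjoint, i.e. the conjugate transpose of the last two dimensions, so it agrees with `.T` on real tensors while also conjugating entries on complex ones, which is what the erasure projection needs to stay correct for complex data.

```python
import torch

# .mH conjugates as well as transposes, generalizing .T to complex dtypes.
A = torch.randn(3, 2, dtype=torch.complex128)
assert torch.equal(A.mH, A.conj().transpose(-2, -1))

# On real tensors, .mH and .T agree exactly.
B = torch.randn(3, 2)
assert torch.equal(B.mH, B.T)
```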


@@ -133,7 +133,7 @@ def __init__(
         self.mean_x = torch.zeros(x_dim, device=device, dtype=dtype)
         self.mean_z = torch.zeros(z_dim, device=device, dtype=dtype)
 
-        self.n = torch.tensor(0, device=device, dtype=dtype)
+        self.n = torch.tensor(0, device=device)
         self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype)
 
         if self.method == "leace":
@@ -203,7 +203,7 @@ def eraser(self) -> LeaceEraser:
         u *= s > self.svd_tol
 
         proj_left = W_inv @ u
-        proj_right = u.T @ W
+        proj_right = u.mH @ W
 
         if self.constrain_cov_trace and self.method == "leace":
             P = eye - proj_left @ proj_right
@@ -216,8 +216,8 @@ def eraser(self) -> LeaceEraser:
             # If applying the projection matrix increases the variance, this might
             # cause instability, especially when erasure is applied multiple times.
             # We regularize toward the orthogonal projection matrix to avoid this.
-            if new_trace > old_trace:
-                Q = eye - u @ u.T
+            if new_trace.real > old_trace.real:
+                Q = eye - u @ u.mH
 
                 # Set up the variables for the quadratic equation
                 x = new_trace
@@ -234,7 +234,7 @@ def eraser(self) -> LeaceEraser:
                 alpha2 = (-y / 2 + z + discr / 2) / (x - y + z)
 
                 # Choose the positive root
-                alpha = torch.where(alpha1 > 0, alpha1, alpha2).clamp(0, 1)
+                alpha = torch.where(alpha1.real > 0, alpha1, alpha2).clamp(0, 1)
                 P = alpha * P + (1 - alpha) * Q
 
         # TODO: Avoid using SVD here
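The `.real` comparisons above have a simple rationale, on my reading (the commit message does not spell it out): PyTorch defines no ordering on complex tensors, so `>` raises at runtime, while the traces being compared should be real up to rounding for Hermitian covariance matrices. A minimal sketch:

```python
import torch

a = torch.tensor(2.0 + 0j)
b = torch.tensor(1.0 + 0j)

# Ordering comparisons on complex tensors raise rather than guess.
try:
    a > b
except RuntimeError:
    print("complex '>' is not supported")

# Comparing real parts works; when the traces are real up to rounding,
# nothing meaningful is discarded.
assert a.real > b.real
```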
2 changes: 1 addition & 1 deletion concept_erasure/oracle.py
@@ -93,7 +93,7 @@ def __init__(
         self.mean_x = torch.zeros(x_dim, device=device, dtype=dtype)
         self.mean_z = torch.zeros(z_dim, device=device, dtype=dtype)
 
-        self.n = torch.tensor(0, device=device, dtype=dtype)
+        self.n = torch.tensor(0, device=device)
         self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype)
         self.sigma_zz_ = torch.zeros(z_dim, z_dim, device=device, dtype=dtype)

2 changes: 1 addition & 1 deletion concept_erasure/quadratic.py
@@ -115,7 +115,7 @@ def __init__(
         self.shrinkage = shrinkage
 
         self.mean_x = torch.zeros(num_classes, x_dim, device=device, dtype=dtype)
-        self.n = torch.zeros(num_classes, device=device, dtype=torch.long)
+        self.n = torch.zeros(num_classes, device=device)
         self.sigma_xx_ = torch.zeros(
             num_classes, x_dim, x_dim, device=device, dtype=dtype
         )
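The counter changes in all three fitters follow the same logic, as I read it (again, not stated in the commit): now that `dtype` may be complex, passing it through would make the sample count complex too. Dropping the explicit dtype falls back to PyTorch's defaults:

```python
import torch

# torch.tensor with an int literal defaults to int64
# (the scalar counters in LeaceFitter and OracleFitter).
assert torch.tensor(0).dtype == torch.int64

# torch.zeros follows the global default dtype, float32 unless changed
# (the per-class counts in QuadraticFitter).
assert torch.zeros(3).dtype == torch.get_default_dtype()
```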
23 changes: 16 additions & 7 deletions tests/test_leace.py
@@ -22,19 +22,22 @@


@pytest.mark.parametrize("shrinkage", [False, True])
def test_stats(shrinkage: bool):
@pytest.mark.parametrize("dtype", [torch.float64, torch.complex128])
def test_stats(shrinkage: bool, dtype: torch.dtype):
batch_size = 10
num_batches = 5
num_classes = 2
num_features = 3
N = batch_size * num_batches

fitter = LeaceFitter(num_features, num_classes, shrinkage=shrinkage)
oracle = OracleFitter(num_features, num_classes, shrinkage=shrinkage)
fitter = LeaceFitter(num_features, num_classes, dtype=dtype, shrinkage=shrinkage)
oracle = OracleFitter(num_features, num_classes, dtype=dtype, shrinkage=shrinkage)

# Generate random data
torch.manual_seed(42)
x_data = [torch.randn(batch_size, num_features) for _ in range(num_batches)]
x_data = [
torch.randn(batch_size, num_features, dtype=dtype) for _ in range(num_batches)
]
z_data = [
torch.randint(0, num_classes, (batch_size, num_classes))
for _ in range(num_batches)
@@ -59,8 +62,12 @@ def test_stats(shrinkage: bool):
     x_centered = x_all - mean_x
     z_centered = z_all - mean_z
 
-    expected_sigma_xx = torch.einsum("b...m,b...n->...mn", x_centered, x_centered)
-    expected_sigma_zz = torch.einsum("b...m,b...n->...mn", z_centered, z_centered)
+    expected_sigma_xx = torch.einsum(
+        "b...m,b...n->...mn", x_centered.conj(), x_centered
+    )
+    expected_sigma_zz = torch.einsum(
+        "b...m,b...n->...mn", z_centered.conj(), z_centered
+    )
     if shrinkage:
         expected_sigma_xx = optimal_linear_shrinkage(
             expected_sigma_xx / N, batch_size * num_batches
@@ -72,7 +79,9 @@ def test_stats(shrinkage: bool):
         expected_sigma_xx /= N - 1
         expected_sigma_zz /= N - 1
 
-    expected_sigma_xz = torch.einsum("b...m,b...n->...mn", x_centered, z_centered)
+    expected_sigma_xz = torch.einsum(
+        "b...m,b...n->...mn", x_centered.conj(), z_centered
+    )
     expected_sigma_xz /= N - 1
 
     torch.testing.assert_close(fitter.sigma_xx, expected_sigma_xx)
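The conjugated einsums encode the standard complex covariance convention, Σ = E[(x − μ)(x − μ)^H], whose left factor is conjugated. A self-contained sketch (variable names here are illustrative, not from the test suite):

```python
import torch

torch.manual_seed(0)
x = torch.randn(100, 3, dtype=torch.complex128)
xc = x - x.mean(dim=0)

# Conjugate the left factor, matching the einsum in the test above.
sigma = torch.einsum("bm,bn->mn", xc.conj(), xc) / (len(x) - 1)

# Equivalent matmul form, plus a Hermitian sanity check.
torch.testing.assert_close(sigma, xc.mH @ xc / (len(x) - 1))
torch.testing.assert_close(sigma, sigma.mH)
```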
