diff --git a/concept_erasure/leace.py b/concept_erasure/leace.py
index 32e57e0..c7b23e9 100644
--- a/concept_erasure/leace.py
+++ b/concept_erasure/leace.py
@@ -43,7 +43,7 @@ def __call__(self, x: Tensor) -> Tensor:
         delta = x - self.bias if self.bias is not None else x
 
         # Ensure we do the matmul in the most efficient order.
-        x_ = x - (delta @ self.proj_right.T) @ self.proj_left.T
+        x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH
         return x_.type_as(x)
 
@@ -133,7 +133,7 @@ def __init__(
         self.mean_x = torch.zeros(x_dim, device=device, dtype=dtype)
         self.mean_z = torch.zeros(z_dim, device=device, dtype=dtype)
 
-        self.n = torch.tensor(0, device=device, dtype=dtype)
+        self.n = torch.tensor(0, device=device)
         self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype)
 
         if self.method == "leace":
@@ -203,7 +203,7 @@ def eraser(self) -> LeaceEraser:
         u *= s > self.svd_tol
 
         proj_left = W_inv @ u
-        proj_right = u.T @ W
+        proj_right = u.mH @ W
 
         if self.constrain_cov_trace and self.method == "leace":
             P = eye - proj_left @ proj_right
@@ -216,8 +216,8 @@
             # If applying the projection matrix increases the variance, this might
             # cause instability, especially when erasure is applied multiple times.
             # We regularize toward the orthogonal projection matrix to avoid this.
-            if new_trace > old_trace:
-                Q = eye - u @ u.T
+            if new_trace.real > old_trace.real:
+                Q = eye - u @ u.mH
 
                 # Set up the variables for the quadratic equation
                 x = new_trace
@@ -234,7 +234,7 @@
                 alpha2 = (-y / 2 + z + discr / 2) / (x - y + z)
 
                 # Choose the positive root
-                alpha = torch.where(alpha1 > 0, alpha1, alpha2).clamp(0, 1)
+                alpha = torch.where(alpha1.real > 0, alpha1, alpha2).clamp(0, 1)
                 P = alpha * P + (1 - alpha) * Q
 
         # TODO: Avoid using SVD here
diff --git a/concept_erasure/oracle.py b/concept_erasure/oracle.py
index e7471d6..ae0a2f2 100644
--- a/concept_erasure/oracle.py
+++ b/concept_erasure/oracle.py
@@ -93,7 +93,7 @@ def __init__(
         self.mean_x = torch.zeros(x_dim, device=device, dtype=dtype)
         self.mean_z = torch.zeros(z_dim, device=device, dtype=dtype)
 
-        self.n = torch.tensor(0, device=device, dtype=dtype)
+        self.n = torch.tensor(0, device=device)
         self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype)
         self.sigma_zz_ = torch.zeros(z_dim, z_dim, device=device, dtype=dtype)
 
diff --git a/concept_erasure/quadratic.py b/concept_erasure/quadratic.py
index eda3e89..e5f984d 100644
--- a/concept_erasure/quadratic.py
+++ b/concept_erasure/quadratic.py
@@ -115,7 +115,7 @@ def __init__(
         self.shrinkage = shrinkage
 
         self.mean_x = torch.zeros(num_classes, x_dim, device=device, dtype=dtype)
-        self.n = torch.zeros(num_classes, device=device, dtype=torch.long)
+        self.n = torch.zeros(num_classes, device=device)
         self.sigma_xx_ = torch.zeros(
             num_classes, x_dim, x_dim, device=device, dtype=dtype
         )
diff --git a/tests/test_leace.py b/tests/test_leace.py
index 917392a..4657f4b 100644
--- a/tests/test_leace.py
+++ b/tests/test_leace.py
@@ -22,19 +22,22 @@
 
 
 @pytest.mark.parametrize("shrinkage", [False, True])
-def test_stats(shrinkage: bool):
+@pytest.mark.parametrize("dtype", [torch.float64, torch.complex128])
+def test_stats(shrinkage: bool, dtype: torch.dtype):
     batch_size = 10
     num_batches = 5
     num_classes = 2
     num_features = 3
     N = batch_size * num_batches
 
-    fitter = LeaceFitter(num_features, num_classes, shrinkage=shrinkage)
-    oracle = OracleFitter(num_features, num_classes, shrinkage=shrinkage)
+    fitter = LeaceFitter(num_features, num_classes, dtype=dtype, shrinkage=shrinkage)
+    oracle = OracleFitter(num_features, num_classes, dtype=dtype, shrinkage=shrinkage)
 
     # Generate random data
     torch.manual_seed(42)
-    x_data = [torch.randn(batch_size, num_features) for _ in range(num_batches)]
+    x_data = [
+        torch.randn(batch_size, num_features, dtype=dtype) for _ in range(num_batches)
+    ]
     z_data = [
         torch.randint(0, num_classes, (batch_size, num_classes))
         for _ in range(num_batches)
     ]
@@ -59,8 +62,12 @@ def test_stats(shrinkage: bool):
     x_centered = x_all - mean_x
     z_centered = z_all - mean_z
 
-    expected_sigma_xx = torch.einsum("b...m,b...n->...mn", x_centered, x_centered)
-    expected_sigma_zz = torch.einsum("b...m,b...n->...mn", z_centered, z_centered)
+    expected_sigma_xx = torch.einsum(
+        "b...m,b...n->...mn", x_centered.conj(), x_centered
+    )
+    expected_sigma_zz = torch.einsum(
+        "b...m,b...n->...mn", z_centered.conj(), z_centered
+    )
     if shrinkage:
         expected_sigma_xx = optimal_linear_shrinkage(
             expected_sigma_xx / N, batch_size * num_batches
@@ -72,7 +79,9 @@
         expected_sigma_xx /= N - 1
         expected_sigma_zz /= N - 1
 
-    expected_sigma_xz = torch.einsum("b...m,b...n->...mn", x_centered, z_centered)
+    expected_sigma_xz = torch.einsum(
+        "b...m,b...n->...mn", x_centered.conj(), z_centered
+    )
     expected_sigma_xz /= N - 1
 
     torch.testing.assert_close(fitter.sigma_xx, expected_sigma_xx)
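
Usage note: `Tensor.mH` is PyTorch's conjugate (Hermitian) transpose; for real tensors it coincides with the plain transpose, so real-dtype behavior is unchanged by the `.T` -> `.mH` swaps above. A minimal sketch of exercising the new complex path end to end, assuming the `LeaceFitter(x_dim, z_dim, dtype=...)` constructor and the `update` / `.eraser` interface seen in the test above (sizes, seed, and data are purely illustrative):

    import torch

    from concept_erasure import LeaceFitter

    n, d, k = 1000, 32, 2  # illustrative: samples, feature dim, concept classes

    torch.manual_seed(0)
    x = torch.randn(n, d, dtype=torch.complex128)
    # One-hot concept labels of shape (n, k).
    z = torch.nn.functional.one_hot(torch.randint(0, k, (n,)), k)

    # Accumulate running statistics in the same complex dtype as the data.
    fitter = LeaceFitter(d, k, dtype=torch.complex128)
    fitter.update(x, z)

    # The fitted eraser applies conjugate transposes (.mH) rather than plain
    # .T internally, so the projection is well-defined for complex activations.
    x_erased = fitter.eraser(x)
    assert x_erased.shape == x.shape and x_erased.dtype == x.dtype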