From c8a035c2ecae73e2cac47e697a5b1a17014b130e Mon Sep 17 00:00:00 2001
From: Nora Belrose
Date: Tue, 19 Sep 2023 04:27:59 +0000
Subject: [PATCH 1/4] QuadraticEditor class

---
 concept_erasure/groupby.py   |  3 +-
 concept_erasure/quadratic.py | 78 +++++++++++++++++-------------------
 2 files changed, 39 insertions(+), 42 deletions(-)

diff --git a/concept_erasure/groupby.py b/concept_erasure/groupby.py
index dbd3037..4ca4ba2 100644
--- a/concept_erasure/groupby.py
+++ b/concept_erasure/groupby.py
@@ -102,6 +102,7 @@ def invert_indices(indices: Tensor) -> Tensor:
     reverse_indices = torch.empty_like(indices)
 
     # Scatter the indices to reverse the permutation
-    reverse_indices.scatter_(0, indices, torch.arange(len(indices)))
+    arange = torch.arange(len(indices), device=indices.device)
+    reverse_indices.scatter_(0, indices, arange)
 
     return reverse_indices
diff --git a/concept_erasure/quadratic.py b/concept_erasure/quadratic.py
index 2a956a9..cbdb9ae 100644
--- a/concept_erasure/quadratic.py
+++ b/concept_erasure/quadratic.py
@@ -2,7 +2,6 @@
 
 import torch
 from torch import Tensor
-from torch.distributions import MultivariateNormal
 
 from .caching import cached_property, invalidates_cache
 from .groupby import groupby
@@ -17,18 +16,12 @@ class QuadraticEraser:
     class_means: Tensor
     """`[k, d]` batch of class centroids."""
 
-    class_prior: Tensor
-    """`[k]` prior probability of each class."""
-
     global_mean: Tensor
     """`[d]` global centroid of the dataset."""
 
     ot_maps: Tensor
     """`[k, d, d]` batch of optimal transport matrices to the concept barycenter."""
 
-    scale_trils: Tensor | None = None
-    """`[k, d, d]` batch of covariance Cholesky factors for each class."""
-
     @classmethod
     def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "QuadraticEraser":
         """Convenience method to fit a QuadraticEraser on data and return it."""
@@ -36,33 +29,38 @@ def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "QuadraticEraser":
 
     def optimal_transport(self, z: int, x: Tensor) -> Tensor:
         """Transport `x` to the barycenter, assuming it was sampled from class `z`"""
-        return (x - self.class_means[z]) @ self.ot_maps[z].mT + self.global_mean
+        x_ = x.flatten(1)
+        x_ = (x_ - self.class_means[z]) @ self.ot_maps[z].mT + self.global_mean
+        return x_.view_as(x)
 
-    def predict(self, x: Tensor) -> Tensor:
-        """Compute the log posterior p(z|x) for each class z."""
-        assert self.scale_trils is not None, "Set store_covariance=True for prediction"
+    def __call__(self, x: Tensor, z: Tensor) -> Tensor:
+        """Apply erasure to `x` with oracle labels `z`."""
+        # Efficiently group `x` by `z`, optimally transport each group, then coalesce
+        return groupby(x, z).map(self.optimal_transport).coalesce()
 
-        # Because we provide the Cholesky factor directly, the initialization is cheap
-        gaussian = MultivariateNormal(
-            loc=self.class_means, scale_tril=self.scale_trils, validate_args=False
-        )
-        # Bayes rule
-        log_prior = self.class_prior.log()
-        log_likelihood = gaussian.log_prob(x[:, None])
-        return torch.log_softmax(log_prior + log_likelihood, dim=-1)
 
-    def __call__(self, x: Tensor, z: Tensor | None = None) -> Tensor:
-        """Apply erasure to `x` with oracle labels `z`.
+@dataclass(frozen=True)
+class QuadraticEditor:
+    """Performs surgical quadratic concept editing."""
 
-        If `z` is not provided, we will estimate it from `x` using `self.predict`. This
-        is only possible if `store_covariance=True` was passed to `QuadraticFitter`.
-        """
-        if z is None:
-            assert self.scale_trils is not None, "Set store_covariance=True"
-            z = self.predict(x).argmax(-1)
+    class_means: Tensor
+    """`[k, d]` batch of class centroids."""
 
-        # Efficiently group `x` by `z`, optimally transport each group, then coalesce
-        return groupby(x, z).map(self.optimal_transport).coalesce()
+    ot_maps: Tensor
+    """`[k, k, d, d]` batch of pairwise optimal transport matrices between classes."""
+
+    def transport(self, x: Tensor, source_z: int, target_z: int) -> Tensor:
+        """Transport `x` from class `source_z` to class `target_z`"""
+        T = self.ot_maps[source_z, target_z]
+        return (x - self.class_means[source_z]) @ T.mT + self.class_means[target_z]
+
+    def __call__(self, x: Tensor, source_z: Tensor, target_z: int) -> Tensor:
+        """Transport `x` from classes `source_z` to class `target_z`."""
+        return (
+            groupby(x, source_z)
+            .map(lambda src, x_grp: self.transport(x_grp, src, target_z))
+            .coalesce()
+        )
 
 
 class QuadraticFitter:
@@ -96,7 +94,6 @@ def __init__(
         *,
         device: str | torch.device | None = None,
         dtype: torch.dtype | None = None,
-        store_covariance: bool = False,
         shrinkage: bool = True,
     ):
         """Initialize a `QuadraticFitter`.
@@ -116,7 +113,6 @@ def __init__(
 
         self.num_classes = num_classes
         self.shrinkage = shrinkage
-        self.store_covariance = store_covariance
 
         self.mean_x = torch.zeros(num_classes, x_dim, device=device, dtype=dtype)
         self.n = torch.zeros(num_classes, device=device, dtype=torch.long)
@@ -150,33 +146,33 @@ def update_single(self, x: Tensor, z: int) -> "QuadraticFitter":
 
         return self
 
+    def editor(self) -> QuadraticEditor:
+        """Quadratic editor for the concept."""
+        sigma = self.sigma_xx
+        return QuadraticEditor(self.mean_x, ot_map(sigma[:, None], sigma))
+
     @cached_property
     def eraser(self) -> QuadraticEraser:
         """Erasure function lazily computed given the current statistics."""
-        class_prior = self.n / self.n.sum()
         sigmas = self.sigma_xx
 
         # Compute Wasserstein barycenter of the classes
         if self.num_classes == 2:
             # Use closed form solution for the binary case
+            class_prior = self.n / self.n.sum()
             center = ot_midpoint(sigmas[0], sigmas[1], *class_prior.tolist())
         else:
             # Use fixed point iteration for the general case
             center = ot_barycenter(self.sigma_xx, self.n)
 
-        # Then compute the optimal ransport maps from each class to the barycenter
+        # Then compute the optimal transport maps from each class to the barycenter
        ot_maps = ot_map(sigmas, center)
 
-        if self.store_covariance:
-            # Add jitter to ensure positive-definiteness
-            torch.linalg.diagonal(sigmas).add_(1e-3)
-            scale_trils = torch.linalg.cholesky(sigmas)
-        else:
-            scale_trils = None
-
         return QuadraticEraser(
-            self.mean_x, class_prior, self.mean_x.mean(dim=0), ot_maps, scale_trils
+            self.mean_x,
+            self.mean_x.mean(dim=0),
+            ot_maps,
        )
 
     @property
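Patch 1 swaps the Gaussian `predict` path (and the `store_covariance` flag that supported it) for a dedicated `QuadraticEditor`: `QuadraticFitter.editor()` builds a `[k, k, d, d]` batch of pairwise transport maps via `ot_map(sigma[:, None], sigma)`, so samples can be moved between arbitrary class pairs. A minimal usage sketch — assuming the library's usual `fit` convenience classmethod and top-level exports, with illustrative shapes:

```python
import torch

from concept_erasure import QuadraticFitter

# Illustrative data: 1000 samples, 32 features, 3 classes.
x = torch.randn(1000, 32)
z = torch.randint(0, 3, (1000,))

fitter = QuadraticFitter.fit(x, z)

# Erasure: transport each class-conditional distribution to the barycenter.
x_erased = fitter.eraser(x, z)

# Editing: transport each sample from its source class to class 2, using
# the pairwise optimal transport maps computed by `editor()`.
editor = fitter.editor()
x_edited = editor(x, z, 2)
```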
- """ - if z is None: - assert self.scale_trils is not None, "Set store_covariance=True" - z = self.predict(x).argmax(-1) + class_means: Tensor + """`[k, d]` batch of class centroids.""" - # Efficiently group `x` by `z`, optimally transport each group, then coalesce - return groupby(x, z).map(self.optimal_transport).coalesce() + ot_maps: Tensor + """`[k, k, d, d]` batch of pairwise optimal transport matrices between classes.""" + + def transport(self, x: Tensor, source_z: int, target_z: int) -> Tensor: + """Transport `x` from class `source_z` to class `target_z`""" + T = self.ot_maps[source_z, target_z] + return (x - self.class_means[source_z]) @ T.mT + self.class_means[target_z] + + def __call__(self, x: Tensor, source_z: Tensor, target_z: int) -> Tensor: + """Transport `x` from classes `source_z` to class `target_z`.""" + return ( + groupby(x, source_z) + .map(lambda src, x_grp: self.transport(x_grp, src, target_z)) + .coalesce() + ) class QuadraticFitter: @@ -96,7 +94,6 @@ def __init__( *, device: str | torch.device | None = None, dtype: torch.dtype | None = None, - store_covariance: bool = False, shrinkage: bool = True, ): """Initialize a `QuadraticFitter`. @@ -116,7 +113,6 @@ def __init__( self.num_classes = num_classes self.shrinkage = shrinkage - self.store_covariance = store_covariance self.mean_x = torch.zeros(num_classes, x_dim, device=device, dtype=dtype) self.n = torch.zeros(num_classes, device=device, dtype=torch.long) @@ -150,33 +146,33 @@ def update_single(self, x: Tensor, z: int) -> "QuadraticFitter": return self + def editor(self) -> QuadraticEditor: + """Quadratic editor for the concept.""" + sigma = self.sigma_xx + return QuadraticEditor(self.mean_x, ot_map(sigma[:, None], sigma)) + @cached_property def eraser(self) -> QuadraticEraser: """Erasure function lazily computed given the current statistics.""" - class_prior = self.n / self.n.sum() sigmas = self.sigma_xx # Compute Wasserstein barycenter of the classes if self.num_classes == 2: # Use closed form solution for the binary case + class_prior = self.n / self.n.sum() center = ot_midpoint(sigmas[0], sigmas[1], *class_prior.tolist()) else: # Use fixed point iteration for the general case center = ot_barycenter(self.sigma_xx, self.n) - # Then compute the optimal ransport maps from each class to the barycenter + # Then compute the optimal transport maps from each class to the barycenter ot_maps = ot_map(sigmas, center) - if self.store_covariance: - # Add jitter to ensure positive-definiteness - torch.linalg.diagonal(sigmas).add_(1e-3) - scale_trils = torch.linalg.cholesky(sigmas) - else: - scale_trils = None - return QuadraticEraser( - self.mean_x, class_prior, self.mean_x.mean(dim=0), ot_maps, scale_trils + self.mean_x, + self.mean_x.mean(dim=0), + ot_maps, ) @property From bf99e10bc1f7ce909db7f2c74f24e78dd95f4695 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 19 Sep 2023 04:47:21 +0000 Subject: [PATCH 2/4] Basic complex number support --- concept_erasure/leace.py | 16 ++++++++-------- concept_erasure/optimal_transport.py | 8 ++++---- concept_erasure/oracle.py | 6 +++--- concept_erasure/quadratic.py | 10 +++++----- tests/test_shrinkage.py | 4 ++-- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/concept_erasure/leace.py b/concept_erasure/leace.py index 02d6241..32e57e0 100644 --- a/concept_erasure/leace.py +++ b/concept_erasure/leace.py @@ -162,7 +162,7 @@ def update(self, x: Tensor, z: Tensor) -> "LeaceFitter": # Update the covariance matrix of X if needed (for LEACE) if self.method == "leace": 
assert self.sigma_xx_ is not None - self.sigma_xx_.addmm_(delta_x.mT, delta_x2) + self.sigma_xx_.addmm_(delta_x.mH, delta_x2) z = z.reshape(n, -1).type_as(x) assert z.shape[-1] == c, f"Unexpected number of classes {z.shape[-1]}" @@ -172,7 +172,7 @@ def update(self, x: Tensor, z: Tensor) -> "LeaceFitter": delta_z2 = z - self.mean_z # Update the cross-covariance matrix - self.sigma_xz_.addmm_(delta_x.mT, delta_z2) + self.sigma_xz_.addmm_(delta_x.mH, delta_z2) return self @@ -192,8 +192,8 @@ def eraser(self) -> LeaceEraser: # Assuming PSD; account for numerical error L.clamp_min_(0.0) - W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mT - W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mT + W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mH + W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mH else: W, W_inv = eye, eye @@ -211,7 +211,7 @@ def eraser(self) -> LeaceEraser: # Prevent the covariance trace from increasing sigma = self.sigma_xx old_trace = torch.trace(sigma) - new_trace = torch.trace(P @ sigma @ P.mT) + new_trace = torch.trace(P @ sigma @ P.mH) # If applying the projection matrix increases the variance, this might # cause instability, especially when erasure is applied multiple times. @@ -221,8 +221,8 @@ def eraser(self) -> LeaceEraser: # Set up the variables for the quadratic equation x = new_trace - y = 2 * torch.trace(P @ sigma @ Q.mT) - z = torch.trace(Q @ sigma @ Q.mT) + y = 2 * torch.trace(P @ sigma @ Q.mH) + z = torch.trace(Q @ sigma @ Q.mH) w = old_trace # Solve for the mixture of P and Q that makes the trace equal to the @@ -255,7 +255,7 @@ def sigma_xx(self) -> Tensor: ), "Covariance statistics are not being tracked for X" # Accumulated numerical error may cause this to be slightly non-symmetric - S_hat = (self.sigma_xx_ + self.sigma_xx_.mT) / 2 + S_hat = (self.sigma_xx_ + self.sigma_xx_.mH) / 2 # Apply Random Matrix Theory-based shrinkage if self.shrinkage: diff --git a/concept_erasure/optimal_transport.py b/concept_erasure/optimal_transport.py index 2db15c8..625d246 100644 --- a/concept_erasure/optimal_transport.py +++ b/concept_erasure/optimal_transport.py @@ -14,7 +14,7 @@ def psd_sqrt(A: Tensor) -> Tensor: """Compute the unique p.s.d. square root of a positive semidefinite matrix.""" L, U = torch.linalg.eigh(A) L = L[..., None, :].clamp_min(0.0) - return U * L.sqrt() @ U.mT + return U * L.sqrt() @ U.mH def psd_sqrt_rsqrt(A: Tensor) -> tuple[Tensor, Tensor]: @@ -23,12 +23,12 @@ def psd_sqrt_rsqrt(A: Tensor) -> tuple[Tensor, Tensor]: L = L[..., None, :].clamp_min(0.0) # Square root is easy - sqrt = U * L.sqrt() @ U.mT + sqrt = U * L.sqrt() @ U.mH # We actually compute the pseudo-inverse here for numerical stability. # Use the same heuristic as `torch.linalg.pinv` to determine the tolerance. thresh = L[..., None, -1] * A.shape[-1] * torch.finfo(A.dtype).eps - rsqrt = U * torch.where(L > thresh, L.rsqrt(), 0.0) @ U.mT + rsqrt = U * torch.where(L > thresh, L.rsqrt(), 0.0) @ U.mH return sqrt, rsqrt @@ -84,7 +84,7 @@ def ot_barycenter( # Equation 7 from Álvarez-Esteban et al. 
(2016) T = torch.sum(weights * rsqrt_mu @ inner @ rsqrt_mu, dim=0) - mu = T @ mu @ T.mT + mu = T @ mu @ T.mH return mu diff --git a/concept_erasure/oracle.py b/concept_erasure/oracle.py index 4886d4a..e7471d6 100644 --- a/concept_erasure/oracle.py +++ b/concept_erasure/oracle.py @@ -119,8 +119,8 @@ def update(self, x: Tensor, z: Tensor) -> "OracleFitter": self.mean_z += delta_z.sum(dim=0) / self.n delta_z2 = z - self.mean_z - self.sigma_xz_.addmm_(delta_x.mT, delta_z2) - self.sigma_zz_.addmm_(delta_z.mT, delta_z2) + self.sigma_xz_.addmm_(delta_x.mH, delta_z2) + self.sigma_zz_.addmm_(delta_z.mH, delta_z2) return self @@ -141,7 +141,7 @@ def sigma_zz(self) -> Tensor: ), "Covariance statistics are not being tracked for X" # Accumulated numerical error may cause this to be slightly non-symmetric - S_hat = (self.sigma_zz_ + self.sigma_zz_.mT) / 2 + S_hat = (self.sigma_zz_ + self.sigma_zz_.mH) / 2 # Apply Random Matrix Theory-based shrinkage if self.shrinkage: diff --git a/concept_erasure/quadratic.py b/concept_erasure/quadratic.py index cbdb9ae..eda3e89 100644 --- a/concept_erasure/quadratic.py +++ b/concept_erasure/quadratic.py @@ -30,7 +30,7 @@ def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "QuadraticEraser": def optimal_transport(self, z: int, x: Tensor) -> Tensor: """Transport `x` to the barycenter, assuming it was sampled from class `z`""" x_ = x.flatten(1) - x_ = (x_ - self.class_means[z]) @ self.ot_maps[z].mT + self.global_mean + x_ = (x_ - self.class_means[z]) @ self.ot_maps[z].mH + self.global_mean return x_.view_as(x) def __call__(self, x: Tensor, z: Tensor) -> Tensor: @@ -52,7 +52,7 @@ class QuadraticEditor: def transport(self, x: Tensor, source_z: int, target_z: int) -> Tensor: """Transport `x` from class `source_z` to class `target_z`""" T = self.ot_maps[source_z, target_z] - return (x - self.class_means[source_z]) @ T.mT + self.class_means[target_z] + return (x - self.class_means[source_z]) @ T.mH + self.class_means[target_z] def __call__(self, x: Tensor, source_z: Tensor, target_z: int) -> Tensor: """Transport `x` from classes `source_z` to class `target_z`.""" @@ -142,7 +142,7 @@ def update_single(self, x: Tensor, z: int) -> "QuadraticFitter": self.mean_x[z] += delta_x.sum(dim=0) / self.n[z] delta_x2 = x - self.mean_x[z] - self.sigma_xx_[z].addmm_(delta_x.mT, delta_x2) + self.sigma_xx_[z].addmm_(delta_x.mH, delta_x2) return self @@ -183,8 +183,8 @@ def sigma_xx(self) -> Tensor: self.sigma_xx_ is not None ), "Covariance statistics are not being tracked for X" - # Accumulated numerical error may cause this to be slightly non-symmetric - S_hat = (self.sigma_xx_ + self.sigma_xx_.mT) / 2 + # Accumulated numerical error may cause this to be slightly non-Hermitian + S_hat = (self.sigma_xx_ + self.sigma_xx_.mH) / 2 # Apply Random Matrix Theory-based shrinkage n = self.n.view(-1, 1, 1) diff --git a/tests/test_shrinkage.py b/tests/test_shrinkage.py index 79468bd..d0c3644 100644 --- a/tests/test_shrinkage.py +++ b/tests/test_shrinkage.py @@ -26,7 +26,7 @@ def test_olse_shrinkage(p: int, n: int): # Generate a random covariance matrix A = torch.randn(N, p, p) - S_true = A @ A.mT / p + S_true = A @ A.mH / p torch.linalg.diagonal(S_true).add_(1e-3) # Generate data with this covariance @@ -37,7 +37,7 @@ def test_olse_shrinkage(p: int, n: int): # Compute the sample covariance matrix X_centered = X - X.mean(dim=0, keepdim=True) - S_hat = (X_centered.mT @ X_centered) / n + S_hat = (X_centered.mH @ X_centered) / n # Apply shrinkage S_olse = optimal_linear_shrinkage(S_hat, n) From 
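The `.mT` → `.mH` substitutions above are the substance of patch 2: for real tensors the two properties coincide, but for complex tensors only the conjugate (Hermitian) transpose makes Gram and covariance matrices Hermitian with a real, non-negative spectrum. A self-contained illustration of the distinction (plain PyTorch, not part of the patch):

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 3, dtype=torch.complex64)

# The conjugate transpose yields a Hermitian PSD Gram matrix...
gram = A.mH @ A
assert torch.allclose(gram, gram.mH)
# ...whose spectrum is real and non-negative (up to rounding error).
assert torch.linalg.eigvalsh(gram).min() > -1e-6

# The plain transpose only gives a *symmetric* complex matrix, which is not
# Hermitian, so treating it as a covariance would break the downstream math.
sym = A.mT @ A
assert torch.allclose(sym, sym.mT)
assert not torch.allclose(sym, sym.mH)
```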
From 0a04406bb91d11f4076d72a6f9a28e7bd202a2ac Mon Sep 17 00:00:00 2001
From: Nora Belrose
Date: Tue, 19 Sep 2023 05:09:49 +0000
Subject: [PATCH 3/4] Extend optimal_linear_shrinkage for complex numbers

---
 concept_erasure/shrinkage.py |  8 ++++----
 tests/test_shrinkage.py      | 13 ++++++-------
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/concept_erasure/shrinkage.py b/concept_erasure/shrinkage.py
index d85a3be..3ee2fb1 100644
--- a/concept_erasure/shrinkage.py
+++ b/concept_erasure/shrinkage.py
@@ -31,12 +31,12 @@ def optimal_linear_shrinkage(S_n: Tensor, n: int | Tensor) -> Tensor:
     trace_S = trace(S_n)
     sigma0 = eye * trace_S / p
 
-    sigma0_norm_sq = sigma0.pow(2).sum(dim=(-2, -1), keepdim=True)
-    S_norm_sq = S_n.pow(2).sum(dim=(-2, -1), keepdim=True)
+    sigma0_norm_sq = sigma0.norm(dim=(-2, -1), keepdim=True) ** 2
+    S_norm_sq = S_n.norm(dim=(-2, -1), keepdim=True) ** 2
 
     prod_trace = trace(S_n @ sigma0)
-    top = trace_S.pow(2) * sigma0_norm_sq / n
-    bottom = S_norm_sq * sigma0_norm_sq - prod_trace**2
+    top = trace_S * trace_S.conj() * sigma0_norm_sq / n
+    bottom = S_norm_sq * sigma0_norm_sq - prod_trace * prod_trace.conj()
 
     # Epsilon prevents dividing by zero for the zero matrix. In that case we end up
     # setting alpha = 0, beta = 1, but it doesn't matter since we're shrinking toward
diff --git a/tests/test_shrinkage.py b/tests/test_shrinkage.py
index d0c3644..7581c70 100644
--- a/tests/test_shrinkage.py
+++ b/tests/test_shrinkage.py
@@ -1,6 +1,5 @@
 import pytest
 import torch
-from torch.distributions import MultivariateNormal
 
 from concept_erasure import optimal_linear_shrinkage
 
@@ -18,21 +17,21 @@
         (16, 256),
     ],
 )
-def test_olse_shrinkage(p: int, n: int):
+@pytest.mark.parametrize("dtype", [torch.float32, torch.complex64])
+def test_olse_shrinkage(p: int, n: int, dtype: torch.dtype):
     torch.manual_seed(42)
 
     # Number of matrices
     N = 1000
 
     # Generate a random covariance matrix
-    A = torch.randn(N, p, p)
+    A = torch.randn(N, p, p, dtype=dtype)
     S_true = A @ A.mH / p
     torch.linalg.diagonal(S_true).add_(1e-3)
 
-    # Generate data with this covariance
-    mean = torch.zeros(N, p)
-    dist = MultivariateNormal(mean, S_true)
-    X = dist.sample([n]).movedim(1, 0)
+    # Generate random Gaussian vectors with this covariance matrix
+    scale_tril = torch.linalg.cholesky(S_true)
+    X = torch.randn(N, n, p, dtype=dtype) @ scale_tril.mH
     assert X.shape == (N, n, p)
 
     # Compute the sample covariance matrix
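Patch 3 rewrites the shrinkage scalars so they stay correct for complex input: `pow(2).sum(...)` equals the squared Frobenius norm only for real matrices, so it becomes `norm(...) ** 2`, and squared traces become `trace · conj(trace)`, i.e. `|trace|²`, which keeps every coefficient in the formula real and non-negative. A hedged sketch of the new code path, using the same Cholesky trick as the updated test to sample complex Gaussian data (shapes and seed are illustrative):

```python
import torch

from concept_erasure import optimal_linear_shrinkage

torch.manual_seed(0)
p, n = 16, 64

# Random Hermitian covariance: A @ A.mH is PSD, and the diagonal
# jitter makes it safely positive-definite for the Cholesky factorization.
A = torch.randn(p, p, dtype=torch.complex64)
S_true = A @ A.mH / p + 1e-3 * torch.eye(p)

# Sample complex Gaussian rows with covariance S_true via its Cholesky factor.
L = torch.linalg.cholesky(S_true)
X = torch.randn(n, p, dtype=torch.complex64) @ L.mH

# Sample covariance, then shrinkage; the result should stay Hermitian.
X_centered = X - X.mean(dim=0, keepdim=True)
S_hat = X_centered.mH @ X_centered / n
S_olse = optimal_linear_shrinkage(S_hat, n)
assert torch.allclose(S_olse, S_olse.mH)
```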
From c3736bf2aabe0683d99e504fd2f18b63ceb38391 Mon Sep 17 00:00:00 2001
From: Nora Belrose
Date: Tue, 19 Sep 2023 05:45:46 +0000
Subject: [PATCH 4/4] Passing basic tests w/ complex numbers

---
 concept_erasure/leace.py     | 12 ++++++------
 concept_erasure/oracle.py    |  2 +-
 concept_erasure/quadratic.py |  2 +-
 tests/test_leace.py          | 23 ++++++++++++++++-------
 4 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/concept_erasure/leace.py b/concept_erasure/leace.py
index 32e57e0..c7b23e9 100644
--- a/concept_erasure/leace.py
+++ b/concept_erasure/leace.py
@@ -43,7 +43,7 @@ def __call__(self, x: Tensor) -> Tensor:
         delta = x - self.bias if self.bias is not None else x
 
         # Ensure we do the matmul in the most efficient order.
-        x_ = x - (delta @ self.proj_right.T) @ self.proj_left.T
+        x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH
 
         return x_.type_as(x)
 
@@ -133,7 +133,7 @@ def __init__(
 
         self.mean_x = torch.zeros(x_dim, device=device, dtype=dtype)
         self.mean_z = torch.zeros(z_dim, device=device, dtype=dtype)
-        self.n = torch.tensor(0, device=device, dtype=dtype)
+        self.n = torch.tensor(0, device=device)
         self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype)
 
         if self.method == "leace":
@@ -203,7 +203,7 @@ def eraser(self) -> LeaceEraser:
             u *= s > self.svd_tol
 
             proj_left = W_inv @ u
-            proj_right = u.T @ W
+            proj_right = u.mH @ W
 
             if self.constrain_cov_trace and self.method == "leace":
                 P = eye - proj_left @ proj_right
@@ -216,8 +216,8 @@ def eraser(self) -> LeaceEraser:
             # If applying the projection matrix increases the variance, this might
             # cause instability, especially when erasure is applied multiple times.
             # We regularize toward the orthogonal projection matrix to avoid this.
-            if new_trace > old_trace:
-                Q = eye - u @ u.T
+            if new_trace.real > old_trace.real:
+                Q = eye - u @ u.mH
 
                 # Set up the variables for the quadratic equation
@@ -234,7 +234,7 @@ def eraser(self) -> LeaceEraser:
                 alpha2 = (-y / 2 + z + discr / 2) / (x - y + z)
 
                 # Choose the positive root
-                alpha = torch.where(alpha1 > 0, alpha1, alpha2).clamp(0, 1)
+                alpha = torch.where(alpha1.real > 0, alpha1, alpha2).clamp(0, 1)
                 P = alpha * P + (1 - alpha) * Q
 
             # TODO: Avoid using SVD here
diff --git a/concept_erasure/oracle.py b/concept_erasure/oracle.py
index e7471d6..ae0a2f2 100644
--- a/concept_erasure/oracle.py
+++ b/concept_erasure/oracle.py
@@ -93,7 +93,7 @@ def __init__(
 
         self.mean_x = torch.zeros(x_dim, device=device, dtype=dtype)
         self.mean_z = torch.zeros(z_dim, device=device, dtype=dtype)
-        self.n = torch.tensor(0, device=device, dtype=dtype)
+        self.n = torch.tensor(0, device=device)
         self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype)
         self.sigma_zz_ = torch.zeros(z_dim, z_dim, device=device, dtype=dtype)
diff --git a/concept_erasure/quadratic.py b/concept_erasure/quadratic.py
index eda3e89..e5f984d 100644
--- a/concept_erasure/quadratic.py
+++ b/concept_erasure/quadratic.py
@@ -115,7 +115,7 @@ def __init__(
         self.shrinkage = shrinkage
 
         self.mean_x = torch.zeros(num_classes, x_dim, device=device, dtype=dtype)
-        self.n = torch.zeros(num_classes, device=device, dtype=torch.long)
+        self.n = torch.zeros(num_classes, device=device)
         self.sigma_xx_ = torch.zeros(
             num_classes, x_dim, x_dim, device=device, dtype=dtype
         )
diff --git a/tests/test_leace.py b/tests/test_leace.py
index 917392a..4657f4b 100644
--- a/tests/test_leace.py
+++ b/tests/test_leace.py
@@ -22,19 +22,22 @@
 
 
 @pytest.mark.parametrize("shrinkage", [False, True])
-def test_stats(shrinkage: bool):
+@pytest.mark.parametrize("dtype", [torch.float64, torch.complex128])
+def test_stats(shrinkage: bool, dtype: torch.dtype):
     batch_size = 10
     num_batches = 5
     num_classes = 2
     num_features = 3
     N = batch_size * num_batches
 
-    fitter = LeaceFitter(num_features, num_classes, shrinkage=shrinkage)
-    oracle = OracleFitter(num_features, num_classes, shrinkage=shrinkage)
+    fitter = LeaceFitter(num_features, num_classes, dtype=dtype, shrinkage=shrinkage)
+    oracle = OracleFitter(num_features, num_classes, dtype=dtype, shrinkage=shrinkage)
 
     # Generate random data
     torch.manual_seed(42)
-    x_data = [torch.randn(batch_size, num_features) for _ in range(num_batches)]
+    x_data = [
+        torch.randn(batch_size, num_features, dtype=dtype) for _ in range(num_batches)
+    ]
     z_data = [
        torch.randint(0, num_classes, (batch_size, num_classes))
        for _ in range(num_batches)
    ]
@@ -59,8 +62,12 @@ def test_stats(shrinkage: bool, dtype: torch.dtype):
     x_centered = x_all - mean_x
     z_centered = z_all - mean_z
 
-    expected_sigma_xx = torch.einsum("b...m,b...n->...mn", x_centered, x_centered)
-    expected_sigma_zz = torch.einsum("b...m,b...n->...mn", z_centered, z_centered)
+    expected_sigma_xx = torch.einsum(
+        "b...m,b...n->...mn", x_centered.conj(), x_centered
+    )
+    expected_sigma_zz = torch.einsum(
+        "b...m,b...n->...mn", z_centered.conj(), z_centered
+    )
     if shrinkage:
         expected_sigma_xx = optimal_linear_shrinkage(
             expected_sigma_xx / N, batch_size * num_batches
@@ -72,7 +79,9 @@ def test_stats(shrinkage: bool, dtype: torch.dtype):
         expected_sigma_xx /= N - 1
         expected_sigma_zz /= N - 1
 
-    expected_sigma_xz = torch.einsum("b...m,b...n->...mn", x_centered, z_centered)
+    expected_sigma_xz = torch.einsum(
+        "b...m,b...n->...mn", x_centered.conj(), z_centered
+    )
     expected_sigma_xz /= N - 1
 
     torch.testing.assert_close(fitter.sigma_xx, expected_sigma_xx)
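Patch 4 threads the complex dtype through the fitters themselves: the sample counters `n` no longer inherit the (possibly complex) data dtype, trace and root comparisons are made on `.real` because complex tensors are unordered, and the test's reference `einsum` computations conjugate their first factor. An end-to-end sketch in the same spirit — the planted class signal and the tolerance here are illustrative assumptions, not taken from the patch:

```python
import torch

from concept_erasure import LeaceFitter

torch.manual_seed(0)
n, d, k = 4096, 16, 2

# Complex features with a planted class signal in the first coordinate.
labels = torch.randint(0, k, (n,))
x = torch.randn(n, d, dtype=torch.complex128)
x[:, 0] += 3 * labels.to(x.dtype)
z = torch.nn.functional.one_hot(labels, k).to(x.dtype)

fitter = LeaceFitter(d, k, dtype=torch.complex128).update(x, z)
x_erased = fitter.eraser(x)

# LEACE should collapse the class-mean gap, complex dtype and all.
gap_before = (x[labels == 0].mean(0) - x[labels == 1].mean(0)).norm()
gap_after = (x_erased[labels == 0].mean(0) - x_erased[labels == 1].mean(0)).norm()
assert gap_after < 0.05 * gap_before
```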