diff --git a/src/torchjd/aggregation/cagrad.py b/src/torchjd/aggregation/cagrad.py index b20d0c9..20fef34 100644 --- a/src/torchjd/aggregation/cagrad.py +++ b/src/torchjd/aggregation/cagrad.py @@ -76,7 +76,7 @@ def forward(self, matrix: Tensor) -> Tensor: U, S, _ = torch.svd(gramian) reduced_matrix = U @ S.sqrt().diag() - reduced_array = reduced_matrix.cpu().detach().numpy() + reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64) dimension = matrix.shape[0] reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension diff --git a/src/torchjd/aggregation/dualproj.py b/src/torchjd/aggregation/dualproj.py index 0706102..7d7fbe0 100644 --- a/src/torchjd/aggregation/dualproj.py +++ b/src/torchjd/aggregation/dualproj.py @@ -107,10 +107,10 @@ def __init__( def forward(self, matrix: Tensor) -> Tensor: weights = self.weighting(matrix) - weights_array = weights.cpu().detach().numpy() + weights_array = weights.cpu().detach().numpy().astype(np.float64) gramian = _compute_normalized_gramian(matrix, self.norm_eps) - gramian_array = gramian.cpu().detach().numpy() + gramian_array = gramian.cpu().detach().numpy().astype(np.float64) dimension = gramian.shape[0] # Because of numerical errors, `gramian_array` might have slightly negative eigenvalue(s), diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py index 19f9f3a..2050fd4 100644 --- a/src/torchjd/aggregation/upgrad.py +++ b/src/torchjd/aggregation/upgrad.py @@ -109,7 +109,7 @@ def forward(self, matrix: Tensor) -> Tensor: def _compute_lagrangian(self, matrix: Tensor, weights: Tensor) -> Tensor: gramian = _compute_normalized_gramian(matrix, self.norm_eps) - gramian_array = gramian.cpu().detach().numpy() + gramian_array = gramian.cpu().detach().numpy().astype(np.float64) dimension = gramian.shape[0] regularization_array = self.reg_eps * np.eye(dimension)