Skip to content

Commit

Permalink
refactor(aggregation): Explicitly cast arrays to float64 (#236)
Browse files Browse the repository at this point in the history
When casting a Tensor to a numpy array, the obtained array is typically typed as float32 while numpy typically uses float64. This makes the cast explicit.
  • Loading branch information
PierreQuinton authored Jan 21, 2025
1 parent 95bb00d commit 3a3c459
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def forward(self, matrix: Tensor) -> Tensor:
U, S, _ = torch.svd(gramian)

reduced_matrix = U @ S.sqrt().diag()
reduced_array = reduced_matrix.cpu().detach().numpy()
reduced_array = reduced_matrix.cpu().detach().numpy().astype(np.float64)

dimension = matrix.shape[0]
reduced_g_0 = reduced_array.T @ np.ones(dimension) / dimension
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def __init__(

def forward(self, matrix: Tensor) -> Tensor:
weights = self.weighting(matrix)
weights_array = weights.cpu().detach().numpy()
weights_array = weights.cpu().detach().numpy().astype(np.float64)

gramian = _compute_normalized_gramian(matrix, self.norm_eps)
gramian_array = gramian.cpu().detach().numpy()
gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
dimension = gramian.shape[0]

# Because of numerical errors, `gramian_array` might have slightly negative eigenvalue(s),
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def forward(self, matrix: Tensor) -> Tensor:

def _compute_lagrangian(self, matrix: Tensor, weights: Tensor) -> Tensor:
gramian = _compute_normalized_gramian(matrix, self.norm_eps)
gramian_array = gramian.cpu().detach().numpy()
gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
dimension = gramian.shape[0]

regularization_array = self.reg_eps * np.eye(dimension)
Expand Down

0 comments on commit 3a3c459

Please sign in to comment.