diff --git a/matrix_functions.py b/matrix_functions.py index 1c25406..94b2269 100644 --- a/matrix_functions.py +++ b/matrix_functions.py @@ -134,9 +134,13 @@ def _matrix_perturbation( """ return ( - A.add(torch.eye(A.shape[0], dtype=A.dtype, device=A.device), alpha=epsilon) - if not is_eigenvalues - else A + epsilon + ( + A.add(torch.eye(A.shape[0], dtype=A.dtype, device=A.device), alpha=epsilon) + if not is_eigenvalues + else A + epsilon + ) + if epsilon != 0 + else A # Fast path when epsilon is 0.0, return A without modification ) @@ -407,17 +411,19 @@ def matrix_eigendecomposition( raise ValueError(f"{epsilon=} should be 0.0 when using pseudo-inverse!") # Add epsilon to the diagonal to help with numerical stability of the eigenvalue decomposition - # Only do it when perturb_before_computation is True - if ( - isinstance( + # Only do it when perturb_before_computation is True. + A_ridge = _matrix_perturbation( + A=A, + # If perturb_before_computation is False, we take the fast path in _matrix_perturbation() by effectively setting epsilon to 0, avoiding the perturbation step. + # If the perturb_before_computation field doesn't exist in the config, default to 0 (equivalent to False). + epsilon=epsilon + * getattr( eigendecomposition_config.rank_deficient_stability_config, - PerturbationConfig, - ) - and eigendecomposition_config.rank_deficient_stability_config.perturb_before_computation - ): - A_ridge = _matrix_perturbation(A, epsilon=epsilon, is_eigenvalues=False) - else: - A_ridge = A + "perturb_before_computation", + 0, + ), + is_eigenvalues=False, + ) match eigendecomposition_config: case EighEigendecompositionConfig(): @@ -620,13 +626,14 @@ def _matrix_inverse_root_eigen( # Add epsilon to the diagonal to help with numerical stability of the eigenvalue decomposition # Only do it when perturb_before_computation is True - if ( - isinstance(rank_deficient_stability_config, PerturbationConfig) - and rank_deficient_stability_config.perturb_before_computation - ): - A_ridge = _matrix_perturbation(A, epsilon=epsilon, is_eigenvalues=False) - else: - A_ridge = A + A_ridge = _matrix_perturbation( + A=A, + # If perturb_before_computation is False, we take the fast path in _matrix_perturbation() by effectively setting epsilon to 0, avoiding the perturbation step. + # If the perturb_before_computation field doesn't exist in the config, default to 0 (equivalent to False). + epsilon=epsilon + * getattr(rank_deficient_stability_config, "perturb_before_computation", 0), + is_eigenvalues=False, + ) # compute eigendecomposition and compute minimum eigenvalue L, Q = _eigh_eigenvalue_decomposition(