From e0fcf1556ff56e7f0739e229a5260960ac6aa215 Mon Sep 17 00:00:00 2001 From: Jonas Greiner Date: Thu, 12 Sep 2024 15:10:01 +0200 Subject: [PATCH 1/2] Ensure matrix logarithms are real for orbital localization --- pyscfad/_src/scipy/linalg.py | 74 ++++++++++++++++++++++++++++++++++++ pyscfad/lo/boys.py | 10 ++--- 2 files changed, 77 insertions(+), 7 deletions(-) diff --git a/pyscfad/_src/scipy/linalg.py b/pyscfad/_src/scipy/linalg.py index e69de29b..1f3cf64f 100644 --- a/pyscfad/_src/scipy/linalg.py +++ b/pyscfad/_src/scipy/linalg.py @@ -0,0 +1,74 @@ +import numpy +import scipy + + +def logm(A, real=False, **kwargs): + ''' + Calculates the matrix logarithm ensuring that it is real when the matrix is normal. + For a normal matrix, the Schur-decomposed matrix t is block-diagonal with blocks of + size 1 and 2, the blocks of size 1 are either positive numbers or they come in + pairs when they are negative while the blocks of size 2 always have the same value + along the diagonal and values with different signs but the same magnitude on the + off-diagonals. Since the block-diagonal matrix is similar to the original matrix, + the real logarithm can be calculated by determining the real logarithm of the + individual blocks and backtransforming with the Schur vectors + ''' + if real: + # perform the real Schur decomposition + t, q = scipy.linalg.schur(A) + + norb = A.shape[0] + idx = 0 + normalmatrix = True + while idx < norb: + # final block reached for an odd number of orbitals + if (idx == norb - 1): + # single positive block + if t[idx,idx] > 0.: + t[idx,idx] = numpy.log(t[idx,idx]) + # single negative block (should not happen for normal matrix) + else: + normalmatrix = False + break + else: + diag = numpy.isclose(t[idx,idx+1], 0.0) and numpy.isclose(t[idx+1,idx], 0.0) + # single positive block + if t[idx,idx] > 0. and diag: + t[idx,idx] = numpy.log(t[idx,idx]) + # pair of two negative blocks + elif ( + t[idx,idx] < 0. + and diag + and numpy.isclose(t[idx,idx],t[idx + 1,idx + 1]) + ): + log_lambda = numpy.log(-t[idx,idx]) + t[idx:idx+2,idx:idx+2] = numpy.array( + [[log_lambda, numpy.pi], [-numpy.pi, log_lambda]] + ) + idx += 1 + # antisymmetric 2x2 block + elif ( + numpy.isclose(t[idx,idx], t[idx + 1,idx + 1]) + and numpy.isclose(t[idx + 1,idx],-t[idx,idx + 1]) + ): + log_comp = numpy.log(complex(t[idx,idx], t[idx,idx + 1])) + t[idx:idx+2,idx:idx+2] = numpy.array( + [ + [numpy.real(log_comp), numpy.imag(log_comp)], + [-numpy.imag(log_comp), numpy.real(log_comp)], + ], + ) + idx += 1 + # should not happen for normal matrix + else: + normalmatrix = False + idx += 1 + + if not normalmatrix: + raise ValueError( + 'Real matrix logarithm can only be ensured for normal matrix' + ) + + return q @ t @ q.T + else: + return scipy.linalg.logm(A, **kwargs) diff --git a/pyscfad/lo/boys.py b/pyscfad/lo/boys.py index 2afe224f..9709b595 100644 --- a/pyscfad/lo/boys.py +++ b/pyscfad/lo/boys.py @@ -11,6 +11,7 @@ pack_uniq_var, ) from pyscfad.tools.linear_solver import gen_gmres +from pyscfad._src.scipy.linalg import logm # modified from pyscf v2.6 def kernel(localizer, mo_coeff=None, callback=None, verbose=None, @@ -132,13 +133,8 @@ def _extract_x0(loc, u): logger.warn(loc, 'Saddle point reached in orbital localization.' f'\n{h_diag}') - - #TODO ensure real solution of logm - mat = scipy.linalg.logm(u) - if not numpy.allclose(mat.imag, 0, atol=1e-6): - raise RuntimeError('Complex solutions are not supported for ' - 'differentiating the Boys localiztion.') - x = pack_uniq_var(mat.real) + + x = pack_uniq_var(logm(u, real=True)) return x def _boys(x, mol, mo_coeff, *, From 894cefab97ed07d38a12ae6016c0b3fb3a0520db Mon Sep 17 00:00:00 2001 From: fishjojo Date: Thu, 12 Sep 2024 12:43:36 -0700 Subject: [PATCH 2/2] add test --- pyscfad/_src/scipy/linalg.py | 87 +++++++++++++++++++------------ pyscfad/lo/boys.py | 11 ++-- pyscfad/scipy/linalg.py | 4 ++ pyscfad/scipy/test/test_linalg.py | 16 +++++- 4 files changed, 79 insertions(+), 39 deletions(-) diff --git a/pyscfad/_src/scipy/linalg.py b/pyscfad/_src/scipy/linalg.py index 1f3cf64f..d3e841ea 100644 --- a/pyscfad/_src/scipy/linalg.py +++ b/pyscfad/_src/scipy/linalg.py @@ -1,45 +1,53 @@ import numpy import scipy +def logm(A, disp=True, real=False): + """Compute matrix logarithm. -def logm(A, real=False, **kwargs): - ''' - Calculates the matrix logarithm ensuring that it is real when the matrix is normal. - For a normal matrix, the Schur-decomposed matrix t is block-diagonal with blocks of - size 1 and 2, the blocks of size 1 are either positive numbers or they come in - pairs when they are negative while the blocks of size 2 always have the same value - along the diagonal and values with different signs but the same magnitude on the - off-diagonals. Since the block-diagonal matrix is similar to the original matrix, - the real logarithm can be calculated by determining the real logarithm of the - individual blocks and backtransforming with the Schur vectors - ''' - if real: + Compute the matrix logarithm ensuring that it is real + when `A` is nonsingular and each Jordan block of `A` belonging + to negative eigenvalue occurs an even number of times. + + Parameters + ---------- + A : (N, N) array_like + Matrix whose logarithm to evaluate + real : bool, default=False + If `True`, compute a real logarithm of a real matrix if it exists. + Otherwise, call `scipy.linalg.logm`. + + See Also + -------- + scipy.linalg.logm + """ + if real and numpy.isreal(A).all(): + A = numpy.real(A) # perform the real Schur decomposition t, q = scipy.linalg.schur(A) - norb = A.shape[0] + n = A.shape[0] idx = 0 - normalmatrix = True - while idx < norb: + real_output = True + while idx < n: # final block reached for an odd number of orbitals - if (idx == norb - 1): + if idx == n - 1: # single positive block - if t[idx,idx] > 0.: + if t[idx,idx] > 0: t[idx,idx] = numpy.log(t[idx,idx]) - # single negative block (should not happen for normal matrix) + # single negative block else: - normalmatrix = False + real_output = False break else: - diag = numpy.isclose(t[idx,idx+1], 0.0) and numpy.isclose(t[idx+1,idx], 0.0) + diag = numpy.isclose(t[idx,idx+1], 0) and numpy.isclose(t[idx+1,idx], 0) # single positive block - if t[idx,idx] > 0. and diag: + if t[idx,idx] > 0 and diag: t[idx,idx] = numpy.log(t[idx,idx]) # pair of two negative blocks elif ( - t[idx,idx] < 0. + t[idx,idx] < 0 and diag - and numpy.isclose(t[idx,idx],t[idx + 1,idx + 1]) + and numpy.isclose(t[idx,idx], t[idx + 1,idx + 1]) ): log_lambda = numpy.log(-t[idx,idx]) t[idx:idx+2,idx:idx+2] = numpy.array( @@ -48,27 +56,38 @@ def logm(A, real=False, **kwargs): idx += 1 # antisymmetric 2x2 block elif ( - numpy.isclose(t[idx,idx], t[idx + 1,idx + 1]) - and numpy.isclose(t[idx + 1,idx],-t[idx,idx + 1]) + numpy.isclose(t[idx,idx], t[idx + 1,idx + 1]) + and numpy.isclose(t[idx + 1,idx], -t[idx,idx + 1]) ): log_comp = numpy.log(complex(t[idx,idx], t[idx,idx + 1])) t[idx:idx+2,idx:idx+2] = numpy.array( [ - [numpy.real(log_comp), numpy.imag(log_comp)], + [numpy.real(log_comp), numpy.imag(log_comp)], [-numpy.imag(log_comp), numpy.real(log_comp)], ], ) idx += 1 - # should not happen for normal matrix else: - normalmatrix = False + real_output = False + break idx += 1 - if not normalmatrix: - raise ValueError( - 'Real matrix logarithm can only be ensured for normal matrix' - ) + if not real_output: + return scipy.linalg.logm(A, disp=disp) + + F = q @ t @ q.T + + # NOTE copied from scipy + errtol = 1000 * numpy.finfo("d").eps + # TODO use a better error approximation + errest = scipy.linalg.norm(scipy.linalg.expm(F) - A, 1) / scipy.linalg.norm(A, 1) + if disp: + if not numpy.isfinite(errest) or errest >= errtol: + print("logm result may be inaccurate, approximate err =", errest) + return F + else: + return F, errest - return q @ t @ q.T else: - return scipy.linalg.logm(A, **kwargs) + return scipy.linalg.logm(A, disp=disp) + diff --git a/pyscfad/lo/boys.py b/pyscfad/lo/boys.py index 9709b595..a975c57e 100644 --- a/pyscfad/lo/boys.py +++ b/pyscfad/lo/boys.py @@ -1,5 +1,4 @@ import numpy -import scipy import jax from pyscf.lib import logger from pyscf.lo import boys as pyscf_boys @@ -11,7 +10,7 @@ pack_uniq_var, ) from pyscfad.tools.linear_solver import gen_gmres -from pyscfad._src.scipy.linalg import logm +from pyscfad.scipy.linalg import logm # modified from pyscf v2.6 def kernel(localizer, mo_coeff=None, callback=None, verbose=None, @@ -133,8 +132,12 @@ def _extract_x0(loc, u): logger.warn(loc, 'Saddle point reached in orbital localization.' f'\n{h_diag}') - - x = pack_uniq_var(logm(u, real=True)) + + mat = logm(u, real=True) + if not numpy.isreal(mat).all(): + raise RuntimeError('Complex solutions are not supported for ' + 'differentiating the Boys localiztion.') + x = pack_uniq_var(numpy.real(mat)) return x def _boys(x, mol, mo_coeff, *, diff --git a/pyscfad/scipy/linalg.py b/pyscfad/scipy/linalg.py index a90743d5..fefac339 100644 --- a/pyscfad/scipy/linalg.py +++ b/pyscfad/scipy/linalg.py @@ -4,3 +4,7 @@ eigh as eigh, svd as svd, ) + +from pyscfad._src.scipy.linalg import ( + logm as logm, +) diff --git a/pyscfad/scipy/test/test_linalg.py b/pyscfad/scipy/test/test_linalg.py index 7393bf00..25d09834 100644 --- a/pyscfad/scipy/test/test_linalg.py +++ b/pyscfad/scipy/test/test_linalg.py @@ -1,8 +1,10 @@ from functools import partial +import numpy as np +import scipy import jax from jax import numpy as jnp from jax import scipy as jsp -from pyscfad.scipy.linalg import eigh, svd +from pyscfad.scipy.linalg import eigh, svd, logm def test_eigh(): a = jnp.ones((2,2)) @@ -41,3 +43,15 @@ def test_svd(): assert abs(jac0[1] - jac1[1]).max() < 1e-7 assert abs(jac0[2] - jac1[2]).max() < 1e-7 +def test_logm(): + theta = 2 * np.pi / 3 + a = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + A = scipy.linalg.block_diag(np.diag(np.array([-1, -1, 1])), a) + + X = logm(A, real=True) + assert np.isreal(X).all() + assert scipy.linalg.norm(scipy.linalg.expm(X) - A, 1) < 1e-12 + + assert np.allclose(logm(A), scipy.linalg.logm(A)) +