Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure matrix logarithms are real for orbital localization #49

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions pyscfad/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numpy
import scipy

def logm(A, disp=True, real=False):
"""Compute matrix logarithm.

Compute the matrix logarithm ensuring that it is real
when `A` is nonsingular and each Jordan block of `A` belonging
to negative eigenvalue occurs an even number of times.

Parameters
----------
A : (N, N) array_like
Matrix whose logarithm to evaluate
real : bool, default=False
If `True`, compute a real logarithm of a real matrix if it exists.
Otherwise, call `scipy.linalg.logm`.

See Also
--------
scipy.linalg.logm
"""
if real and numpy.isreal(A).all():
A = numpy.real(A)
# perform the real Schur decomposition
t, q = scipy.linalg.schur(A)

n = A.shape[0]
idx = 0
real_output = True
while idx < n:
# final block reached for an odd number of orbitals
if idx == n - 1:
# single positive block
if t[idx,idx] > 0:
t[idx,idx] = numpy.log(t[idx,idx])
# single negative block
else:
real_output = False
break
else:
diag = numpy.isclose(t[idx,idx+1], 0) and numpy.isclose(t[idx+1,idx], 0)
# single positive block
if t[idx,idx] > 0 and diag:
t[idx,idx] = numpy.log(t[idx,idx])
# pair of two negative blocks
elif (
t[idx,idx] < 0
and diag
and numpy.isclose(t[idx,idx], t[idx + 1,idx + 1])
):
log_lambda = numpy.log(-t[idx,idx])
t[idx:idx+2,idx:idx+2] = numpy.array(
[[log_lambda, numpy.pi], [-numpy.pi, log_lambda]]
)
idx += 1
# antisymmetric 2x2 block
elif (
numpy.isclose(t[idx,idx], t[idx + 1,idx + 1])
and numpy.isclose(t[idx + 1,idx], -t[idx,idx + 1])
):
log_comp = numpy.log(complex(t[idx,idx], t[idx,idx + 1]))
t[idx:idx+2,idx:idx+2] = numpy.array(
[
[numpy.real(log_comp), numpy.imag(log_comp)],
[-numpy.imag(log_comp), numpy.real(log_comp)],
],
)
idx += 1
else:
real_output = False
break
idx += 1

if not real_output:
return scipy.linalg.logm(A, disp=disp)

F = q @ t @ q.T

# NOTE copied from scipy
errtol = 1000 * numpy.finfo("d").eps
# TODO use a better error approximation
errest = scipy.linalg.norm(scipy.linalg.expm(F) - A, 1) / scipy.linalg.norm(A, 1)
if disp:
if not numpy.isfinite(errest) or errest >= errtol:
print("logm result may be inaccurate, approximate err =", errest)
return F
else:
return F, errest

else:
return scipy.linalg.logm(A, disp=disp)

9 changes: 4 additions & 5 deletions pyscfad/lo/boys.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy
import scipy
import jax
from pyscf.lib import logger
from pyscf.lo import boys as pyscf_boys
Expand All @@ -11,6 +10,7 @@
pack_uniq_var,
)
from pyscfad.tools.linear_solver import gen_gmres
from pyscfad.scipy.linalg import logm

# modified from pyscf v2.6
def kernel(localizer, mo_coeff=None, callback=None, verbose=None,
Expand Down Expand Up @@ -133,12 +133,11 @@ def _extract_x0(loc, u):
'Saddle point reached in orbital localization.'
f'\n{h_diag}')

#TODO ensure real solution of logm
mat = scipy.linalg.logm(u)
if not numpy.allclose(mat.imag, 0, atol=1e-6):
mat = logm(u, real=True)
if not numpy.isreal(mat).all():
raise RuntimeError('Complex solutions are not supported for '
'differentiating the Boys localiztion.')
x = pack_uniq_var(mat.real)
x = pack_uniq_var(numpy.real(mat))
return x

def _boys(x, mol, mo_coeff, *,
Expand Down
4 changes: 4 additions & 0 deletions pyscfad/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
eigh as eigh,
svd as svd,
)

from pyscfad._src.scipy.linalg import (
logm as logm,
)
16 changes: 15 additions & 1 deletion pyscfad/scipy/test/test_linalg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from functools import partial
import numpy as np
import scipy
import jax
from jax import numpy as jnp
from jax import scipy as jsp
from pyscfad.scipy.linalg import eigh, svd
from pyscfad.scipy.linalg import eigh, svd, logm

def test_eigh():
a = jnp.ones((2,2))
Expand Down Expand Up @@ -41,3 +43,15 @@ def test_svd():
assert abs(jac0[1] - jac1[1]).max() < 1e-7
assert abs(jac0[2] - jac1[2]).max() < 1e-7

def test_logm():
theta = 2 * np.pi / 3
a = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
A = scipy.linalg.block_diag(np.diag(np.array([-1, -1, 1])), a)

X = logm(A, real=True)
assert np.isreal(X).all()
assert scipy.linalg.norm(scipy.linalg.expm(X) - A, 1) < 1e-12

assert np.allclose(logm(A), scipy.linalg.logm(A))

Loading