Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Oct 30, 2024
2 parents 82e7a12 + 952acdf commit 3864c7f
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 6 deletions.
93 changes: 93 additions & 0 deletions pyscfad/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import numpy
import scipy

def logm(A, disp=True, real=False):
"""Compute matrix logarithm.
Compute the matrix logarithm ensuring that it is real
when `A` is nonsingular and each Jordan block of `A` belonging
to negative eigenvalue occurs an even number of times.
Parameters
----------
A : (N, N) array_like
Matrix whose logarithm to evaluate
real : bool, default=False
If `True`, compute a real logarithm of a real matrix if it exists.
Otherwise, call `scipy.linalg.logm`.
See Also
--------
scipy.linalg.logm
"""
if real and numpy.isreal(A).all():
A = numpy.real(A)
# perform the real Schur decomposition
t, q = scipy.linalg.schur(A)

n = A.shape[0]
idx = 0
real_output = True
while idx < n:
# final block reached for an odd number of orbitals
if idx == n - 1:
# single positive block
if t[idx,idx] > 0:
t[idx,idx] = numpy.log(t[idx,idx])
# single negative block
else:
real_output = False
break
else:
diag = numpy.isclose(t[idx,idx+1], 0) and numpy.isclose(t[idx+1,idx], 0)
# single positive block
if t[idx,idx] > 0 and diag:
t[idx,idx] = numpy.log(t[idx,idx])
# pair of two negative blocks
elif (
t[idx,idx] < 0
and diag
and numpy.isclose(t[idx,idx], t[idx + 1,idx + 1])
):
log_lambda = numpy.log(-t[idx,idx])
t[idx:idx+2,idx:idx+2] = numpy.array(
[[log_lambda, numpy.pi], [-numpy.pi, log_lambda]]
)
idx += 1
# antisymmetric 2x2 block
elif (
numpy.isclose(t[idx,idx], t[idx + 1,idx + 1])
and numpy.isclose(t[idx + 1,idx], -t[idx,idx + 1])
):
log_comp = numpy.log(complex(t[idx,idx], t[idx,idx + 1]))
t[idx:idx+2,idx:idx+2] = numpy.array(
[
[numpy.real(log_comp), numpy.imag(log_comp)],
[-numpy.imag(log_comp), numpy.real(log_comp)],
],
)
idx += 1
else:
real_output = False
break
idx += 1

if not real_output:
return scipy.linalg.logm(A, disp=disp)

F = q @ t @ q.T

# NOTE copied from scipy
errtol = 1000 * numpy.finfo("d").eps
# TODO use a better error approximation
errest = scipy.linalg.norm(scipy.linalg.expm(F) - A, 1) / scipy.linalg.norm(A, 1)
if disp:
if not numpy.isfinite(errest) or errest >= errtol:
print("logm result may be inaccurate, approximate err =", errest)
return F
else:
return F, errest

else:
return scipy.linalg.logm(A, disp=disp)

9 changes: 4 additions & 5 deletions pyscfad/lo/boys.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy
import scipy
import jax
from pyscf.lib import logger
from pyscf.lo import boys as pyscf_boys
Expand All @@ -11,6 +10,7 @@
pack_uniq_var,
)
from pyscfad.tools.linear_solver import gen_gmres
from pyscfad.scipy.linalg import logm

# modified from pyscf v2.6
def kernel(localizer, mo_coeff=None, callback=None, verbose=None,
Expand Down Expand Up @@ -133,12 +133,11 @@ def _extract_x0(loc, u):
'Saddle point reached in orbital localization.'
f'\n{h_diag}')

#TODO ensure real solution of logm
mat = scipy.linalg.logm(u)
if not numpy.allclose(mat.imag, 0, atol=1e-6):
mat = logm(u, real=True)
if not numpy.isreal(mat).all():
raise RuntimeError('Complex solutions are not supported for '
'differentiating the Boys localiztion.')
x = pack_uniq_var(mat.real)
x = pack_uniq_var(numpy.real(mat))
return x

def _boys(x, mol, mo_coeff, *,
Expand Down
4 changes: 4 additions & 0 deletions pyscfad/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
eigh as eigh,
svd as svd,
)

from pyscfad._src.scipy.linalg import (
logm as logm,
)
16 changes: 15 additions & 1 deletion pyscfad/scipy/test/test_linalg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from functools import partial
import numpy as np
import scipy
import jax
from jax import numpy as jnp
from jax import scipy as jsp
from pyscfad.scipy.linalg import eigh, svd
from pyscfad.scipy.linalg import eigh, svd, logm

def test_eigh():
a = jnp.ones((2,2))
Expand Down Expand Up @@ -41,3 +43,15 @@ def test_svd():
assert abs(jac0[1] - jac1[1]).max() < 1e-7
assert abs(jac0[2] - jac1[2]).max() < 1e-7

def test_logm():
theta = 2 * np.pi / 3
a = np.array([[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]])
A = scipy.linalg.block_diag(np.diag(np.array([-1, -1, 1])), a)

X = logm(A, real=True)
assert np.isreal(X).all()
assert scipy.linalg.norm(scipy.linalg.expm(X) - A, 1) < 1e-12

assert np.allclose(logm(A), scipy.linalg.logm(A))

0 comments on commit 3864c7f

Please sign in to comment.