Skip to content

Commit

Permalink
test lo symmetry
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Oct 10, 2023
1 parent d250371 commit b7cf4b3
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 34 deletions.
68 changes: 64 additions & 4 deletions pyscfad/_src/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,85 @@
from scipy.sparse.linalg import LinearOperator
from functools import reduce
import numpy
from scipy.sparse.linalg import LinearOperator, eigsh
from scipy.sparse.linalg import gmres as scipy_gmres

def gmres(A_or_matvec, b, x0=None, *,
tol=1e-05, atol=None, restart=None, maxiter=None, M=None,
callback=None, callback_type=None):
b_shape = b.shape
if x0 is not None:
x0 = x0.flatten()
x0 = x0.ravel()

if callable(A_or_matvec):
def _matvec(u):
return A_or_matvec(u.reshape(b_shape)).flatten()
return A_or_matvec(u.reshape(b_shape)).ravel()
A = LinearOperator((b.size, b.size), matvec=_matvec)
else:
A = A_or_matvec

u, info = scipy_gmres(A, b.flatten(), x0=x0, tol=tol,
u, info = scipy_gmres(A, b.ravel(), x0=x0, tol=tol,
restart=restart, maxiter=maxiter, M=M, atol=atol,
callback=callback, callback_type=callback_type)
if info > 0:
raise RuntimeError(f'scipy gmres failed to converge in {info} iterations.')
elif info < 0:
raise RuntimeError('scipy gmres failed.')
return u.reshape(b_shape), info


def gmres_safe(A_or_matvec, b, x0=None, *,
tol=1e-05, atol=None, restart=None, maxiter=None, M=None,
callback=None, callback_type=None, cond=1e-6):
b_shape = b.shape

if callable(A_or_matvec):
def _matvec(u):
return A_or_matvec(u.reshape(b_shape)).ravel()
A = LinearOperator((b.size, b.size), matvec=_matvec)
else:
A = A_or_matvec

k = 3
v_null = None
while True:
w, v = eigsh(A, k=k, which='SM')
if numpy.all(abs(w) >= cond) and w[-1] > 0:
break
elif numpy.any(abs(w) < cond) and w[-1] >= cond:
v_null = v[:,abs(w)<cond]
break
else:
k *= 2
continue

if v_null is None:
return gmres(A_or_matvec, b, x0=x0,
tol=tol, atol=atol, restart=restart, maxiter=maxiter, M=M,
callback=callback, callback_type=callback_type)

_proj_to_null = lambda u: numpy.dot(v_null, numpy.dot(v_null.T, u))
_proj_to_range = lambda u: u - _proj_to_null(u)
if callable(A_or_matvec):
def _matvec(u):
u_range = _proj_to_range(u)
Au = A_or_matvec(u_range.reshape(b_shape)).ravel()
return _proj_to_range(Au) + _proj_to_null(u)
A = LinearOperator((b.size, b.size), matvec=_matvec)
else:
P_null = numpy.dot(v_null, v_null.T)
P_range = numpy.eye(A_or_matvec.shape[0]) - P_null
A = reduce(numpy.dot, (P_range, A_or_matvec, P_range)) + P_null

if x0 is not None:
x0 = x0.ravel()
u, info = scipy_gmres(A, b.ravel(), x0=x0, tol=tol,
restart=restart, maxiter=maxiter, M=M, atol=atol,
callback=callback, callback_type=callback_type)

if info > 0:
raise RuntimeError(f'scipy gmres failed to converge in {info} iterations.')
elif info < 0:
raise RuntimeError('scipy gmres failed.')

u = _proj_to_range(u)
return u.reshape(b_shape), info
51 changes: 26 additions & 25 deletions pyscfad/lo/pipek.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from functools import reduce, partial
import numpy
import scipy
import jax
from pyscf import numpy as np
from pyscf.lib import logger
from pyscf.lo.pipek import PM
from pyscfad import config
from pyscfad.lib import vmap
from pyscfad.implicit_diff import make_implicit_diff
from pyscfad.soscf.ciah import extract_rotation, pack_uniq_var
from pyscfad.tools.linear_solver import precond_by_hdiag, gen_gmres, GMRESDisp
from pyscfad.tools.linear_solver import gen_gmres, GMRESDisp
from pyscfad.lo import orth

def atomic_pops(mol, mo_coeff, method='mulliken'):
Expand Down Expand Up @@ -60,26 +62,19 @@ def fn(aoslice, csc, idx):

return proj


def opt_cond(x, mol, mo_coeff, pop_method='mulliken', exponent=2):
def cost_function(x, mol, mo_coeff, pop_method='mulliken', exponent=2):
u = extract_rotation(x)
mo_coeff = np.dot(mo_coeff, u)
pop = atomic_pops(mol, mo_coeff, pop_method)
if exponent == 2:
g0 = np.einsum('xii,xip->pi', pop, pop)
g = -pack_uniq_var(g0-g0.conj().T) * 2
return -np.einsum('xii,xii->', pop, pop)
else:
pop3 = np.einsum('xii->xi', pop)**3
g0 = np.einsum('xi,xip->pi', pop3, pop)
g = -pack_uniq_var(g0-g0.conj().T) * 4

h_diag = np.einsum('xii,xpp->pi', pop, pop) * 2
g_diag = g0.diagonal()
h_diag-= g_diag + g_diag.reshape(-1,1)
h_diag+= np.einsum('xip,xip->pi', pop, pop) * 2
h_diag+= np.einsum('xip,xpi->pi', pop, pop) * 2
h_diag = -pack_uniq_var(h_diag) * 2
return g, h_diag
pop2 = np.einsum('xii->xi', pop)**2
return -np.einsum('xi,xi', pop2, pop2)

def opt_cond(x, mol, mo_coeff, pop_method='mulliken', exponent=2):
g = jax.grad(cost_function, 0)(x, mol, mo_coeff, pop_method, exponent)
return g

def _pm(x, mol, mo_coeff, *,
pop_method=None, exponent=None, init_guess=None,
Expand All @@ -102,34 +97,40 @@ def _pm(x, mol, mo_coeff, *,
u, sorted_idx = loc.kernel(mo_coeff=mo_coeff, return_u=True)
else:
u, sorted_idx = loc.kernel(mo_coeff=None, return_u=True)
h_diag = loc.gen_g_hop(u)[2]
if numpy.any(h_diag < 0):
logger.warn(loc, 'Saddle point reached in orbital localization.')
if numpy.linalg.det(u) < 0:
u[:,0] *= -1
mat = scipy.linalg.logm(u)
x = pack_uniq_var(mat)
if numpy.any(abs(x.imag) > 1e-9):
if numpy.any(abs(x.imag) > 1e-6):
raise RuntimeError('Complex solutions are not supported for '
'differentiating the Boys localiztion.')
else:
x = x.real
logger.debug(loc, 'PM localization |g| = %.6g',
numpy.linalg.norm(opt_cond(x, mol, mo_coeff,
pop_method=pop_method,
exponent=exponent)))
return x, sorted_idx

def pm(mol, mo_coeff, *,
pop_method='mulliken', exponent=2, init_guess=None,
conv_tol=None, conv_tol_grad=None, max_cycle=None):
conv_tol=None, conv_tol_grad=None, max_cycle=None, symmetry=False):
if mo_coeff.shape[-1] == 1:
return mo_coeff

gen_precond = None
if config.moleintor_opt:
gen_precond = precond_by_hdiag
solver = gen_gmres(restart=40,
callback=GMRESDisp(mol.verbose), callback_type='pr_norm')
callback=GMRESDisp(mol.verbose),
callback_type='pr_norm',
safe=symmetry)
_pm_iter = make_implicit_diff(_pm, implicit_diff=True,
fixed_point=False,
optimality_cond=partial(opt_cond,
pop_method=pop_method,
exponent=exponent),
solver=solver, has_aux=True,
optimality_fun_has_aux=True,
gen_precond=gen_precond)
solver=solver, has_aux=True)

x, sorted_idx = _pm_iter(None, mol, mo_coeff,
pop_method=pop_method,
Expand Down
1 change: 1 addition & 0 deletions pyscfad/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint: disable = unused-import
from pyscfad._src.scipy.sparse.linalg import (
gmres,
gmres_safe,
)
19 changes: 14 additions & 5 deletions pyscfad/tools/linear_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from scipy.sparse.linalg import LinearOperator
from jax.scipy.sparse.linalg import gmres as jax_gmres
from pyscfad.scipy.sparse.linalg import gmres as pyscfad_gmres
from pyscfad.scipy.sparse.linalg import gmres_safe

class GMRESDisp:
def __init__(self, disp=False):
Expand All @@ -27,13 +28,21 @@ def precond_by_hdiag(h_diag, thresh=1e-12):

def gen_gmres_with_default_kwargs(tol=1e-6, atol=1e-12, maxiter=30,
x0=None, M=None, restart=20,
callback=None, callback_type=None):
callback=None, callback_type=None,
safe=False, **kwargs):
from pyscfad import config
if config.moleintor_opt:
gmres = partial(pyscfad_gmres,
tol=tol, atol=atol, maxiter=maxiter,
x0=x0, M=M, restart=restart,
callback=callback, callback_type=callback_type)
if safe:
gmres = partial(gmres_safe,
tol=tol, atol=atol, maxiter=maxiter,
x0=x0, M=M, restart=restart,
callback=callback, callback_type=callback_type,
**kwargs)
else:
gmres = partial(pyscfad_gmres,
tol=tol, atol=atol, maxiter=maxiter,
x0=x0, M=M, restart=restart,
callback=callback, callback_type=callback_type)
else:
gmres = partial(jax_gmres,
tol=tol, atol=atol, maxiter=maxiter,
Expand Down

0 comments on commit b7cf4b3

Please sign in to comment.