From b7cf4b3aa2d21785e89c80b5aef66d9bdfa2e8a2 Mon Sep 17 00:00:00 2001 From: fishjojo Date: Tue, 10 Oct 2023 15:18:13 -0700 Subject: [PATCH] test lo symmetry --- pyscfad/_src/scipy/sparse/linalg.py | 68 +++++++++++++++++++++++++++-- pyscfad/lo/pipek.py | 51 +++++++++++----------- pyscfad/scipy/sparse/linalg.py | 1 + pyscfad/tools/linear_solver.py | 19 +++++--- 4 files changed, 105 insertions(+), 34 deletions(-) diff --git a/pyscfad/_src/scipy/sparse/linalg.py b/pyscfad/_src/scipy/sparse/linalg.py index ff863eb8..c54d7636 100644 --- a/pyscfad/_src/scipy/sparse/linalg.py +++ b/pyscfad/_src/scipy/sparse/linalg.py @@ -1,4 +1,6 @@ -from scipy.sparse.linalg import LinearOperator +from functools import reduce +import numpy +from scipy.sparse.linalg import LinearOperator, eigsh from scipy.sparse.linalg import gmres as scipy_gmres def gmres(A_or_matvec, b, x0=None, *, @@ -6,16 +8,16 @@ def gmres(A_or_matvec, b, x0=None, *, callback=None, callback_type=None): b_shape = b.shape if x0 is not None: - x0 = x0.flatten() + x0 = x0.ravel() if callable(A_or_matvec): def _matvec(u): - return A_or_matvec(u.reshape(b_shape)).flatten() + return A_or_matvec(u.reshape(b_shape)).ravel() A = LinearOperator((b.size, b.size), matvec=_matvec) else: A = A_or_matvec - u, info = scipy_gmres(A, b.flatten(), x0=x0, tol=tol, + u, info = scipy_gmres(A, b.ravel(), x0=x0, tol=tol, restart=restart, maxiter=maxiter, M=M, atol=atol, callback=callback, callback_type=callback_type) if info > 0: @@ -23,3 +25,61 @@ def _matvec(u): elif info < 0: raise RuntimeError('scipy gmres failed.') return u.reshape(b_shape), info + + +def gmres_safe(A_or_matvec, b, x0=None, *, + tol=1e-05, atol=None, restart=None, maxiter=None, M=None, + callback=None, callback_type=None, cond=1e-6): + b_shape = b.shape + + if callable(A_or_matvec): + def _matvec(u): + return A_or_matvec(u.reshape(b_shape)).ravel() + A = LinearOperator((b.size, b.size), matvec=_matvec) + else: + A = A_or_matvec + + k = 3 + v_null = None + while True: + w, v = eigsh(A, k=k, which='SM') + if numpy.all(abs(w) >= cond) and w[-1] > 0: + break + elif numpy.any(abs(w) < cond) and w[-1] >= cond: + v_null = v[:,abs(w) 0: + raise RuntimeError(f'scipy gmres failed to converge in {info} iterations.') + elif info < 0: + raise RuntimeError('scipy gmres failed.') + + u = _proj_to_range(u) + return u.reshape(b_shape), info diff --git a/pyscfad/lo/pipek.py b/pyscfad/lo/pipek.py index a7b3704f..566c41c3 100644 --- a/pyscfad/lo/pipek.py +++ b/pyscfad/lo/pipek.py @@ -1,13 +1,15 @@ from functools import reduce, partial import numpy import scipy +import jax from pyscf import numpy as np +from pyscf.lib import logger from pyscf.lo.pipek import PM from pyscfad import config from pyscfad.lib import vmap from pyscfad.implicit_diff import make_implicit_diff from pyscfad.soscf.ciah import extract_rotation, pack_uniq_var -from pyscfad.tools.linear_solver import precond_by_hdiag, gen_gmres, GMRESDisp +from pyscfad.tools.linear_solver import gen_gmres, GMRESDisp from pyscfad.lo import orth def atomic_pops(mol, mo_coeff, method='mulliken'): @@ -60,26 +62,19 @@ def fn(aoslice, csc, idx): return proj - -def opt_cond(x, mol, mo_coeff, pop_method='mulliken', exponent=2): +def cost_function(x, mol, mo_coeff, pop_method='mulliken', exponent=2): u = extract_rotation(x) mo_coeff = np.dot(mo_coeff, u) pop = atomic_pops(mol, mo_coeff, pop_method) if exponent == 2: - g0 = np.einsum('xii,xip->pi', pop, pop) - g = -pack_uniq_var(g0-g0.conj().T) * 2 + return -np.einsum('xii,xii->', pop, pop) else: - pop3 = np.einsum('xii->xi', pop)**3 - g0 = np.einsum('xi,xip->pi', pop3, pop) - g = -pack_uniq_var(g0-g0.conj().T) * 4 - - h_diag = np.einsum('xii,xpp->pi', pop, pop) * 2 - g_diag = g0.diagonal() - h_diag-= g_diag + g_diag.reshape(-1,1) - h_diag+= np.einsum('xip,xip->pi', pop, pop) * 2 - h_diag+= np.einsum('xip,xpi->pi', pop, pop) * 2 - h_diag = -pack_uniq_var(h_diag) * 2 - return g, h_diag + pop2 = np.einsum('xii->xi', pop)**2 + return -np.einsum('xi,xi', pop2, pop2) + +def opt_cond(x, mol, mo_coeff, pop_method='mulliken', exponent=2): + g = jax.grad(cost_function, 0)(x, mol, mo_coeff, pop_method, exponent) + return g def _pm(x, mol, mo_coeff, *, pop_method=None, exponent=None, init_guess=None, @@ -102,34 +97,40 @@ def _pm(x, mol, mo_coeff, *, u, sorted_idx = loc.kernel(mo_coeff=mo_coeff, return_u=True) else: u, sorted_idx = loc.kernel(mo_coeff=None, return_u=True) + h_diag = loc.gen_g_hop(u)[2] + if numpy.any(h_diag < 0): + logger.warn(loc, 'Saddle point reached in orbital localization.') + if numpy.linalg.det(u) < 0: + u[:,0] *= -1 mat = scipy.linalg.logm(u) x = pack_uniq_var(mat) - if numpy.any(abs(x.imag) > 1e-9): + if numpy.any(abs(x.imag) > 1e-6): raise RuntimeError('Complex solutions are not supported for ' 'differentiating the Boys localiztion.') else: x = x.real + logger.debug(loc, 'PM localization |g| = %.6g', + numpy.linalg.norm(opt_cond(x, mol, mo_coeff, + pop_method=pop_method, + exponent=exponent))) return x, sorted_idx def pm(mol, mo_coeff, *, pop_method='mulliken', exponent=2, init_guess=None, - conv_tol=None, conv_tol_grad=None, max_cycle=None): + conv_tol=None, conv_tol_grad=None, max_cycle=None, symmetry=False): if mo_coeff.shape[-1] == 1: return mo_coeff - gen_precond = None - if config.moleintor_opt: - gen_precond = precond_by_hdiag solver = gen_gmres(restart=40, - callback=GMRESDisp(mol.verbose), callback_type='pr_norm') + callback=GMRESDisp(mol.verbose), + callback_type='pr_norm', + safe=symmetry) _pm_iter = make_implicit_diff(_pm, implicit_diff=True, fixed_point=False, optimality_cond=partial(opt_cond, pop_method=pop_method, exponent=exponent), - solver=solver, has_aux=True, - optimality_fun_has_aux=True, - gen_precond=gen_precond) + solver=solver, has_aux=True) x, sorted_idx = _pm_iter(None, mol, mo_coeff, pop_method=pop_method, diff --git a/pyscfad/scipy/sparse/linalg.py b/pyscfad/scipy/sparse/linalg.py index 732b5133..7c51089f 100644 --- a/pyscfad/scipy/sparse/linalg.py +++ b/pyscfad/scipy/sparse/linalg.py @@ -1,4 +1,5 @@ # pylint: disable = unused-import from pyscfad._src.scipy.sparse.linalg import ( gmres, + gmres_safe, ) diff --git a/pyscfad/tools/linear_solver.py b/pyscfad/tools/linear_solver.py index a66431ba..f2fdfd33 100644 --- a/pyscfad/tools/linear_solver.py +++ b/pyscfad/tools/linear_solver.py @@ -3,6 +3,7 @@ from scipy.sparse.linalg import LinearOperator from jax.scipy.sparse.linalg import gmres as jax_gmres from pyscfad.scipy.sparse.linalg import gmres as pyscfad_gmres +from pyscfad.scipy.sparse.linalg import gmres_safe class GMRESDisp: def __init__(self, disp=False): @@ -27,13 +28,21 @@ def precond_by_hdiag(h_diag, thresh=1e-12): def gen_gmres_with_default_kwargs(tol=1e-6, atol=1e-12, maxiter=30, x0=None, M=None, restart=20, - callback=None, callback_type=None): + callback=None, callback_type=None, + safe=False, **kwargs): from pyscfad import config if config.moleintor_opt: - gmres = partial(pyscfad_gmres, - tol=tol, atol=atol, maxiter=maxiter, - x0=x0, M=M, restart=restart, - callback=callback, callback_type=callback_type) + if safe: + gmres = partial(gmres_safe, + tol=tol, atol=atol, maxiter=maxiter, + x0=x0, M=M, restart=restart, + callback=callback, callback_type=callback_type, + **kwargs) + else: + gmres = partial(pyscfad_gmres, + tol=tol, atol=atol, maxiter=maxiter, + x0=x0, M=M, restart=restart, + callback=callback, callback_type=callback_type) else: gmres = partial(jax_gmres, tol=tol, atol=atol, maxiter=maxiter,