Skip to content

Commit

Permalink
update lo, add geomopt
Browse files Browse the repository at this point in the history
  • Loading branch information
fishjojo committed Nov 19, 2023
1 parent b7cf4b3 commit 72eaa2b
Show file tree
Hide file tree
Showing 11 changed files with 257 additions and 55 deletions.
32 changes: 22 additions & 10 deletions pyscfad/_src/implicit_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def _map_back(diff_items, items, keys):
def root_vjp(optimality_fun, sol, args, cotangent,
solve=gmres, nondiff_argnums=(),
optfn_has_aux=False, solver_kwargs=None,
gen_precond=None):
gen_precond=None,
custom_vjp_from_optcond=False):
if solver_kwargs is None:
solver_kwargs = {}

Expand All @@ -29,7 +30,10 @@ def fun_sol(sol):
# FIXME M may not work for solvers other than scipy
M = None
if optfn_has_aux:
_, vjp_fun_sol, optfn_aux = jax.vjp(fun_sol, sol, has_aux=True)
if custom_vjp_from_optcond:
_, (vjp_fun_sol, optfn_aux) = fun_sol(sol)
else:
_, vjp_fun_sol, optfn_aux = jax.vjp(fun_sol, sol, has_aux=True)
if gen_precond is not None:
M = gen_precond(optfn_aux)
else:
Expand Down Expand Up @@ -58,7 +62,8 @@ def fun_args(*diff_args):
def _custom_root(solver_fun, optimality_fun, solve,
has_aux=False, nondiff_argnums=(), use_converged_args=None,
optfn_has_aux=False, solver_kwargs=None,
gen_precond=None):
gen_precond=None,
custom_vjp_from_optcond=False):
solver_fun_sig = inspect.signature(solver_fun)
optimality_fun_sig = inspect.signature(optimality_fun)

Expand Down Expand Up @@ -96,7 +101,8 @@ def solver_fun_rev(tup, cotangent):
nondiff_argnums=nondiff_argnums,
optfn_has_aux=optfn_has_aux,
solver_kwargs=solver_kwargs,
gen_precond=gen_precond)
gen_precond=gen_precond,
custom_vjp_from_optcond=custom_vjp_from_optcond)
vjps = (None,) + vjps
return vjps

Expand All @@ -113,7 +119,8 @@ def wrapped_solver_fun(*args, **kwargs):
def custom_root(optimality_fun, solve=None, has_aux=False,
nondiff_argnums=(), use_converged_args=None,
optfn_has_aux=False, solver_kwargs=None,
gen_precond=None):
gen_precond=None,
custom_vjp_from_optcond=False):
if solve is None:
solve = gmres

Expand All @@ -123,14 +130,16 @@ def wrapper(solver_fun):
use_converged_args=use_converged_args,
optfn_has_aux=optfn_has_aux,
solver_kwargs=solver_kwargs,
gen_precond=gen_precond)
gen_precond=gen_precond,
custom_vjp_from_optcond=custom_vjp_from_optcond)

return wrapper

def custom_fixed_point(fixed_point_fun, solve=None, has_aux=False,
nondiff_argnums=(), use_converged_args=None,
optfn_has_aux=False, solver_kwargs=None,
gen_precond=None):
gen_precond=None,
custom_vjp_from_optcond=False):

def optimality_fun(x0, *args):
return _Sub(fixed_point_fun(x0, *args), x0)
Expand All @@ -142,13 +151,15 @@ def optimality_fun(x0, *args):
use_converged_args=use_converged_args,
optfn_has_aux=optfn_has_aux,
solver_kwargs=solver_kwargs,
gen_precond=gen_precond)
gen_precond=gen_precond,
custom_vjp_from_optcond=custom_vjp_from_optcond)

def make_implicit_diff(fn, implicit_diff=False, fixed_point=True,
optimality_cond=None, solver=None, has_aux=False,
nondiff_argnums=(), use_converged_args=None,
optimality_fun_has_aux=False,
solver_kwargs=None, gen_precond=None):
solver_kwargs=None, gen_precond=None,
custom_vjp_from_optcond=False):
'''Wrap a function for implicit differentiation.
Args:
Expand Down Expand Up @@ -211,6 +222,7 @@ def make_implicit_diff(fn, implicit_diff=False, fixed_point=True,
use_converged_args=use_converged_args,
optfn_has_aux=optimality_fun_has_aux,
solver_kwargs=solver_kwargs,
gen_precond=gen_precond)(fn)
gen_precond=gen_precond,
custom_vjp_from_optcond=custom_vjp_from_optcond)(fn)
else:
return fn
4 changes: 3 additions & 1 deletion pyscfad/_src/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,16 @@ def _matvec(u):
k = 3
v_null = None
while True:
if k > b.size:
raise RuntimeError
w, v = eigsh(A, k=k, which='SM')
if numpy.all(abs(w) >= cond) and w[-1] > 0:
break
elif numpy.any(abs(w) < cond) and w[-1] >= cond:
v_null = v[:,abs(w)<cond]
break
else:
k *= 2
k += 3
continue

if v_null is None:
Expand Down
1 change: 0 additions & 1 deletion pyscfad/fci/fci_slow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from pyscf import numpy as np
from pyscf.fci import cistring
from pyscf.fci import fci_slow as pyscf_fci_slow
#from pyscfad.lib import numpy as np
from pyscfad.lib import vmap
from pyscfad.gto import mole

Expand Down
7 changes: 7 additions & 0 deletions pyscfad/geomopt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

def optimize(*args, **kwargs):
try:
from . import geometric_solver as geom
except ImportError as err:
raise ImportError('Unable to import geometric.') from err
return geom.kernel(*args, **kwargs)
87 changes: 87 additions & 0 deletions pyscfad/geomopt/geometric_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
import uuid
import tempfile
import geometric

import numpy
from pyscf import lib
from pyscf.lib import logger
from pyscf.geomopt.addons import dump_mol_geometry

class PySCFADEngine(geometric.engine.Engine):
def __init__(self, mol, value_and_grad,
maxsteps=100, callback=None):
molecule = geometric.molecule.Molecule()
molecule.elem = [mol.atom_symbol(i) for i in range(mol.natm)]
molecule.xyzs = [mol.atom_coords()*lib.param.BOHR] # In Angstrom
super().__init__(molecule)

self.mol = mol
self.value_and_grad = value_and_grad

self.cycle = 0
self.maxsteps = maxsteps
self.callback = callback
self.e_last = 0
#self.assert_convergence = assert_convergence

def calc_new(self, coords, dirname):
if self.cycle >= self.maxsteps:
raise NotConvergedError('Geometry optimization is not converged in '
'%d iterations' % self.maxsteps)

mol = self.mol
value_and_grad = self.value_and_grad
self.cycle += 1
logger.note(mol, '\nGeometry optimization cycle %d', self.cycle)

# geomeTRIC requires coords and gradients in atomic unit
coords = numpy.asarray(coords.reshape(-1,3))
if mol.verbose >= logger.NOTE:
dump_mol_geometry(mol, coords*lib.param.BOHR)

if mol.symmetry:
pass

mol.coords = None
mol = mol.set_geom_(coords, unit='Bohr', inplace=False)
mol.build(trace_coords=True, trace_exp=False, trace_ctr_coeff=False)
energy, gradients = value_and_grad(mol)
energy = numpy.asarray(energy)
gradients = numpy.asarray(gradients)
logger.note(mol, 'cycle %d: E = %.12g dE = %g norm(grad) = %g',
self.cycle, energy, energy - self.e_last, numpy.linalg.norm(gradients))
self.e_last = energy
self.mol = mol

if callable(self.callback):
self.callback(locals())
return {"energy": energy, "gradient": gradients.ravel()}

def kernel(mol, value_and_grad,
constraints=None, callback=None,
maxsteps=100, **kwargs):
engine = PySCFADEngine(mol, value_and_grad)
engine.callback = callback
engine.maxsteps = maxsteps

if engine.mol.symmetry:
pass

if not os.path.exists(os.path.abspath(
os.path.join(geometric.optimize.__file__, '..', 'log.ini'))) and kwargs.get('logIni') is None:
kwargs['logIni'] = os.path.abspath(os.path.join(__file__, '..', 'log.ini'))

with tempfile.TemporaryDirectory(dir=lib.param.TMPDIR) as tmpdir:
tmpf = os.path.join(tmpdir, str(uuid.uuid4()))
try:
geometric.optimize.run_optimizer(customengine=engine, input=tmpf,
constraints=constraints, **kwargs)
conv = True
except NotConvergedError as e:
logger.note(mol, str(e))
conv = False
return conv, engine.mol

class NotConvergedError(RuntimeError):
pass
21 changes: 21 additions & 0 deletions pyscfad/geomopt/log.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[loggers]
keys=root

[handlers]
keys=stream_handler

[formatters]
keys=formatter

[logger_root]
level=INFO
handlers=stream_handler

[handler_stream_handler]
class=geometric.nifty.RawStreamHandler
level=INFO
formatter=formatter
args=(sys.stderr,)

[formatter_formatter]
format=%(message)s
17 changes: 16 additions & 1 deletion pyscfad/gto/_pyscf_moleintor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
'int1e_rinv_dr01' : (3, 3),
'int2c2e_dr10' : (3, 3),
'int2c2e_dr01' : (3, 3),
'int1e_r2_dr10' : (3, 3),
'int1e_r2_dr01' : (3, 3),
'int2e_dr1000' : (3, 3),
'int2e_dr0010' : (3, 3),
'int1e_ovlp_dr20' : (9, 9),
Expand All @@ -54,6 +56,9 @@
'int2c2e_dr20' : (9, 9),
'int2c2e_dr11' : (9, 9),
'int2c2e_dr02' : (9, 9),
'int1e_r2_dr20' : (9, 9),
'int1e_r2_dr11' : (9, 9),
'int1e_r2_dr02' : (9, 9),
'int2e_dr2000' : (9, 9),
'int2e_dr1100' : (9, 9),
'int2e_dr1010' : (9, 9),
Expand All @@ -79,6 +84,10 @@
'int2c2e_dr21' : (27, 27),
'int2c2e_dr12' : (27, 27),
'int2c2e_dr03' : (27, 27),
'int1e_r2_dr30' : (27, 27),
'int1e_r2_dr21' : (27, 27),
'int1e_r2_dr12' : (27, 27),
'int1e_r2_dr03' : (27, 27),
'int2e_dr3000' : (27, 27),
'int2e_dr2100' : (27, 27),
'int2e_dr2010' : (27, 27),
Expand Down Expand Up @@ -114,6 +123,11 @@
'int2c2e_dr22' : (81, 81),
'int2c2e_dr13' : (81, 81),
'int2c2e_dr04' : (81, 81),
'int1e_r2_dr40' : (81, 81),
'int1e_r2_dr31' : (81, 81),
'int1e_r2_dr22' : (81, 81),
'int1e_r2_dr13' : (81, 81),
'int1e_r2_dr04' : (81, 81),
'int2e_dr4000' : (81, 81),
'int2e_dr3100' : (81, 81),
'int2e_dr3010' : (81, 81),
Expand All @@ -131,7 +145,8 @@
'int2e_dr0040' : (81, 81),
'int2e_dr0031' : (81, 81),
'int2e_dr0022' : (81, 81),
'int2e_dr0013' : (81, 81),})
'int2e_dr0013' : (81, 81),
})

def _get_intor_and_comp(intor_name, comp=None):
intor_name = ascint3(intor_name)
Expand Down
Loading

0 comments on commit 72eaa2b

Please sign in to comment.