diff --git a/MANIFEST.in b/MANIFEST.in index 86e322e3..a8342829 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,9 +1,11 @@ include MANIFEST.in include README.md setup.py LICENSE -prune pyscfad/lib/build +recursive-include pyscfadlib/thirdparty *.so *.dylib +include pyscfadlib/*.so pyscfadlib/*.dylib pyscfadlib/config.h.in -include pyscfad/lib/*.so -include pyscfad/lib/*.dylib +# source code +recursive-include pyscfadlib *.c *.h CMakeLists.txt -recursive-include pyscfad/lib *.c *.h +global-exclude *.py[cod] +prune pyscfadlib/build diff --git a/pyscfad/backend/_jax/core.py b/pyscfad/backend/_jax/core.py index 61a9cab7..2bdd309f 100644 --- a/pyscfad/backend/_jax/core.py +++ b/pyscfad/backend/_jax/core.py @@ -1,6 +1,9 @@ import numpy as np import jax +def is_tensor(x): + return isinstance(x, jax.Array) + def stop_gradient(x): return jax.lax.stop_gradient(x) diff --git a/pyscfad/backend/_numpy/core.py b/pyscfad/backend/_numpy/core.py index b116a00f..635f9944 100644 --- a/pyscfad/backend/_numpy/core.py +++ b/pyscfad/backend/_numpy/core.py @@ -1,5 +1,8 @@ import numpy as np +def is_tensor(x): + return isinstance(x, (np.ndarray, np.generic)) + def stop_gradient(x): return x diff --git a/pyscfad/backend/_torch/core.py b/pyscfad/backend/_torch/core.py index a3fff0fc..9f9c24dc 100644 --- a/pyscfad/backend/_torch/core.py +++ b/pyscfad/backend/_torch/core.py @@ -1,9 +1,9 @@ from keras_core import ops +is_tensor = ops.is_tensor stop_gradient = ops.stop_gradient +convert_to_tensor = ops.convert_to_tensor def convert_to_numpy(x): x = stop_gradient(x) return ops.convert_to_numpy(x) - -convert_to_tensor = ops.convert_to_tensor diff --git a/pyscfad/lib/logger.py b/pyscfad/lib/logger.py new file mode 100644 index 00000000..966f2c98 --- /dev/null +++ b/pyscfad/lib/logger.py @@ -0,0 +1,92 @@ +import sys +from functools import wraps +from pyscf.lib import logger +from pyscf.lib.logger import ( + DEBUG1 as DEBUG1, + DEBUG as DEBUG, + INFO as INFO, + NOTE as NOTE, + NOTICE as NOTICE, + WARN as WARN, + WARNING as WARNING, + ERR as ERR, + ERROR as ERROR, + QUIET as QUIET, + TIMER_LEVEL as TIMER_LEVEL, + process_clock as process_clock, + perf_counter as perf_counter, +) +from pyscfad import ops + +def flush(rec, msg, *args): + arg_list = [] + for arg in args: + if ops.is_tensor(arg): + arg = ops.convert_to_numpy(arg) + arg_list.append(arg) + logger.flush(rec, msg, *arg_list) + +def log(rec, msg, *args): + if rec.verbose > QUIET: + flush(rec, msg, *args) + +def error(rec, msg, *args): + if rec.verbose >= ERROR: + flush(rec, '\nERROR: '+msg+'\n', *args) + #sys.stderr.write('ERROR: ' + (msg%args) + '\n') + +def warn(rec, msg, *args): + if rec.verbose >= WARN: + flush(rec, '\nWARN: '+msg+'\n', *args) + #if rec.stdout is not sys.stdout: + # sys.stderr.write('WARN: ' + (msg%args) + '\n') + +def info(rec, msg, *args): + if rec.verbose >= INFO: + flush(rec, msg, *args) + +def note(rec, msg, *args): + if rec.verbose >= NOTICE: + flush(rec, msg, *args) + +def debug(rec, msg, *args): + if rec.verbose >= DEBUG: + flush(rec, msg, *args) + +def debug1(rec, msg, *args): + if rec.verbose >= DEBUG1: + flush(rec, msg, *args) + +def timer(rec, msg, cpu0=None, wall0=None): + if cpu0 is None: + cpu0 = rec._t0 + if wall0 is None: + wall0 = rec._w0 + rec._t0, rec._w0 = process_clock(), perf_counter() + if rec.verbose >= TIMER_LEVEL: + flush(rec, ' CPU time for %s %9.2f sec, wall time %9.2f sec' + % (msg, rec._t0-cpu0, rec._w0-wall0)) + return rec._t0, rec._w0 + +class Logger(logger.Logger): + log = log + error = error + warn = warn + note = note + info = info + debug = debug + debug1 = debug1 + timer = timer + +@wraps(logger.new_logger) +def new_logger(rec=None, verbose=None): + if isinstance(verbose, Logger): + log = verbose + elif isinstance(verbose, int): + if getattr(rec, 'stdout', None): + log = Logger(rec.stdout, verbose) + else: + log = Logger(sys.stdout, verbose) + else: + log = Logger(rec.stdout, rec.verbose) + return log diff --git a/pyscfad/ops/core.py b/pyscfad/ops/core.py index 77a4f7d2..8cc695fc 100644 --- a/pyscfad/ops/core.py +++ b/pyscfad/ops/core.py @@ -1,12 +1,16 @@ from pyscfad import backend __all__ = [ + 'is_tensor', 'stop_gradient', 'stop_grad', 'convert_to_numpy', 'convert_to_tensor', ] +def is_tensor(x): + return backend.core.is_tensor(x) + def stop_gradient(x): return backend.core.stop_gradient(x) diff --git a/pyscfad/scf/hf.py b/pyscfad/scf/hf.py index fc29591c..332fc161 100644 --- a/pyscfad/scf/hf.py +++ b/pyscfad/scf/hf.py @@ -2,7 +2,7 @@ import numpy import jax from pyscf.data import nist -from pyscf.lib import logger, module_method +from pyscf.lib import module_method from pyscf.scf import hf as pyscf_hf from pyscf.scf import chkfile from pyscf.scf.hf import TIGHT_GRAD_CONV_TOL @@ -11,7 +11,7 @@ from pyscfad import ops from pyscfad import numpy as np from pyscfad import lib -from pyscfad.lib import jit, stop_trace +from pyscfad.lib import logger, jit, stop_trace from pyscfad.ops import stop_grad from pyscfad import util from pyscfad.implicit_diff import make_implicit_diff @@ -47,7 +47,7 @@ def _scf(dm, mf, s1e, h1e, *, vhf = mf.get_veff(mol, dm) e_tot = mf.energy_tot(dm, h1e, vhf) log.info('init E= %.15g', e_tot) - cput1 = log.timer('initialize scf', log._t0, log._w0) + cput1 = log.timer('initialize scf') for cycle in range(mf.max_cycle): dm_last = dm diff --git a/pyscfadlib/CMakeLists.txt b/pyscfadlib/CMakeLists.txt index 7c1ceb38..61f5e380 100644 --- a/pyscfadlib/CMakeLists.txt +++ b/pyscfadlib/CMakeLists.txt @@ -79,13 +79,13 @@ set(C_LINK_TEMPLATE " =1.17', 'scipy',