From 6ad7181efd527b123c698fd0de87fd77135f91e8 Mon Sep 17 00:00:00 2001 From: Xing Zhang Date: Sat, 22 Jun 2024 01:41:06 -0700 Subject: [PATCH] Release 0.1.5 (#29) --- .github/workflows/build_pyscf_linux.sh | 22 - .github/workflows/build_pyscf_macos.sh | 7 - .github/workflows/ci.yml | 4 +- .github/workflows/install_pyscf.sh | 9 +- .github/workflows/publish_pyscfad.yml | 36 ++ ...uild_wheels.yml => publish_pyscfadlib.yml} | 35 +- .github/workflows/release_github.yml | 24 + .github/workflows/release_info.sh | 28 ++ .github/workflows/run_test.sh | 6 +- .pylintrc | 4 +- CHANGELOG.md | 61 +++ examples/scf/00-simple.py | 2 - examples/scf/10-polarizability.py | 6 +- pyscfad/__init__.py | 14 +- pyscfad/_src/_config.py | 1 - pyscfad/_src/scipy/linalg.py | 2 +- pyscfad/_src/util.py | 131 ------ pyscfad/ao2mo/__init__.py | 3 +- pyscfad/ao2mo/_ao2mo.py | 2 +- pyscfad/ao2mo/addons.py | 2 +- pyscfad/ao2mo/incore.py | 4 +- pyscfad/backend/__init__.py | 22 + pyscfad/backend/_common/__init__.py | 1 + pyscfad/backend/_common/core.py | 48 ++ pyscfad/backend/_cupy/__init__.py | 56 +++ pyscfad/backend/_cupy/core.py | 7 + pyscfad/backend/_jax/__init__.py | 65 +++ pyscfad/backend/_jax/core.py | 129 +++++ pyscfad/backend/_numpy/__init__.py | 58 +++ pyscfad/backend/_numpy/core.py | 50 ++ pyscfad/backend/_torch/__init__.py | 60 +++ pyscfad/backend/_torch/core.py | 15 + pyscfad/backend/_torch/linalg.py | 18 + pyscfad/backend/_torch/numpy.py | 4 + pyscfad/backend/config.py | 111 +++++ pyscfad/backend/numpy.py | 4 + pyscfad/backend/ops.py | 65 +++ pyscfad/cc/ccsd.py | 7 +- pyscfad/cc/ccsd_rdm.py | 2 +- pyscfad/cc/ccsd_t_slow.py | 5 +- pyscfad/cc/dfccsd.py | 2 +- pyscfad/cc/rccsd.py | 3 +- pyscfad/cc/rintermediates.py | 2 +- pyscfad/cc/test/test_ccsd_t.py | 4 + pyscfad/cc/test/test_rccsd_hess.py | 4 + pyscfad/df/addons.py | 2 +- pyscfad/df/df.py | 2 +- pyscfad/df/df_jk.py | 2 +- pyscfad/df/incore.py | 3 +- pyscfad/dft/libxc.py | 23 +- pyscfad/dft/numint.py | 13 +- pyscfad/dft/rks.py | 57 +-- pyscfad/fci/__init__.py | 2 +- pyscfad/fci/fci_slow.py | 5 +- pyscfad/gto/_mole_helper.py | 23 +- pyscfad/gto/_moleintor_jvp.py | 11 +- pyscfad/gto/eval_gto.py | 10 +- pyscfad/gto/mole.py | 75 ++- pyscfad/gto/moleintor_opt.py | 440 ++++++++++++++++++ pyscfad/gw/rpa.py | 4 +- pyscfad/gw/test/test_rpa.py | 2 +- pyscfad/lib/__init__.py | 5 +- pyscfad/lib/config.h.in | 6 - pyscfad/lib/diis.py | 5 +- pyscfad/lib/jax_helper.py | 133 ------ pyscfad/lib/linalg_helper.py | 5 +- pyscfad/lib/logger.py | 3 + pyscfad/lib/misc.py | 16 - pyscfad/lib/numpy_helper.py | 54 ++- pyscfad/lib/ops.py | 62 --- pyscfad/lo/boys.py | 4 +- pyscfad/lo/orth.py | 4 +- pyscfad/lo/pipek.py | 2 +- pyscfad/ml/__init__.py | 3 + pyscfad/ml/scf/__init__.py | 3 + pyscfad/ml/scf/hf.py | 53 +++ pyscfad/mp/dfmp2.py | 5 +- pyscfad/mp/mp2.py | 5 +- pyscfad/mp/test/test_oomp2.py | 6 +- pyscfad/pbc/df/df_jk.py | 4 +- pyscfad/pbc/df/fft_jk.py | 6 +- pyscfad/pbc/dft/krks.py | 33 +- pyscfad/pbc/dft/numint.py | 5 +- pyscfad/pbc/dft/rks.py | 45 +- pyscfad/pbc/gto/cell.py | 59 +-- pyscfad/pbc/gto/eval_gto.py | 4 +- pyscfad/pbc/lib/kpts_helper.py | 2 +- pyscfad/pbc/scf/hf.py | 86 ++-- pyscfad/pbc/scf/khf.py | 75 ++- pyscfad/pbc/scf/test/test_hf.py | 2 +- pyscfad/pbc/tools/pbc.py | 6 +- pyscfad/scf/chkfile.py | 9 +- pyscfad/scf/diis.py | 2 +- pyscfad/scf/hf.py | 152 +++--- pyscfad/scf/rohf.py | 7 +- pyscfad/scf/uhf.py | 16 +- pyscfad/soscf/ciah.py | 5 +- pyscfad/tdscf/rhf.py | 15 +- pyscfad/tools/util.py | 2 +- pyscfad/util.py | 31 +- pyscfad/version.py | 2 +- pyscfadlib/pyproject.toml | 2 +- pyscfadlib/pyscfadlib/libcint.patch | 2 +- pyscfadlib/pyscfadlib/libcint.patch.6.1 | 30 ++ pyscfadlib/pyscfadlib/version.py | 2 +- requirements.txt | 4 +- setup.py | 4 +- 107 files changed, 2002 insertions(+), 803 deletions(-) delete mode 100755 .github/workflows/build_pyscf_linux.sh delete mode 100755 .github/workflows/build_pyscf_macos.sh create mode 100644 .github/workflows/publish_pyscfad.yml rename .github/workflows/{build_wheels.yml => publish_pyscfadlib.yml} (64%) create mode 100644 .github/workflows/release_github.yml create mode 100755 .github/workflows/release_info.sh create mode 100644 CHANGELOG.md delete mode 100644 pyscfad/_src/util.py create mode 100644 pyscfad/backend/__init__.py create mode 100644 pyscfad/backend/_common/__init__.py create mode 100644 pyscfad/backend/_common/core.py create mode 100644 pyscfad/backend/_cupy/__init__.py create mode 100644 pyscfad/backend/_cupy/core.py create mode 100644 pyscfad/backend/_jax/__init__.py create mode 100644 pyscfad/backend/_jax/core.py create mode 100644 pyscfad/backend/_numpy/__init__.py create mode 100644 pyscfad/backend/_numpy/core.py create mode 100644 pyscfad/backend/_torch/__init__.py create mode 100644 pyscfad/backend/_torch/core.py create mode 100644 pyscfad/backend/_torch/linalg.py create mode 100644 pyscfad/backend/_torch/numpy.py create mode 100644 pyscfad/backend/config.py create mode 100644 pyscfad/backend/numpy.py create mode 100644 pyscfad/backend/ops.py create mode 100644 pyscfad/gto/moleintor_opt.py delete mode 100644 pyscfad/lib/config.h.in delete mode 100644 pyscfad/lib/jax_helper.py delete mode 100644 pyscfad/lib/misc.py delete mode 100644 pyscfad/lib/ops.py create mode 100644 pyscfad/ml/__init__.py create mode 100644 pyscfad/ml/scf/__init__.py create mode 100644 pyscfad/ml/scf/hf.py create mode 100644 pyscfadlib/pyscfadlib/libcint.patch.6.1 diff --git a/.github/workflows/build_pyscf_linux.sh b/.github/workflows/build_pyscf_linux.sh deleted file mode 100755 index f467ad09..00000000 --- a/.github/workflows/build_pyscf_linux.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env bash -# MKL -wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB -sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB -rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB -sudo echo "deb https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list -sudo apt-get update -sudo apt-get install -y intel-oneapi-mkl -source /opt/intel/oneapi/setvars.sh -echo $MKLROOT -printenv >> $GITHUB_ENV - -sudo apt-get -qq install \ - gcc \ - libblas-dev \ - cmake - -cd pyscf/pyscf/lib -curl -L "https://github.com/fishjojo/pyscf-deps/blob/master/pyscf-2.1.1-ad-deps.tar.gz?raw=true" | tar xzf - -mkdir build; cd build -cmake -DBUILD_LIBXC=OFF -DBUILD_XCFUN=OFF -DBUILD_LIBCINT=OFF .. -make -j4 diff --git a/.github/workflows/build_pyscf_macos.sh b/.github/workflows/build_pyscf_macos.sh deleted file mode 100755 index ecec7d15..00000000 --- a/.github/workflows/build_pyscf_macos.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash -cd pyscf/pyscf/lib -#curl -L https://github.com/fishjojo/pyscf-deps/raw/master/pyscf-1.7.5-deps-macos-10.14.tar.gz | tar xzf - -mkdir build; cd build -#cmake -DBUILD_LIBXC=OFF -DBUILD_XCFUN=OFF .. -cmake .. -make -j4 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9b66d399..058a4425 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,12 +14,12 @@ jobs: strategy: matrix: os: [ubuntu-20.04] - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] environment: ci steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: install pyscf diff --git a/.github/workflows/install_pyscf.sh b/.github/workflows/install_pyscf.sh index 69eaa825..feeaf120 100755 --- a/.github/workflows/install_pyscf.sh +++ b/.github/workflows/install_pyscf.sh @@ -2,13 +2,12 @@ python -m pip install --upgrade pip python -m pip cache purge pip install wheel +pip install pytest +pip install pytest-cov pip install numpy pip install 'scipy<1.12' pip install h5py pip install jaxlib pip install jax -pip install pytest -pip install pytest-cov - -pip install 'pyscf==2.3' -pip install 'pyscfadlib==0.1.4' +pip install 'pyscf>=2.3,<2.7' +pip install 'pyscfadlib>=0.1.4' diff --git a/.github/workflows/publish_pyscfad.yml b/.github/workflows/publish_pyscfad.yml new file mode 100644 index 00000000..daefabd1 --- /dev/null +++ b/.github/workflows/publish_pyscfad.yml @@ -0,0 +1,36 @@ +name: Publish pyscfad + +on: + release: + types: + - released + workflow_dispatch: + +jobs: + publish_pypi_any: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-20.04] + python-version: ["3.11"] + + environment: release + permissions: + id-token: write + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels + run: | + python3 -m pip install --upgrade build + python3 -m build + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/publish_pyscfadlib.yml similarity index 64% rename from .github/workflows/build_wheels.yml rename to .github/workflows/publish_pyscfadlib.yml index 6d1dc3ff..dc05fe2a 100644 --- a/.github/workflows/build_wheels.yml +++ b/.github/workflows/publish_pyscfadlib.yml @@ -1,17 +1,23 @@ -name: Build +name: Publish pyscfadlib on: release: - types: [published] + types: + - released + workflow_dispatch: jobs: - build_wheels_x86: + publish_pypi_linux_macos_x86: name: Build wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-20.04, macos-12] + environment: release + permissions: + id-token: write + steps: - uses: actions/checkout@v4 @@ -22,21 +28,23 @@ jobs: CMAKE_CONFIGURE_ARGS: "-DWITH_F12=OFF" with: package-dir: pyscfadlib - output-dir: pyscfadlib/wheelhouse + output-dir: dist config-file: "{package}/pyproject.toml" - - uses: actions/upload-artifact@v4 - with: - name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} - path: pyscfadlib/wheelhouse/*.whl + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 - build_wheels_arm64: + publish_pypi_macos_arm64: name: Build wheels on ${{ matrix.os }} runs-on: ${{ matrix.os }} strategy: matrix: os: [macos-14] + environment: release + permissions: + id-token: write + steps: - uses: actions/checkout@v4 @@ -47,11 +55,8 @@ jobs: CMAKE_OSX_ARCHITECTURES: arm64 with: package-dir: pyscfadlib - output-dir: pyscfadlib/wheelhouse + output-dir: dist config-file: "{package}/pyproject.toml" - - uses: actions/upload-artifact@v4 - with: - name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} - path: pyscfadlib/wheelhouse/*.whl - + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/release_github.yml b/.github/workflows/release_github.yml new file mode 100644 index 00000000..bf028458 --- /dev/null +++ b/.github/workflows/release_github.yml @@ -0,0 +1,24 @@ +name: github release + +on: [workflow_dispatch] + +jobs: + create-release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Release Info + run: ./.github/workflows/release_info.sh + id: release_info + - name: Create Release + if: ${{ steps.release_info.outputs.version_tag }} + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.release_info.outputs.version_tag }} + release_name: pyscfad release ${{ steps.release_info.outputs.version_tag }} + body_path: RELEASE.md + prerelease: true + # draft: true diff --git a/.github/workflows/release_info.sh b/.github/workflows/release_info.sh new file mode 100755 index 00000000..70558d43 --- /dev/null +++ b/.github/workflows/release_info.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +set -e + +# Get latest version tag +get_last_tag() { + curl --silent "https://api.github.com/repos/fishjojo/pyscfad/releases/latest" | sed -n 's/.*"tag_name": "v\(.*\)",.*/\1/p' +} +last_version=$(get_last_tag) +echo Last version: $last_version + +# Get current version tag +cur_version=$(sed -n "/^__version__ =/s/.*\"\(.*\)\"/\1/p" pyscfad/version.py) +if [ -z "$cur_version" ]; then + cur_version=$(sed -n "/^__version__ =/s/.*'\(.*\)'/\1/p" pyscfad/version.py) +fi +echo Current version: $cur_version + +# Create version tag +if [ -n "$last_version" ] && [ -n "$cur_version" ] && [ "$cur_version" != "$last_version" ]; then + git config user.name "Github Actions" + git config user.email "github-actions@users.noreply.github.com" + version_tag=v"$cur_version" + + # Extract release info from CHANGELOG + sed -n "/^## pyscfad $cur_version/,/^## pyscfad $last_version/p" CHANGELOG.md | tail -n +2 | sed -e '/^## pyscfad /,$d' | head -n -1 > RELEASE.md + echo "::set-output name=version_tag::$version_tag" +fi diff --git a/.github/workflows/run_test.sh b/.github/workflows/run_test.sh index 2ced9abf..5bfc6c5a 100755 --- a/.github/workflows/run_test.sh +++ b/.github/workflows/run_test.sh @@ -1,9 +1,5 @@ #!/usr/bin/env bash export OMP_NUM_THREADS=1 -export PYTHONPATH=$(pwd):$(pwd)/pyscf:$PYTHONPATH -echo "pyscfad = True" >> $HOME/.pyscf_conf.py -#echo "pyscf_numpy_backend = 'jax'" >> $HOME/.pyscf_conf.py -#echo "pyscf_scipy_linalg_backend = 'pyscfad'" >> $HOME/.pyscf_conf.py -#echo "pyscf_scipy_backend = 'jax'" >> $HOME/.pyscf_conf.py +export PYTHONPATH=$(pwd):$PYTHONPATH pytest ./pyscfad --cov-report xml --cov=. --verbosity=1 --durations=10 diff --git a/.pylintrc b/.pylintrc index ee34511d..2de27ccb 100644 --- a/.pylintrc +++ b/.pylintrc @@ -8,7 +8,9 @@ [MASTER] # Files or directories to be skipped. They should be base names, not paths. -ignore=test +ignore=test, + backend, + ml, # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..5020d0a8 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,61 @@ +# Change log + +## pyscfad 0.1.5 +* Changes + * pyscfad is now compatable with pyscf 2.6. + * Add `backend` module (experimental). + * Add GCCSD(T). + * Add interface to Jabobi sweep for Pipek-Mezey localization. + +* Bug fixes + * Fix LRC hybrid density functionals. + +## pyscfad 0.1.4 (Mar 5, 2024) +* Changes + * pyscfad is now compatable with pyscf 2.3. + * Drop support for python 3.7. + * Drop dependence on jaxopt. + * Update JAX custom pytree node auxiliary data for safe comparison. + * Add `pyscfadlib`, fast C codes for custom VJPs. + * Add dynamic configuration. + * Allow `implicit_diff` to use preconditioners. + * Improve `scipy.linalg.eigh`. + * Add `scipy.linalg.svd`. + * Improve `lib.unpack_tril`, `lib.pack_tril`. + * Refactor `gto.moleintor`. + * Add fast VJP for molecular integrals. + * Improve `gto.eval_gto` performance. + * Add `gto.eval_gto` gradient w.r.t. `Mole.ctr_coeff` and `Mole.exp`. + * Avoid `df.DF` to create temperary files. + * Optimize `df.df_jk`. + * Add `scf.cphf`. + * Add `ao2mo._ao2mo`. + * Add `lo.iao`, `lo.boys`, `lo.pipek`, `lo.orth`. + * Add `geomopt`. + * Add `mp.dfmp2`, MP2 one-RDM. + * Consider permutation symmetry for CCSD. + * Disable `jit` for CCSD which causes memory leak. + * Simplify implementation of `cc.ccsd_t_slow`. + * Add optimized `cc.ccsd_t`. + * Add iterative CCSD(T) solver. + * Add `cc.dfccsd`, `cc.dcsd`. + * Add `tools.timer`. + +* Bug fixes + * Fix integral derivatives w.r.t. `Mole.ctr_coeff` and `Mole.exp`. + +## pyscfad 0.1.3 (Sep 13, 2023) +* Bug fixes + * Fix installation issues. + +## pyscfad 0.1.2 (Mar 13, 2023) +* Changes + * Add AD support for ROHF. + +## pyscfad 0.1.1 (Jan 25, 2023) +* Changes + * pyscfad is now compatable with pyscf 2.1. + +## pyscfad 0.1.0 (Aug 3, 2022) +* Changes + * First release. diff --git a/examples/scf/00-simple.py b/examples/scf/00-simple.py index 2823faf0..126efa1a 100644 --- a/examples/scf/00-simple.py +++ b/examples/scf/00-simple.py @@ -1,4 +1,3 @@ -import pyscf from pyscfad import gto, scf """ @@ -12,7 +11,6 @@ mol.build() mf = scf.RHF(mol) -mf.kernel() jac = mf.energy_grad() print(f'Nuclaer gradient:\n{jac.coords}') print(f'Gradient wrt basis exponents:\n{jac.exp}') diff --git a/examples/scf/10-polarizability.py b/examples/scf/10-polarizability.py index 4426b0fb..718e2c75 100644 --- a/examples/scf/10-polarizability.py +++ b/examples/scf/10-polarizability.py @@ -1,7 +1,7 @@ import numpy import jax +from pyscfad import numpy as np from pyscfad import gto, scf -from pyscfad.lib import numpy as jnp mol = gto.Mole() mol.atom = '''H , 0. 0. 0. @@ -15,7 +15,7 @@ h1 = mf.get_hcore() def apply_E(E): - mf.get_hcore = lambda *args, **kwargs: h1 + jnp.einsum('x,xij->ij', E, ao_dip) + mf.get_hcore = lambda *args, **kwargs: h1 + np.einsum('x,xij->ij', E, ao_dip) mf.kernel() return mf.dip_moment(mol, mf.make_rdm1(), unit='AU', verbose=0) @@ -24,7 +24,7 @@ def apply_E(E): print(polar) def apply_E1(E): - mf.get_hcore = lambda *args, **kwargs: h1 + jnp.einsum('x,xij->ij', E, ao_dip) + mf.get_hcore = lambda *args, **kwargs: h1 + np.einsum('x,xij->ij', E, ao_dip) return mf.kernel() polar = -jax.hessian(apply_E1)(E0) diff --git a/pyscfad/__init__.py b/pyscfad/__init__.py index a7dab95e..6011d04e 100644 --- a/pyscfad/__init__.py +++ b/pyscfad/__init__.py @@ -1,15 +1,19 @@ """ PySCF with auto-differentiation """ +import sys from pyscfad.version import __version__ -#from pyscfad import util -#from pyscfad import implicit_diff - from pyscfad._src._config import ( config, config_update ) -import jax -jax.config.update("jax_enable_x64", True) +# export backend.numpy to pyscfad namespace +# pylint: disable=useless-import-alias +from pyscfad.backend import numpy as numpy +from pyscfad.backend import ops as ops +sys.modules['pyscfad.numpy'] = numpy +sys.modules['pyscfad.ops'] = ops + +del sys diff --git a/pyscfad/_src/_config.py b/pyscfad/_src/_config.py index 32b82105..a545bb56 100644 --- a/pyscfad/_src/_config.py +++ b/pyscfad/_src/_config.py @@ -25,7 +25,6 @@ def reset(self): config = _Config() -config.set_default('pyscfad_numpy_backend', 'jax') config.set_default('pyscfad_scf_implicit_diff', False) config.set_default('pyscfad_ccsd_implicit_diff', False) config.set_default('pyscfad_ccsd_checkpoint', False) diff --git a/pyscfad/_src/scipy/linalg.py b/pyscfad/_src/scipy/linalg.py index 4897191a..b2e7ab20 100644 --- a/pyscfad/_src/scipy/linalg.py +++ b/pyscfad/_src/scipy/linalg.py @@ -4,7 +4,7 @@ import scipy.linalg from jax import numpy as np from jax import scipy as jax_scipy -from pyscfad.lib import custom_jvp, jit +from pyscfad.ops import custom_jvp, jit # default threshold for degenerate eigenvalues DEG_THRESH = 1e-9 diff --git a/pyscfad/_src/util.py b/pyscfad/_src/util.py deleted file mode 100644 index b0453e00..00000000 --- a/pyscfad/_src/util.py +++ /dev/null @@ -1,131 +0,0 @@ -import warnings -from jax import tree_util -from pyscf import __config__ - -PYSCFAD = getattr(__config__, 'pyscfad', False) - - -def _dict_hash(this): - from pyscf.lib.misc import finger - fg = [] - leaves, tree = tree_util.tree_flatten(this) - fg.append(hash(tree)) - for v in leaves: - if hasattr(v, 'size'): # arrays - fg.append(finger(v)) - elif isinstance(v, set): - fg.append(_dict_hash(tuple(v))) - else: - try: - fg.append(hash(v)) - except TypeError as e: - raise e - return hash(tuple(fg)) - -def _dict_equality(d1, d2): - leaves1, tree1 = tree_util.tree_flatten(d1) - leaves2, tree2 = tree_util.tree_flatten(d2) - if tree1 != tree2: - return False - - for v1, v2 in zip(leaves1, leaves2): - if v1 is v2: - neq = False - else: - if hasattr(v1, 'size') and hasattr(v2, 'size'): # arrays - if v1.size != v2.size: - neq = True - elif v1.size == 0 and v2.size == 0: - neq = False - else: - try: - neq = not v1 == v2 - except ValueError as e: - try: - neq = not (v1 == v2).all() - except Exception: - raise e from None - else: - try: - neq = not v1 == v2 - except ValueError as e: - raise e - if neq: - return False - return True - - -class _AuxData: - def __init__(self, **kwargs): - self.data = {**kwargs} - - def __eq__(self, other): - if self is other: - return True - if not isinstance(other, _AuxData): - return False - return _dict_equality(self.data, other.data) - - def __hash__(self): - return _dict_hash(self.data) - - -def pytree_node(leaf_names, num_args=0): - """Class decorator that registers the underlying class as a pytree node. - - See `jax document `_ - for the definition of pytrees. - - Parameters - ---------- - leaf_names : list or tuple - Attributes of the class that are traced as pytree leaves. - num_args : int, optional - Number of positional arguments in ``leaf_names``. - This is useful when the ``__init__`` method of the class - has positional arguments that are named differently than - the actual attribute names. Default value is 0. - - Notes - ----- - The ``__init__`` method of the class can't have positional arguments - that are not included in ``leaf_names``. If ``num_args`` is greater - than 0, the sequence of positional arguments in ``leaf_names`` must - follow that in the ``__init__`` method. - """ - def class_orig(cls): - return cls - - def class_as_pytree_node(cls): - def tree_flatten(obj): - keys = obj.__dict__.keys() - for leaf_name in leaf_names: - if leaf_name not in keys: - raise KeyError(f'Pytree leaf {leaf_name} is not defined in class {cls}.') - children = tuple(getattr(obj, leaf_name, None) for leaf_name in leaf_names) - if len(children) <= 0: - #raise KeyError("Empty pytree node is not supported.") - warnings.warn(f'Not taking derivatives wrt the leaves in ' - f'the node {obj.__class__} as none of those was specified.') - - aux_keys = list(set(keys) - set(leaf_names)) - aux_data = list(getattr(obj, key, None) for key in aux_keys) - metadata = (num_args,) + (_AuxData(**dict(zip(aux_keys, aux_data))),) - return children, metadata - - def tree_unflatten(metadata, children): - num_args = metadata[0] - auxdata = metadata[1] - leaves_args = children[:num_args] - leaves_kwargs = dict(zip(leaf_names[num_args:], children[num_args:])) - kwargs = {**leaves_kwargs, **(auxdata.data)} - obj = cls(*leaves_args, **kwargs) - return obj - - tree_util.register_pytree_node(cls, tree_flatten, tree_unflatten) - return cls - - if PYSCFAD: - return class_as_pytree_node - else: - return class_orig diff --git a/pyscfad/ao2mo/__init__.py b/pyscfad/ao2mo/__init__.py index fc7ad4a7..ec73f962 100644 --- a/pyscfad/ao2mo/__init__.py +++ b/pyscfad/ao2mo/__init__.py @@ -4,8 +4,7 @@ def general(eri_or_mol, mo_coeffs, *args, erifile=None, dataname='eri_mo', intor='int2e', **kwargs): - from pyscfad import lib - if lib.isarray(eri_or_mol): + if hasattr(eri_or_mol, 'shape'): return incore.general(eri_or_mol, mo_coeffs, *args, **kwargs) else: raise NotImplementedError diff --git a/pyscfad/ao2mo/_ao2mo.py b/pyscfad/ao2mo/_ao2mo.py index 2bea7529..162f3420 100644 --- a/pyscfad/ao2mo/_ao2mo.py +++ b/pyscfad/ao2mo/_ao2mo.py @@ -1,4 +1,4 @@ -from jax import numpy as np +from pyscfad import numpy as np from pyscfad import config from ._ao2mo_opt import nr_e2 as nr_e2_opt diff --git a/pyscfad/ao2mo/addons.py b/pyscfad/ao2mo/addons.py index fad74f44..6864f243 100644 --- a/pyscfad/ao2mo/addons.py +++ b/pyscfad/ao2mo/addons.py @@ -1,6 +1,6 @@ from pyscf.ao2mo import addons as pyscf_addons from pyscfad import lib -from pyscfad.lib import vmap +from pyscfad.ops import vmap class load(pyscf_addons.load): def __enter__(self): diff --git a/pyscfad/ao2mo/incore.py b/pyscfad/ao2mo/incore.py index 3a5a1449..e6036f21 100644 --- a/pyscfad/ao2mo/incore.py +++ b/pyscfad/ao2mo/incore.py @@ -1,8 +1,8 @@ -from jax import numpy as np from pyscf.ao2mo import incore from pyscf.ao2mo.incore import iden_coeffs +from pyscfad import numpy as np from pyscfad import lib -from pyscfad.lib import jit, vmap +from pyscfad.ops import jit, vmap def full(eri_ao, mo_coeff, verbose=0, compact=True, **kwargs): nao = mo_coeff.shape[0] diff --git a/pyscfad/backend/__init__.py b/pyscfad/backend/__init__.py new file mode 100644 index 00000000..a1d473d1 --- /dev/null +++ b/pyscfad/backend/__init__.py @@ -0,0 +1,22 @@ +""" +The module oversees the backend. +""" +from pyscfad.backend.config import ( + default_backend, + set_backend, + get_backend, + with_backend, +) +from pyscfad.backend import numpy as numpy +from pyscfad.backend import ops as ops + +set_backend(default_backend()) + +__all__ = [ + 'set_backend', + 'get_backend', + 'with_backend', + 'numpy', + 'ops', +] + diff --git a/pyscfad/backend/_common/__init__.py b/pyscfad/backend/_common/__init__.py new file mode 100644 index 00000000..bb67a43f --- /dev/null +++ b/pyscfad/backend/_common/__init__.py @@ -0,0 +1 @@ +from .core import * diff --git a/pyscfad/backend/_common/core.py b/pyscfad/backend/_common/core.py new file mode 100644 index 00000000..8f479fdd --- /dev/null +++ b/pyscfad/backend/_common/core.py @@ -0,0 +1,48 @@ +def stop_gradient(x): + return x + +def class_as_pytree_node(cls, leaf_names, num_args=0): + return cls + +class custom_jvp: + """Fake ``custom_jvp`` that does nothing. + """ + def __init__(self, fun, *args, **kwargs): + self.fun = fun + self.jvp = None + + def defjvp(self, jvp): + self.jvp = jvp + return jvp + + def __call__(self, *args, **kwargs): + return self.fun(*args, **kwargs) + +def jit(fun, **kwargs): + return fun + +# TODO deprecate these +class _Indexable(object): + # pylint: disable=line-too-long + """ + see https://github.com/google/jax/blob/97d00584f8b87dfe5c95e67892b54db993f34486/jax/_src/ops/scatter.py#L87 + """ + __slots__ = () + + def __getitem__(self, idx): + return idx + +index = _Indexable() + +def index_update(x, idx, y): + x[idx] = y + return x + +def index_add(x, idx, y): + x[idx] += y + return x + +def index_mul(x, idx, y): + x[idx] *= y + return x + diff --git a/pyscfad/backend/_cupy/__init__.py b/pyscfad/backend/_cupy/__init__.py new file mode 100644 index 00000000..12b17189 --- /dev/null +++ b/pyscfad/backend/_cupy/__init__.py @@ -0,0 +1,56 @@ +from types import ModuleType +try: + import cupy as cp +except ImportError as err: + raise ImportError("Unable to import cupy.") from err + +from .._common import ( + stop_gradient, + class_as_pytree_node, + custom_jvp, + jit, + index, + index_update, + index_add, + index_mul, +) +from .core import ( + is_array, + to_numpy, +) + +class CupyBackend: + def __init__(self, package): + self._pkg = package + self._cache = {} + + def __getattr__(self, name): + if name in self._cache: + return self._cache[name] + + try: + attr = getattr(self._pkg, name) + if isinstance(attr, ModuleType): + submodule = self.__class__(attr) + self._cache[name] = submodule + return submodule + else: + self._cache[name] = attr + return attr + except AttributeError as err: + raise AttributeError(f"{self._pkg.__name__} has no attribute {name}") from err + +backend = CupyBackend(cp) + +backend._cache['is_array'] = is_array +backend._cache['to_numpy'] = to_numpy +backend._cache['stop_gradient'] = stop_gradient +backend._cache['class_as_pytree_node'] = class_as_pytree_node +backend._cache['custom_jvp'] = custom_jvp +backend._cache['jit'] = jit +backend._cache['vmap'] = NotImplemented +backend._cache['index'] = index +backend._cache['index_update'] = index_update +backend._cache['index_add'] = index_add +backend._cache['index_mul'] = index_mul + diff --git a/pyscfad/backend/_cupy/core.py b/pyscfad/backend/_cupy/core.py new file mode 100644 index 00000000..466f4a0c --- /dev/null +++ b/pyscfad/backend/_cupy/core.py @@ -0,0 +1,7 @@ +import cupy + +def is_array(x): + return isinstance(x, cupy.ndarray) + +def to_numpy(x): + return cupy.asnumpy(x) diff --git a/pyscfad/backend/_jax/__init__.py b/pyscfad/backend/_jax/__init__.py new file mode 100644 index 00000000..897bd3be --- /dev/null +++ b/pyscfad/backend/_jax/__init__.py @@ -0,0 +1,65 @@ +from types import ModuleType +try: + import jax +except ImportError as err: + raise ImportError("Unable to import jax.") from err + +from ..config import default_floatx +if default_floatx() == 'float64': + jax.config.update("jax_enable_x64", True) + +from .._common import ( + index, +) + +from jax import ( + custom_jvp, + jit, +) +from jax.lax import stop_gradient +from .core import ( + is_array, + to_numpy, + vmap, + class_as_pytree_node, + index_update, + index_add, + index_mul, +) + +class JaxBackend: + def __init__(self, package): + self._pkg = package + self._cache = {} + + def __getattr__(self, name): + if name in self._cache: + return self._cache[name] + + try: + attr = getattr(self._pkg, name) + if isinstance(attr, ModuleType): + submodule = self.__class__(attr) + self._cache[name] = submodule + return submodule + else: + self._cache[name] = attr + return attr + except AttributeError as err: + raise AttributeError(f"{self._pkg.__name__} has no attribute {name}") from err + +backend = JaxBackend(jax.numpy) + +# FIXME maybe separate ops from numpy +backend._cache['is_array'] = is_array +backend._cache['to_numpy'] = to_numpy +backend._cache['stop_gradient'] = stop_gradient +backend._cache['class_as_pytree_node'] = class_as_pytree_node +backend._cache['custom_jvp'] = custom_jvp +backend._cache['jit'] = jit +backend._cache['vmap'] = vmap +backend._cache['index'] = index +backend._cache['index_update'] = index_update +backend._cache['index_add'] = index_add +backend._cache['index_mul'] = index_mul + diff --git a/pyscfad/backend/_jax/core.py b/pyscfad/backend/_jax/core.py new file mode 100644 index 00000000..595a7b8f --- /dev/null +++ b/pyscfad/backend/_jax/core.py @@ -0,0 +1,129 @@ +import warnings +import jax +from jax import numpy as jnp +from jax import tree_util + +def is_array(x): + return isinstance(x, jax.Array) + +def to_numpy(x): + x = jax.lax.stop_gradient(x) + return x.__array__() + +def convert_to_tensor(x, dtype=None, **kwargs): + return jnp.asarray(x, dtype=dtype, **kwargs) + +def vmap(fun, in_axes=0, out_axes=0, chunk_size=None, signature=None): + return jax.vmap(fun, in_axes=in_axes, out_axes=out_axes) + +# TODO deprecate these +def index_update(x, idx, y): + x = jnp.asarray(x) + y = jnp.asarray(y) + return x.at[idx].set(y) + +def index_add(x, idx, y): + x = jnp.asarray(x) + y = jnp.asarray(y) + return x.at[idx].add(y) + +def index_mul(x, idx, y): + x = jnp.asarray(x) + y = jnp.asarray(y) + return x.at[idx].multiply(y) + + +def _dict_hash(this): + from pyscf.lib.misc import finger + fg = [] + leaves, tree = tree_util.tree_flatten(this) + fg.append(hash(tree)) + for v in leaves: + if hasattr(v, 'size'): # arrays + fg.append(finger(v)) + elif isinstance(v, set): + fg.append(_dict_hash(tuple(v))) + else: + try: + fg.append(hash(v)) + except TypeError as e: + raise e + return hash(tuple(fg)) + + +def _dict_equality(d1, d2): + leaves1, tree1 = tree_util.tree_flatten(d1) + leaves2, tree2 = tree_util.tree_flatten(d2) + if tree1 != tree2: + return False + + for v1, v2 in zip(leaves1, leaves2): + if v1 is v2: + neq = False + else: + if hasattr(v1, 'size') and hasattr(v2, 'size'): # arrays + if v1.size != v2.size: + neq = True + elif v1.size == 0 and v2.size == 0: + neq = False + else: + try: + neq = not v1 == v2 + except ValueError as e: + try: + neq = not (v1 == v2).all() + except Exception: + raise e from None + else: + try: + neq = not v1 == v2 + except ValueError as e: + raise e + if neq: + return False + return True + + +class _AuxData: + def __init__(self, **kwargs): + self.data = {**kwargs} + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, _AuxData): + return False + return _dict_equality(self.data, other.data) + + def __hash__(self): + return _dict_hash(self.data) + + +def class_as_pytree_node(cls, leaf_names, num_args=0): + def tree_flatten(obj): + keys = obj.__dict__.keys() + for leaf_name in leaf_names: + if leaf_name not in keys: + raise KeyError(f'Pytree leaf {leaf_name} is not defined in class {cls}.') + children = tuple(getattr(obj, leaf_name, None) for leaf_name in leaf_names) + if len(children) <= 0: + #raise KeyError("Empty pytree node is not supported.") + warnings.warn(f'Not taking derivatives w.r.t. the leaves in ' + f'the node {obj.__class__} as none of those was specified.') + + aux_keys = list(set(keys) - set(leaf_names)) + aux_data = list(getattr(obj, key, None) for key in aux_keys) + metadata = (num_args,) + (_AuxData(**dict(zip(aux_keys, aux_data))),) + return children, metadata + + def tree_unflatten(metadata, children): + num_args = metadata[0] + auxdata = metadata[1] + leaves_args = children[:num_args] + leaves_kwargs = dict(zip(leaf_names[num_args:], children[num_args:])) + kwargs = {**leaves_kwargs, **(auxdata.data)} + obj = cls(*leaves_args, **kwargs) + return obj + + tree_util.register_pytree_node(cls, tree_flatten, tree_unflatten) + return cls diff --git a/pyscfad/backend/_numpy/__init__.py b/pyscfad/backend/_numpy/__init__.py new file mode 100644 index 00000000..5a02b7fb --- /dev/null +++ b/pyscfad/backend/_numpy/__init__.py @@ -0,0 +1,58 @@ +from types import ModuleType +try: + import numpy as np +except ImportError as err: + raise ImportError("Unable to import numpy.") from err + +from .._common import ( + stop_gradient, + class_as_pytree_node, + custom_jvp, + jit, # TODO use numba + index, + index_update, + index_add, + index_mul, +) +from .core import ( + is_array, + to_numpy, + vmap, +) + +class NumpyBackend: + def __init__(self, package): + self._pkg = package + self._cache = {} + + def __getattr__(self, name): + if name in self._cache: + return self._cache[name] + + try: + attr = getattr(self._pkg, name) + if isinstance(attr, ModuleType): + submodule = self.__class__(attr) + self._cache[name] = submodule + return submodule + else: + self._cache[name] = attr + return attr + except AttributeError as err: + raise AttributeError(f"{self._pkg.__name__} has no attribute {name}") from err + +backend = NumpyBackend(np) + +# FIXME maybe separate ops from numpy +backend._cache['is_array'] = is_array +backend._cache['to_numpy'] = to_numpy +backend._cache['stop_gradient'] = stop_gradient +backend._cache['class_as_pytree_node'] = class_as_pytree_node +backend._cache['custom_jvp'] = custom_jvp +backend._cache['jit'] = jit +backend._cache['vmap'] = vmap +backend._cache['index'] = index +backend._cache['index_update'] = index_update +backend._cache['index_add'] = index_add +backend._cache['index_mul'] = index_mul + diff --git a/pyscfad/backend/_numpy/core.py b/pyscfad/backend/_numpy/core.py new file mode 100644 index 00000000..7ad3ee96 --- /dev/null +++ b/pyscfad/backend/_numpy/core.py @@ -0,0 +1,50 @@ +import numpy as np + +def is_array(x): + # FIXME should np.generic be included? + return isinstance(x, (np.ndarray, np.generic)) + +def to_numpy(x): + return np.asarray(x) + +def convert_to_tensor(x, dtype=None, **kwargs): + return np.asarray(x, dtype=dtype, **kwargs) + +def vmap(fun, in_axes=0, out_axes=0, chunk_size=None, signature=None): + if not isinstance(out_axes, int): + raise NotImplementedError + + def vmap_f(*args): + if isinstance(in_axes, int): + in_axes_loc = (in_axes,) * len(args) + else: + in_axes_loc = in_axes + + if isinstance(in_axes_loc, (list, tuple)): + excluded = [] + vmap_args = [] + assert len(in_axes_loc) == len(args) + for i, axis in enumerate(in_axes_loc): + if axis is None: + excluded.append(i) + vmap_args.append(args[i]) + elif isinstance(axis, int): + vmap_args.append(np.moveaxis(args[i], axis, 0)) + else: + raise KeyError + if len(excluded) > 0: + excluded = set(excluded) + else: + excluded = None + + vfun = np.vectorize(fun, excluded=excluded, signature=signature) + out = vfun(*vmap_args) + else: + raise KeyError + + if out_axes != 0: + out = np.moveaxis(out, 0, out_axes) + return out + + return vmap_f + diff --git a/pyscfad/backend/_torch/__init__.py b/pyscfad/backend/_torch/__init__.py new file mode 100644 index 00000000..2a52365b --- /dev/null +++ b/pyscfad/backend/_torch/__init__.py @@ -0,0 +1,60 @@ +from types import ModuleType +try: + import torch as torch +except ImportError as err: + raise ImportError("Unable to import torch.") from err + +from ..config import default_floatx +if default_floatx() == 'float64': + torch.set_default_dtype(torch.float64) + +from torch import ( + is_tensor as is_array, +) + +from .._common import ( + class_as_pytree_node, + custom_jvp, +) +from .numpy import ( + iscomplexobj, +) +from .core import ( + to_numpy, + stop_gradient, + vmap, + jit, +) + +class TorchBackend: + def __init__(self, package): + self._pkg = package + self._cache = {} + + def __getattr__(self, name): + if name in self._cache: + return self._cache[name] + + try: + attr = getattr(self._pkg, name) + if isinstance(attr, ModuleType): + submodule = self.__class__(attr) + self._cache[name] = submodule + return submodule + else: + self._cache[name] = attr + return attr + except AttributeError as err: + raise AttributeError(f"{self._pkg.__name__} has no attribute {name}") from err + +backend = TorchBackend(torch) +backend._cache['iscomplexobj'] = iscomplexobj + +backend._cache['is_array'] = is_array +backend._cache['to_numpy'] = to_numpy +backend._cache['stop_gradient'] = stop_gradient +backend._cache['class_as_pytree_node'] = class_as_pytree_node +backend._cache['custom_jvp'] = custom_jvp +backend._cache['jit'] = jit +backend._cache['vmap'] = vmap + diff --git a/pyscfad/backend/_torch/core.py b/pyscfad/backend/_torch/core.py new file mode 100644 index 00000000..121f2e95 --- /dev/null +++ b/pyscfad/backend/_torch/core.py @@ -0,0 +1,15 @@ +import torch + +def to_numpy(x): + return x.numpy(force=True) + +def stop_gradient(x): + return x.detach() + +def vmap(fun, in_axes=0, out_axes=0, chunk_size=None, signature=None): + return torch.vmap(fun, in_dims=in_axes, out_dims=out_axes, chunk_size=chunk_size) + +def jit(obj, **kwargs): + # TODO make jit work + #return torch.jit.script(obj, **kwargs) + return obj diff --git a/pyscfad/backend/_torch/linalg.py b/pyscfad/backend/_torch/linalg.py new file mode 100644 index 00000000..b7ebcd81 --- /dev/null +++ b/pyscfad/backend/_torch/linalg.py @@ -0,0 +1,18 @@ +import torch +from .core import convert_to_tensor + +def cholesky(a, **kwargs): + a = convert_to_tensor(a) + return torch.linalg.cholesky(a, **kwargs) + +def eigh(a, UPLO='L', **kwargs): + a = convert_to_tensor(a) + return torch.linalg.eigh(a, UPLO, **kwargs) + +def inv(a, **kwargs): + a = convert_to_tensor(a) + return torch.linalg.inv(a, **kwargs) + +def norm(x, ord=None, axis=None, keepdims=False, **kwargs): + x = convert_to_tensor(x) + return torch.linalg.norm(x, ord, axis, keepdims, **kwargs) diff --git a/pyscfad/backend/_torch/numpy.py b/pyscfad/backend/_torch/numpy.py new file mode 100644 index 00000000..45289803 --- /dev/null +++ b/pyscfad/backend/_torch/numpy.py @@ -0,0 +1,4 @@ +import torch + +def iscomplexobj(x): + return torch.is_complex(x) diff --git a/pyscfad/backend/config.py b/pyscfad/backend/config.py new file mode 100644 index 00000000..b30bcdd4 --- /dev/null +++ b/pyscfad/backend/config.py @@ -0,0 +1,111 @@ +""" +Default configurations related to the backend. +""" +import os +import sys +import json +import importlib +import contextlib +import threading + +# default +_FLOATX = 'float64' +_BACKEND = 'jax' + +_allowed_floatx = ('float32', 'float64') +_allowed_backend = ('numpy', 'cupy', 'jax', 'torch') +_floatx = _backend = None + +if 'PYSCFAD_HOME' in os.environ: + _base_dir = os.environ['PYSCFAD_HOME'] + _PYSCFAD_DIR = os.path.expanduser(_base_dir) +else: + _base_dir = os.path.expanduser('~') + if not os.access(_base_dir, os.W_OK): + _base_dir = '/tmp' + _PYSCFAD_DIR = os.path.join(_base_dir, '.pyscfad') + +_config_path = os.path.expanduser(os.path.join(_PYSCFAD_DIR, "pyscfad.json")) +if os.path.exists(_config_path): + try: + with open(_config_path) as f: + _config = json.load(f) + except ValueError: + _config = {} + + _floatx = _config.get('floatx', None) + _backend = _config.get('backend', None) + +# NOTE environment variables overwrite configure file +if 'PYSCFAD_FLOATX' in os.environ: + _floatx = os.environ['PYSCFAD_FLOATX'] +if 'PYSCFAD_BACKEND' in os.environ: + _backend = os.environ['PYSCFAD_BACKEND'] + +if _floatx in _allowed_floatx: + _FLOATX = _floatx +if _backend in _allowed_backend: + _BACKEND = _backend + +if not os.path.exists(_PYSCFAD_DIR): + try: + os.makedirs(_PYSCFAD_DIR) + except OSError: + pass + +if not os.path.exists(_config_path): + _config = { + "floatx": _FLOATX, + "backend": _BACKEND, + } + try: + with open(_config_path, "w") as f: + f.write(json.dumps(_config, indent=4)) + except IOError: + pass + +del (_floatx, _backend) +del (os, sys, json) + +def default_backend(): + return _BACKEND + +def default_floatx(): + return _FLOATX + + +#---------------- dynamic backend update ----------------# + +_current_backend = None +_backend_cache = {} + +def set_backend(backend_name): + if not backend_name in _allowed_backend: + raise KeyError(f"Required backend {backend_name} is not supported.") + + with threading.RLock(): + global _current_backend + if backend_name in _backend_cache: + _current_backend = _backend_cache[backend_name] + else: + try: + module = importlib.import_module(f"pyscfad.backend._{backend_name}").backend + except Exception: + raise RuntimeError("Failed setting backend {backend_name}.") + _backend_cache[backend_name] = module + _current_backend = module + +def get_backend(): + return _current_backend + +@contextlib.contextmanager +def with_backend(backend_name): + with threading.RLock(): + global _current_backend + previous_backend = _current_backend + set_backend(backend_name) + try: + yield + finally: + _current_backend = previous_backend + diff --git a/pyscfad/backend/numpy.py b/pyscfad/backend/numpy.py new file mode 100644 index 00000000..60faee4c --- /dev/null +++ b/pyscfad/backend/numpy.py @@ -0,0 +1,4 @@ +from .config import get_backend + +def __getattr__(name): + return getattr(get_backend(), name) diff --git a/pyscfad/backend/ops.py b/pyscfad/backend/ops.py new file mode 100644 index 00000000..5c8c6f33 --- /dev/null +++ b/pyscfad/backend/ops.py @@ -0,0 +1,65 @@ +from .config import get_backend + +__all__ = [ + 'is_array', + 'isarray', + 'is_tensor', + 'to_numpy', + 'stop_gradient', + 'stop_grad', + 'stop_trace', + 'class_as_pytree_node', + 'custom_jvp', + 'jit', + 'vmap', + 'index', + 'index_update', + 'index_add', + 'index_mul', +] + +def __getattr__(name): + return getattr(get_backend(), name) + +def is_array(x): + return get_backend().is_array(x) + +is_tensor = isarray = is_array + +def to_numpy(x): + return get_backend().to_numpy(x) + +def stop_gradient(x): + return get_backend().stop_gradient(x) + +stop_grad = stop_gradient + +def stop_trace(fn): + """Convenient wrapper to call functions with arguments + detached from the graph. + """ + def wrapped_fn(*args, **kwargs): + args_no_grad = [stop_grad(arg) for arg in args] + kwargs_no_grad = {k : stop_grad(v) for k, v in kwargs.items()} + return fn(*args_no_grad, **kwargs_no_grad) + return wrapped_fn + +def class_as_pytree_node(cls, leaf_names, num_args=0): + return get_backend().class_as_pytree_node(cls, leaf_names, num_args=num_args) + +def jit(obj, **kwargs): + return get_backend().jit(obj, **kwargs) + +def vmap(fun, in_axes=0, out_axes=0, chunk_size=None, signature=None): + return get_backend().vmap(fun, in_axes=in_axes, out_axes=out_axes, + chunk_size=chunk_size, signature=signature) + +def index_update(x, idx, y): + return get_backend().index_update(x, idx, y) + +def index_add(x, idx, y): + return get_backend().index_add(x, idx, y) + +def index_mul(x, idx, y): + return get_backend().index_mul(x, idx, y) + diff --git a/pyscfad/cc/ccsd.py b/pyscfad/cc/ccsd.py index ed79d68c..5bbb5d9d 100644 --- a/pyscfad/cc/ccsd.py +++ b/pyscfad/cc/ccsd.py @@ -1,11 +1,12 @@ from functools import reduce import numpy -from jax import numpy as np from pyscf.cc import ccsd as pyscf_ccsd from pyscf.mp.mp2 import _mo_without_core +from pyscfad import numpy as np +from pyscfad import ops from pyscfad import lib -from pyscfad.lib import ops, logger -#from pyscfad.lib import jit +from pyscfad.lib import logger +#from pyscfad.ops import jit #from pyscfad import util from pyscfad import config from pyscfad.implicit_diff import make_implicit_diff diff --git a/pyscfad/cc/ccsd_rdm.py b/pyscfad/cc/ccsd_rdm.py index adddf0b1..51400067 100644 --- a/pyscfad/cc/ccsd_rdm.py +++ b/pyscfad/cc/ccsd_rdm.py @@ -1,4 +1,4 @@ -from jax import numpy as np +from pyscfad import numpy as np def _make_rdm1(mycc, d1, with_frozen=True, ao_repr=False, with_mf=True): doo, dov, dvo, dvv = d1 diff --git a/pyscfad/cc/ccsd_t_slow.py b/pyscfad/cc/ccsd_t_slow.py index 41ba1b38..04b716cb 100644 --- a/pyscfad/cc/ccsd_t_slow.py +++ b/pyscfad/cc/ccsd_t_slow.py @@ -1,10 +1,11 @@ ''' CCSD(T) ''' -from jax import numpy as np +from pyscfad import numpy as np from pyscfad import config, config_update from pyscfad import lib -from pyscfad.lib import logger, jit, vmap +from pyscfad.lib import logger +from pyscfad.ops import jit, vmap from pyscfad.implicit_diff import make_implicit_diff from pyscfad.tools.linear_solver import gen_gmres diff --git a/pyscfad/cc/dfccsd.py b/pyscfad/cc/dfccsd.py index 1b63929e..6d70e891 100644 --- a/pyscfad/cc/dfccsd.py +++ b/pyscfad/cc/dfccsd.py @@ -1,5 +1,5 @@ -from jax import numpy as np from pyscf.lib import square_mat_in_trilu_indices +from pyscfad import numpy as np from pyscfad import util from pyscfad import lib from pyscfad.ao2mo import _ao2mo diff --git a/pyscfad/cc/rccsd.py b/pyscfad/cc/rccsd.py index 5435d07f..203f7011 100644 --- a/pyscfad/cc/rccsd.py +++ b/pyscfad/cc/rccsd.py @@ -1,7 +1,8 @@ from jax import numpy as np from pyscf.lib import current_memory from pyscfad import util -from pyscfad.lib import logger, jit +from pyscfad.lib import logger +from pyscfad.ops import jit from pyscfad import ao2mo from pyscfad.cc import ccsd from pyscfad.cc import rintermediates as imd diff --git a/pyscfad/cc/rintermediates.py b/pyscfad/cc/rintermediates.py index dc6ccbf1..1e7b4edd 100644 --- a/pyscfad/cc/rintermediates.py +++ b/pyscfad/cc/rintermediates.py @@ -1,7 +1,7 @@ ''' Intermediates for restricted CCSD. Complex integrals are supported. ''' -from jax import numpy as np +from pyscfad import numpy as np #from pyscfad.lib import jit # This is restricted (R)CCSD diff --git a/pyscfad/cc/test/test_ccsd_t.py b/pyscfad/cc/test/test_ccsd_t.py index 08b1a463..f933317b 100644 --- a/pyscfad/cc/test/test_ccsd_t.py +++ b/pyscfad/cc/test/test_ccsd_t.py @@ -2,6 +2,10 @@ import numpy import jax from pyscfad import gto, scf, cc +from pyscfad import config + +config.update('pyscfad_scf_implicit_diff', True) +config.update('pyscfad_ccsd_implicit_diff', True) def test_nuc_grad(get_mol): mol = get_mol diff --git a/pyscfad/cc/test/test_rccsd_hess.py b/pyscfad/cc/test/test_rccsd_hess.py index 66111ce1..e76c2201 100644 --- a/pyscfad/cc/test/test_rccsd_hess.py +++ b/pyscfad/cc/test/test_rccsd_hess.py @@ -3,6 +3,10 @@ import jax from pyscf import lib from pyscfad import gto, scf, cc +from pyscfad import config + +config.update('pyscfad_scf_implicit_diff', True) +config.update('pyscfad_ccsd_implicit_diff', True) def test_nuc_hessian(get_mol): mol = get_mol diff --git a/pyscfad/df/addons.py b/pyscfad/df/addons.py index 773cae7e..495de752 100644 --- a/pyscfad/df/addons.py +++ b/pyscfad/df/addons.py @@ -1,5 +1,5 @@ -from jax import numpy as np from pyscf.df import addons as pyscf_addons +from pyscfad import numpy as np from pyscfad import lib from pyscfad import ao2mo from pyscfad.gto._mole_helper import setup_exp, setup_ctr_coeff diff --git a/pyscfad/df/df.py b/pyscfad/df/df.py index 8d47e918..54021bce 100644 --- a/pyscfad/df/df.py +++ b/pyscfad/df/df.py @@ -5,7 +5,7 @@ from pyscf.lib import logger from pyscf.df import df as pyscf_df from pyscfad import util -from pyscfad.lib import isarray +from pyscfad.ops import isarray from pyscfad.df import addons, incore, df_jk @util.pytree_node(['mol', 'auxmol', '_cderi'], num_args=1) diff --git a/pyscfad/df/df_jk.py b/pyscfad/df/df_jk.py index 510d8a14..162c7300 100644 --- a/pyscfad/df/df_jk.py +++ b/pyscfad/df/df_jk.py @@ -1,4 +1,4 @@ -from jax import numpy as np +from pyscfad import numpy as np from pyscfad import config from .addons import restore from ._df_jk_opt import get_jk as get_jk_opt diff --git a/pyscfad/df/incore.py b/pyscfad/df/incore.py index 83dc6f55..684460d3 100644 --- a/pyscfad/df/incore.py +++ b/pyscfad/df/incore.py @@ -6,7 +6,8 @@ from pyscf import gto from pyscf.df.outcore import _guess_shell_ranges from pyscf import __config__ -from pyscfad.lib import ops, vmap, jit +from pyscfad import ops +from pyscfad.ops import vmap, jit from pyscfad import config from . import addons, _int3c_cross_opt diff --git a/pyscfad/dft/libxc.py b/pyscfad/dft/libxc.py index c0043d53..be86abf0 100644 --- a/pyscfad/dft/libxc.py +++ b/pyscfad/dft/libxc.py @@ -1,28 +1,24 @@ from functools import partial -from jax import numpy as np -from jax import jit, custom_jvp from pyscf.dft import libxc from pyscf.dft.libxc import parse_xc, is_lda, is_meta_gga +from pyscfad import numpy as np +from pyscfad.ops import jit, custom_jvp def eval_xc(xc_code, rho, spin=0, relativity=0, deriv=1, omega=None, verbose=None): # NOTE only consider exc and vxc if deriv > 1: raise NotImplementedError - hyb, fn_facs = parse_xc(xc_code) - if omega is not None: - hyb = hyb[:2] + (float(omega),) - - exc = _eval_xc_comp(rho, hyb, fn_facs, spin, relativity, deriv=0, verbose=verbose) + exc = _eval_xc_comp(rho, xc_code, spin, relativity, deriv=0, omega=omega, verbose=verbose) if deriv == 0: vxc = (None,) * 4 elif deriv == 1: - vxc = _eval_xc_comp(rho, hyb, fn_facs, spin, relativity, deriv=1, verbose=verbose) + vxc = _eval_xc_comp(rho, xc_code, spin, relativity, deriv=1, omega=omega, verbose=verbose) return exc, vxc, None, None @partial(custom_jvp, nondiff_argnums=tuple(range(1,7))) -def _eval_xc_comp(rho, hyb, fn_facs, spin=0, relativity=0, deriv=1, verbose=None): - out = libxc._eval_xc(hyb, fn_facs, rho, spin, relativity, deriv, verbose)[deriv] +def _eval_xc_comp(rho, xc_code, spin=0, relativity=0, deriv=1, omega=None, verbose=None): + out = libxc.eval_xc(xc_code, rho, spin, relativity, deriv, omega, verbose)[deriv] if deriv == 1: out = tuple(out) if len(out) < 4: @@ -34,16 +30,17 @@ def _eval_xc_comp(rho, hyb, fn_facs, spin=0, relativity=0, deriv=1, verbose=None return out @_eval_xc_comp.defjvp -def _eval_xc_comp_jvp(hyb, fn_facs, spin, relativity, deriv, verbose, +def _eval_xc_comp_jvp(xc_code, spin, relativity, deriv, omega, verbose, primals, tangents): rho, = primals rho_t, = tangents if deriv > 2: raise NotImplementedError - val = _eval_xc_comp(rho, hyb, fn_facs, spin, relativity, deriv, verbose) - val1 = _eval_xc_comp(rho, hyb, fn_facs, spin, relativity, deriv+1, verbose) + val = _eval_xc_comp(rho, xc_code, spin, relativity, deriv, omega, verbose) + val1 = _eval_xc_comp(rho, xc_code, spin, relativity, deriv+1, omega, verbose) + hyb, fn_facs = parse_xc(xc_code) fn_ids = [x[0] for x in fn_facs] n = len(fn_ids) if (n == 0 or diff --git a/pyscfad/dft/numint.py b/pyscfad/dft/numint.py index 717236e9..6dee3249 100644 --- a/pyscfad/dft/numint.py +++ b/pyscfad/dft/numint.py @@ -1,16 +1,17 @@ import warnings from functools import partial import numpy -from jax import numpy as np from pyscf.lib import load_library from pyscf.dft import numint from pyscf.dft.numint import SWITCH_SIZE from pyscf.dft.gen_grid import BLKSIZE -from pyscfad.lib import ops -from pyscfad.lib import stop_grad -from pyscfad.lib import jit -from pyscfad.lib import custom_jvp -#from pyscfad.lib import vmap +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.ops import ( + stop_grad, + jit, + custom_jvp, +) from pyscfad.dft import libxc libdft = load_library('libdft') diff --git a/pyscfad/dft/rks.py b/pyscfad/dft/rks.py index baa1665a..b5724cd1 100644 --- a/pyscfad/dft/rks.py +++ b/pyscfad/dft/rks.py @@ -1,12 +1,11 @@ -import numpy -from jax import numpy as np from pyscf import __config__ from pyscf.lib import current_memory from pyscf.lib import logger from pyscf.dft import rks as pyscf_rks from pyscf.dft import gen_grid +from pyscfad import numpy as np from pyscfad import util -from pyscfad.lib import stop_grad +from pyscfad.ops import stop_grad from pyscfad.scf import hf from pyscfad.dft import numint @@ -114,51 +113,38 @@ def energy_elec(ks, dm=None, h1e=None, vhf=None): vhf = ks.get_veff(ks.mol, dm) e1 = np.einsum('ij,ji->', h1e, dm) e2 = vhf.ecoul + vhf.exc - ks.scf_summary['e1'] = stop_grad(e1).real - ks.scf_summary['coul'] = stop_grad(vhf.ecoul).real - ks.scf_summary['exc'] = stop_grad(vhf.exc).real + ks.scf_summary['e1'] = e1.real + ks.scf_summary['coul'] = vhf.ecoul.real + ks.scf_summary['exc'] = vhf.exc.real logger.debug(ks, 'E1 = %s Ecoul = %s Exc = %s', e1, vhf.ecoul, vhf.exc) return (e1+e2).real, e2 -NELEC_ERROR_TOL = getattr(__config__, 'dft_rks_prune_error_tol', 0.02) def prune_small_rho_grids_(ks, mol, dm, grids): mol = stop_grad(mol) dm = stop_grad(dm) rho = ks._numint.get_rho(mol, dm, grids, ks.max_memory) - n = numpy.dot(rho, grids.weights) - if abs(n-mol.nelectron) < NELEC_ERROR_TOL*n: - rho *= grids.weights - idx = abs(rho) > ks.small_rho_cutoff / grids.weights.size - logger.debug(ks, 'Drop grids %d', - grids.weights.size - numpy.count_nonzero(idx)) - grids.coords = numpy.asarray(grids.coords [idx], order='C') - grids.weights = numpy.asarray(grids.weights[idx], order='C') - grids.non0tab = grids.make_mask(mol, grids.coords) - return grids + return grids.prune_by_density_(rho, ks.small_rho_cutoff) def _dft_common_init_(mf, xc='LDA,VWN'): mf.xc = xc mf.nlc = '' + mf.disp = None mf.grids = None mf.nlcgrids = None # Use rho to filter grids - mf.small_rho_cutoff = getattr(__config__, 'dft_rks_RKS_small_rho_cutoff', 1e-7) -################################################## -# don't modify the following attributes, they are not input options + mf.small_rho_cutoff = getattr( + __config__, 'dft_rks_RKS_small_rho_cutoff', 1e-7) mf._numint = numint.NumInt() - mf._keys = mf._keys.union(['xc', 'nlc', 'omega', 'grids', 'nlcgrids', - 'small_rho_cutoff']) def _dft_common_post_init_(mf): if mf.grids is None: mf.grids = gen_grid.Grids(stop_grad(mf.mol)) - mf.grids.level = getattr(__config__, 'dft_rks_RKS_grids_level', - mf.grids.level) + mf.grids.level = getattr( + __config__, 'dft_rks_RKS_grids_level', mf.grids.level) if mf.nlcgrids is None: mf.nlcgrids = gen_grid.Grids(stop_grad(mf.mol)) - mf.nlcgrids.level = getattr(__config__, 'dft_rks_RKS_nlcgrids_level', - mf.nlcgrids.level) - + mf.nlcgrids.level = getattr( + __config__, 'dft_rks_RKS_nlcgrids_level', mf.nlcgrids.level) class KohnShamDFT(pyscf_rks.KohnShamDFT): __init__ = _dft_common_init_ @@ -174,6 +160,23 @@ def reset(self, mol=None): @util.pytree_node(hf.Traced_Attributes, num_args=1) class RKS(KohnShamDFT, hf.RHF): + """Subclass of :class:`pyscf.dft.rks.RKS` with traceable attributes. + + Attributes + ---------- + mol : :class:`pyscfad.gto.Mole` + :class:`pyscfad.gto.Mole` instance. + mo_coeff : array + MO coefficients. + mo_energy : array + MO energies. + _eri : array + Two-electron repulsion integrals. + + Notes + ----- + Grid response is not considered with AD. + """ def __init__(self, mol, xc='LDA,VWN', **kwargs): hf.RHF.__init__(self, mol) KohnShamDFT.__init__(self, xc) diff --git a/pyscfad/fci/__init__.py b/pyscfad/fci/__init__.py index 1692092b..fdcfedc4 100644 --- a/pyscfad/fci/__init__.py +++ b/pyscfad/fci/__init__.py @@ -1,5 +1,5 @@ from functools import reduce -from jax import numpy as np +from pyscfad import numpy as np from pyscfad import ao2mo from pyscfad.fci import fci_slow from pyscfad.fci.fci_slow import fci_ovlp diff --git a/pyscfad/fci/fci_slow.py b/pyscfad/fci/fci_slow.py index bf325081..feb9d8f9 100644 --- a/pyscfad/fci/fci_slow.py +++ b/pyscfad/fci/fci_slow.py @@ -1,7 +1,8 @@ import numpy -from jax import numpy as np from pyscf.fci import cistring -from pyscfad.lib import vmap, ops, stop_grad +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.ops import vmap, stop_grad from pyscfad.lib.linalg_helper import davidson from pyscfad.gto import mole from pyscfad import ao2mo diff --git a/pyscfad/gto/_mole_helper.py b/pyscfad/gto/_mole_helper.py index ddffeaf3..e594b14a 100644 --- a/pyscfad/gto/_mole_helper.py +++ b/pyscfad/gto/_mole_helper.py @@ -3,13 +3,12 @@ KAPPA_OF, PTR_EXP, PTR_COEFF, PTR_ENV_START) def uncontract(mol, shls_slice=None): - """ - Return the uncontracted basis functions. + """Uncontract basis functions. Parameters ---------- - mol : Mole instance - Mole instance with the contracted basis functions. + mol : :class:`Mole` instance + :class:`Mole` instance with the contracted basis functions. shls_slice : tuple Starting and ending indices of the shells being @@ -18,8 +17,8 @@ def uncontract(mol, shls_slice=None): Returns ------- - mol1 : Mole instance - Mole instance with the uncontracted basis functions. + mol1 : :class:`Mole` instance + :class:`Mole` instance with the uncontracted basis functions. Notes ----- @@ -73,7 +72,7 @@ def shlmap_ctr2unctr(mol): Parameters ---------- - mol : Mole instance + mol : :class:`Mole` instance Returns ------- @@ -91,12 +90,11 @@ def shlmap_ctr2unctr(mol): return map_c2u def setup_exp(mol): - """ - Find unique exponents of the basis functions. + """Find unique exponents of the basis functions. Parameters ---------- - mol : Mole instance + mol : :class:`Mole` instance Returns ------- @@ -139,12 +137,11 @@ def setup_exp(mol): return es, es_of, env_of def setup_ctr_coeff(mol): - """ - Find unique contraction coefficients of the basis functions. + """Find unique contraction coefficients of the basis functions. Parameters ---------- - mol : Mole instance + mol : :class:`Mole` instance Returns ------- diff --git a/pyscfad/gto/_moleintor_jvp.py b/pyscfad/gto/_moleintor_jvp.py index 0fbf9d4d..a015a3cf 100644 --- a/pyscfad/gto/_moleintor_jvp.py +++ b/pyscfad/gto/_moleintor_jvp.py @@ -1,11 +1,18 @@ from functools import partial import numpy -from jax import numpy as np + from pyscf import ao2mo from pyscf.gto import mole as pyscf_mole from pyscf.gto import ATOM_OF from pyscf.gto.moleintor import _get_intor_and_comp -from pyscfad.lib import ops, custom_jvp, jit, vmap + +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.ops import ( + custom_jvp, + jit, + vmap, +) from ._mole_helper import ( setup_exp, setup_ctr_coeff, diff --git a/pyscfad/gto/eval_gto.py b/pyscfad/gto/eval_gto.py index 43795c6a..e90db7fa 100644 --- a/pyscfad/gto/eval_gto.py +++ b/pyscfad/gto/eval_gto.py @@ -1,12 +1,18 @@ from functools import partial import numpy #import jax -from jax import numpy as np + from pyscf.gto import mole as pyscf_mole from pyscf.gto.moleintor import make_loc from pyscf.gto.eval_gto import _get_intor_and_comp from pyscf.gto.eval_gto import eval_gto as pyscf_eval_gto -from pyscfad.lib import jit, custom_jvp, vmap + +from pyscfad import numpy as np +from pyscfad.ops import ( + custom_jvp, + jit, + vmap, +) from pyscfad.gto._moleintor_helper import get_bas_label, promote_xyz from pyscfad.gto._mole_helper import ( setup_exp, diff --git a/pyscfad/gto/mole.py b/pyscfad/gto/mole.py index 2dd05541..a5a5da2b 100644 --- a/pyscfad/gto/mole.py +++ b/pyscfad/gto/mole.py @@ -1,42 +1,35 @@ from functools import wraps -from jax import numpy as np from pyscf.gto import mole as pyscf_mole from pyscf.lib import logger, param +from pyscfad import numpy as np from pyscfad import util -from pyscfad.lib import custom_jvp, vmap from pyscfad.gto import moleintor from pyscfad.gto.eval_gto import eval_gto from ._mole_helper import setup_exp, setup_ctr_coeff Traced_Attributes = ['coords', 'exp', 'ctr_coeff', 'r0'] -def energy_nuc(mol, charges=None, **kwargs): +@wraps(pyscf_mole.inter_distance) +def inter_distance(mol, coords=None): + if coords is None: + coords = mol.atom_coords() + r = coords[:,None,:] - coords[None,:,:] + rr = np.sum(r*r, axis=2) + rr = np.sqrt(np.where(rr>1e-10, rr, 0)) + return rr + +@wraps(pyscf_mole.classical_coulomb_energy) +def classical_coulomb_energy(mol, charges=None, coords=None): if charges is None: - charges = mol.atom_charges() + charges = np.asarray(mol.atom_charges(), dtype=float) if len(charges) <= 1: return 0.0 - r = distance_matrix(mol.atom_coords()) - enuc = np.einsum('i,ij,j->', charges, 1./r, charges) * .5 + rr = inter_distance(mol, coords) + rr = np.where(rr>1e-5, rr, 1e200) + enuc = np.einsum('i,ij,j->', charges, 1./rr, charges) * .5 return enuc -@custom_jvp -def distance_matrix(coords): - r = np.linalg.norm(coords[:,None,:] - coords[None,:,:], axis=2) - r += np.eye(r.shape[-1]) * 1e200 - return r - -@distance_matrix.defjvp -def distance_matrix_jvp(primals, tangents): - coords, = primals - coords_t, = tangents - rnorm = primal_out = distance_matrix(coords) - def body(r1, r2, rnorm, coords_t): - r = r1 - r2 - jvp = np.dot(r / rnorm[:,None], coords_t) - return jvp - tangent_out = vmap(body, (0,None,0,0))(coords, coords, rnorm, coords_t) - tangent_out += tangent_out.T - return primal_out, tangent_out +energy_nuc = classical_coulomb_energy @wraps(pyscf_mole.intor_cross) def intor_cross(intor, mol1, mol2, comp=None, grids=None): @@ -55,22 +48,24 @@ def nao_nr_range(mol, bas_id0, bas_id1): @util.pytree_node(Traced_Attributes) class Mole(pyscf_mole.Mole): - ''' - A subclass of :class:`pyscf.gto.Mole`, where the following - attributes can be traced. - - Attributes: - coords : array - Atomic coordinates. - exp : array - Exponents of Gaussian basis functions. - ctr_coeff : array - Contraction coefficients of Gaussian basis functions. - r0 : array - Centers of Gaussian basis functions. Currently this is - not used as the basis functions are atom centered. This - is a placeholder for floating Gaussian basis sets. - ''' + """Subclass of :class:`pyscf.gto.Mole` with traceable attributes. + + Attributes + ---------- + coords : array + Atomic coordinates. + exp : array + Exponents of Gaussian basis functions. + ctr_coeff : array + Contraction coefficients of Gaussian basis functions. + r0 : array + Centers of Gaussian basis functions. Currently this is + not used as the basis functions are atom centered. This + is a placeholder for floating Gaussian basis sets. + """ + + _keys = {'coords', 'exp', 'ctr_coeff', 'r0'} + def __init__(self, **kwargs): self.coords = None self.exp = None diff --git a/pyscfad/gto/moleintor_opt.py b/pyscfad/gto/moleintor_opt.py new file mode 100644 index 00000000..4754bb53 --- /dev/null +++ b/pyscfad/gto/moleintor_opt.py @@ -0,0 +1,440 @@ +from functools import partial +import ctypes +import numpy +from jax import custom_vjp +from jax.tree_util import tree_flatten, tree_unflatten + +from pyscf import lib +from pyscf.lib import logger +from pyscf.gto import mole as pyscf_mole +from pyscfad.gto._pyscf_moleintor import ( + make_loc, + make_cintopt, + _stand_sym_code, + _get_intor_and_comp, +) + +from pyscfad.gto._mole_helper import ( + get_fakemol_exp, + get_fakemol_cs, + setup_exp, + setup_ctr_coeff, + shlmap_ctr2unctr, +) +from pyscfad.gto._moleintor_helper import ( + int1e_dr1_name, + int2e_dr1_name, +) +from pyscfad.gto.moleintor import _intor +from pyscfadlib import libcgto_vjp as libcgto + +def getints(mol, intor, shls_slice=None, + comp=None, hermi=0, aosym='s1', + out=None, grids=None): + if intor.endswith('_spinor'): + raise NotImplementedError('Integrals for spinors are not supported.') + if grids is not None: + raise NotImplementedError('Integrals on grids are not supported.') + if out is not None: + logger.warn(mol, f'Argument out = {out} will be ignored.') + if hermi == 2: + hermi = 0 + msg = f'Anti-hermitian symmetry is not supported. Setting hermi = {hermi}.' + logger.warn(mol, msg) + + if (intor.startswith('int1e') or + intor.startswith('int2c2e') or + intor.startswith('ECP')): + return getints2c(mol, intor, shls_slice, comp, hermi, out=None) + elif intor.startswith('int2e'): + return getints4c(mol, intor, shls_slice, comp, aosym, out=None) + else: + raise NotImplementedError(f'Integral {intor} is not supported.') + +@partial(custom_vjp, nondiff_argnums=(1,2,3,4,5)) +def getints2c(mol, intor, shls_slice=None, comp=None, hermi=0, out=None): + return _intor(mol, intor, comp=comp, hermi=hermi, + shls_slice=shls_slice, out=out) + +def getints2c_fwd(mol, intor, shls_slice, comp, hermi, out): + y = getints2c(mol, intor, shls_slice, comp, hermi, out) + return y, (mol,) + +def getints2c_bwd(intor, shls_slice, comp, hermi, out, + res, ybar): + mol = res[0] + leaves = [] + + if mol.coords is not None: + vjp_coords = getints2c_coords_bwd(intor, shls_slice, comp, hermi, out, + mol, ybar) + leaves.append(vjp_coords) + + if mol.exp is not None: + vjp_exp = getints2c_exp_bwd(intor, shls_slice, comp, hermi, out, + mol, ybar) + leaves.append(vjp_exp) + + if mol.ctr_coeff is not None: + vjp_coeff = getints2c_coeff_bwd(intor, shls_slice, comp, hermi, out, + mol, ybar) + leaves.append(vjp_coeff) + + if mol.r0 is not None: + raise NotImplementedError + + _, tree = tree_flatten(mol) + molbar = tree_unflatten(tree, leaves) + return (molbar,) + +getints2c.defvjp(getints2c_fwd, getints2c_bwd) + +@partial(custom_vjp, nondiff_argnums=(1,2,3,4,5)) +def getints4c(mol, intor, shls_slice=None, comp=None, aosym='s1', out=None): + return _intor(mol, intor, comp=comp, aosym=aosym, + shls_slice=shls_slice, out=out) + +def getints4c_fwd(mol, intor, shls_slice, comp, aosym, out): + y = getints4c(mol, intor, shls_slice, comp, aosym, out) + return y, (mol,) + +def getints4c_bwd(intor, shls_slice, comp, aosym, out, + res, ybar): + mol = res[0] + leaves = [] + + if mol.coords is not None: + vjp_coords = getints4c_coords_bwd(intor, shls_slice, comp, aosym, out, + mol, ybar) + leaves.append(vjp_coords) + + if mol.exp is not None: + raise NotImplementedError + #vjp_exp = getints4c_exp_bwd(intor, shls_slice, comp, aosym, out, + # mol, ybar) + + if mol.ctr_coeff is not None: + raise NotImplementedError + #vjp_coeff = getints4c_coeff_bwd(intor, shls_slice, comp, aosym, out, + # mol, ybar) + + _, tree = tree_flatten(mol) + molbar = tree_unflatten(tree, leaves) + return (molbar,) + +getints4c.defvjp(getints4c_fwd, getints4c_bwd) + + +def _int1e_r0_deriv(intor_ip_bra, shls_slice, comp, hermi, + mol, ybar, rc_deriv=None, switch_ij=False): + nbas = mol.nbas + if shls_slice is None: + shls_slice = (0, nbas, 0, nbas) + i0, i1, j0, j1 = shls_slice[:4] + assert i0 >= 0 and i1 <= nbas + assert j0 >= 0 and j1 <= nbas + assert i0 < i1 and j0 < j1 + + ao_loc = make_loc(mol._bas, intor_ip_bra) + ao_loc = numpy.asarray(ao_loc, order='C', dtype=numpy.int32) + naoi = ao_loc[i1] - ao_loc[i0] + naoj = ao_loc[j1] - ao_loc[j0] + + atm = numpy.asarray(mol._atm, order='C', dtype=numpy.int32) + bas = numpy.asarray(mol._bas, order='C', dtype=numpy.int32) + env = numpy.asarray(mol._env, order='C', dtype=numpy.double) + + cintopt = make_cintopt(atm, bas, env, intor_ip_bra) + + ybar = numpy.asarray(ybar).reshape(comp, naoi, naoj) + if hermi == 1: + assert i0 == j0 and i1 == j1 + ybar = ybar + ybar.transpose(0,2,1) + elif switch_ij: + ybar = ybar.transpose(0,2,1) + shls_slice = (j0, j1, i0, i1) + ybar = numpy.asarray(ybar, order='C', dtype=numpy.double) + + ndim = 3 + natm = len(atm) + if rc_deriv is not None: + vjp = numpy.zeros((ndim,), order='C', dtype=numpy.double) + fn = getattr(libcgto, 'GTOint2c_rc_vjp') + else: + vjp = numpy.zeros((natm,ndim), order='C', dtype=numpy.double) + fn = getattr(libcgto, 'GTOint2c_r0_vjp') + + fn(getattr(libcgto, intor_ip_bra), + vjp.ctypes.data_as(ctypes.c_void_p), + ybar.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(comp), ctypes.c_int(ndim), ctypes.c_int(hermi), + (ctypes.c_int*4)(*(shls_slice[:4])), + ao_loc.ctypes.data_as(ctypes.c_void_p), cintopt, + atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(natm), + bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(nbas), + env.ctypes.data_as(ctypes.c_void_p)) + return vjp + +def _int1e_rc_deriv(intor_ip_bra, shls_slice, comp, hermi, + mol, ybar): + intor_ip_bra = intor_ip_bra.replace('nuc', 'rinv') + + natm = mol.natm + vjp = numpy.zeros((natm,3), dtype=numpy.double) + for iatm in range(natm): + with mol.with_rinv_at_nucleus(iatm): + charge = -mol.atom_charge(iatm) + vjp[iatm] += _int1e_r0_deriv(intor_ip_bra, shls_slice, comp, hermi, + mol, ybar, rc_deriv=iatm) * charge + return vjp + +def getints2c_coords_bwd(intor, shls_slice, comp, hermi, out, + mol, ybar): + log = logger.new_logger(mol) + _, comp = _get_intor_and_comp(intor, comp) + + switch_ij = False + intor_ip_bra, intor_ip_ket = int1e_dr1_name(intor) + if not intor_ip_bra: + #switch i and j shells due to derivative over ket + switch_ij = True + intor_ip_bra = intor_ip_ket + + vjp = _int1e_r0_deriv(intor_ip_bra, shls_slice, comp, hermi, + mol, ybar, switch_ij=switch_ij) + if 'nuc' in intor_ip_bra: + vjp += _int1e_rc_deriv(intor_ip_bra, shls_slice, comp, hermi, + mol, ybar) + log.timer(f'getints2c_coords_bwd {intor}') + del log + return vjp + +def getints2c_exp_bwd(intor, shls_slice, comp, hermi, out, + mol, ybar): + log = logger.new_logger(mol) + nbas = mol.nbas + if shls_slice is None: + shls_slice = (0, nbas, 0, nbas) + i0, i1, j0, j1 = shls_slice[:4] + assert i0 >= 0 and i1 <= nbas + assert j0 >= 0 and j1 <= nbas + assert i0 < i1 and j0 < j1 + + if comp is None: + comp = 1 + elif comp != 1: + raise NotImplementedError + + order = 2 # first order derivative of Gaussians + + shlmap_c2u = shlmap_ctr2unctr(mol) + shlmap_c2u = numpy.asarray(shlmap_c2u, order='C', dtype=numpy.int32) + mol1 = get_fakemol_exp(mol, order) + mol1._atm[:,pyscf_mole.CHARGE_OF] = 0 # set nuclear charge to zero + + ao_loc = make_loc(mol._bas, intor) + ao_loc = numpy.asarray(ao_loc, order='C', dtype=numpy.int32) + + if intor.endswith('_sph'): + cart = False + intor = intor.replace('_sph', '_cart') + ao_loc_cart = make_loc(mol._bas, intor) + ao_loc_cart = numpy.asarray(ao_loc_cart, order='C', dtype=numpy.int32) + elif intor.endswith('_cart'): + cart = True + ao_loc_cart = ao_loc + else: + raise NotImplementedError + + nbas1 = len(mol1._bas) + shls_slice = shls_slice + (nbas, nbas+nbas1) + if 'ECP' in intor: + assert mol._ecp is not None + bas1 = numpy.vstack((mol1._bas, mol._ecpbas)) + else: + bas1 = mol1._bas + atmc, basc, envc = pyscf_mole.conc_env(mol._atm, mol._bas, mol._env, + mol1._atm, bas1, mol1._env) + if 'ECP' in intor: + envc[pyscf_mole.AS_ECPBAS_OFFSET] = nbas + nbas1 + envc[pyscf_mole.AS_NECPBAS] = len(mol._ecpbas) + + atmc = numpy.asarray(atmc, order='C', dtype=numpy.int32) + basc = numpy.asarray(basc, order='C', dtype=numpy.int32) + envc = numpy.asarray(envc, order='C', dtype=numpy.double) + + _, es_of, _ = setup_exp(mol) + es_of = numpy.asarray(es_of, order='C', dtype=numpy.int32) + + nes = len(mol.exp) + vjp = numpy.zeros((nes,), order='C', dtype=numpy.double) + + cintopt = make_cintopt(atmc, basc, envc, intor) + + if hermi == 1: + ybar = ybar + ybar.T + ybar = numpy.asarray(ybar, order='C', dtype=numpy.double) + + fn = getattr(libcgto, 'GTOint2c_exp_vjp') + fn(getattr(libcgto, intor), + vjp.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(nes), + ybar.ctypes.data_as(ctypes.c_void_p), + shlmap_c2u.ctypes.data_as(ctypes.c_void_p), + es_of.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(comp), ctypes.c_int(hermi), + (ctypes.c_int*6)(*(shls_slice[:6])), + ao_loc.ctypes.data_as(ctypes.c_void_p), + ao_loc_cart.ctypes.data_as(ctypes.c_void_p), cintopt, + atmc.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(atmc)), + basc.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(basc)), + envc.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(cart), ctypes.c_int(order)) + log.timer('getints2c_exp_bwd') + del log + return vjp + +def getints2c_coeff_bwd(intor, shls_slice, comp, hermi, out, + mol, ybar): + log = logger.new_logger(mol) + nbas = mol.nbas + if shls_slice is None: + shls_slice = (0, nbas, 0, nbas) + i0, i1, j0, j1 = shls_slice[:4] + assert i0 >= 0 and i1 <= nbas + assert j0 >= 0 and j1 <= nbas + assert i0 < i1 and j0 < j1 + + if comp is None: + comp = 1 + elif comp != 1: + raise NotImplementedError + + shlmap_c2u = shlmap_ctr2unctr(mol) + shlmap_c2u = numpy.asarray(shlmap_c2u, order='C', dtype=numpy.int32) + mol1 = get_fakemol_cs(mol) + mol1._atm[:,pyscf_mole.CHARGE_OF] = 0 # set nuclear charge to zero + + ao_loc = make_loc(mol._bas, intor) + ao_loc = numpy.asarray(ao_loc, order='C', dtype=numpy.int32) + + if intor.endswith('_sph'): + cart = False + elif intor.endswith('_cart'): + cart = True + else: + raise NotImplementedError + + nbas1 = len(mol1._bas) + shls_slice = shls_slice + (nbas, nbas+nbas1) + if 'ECP' in intor: + assert mol._ecp is not None + bas1 = numpy.vstack((mol1._bas, mol._ecpbas)) + else: + bas1 = mol1._bas + atmc, basc, envc = pyscf_mole.conc_env(mol._atm, mol._bas, mol._env, + mol1._atm, bas1, mol1._env) + if 'ECP' in intor: + envc[pyscf_mole.AS_ECPBAS_OFFSET] = nbas + nbas1 + envc[pyscf_mole.AS_NECPBAS] = len(mol._ecpbas) + + atmc = numpy.asarray(atmc, order='C', dtype=numpy.int32) + basc = numpy.asarray(basc, order='C', dtype=numpy.int32) + envc = numpy.asarray(envc, order='C', dtype=numpy.double) + + _, cs_of, _ = setup_ctr_coeff(mol) + cs_of = numpy.asarray(cs_of, order='C', dtype=numpy.int32) + + ncs = len(mol.ctr_coeff) + vjp = numpy.zeros((ncs,), order='C', dtype=numpy.double) + + cintopt = make_cintopt(atmc, basc, envc, intor) + + if hermi == 1: + ybar = ybar + ybar.T + ybar = numpy.asarray(ybar, order='C', dtype=numpy.double) + + fn = getattr(libcgto, 'GTOint2c_coeff_vjp') + fn(getattr(libcgto, intor), + vjp.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(ncs), + ybar.ctypes.data_as(ctypes.c_void_p), + shlmap_c2u.ctypes.data_as(ctypes.c_void_p), + cs_of.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(comp), ctypes.c_int(hermi), + (ctypes.c_int*6)(*(shls_slice[:6])), + ao_loc.ctypes.data_as(ctypes.c_void_p), cintopt, + atmc.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(atmc)), + basc.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(len(basc)), + envc.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(cart)) + log.timer('getints2c_coeff_bwd') + del log + return vjp + + +def getints4c_coords_bwd(intor, shls_slice, comp, aosym, out, + mol, ybar): + log = logger.new_logger(mol) + aosym = _stand_sym_code(aosym) + if aosym != 's4': + raise NotImplementedError + + atm = numpy.asarray(mol._atm, dtype=numpy.int32, order='C') + bas = numpy.asarray(mol._bas, dtype=numpy.int32, order='C') + env = numpy.asarray(mol._env, dtype=numpy.double, order='C') + c_atm = atm.ctypes.data_as(ctypes.c_void_p) + c_bas = bas.ctypes.data_as(ctypes.c_void_p) + c_env = env.ctypes.data_as(ctypes.c_void_p) + natm = atm.shape[0] + nbas = bas.shape[0] + + if shls_slice is None: + shls_slice = (0, nbas, 0, nbas, 0, nbas, 0, nbas) + elif len(shls_slice) == 4: + shls_slice = shls_slice + (0, nbas, 0, nbas) + + i0, i1, j0, j1, k0, k1, l0, l1 = shls_slice + assert i0 >= 0 and i1 <= nbas + assert j0 >= 0 and j1 <= nbas + assert k0 >= 0 and k1 <= nbas + assert l0 >= 0 and l1 <= nbas + assert i0 < i1 and j0 < j1 and k0 < k1 and l0 < l1 + + if comp is None: + comp = 1 + elif comp != 1: + raise NotImplementedError + comp = 3 # first order + + intor1 = int2e_dr1_name(intor)[0] + ao_loc = make_loc(bas, intor1) + ao_loc = numpy.asarray(ao_loc, order='C', dtype=numpy.int32) + + naoi = ao_loc[i1] - ao_loc[i0] + naoj = ao_loc[j1] - ao_loc[j0] + naok = ao_loc[k1] - ao_loc[k0] + naol = ao_loc[l1] - ao_loc[l0] + + if aosym in ('s4', 's2ij'): + assert numpy.all(ao_loc[i0:i1]-ao_loc[i0] == ao_loc[j0:j1]-ao_loc[j0]) + if aosym in ('s4', 's2kl'): + assert numpy.all(ao_loc[k0:k1]-ao_loc[k0] == ao_loc[l0:l1]-ao_loc[l0]) + + drv = libcgto.GTOnr2e_fill_r0_vjp + fill = getattr(libcgto, 'GTOnr2e_fill_r0_vjp_'+aosym) + vjp = numpy.zeros((natm, comp), order='C', dtype=numpy.double) + if aosym == 's4': + ybar += ybar.T + ybar = numpy.asarray(ybar, order='C', dtype=numpy.double) + + cintopt = make_cintopt(atm, bas, env, intor1) + prescreen = lib.c_null_ptr() + drv(getattr(libcgto, intor1), fill, prescreen, + vjp.ctypes.data_as(ctypes.c_void_p), + ybar.ctypes.data_as(ctypes.c_void_p), + ctypes.c_int(comp), (ctypes.c_int*8)(*shls_slice), + ao_loc.ctypes.data_as(ctypes.c_void_p), cintopt, + c_atm, ctypes.c_int(natm), c_bas, ctypes.c_int(nbas), c_env) + + log.timer('getints4c_coords_bwd') + del log + return vjp diff --git a/pyscfad/gw/rpa.py b/pyscfad/gw/rpa.py index 618a9648..a25fddae 100644 --- a/pyscfad/gw/rpa.py +++ b/pyscfad/gw/rpa.py @@ -1,10 +1,10 @@ import numpy -from jax import numpy as np from pyscf.lib import logger, current_memory from pyscf import df as pyscf_df from pyscf.gw import rpa as pyscf_rpa +from pyscfad import numpy as np from pyscfad import util -from pyscfad.lib import vmap, jit +from pyscfad.ops import vmap, jit from pyscfad import scf, dft, df from pyscfad.df.addons import restore diff --git a/pyscfad/gw/test/test_rpa.py b/pyscfad/gw/test/test_rpa.py index 764f601c..c3cf56bc 100644 --- a/pyscfad/gw/test/test_rpa.py +++ b/pyscfad/gw/test/test_rpa.py @@ -24,7 +24,7 @@ def get_h2o(): config.reset() -def test_nuc_grad(get_h2o): +def test_nuc_grad_skip(get_h2o): mol = get_h2o auxbasis = pyscf_df.addons.make_auxbasis(mol, mp2fit=True) auxmol = df.addons.make_auxmol(mol, auxbasis) diff --git a/pyscfad/lib/__init__.py b/pyscfad/lib/__init__.py index a4cb0f87..418d7938 100644 --- a/pyscfad/lib/__init__.py +++ b/pyscfad/lib/__init__.py @@ -1,3 +1,4 @@ -from pyscfad.lib.misc import * -from pyscfad.lib.jax_helper import * +""" +Wrappers for functions in pyscf.lib +""" from pyscfad.lib.numpy_helper import * diff --git a/pyscfad/lib/config.h.in b/pyscfad/lib/config.h.in deleted file mode 100644 index bb1f255c..00000000 --- a/pyscfad/lib/config.h.in +++ /dev/null @@ -1,6 +0,0 @@ -#if defined _OPENMP -#include -#else -#define omp_get_thread_num() 0 -#define omp_get_num_threads() 1 -#endif diff --git a/pyscfad/lib/diis.py b/pyscfad/lib/diis.py index ede45e43..9b76890c 100644 --- a/pyscfad/lib/diis.py +++ b/pyscfad/lib/diis.py @@ -1,10 +1,11 @@ import numpy import scipy.linalg -from jax import numpy as np from pyscf.lib import prange from pyscf.lib import diis as pyscf_diis from pyscf.lib.diis import INCORE_SIZE, BLOCK_SIZE -from pyscfad.lib import logger, stop_grad +from pyscfad import numpy as np +from pyscfad.ops import stop_grad +from pyscfad.lib import logger # pylint: disable=consider-using-f-string class DIIS(pyscf_diis.DIIS): diff --git a/pyscfad/lib/jax_helper.py b/pyscfad/lib/jax_helper.py deleted file mode 100644 index 8b61c5b9..00000000 --- a/pyscfad/lib/jax_helper.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Helper functions for jax -""" - -#import dataclasses -import numpy -import jax -#from jax import tree_util -from pyscf import __config__ - -PYSCFAD = getattr(__config__, 'pyscfad', False) - -def stop_grad(x): - if PYSCFAD: - return jax.lax.stop_gradient(x) - else: - return x - -def stop_trace(fn): - if PYSCFAD: - def wrapped_fn(*args, **kwargs): - args_no_grad = [stop_grad(arg) for arg in args] - kwargs_no_grad = {k : stop_grad(v) for k, v in kwargs.items()} - return fn(*args_no_grad, **kwargs_no_grad) - return wrapped_fn - else: - return fn - -def jit(fun, **kwargs): - if PYSCFAD: - return jax.jit(fun, **kwargs) - else: - return fun - -def vmap_numpy(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, signature=None): - if axis_name is not None: - raise NotImplementedError - if axis_size is not None: - raise NotImplementedError - if not isinstance(out_axes, int): - raise NotImplementedError - - def vmap_f(*args): - if isinstance(in_axes, int): - in_axes_loc = (in_axes,) * len(args) - else: - in_axes_loc = in_axes - - if isinstance(in_axes_loc, (list, tuple)): - excluded = [] - vmap_args = [] - assert len(in_axes_loc) == len(args) - for i, axis in enumerate(in_axes_loc): - if axis is None: - excluded.append(i) - vmap_args.append(args[i]) - elif isinstance(axis, int): - vmap_args.append(numpy.moveaxis(args[i], axis, 0)) - else: - raise KeyError - if len(excluded) > 0: - excluded = set(excluded) - else: - excluded = None - - vfun = numpy.vectorize(fun, excluded=excluded, signature=signature) - out = vfun(*vmap_args) - else: - raise KeyError - - if out_axes != 0: - out = numpy.moveaxis(out, 0, out_axes) - return out - - return vmap_f - -if PYSCFAD: - def vmap(fun, in_axes=0, out_axes=0, axis_name=None, axis_size=None, signature=None): - f_vmap = jax.vmap(fun, in_axes=in_axes, out_axes=out_axes, - axis_name=axis_name, axis_size=axis_size) - return f_vmap -else: - vmap = vmap_numpy - -if PYSCFAD: - custom_jvp = jax.custom_jvp -else: - class custom_jvp(): - ''' - A fake custom_jvp that does nothing - ''' - def __init__(self, fun, *args, **kwargs): - self.fun = fun - self.jvp = None - - def defjvp(self, jvp): - self.jvp = jvp - return jvp - - def __call__(self, *args, **kwargs): - return self.fun(*args, **kwargs) - - -#def dataclass(cls): -# data_cls = dataclasses.dataclass()(cls) -# data_fields = [] -# meta_fields = [] -# for field_name, field_info in data_cls.__dataclass_fields__.items(): -# is_pytree_node = field_info.metadata.get('pytree_node', False) -# if is_pytree_node: -# data_fields.append(field_name) -# else: -# meta_fields.append(field_name) -# -# def tree_flatten(obj): -# children = tuple(getattr(obj, key, None) for key in data_fields) -# metadata = tuple(getattr(obj, key, None) for key in meta_fields) -# return children, metadata -# -# def tree_unflatten(metadata, children): -# data_args = tuple(zip(data_fields, children)) -# meta_args = tuple(zip(meta_fields, metadata)) -# kwargs = dict(data_args + meta_args) -# obj = data_cls(**kwargs) -# return obj -# -# tree_util.register_pytree_node(data_cls, -# tree_flatten, -# tree_unflatten) -# return data_cls -# -#def field(pytree_node=False, **kwargs): -# return dataclasses.field(metadata={'pytree_node': pytree_node}, **kwargs) diff --git a/pyscfad/lib/linalg_helper.py b/pyscfad/lib/linalg_helper.py index e732108c..36505eec 100644 --- a/pyscfad/lib/linalg_helper.py +++ b/pyscfad/lib/linalg_helper.py @@ -1,7 +1,6 @@ import sys import warnings import numpy -from jax import numpy as np from jax import scipy from pyscf import __config__ from pyscf.lib.linalg_helper import ( @@ -9,7 +8,9 @@ _sort_by_similarity, _sort_elast, ) -from pyscfad.lib import logger, stop_grad, jit +from pyscfad import numpy as np +from pyscfad.lib import logger +from pyscfad.ops import stop_grad, jit DAVIDSON_LINDEP = getattr(__config__, 'lib_linalg_helper_davidson_lindep', 1e-14) MAX_MEMORY = getattr(__config__, 'lib_linalg_helper_davidson_max_memory', 2000) diff --git a/pyscfad/lib/logger.py b/pyscfad/lib/logger.py index a517ba90..1d7067f0 100644 --- a/pyscfad/lib/logger.py +++ b/pyscfad/lib/logger.py @@ -1,10 +1,13 @@ # pylint: skip-file from pyscf.lib import logger from pyscf.lib.logger import * +from pyscfad import ops def flush(rec, msg, *args): args_list = [] for arg in args: + if ops.is_array(arg): + arg = ops.to_numpy(arg) args_list.append(getattr(arg, 'val', arg)) rec.stdout.write(msg % tuple(args_list)) rec.stdout.write('\n') diff --git a/pyscfad/lib/misc.py b/pyscfad/lib/misc.py deleted file mode 100644 index c933b098..00000000 --- a/pyscfad/lib/misc.py +++ /dev/null @@ -1,16 +0,0 @@ -import numpy -import jax -from jax import numpy as np -from .jax_helper import jit - -def isarray(a): - return isinstance(a, (numpy.ndarray, jax.Array)) - -@jit -def square_mat_in_trilu_indices(n): - tril2sq = np.zeros((n,n), dtype=int) - idx = np.tril_indices(n) - idx_flat = np.arange(n*(n+1)//2) - tril2sq = tril2sq.at[idx[1],idx[0]].set(idx_flat) - tril2sq = tril2sq.at[idx[0],idx[1]].set(idx_flat) - return tril2sq diff --git a/pyscfad/lib/numpy_helper.py b/pyscfad/lib/numpy_helper.py index 25d25b85..6b66f97b 100644 --- a/pyscfad/lib/numpy_helper.py +++ b/pyscfad/lib/numpy_helper.py @@ -1,21 +1,25 @@ from functools import partial import math -from jax import numpy -from .jax_helper import jit, vmap +from pyscf.lib.numpy_helper import ( + PLAIN, + HERMITIAN, + ANTIHERMI, + SYMMETRIC, +) +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.ops import jit, vmap from pyscfad import config -from pyscfad.lib import ops -einsum = numpy.einsum -dot = numpy.dot - -__all__ = ['numpy', 'einsum', 'dot', - 'PLAIN', 'HERMITIAN', 'ANTIHERMI', 'SYMMETRIC', - 'unpack_triu', 'unpack_tril', 'pack_tril'] - -PLAIN = 0 -HERMITIAN = 1 -ANTIHERMI = 2 -SYMMETRIC = 3 +__all__ = [ + 'PLAIN', + 'HERMITIAN', + 'ANTIHERMI', + 'SYMMETRIC', + 'unpack_triu', + 'unpack_tril', + 'pack_tril' +] @partial(jit, static_argnums=1) def _unpack_triu(triu, filltril=HERMITIAN): @@ -24,19 +28,19 @@ def _unpack_triu(triu, filltril=HERMITIAN): ''' assert triu.ndim == 1 nd = int(math.sqrt(2*triu.size)) - out = numpy.zeros((nd,nd), dtype=triu.dtype) - idx = numpy.triu_indices(nd) + out = np.zeros((nd,nd), dtype=triu.dtype) + idx = np.triu_indices(nd) out = ops.index_update(out, idx, triu) if filltril == PLAIN: return out elif filltril == HERMITIAN: - out += numpy.tril(out.T.conj(), -1) + out += np.tril(out.T.conj(), -1) return out elif filltril == ANTIHERMI: out -= out.conj().T return out elif filltril == SYMMETRIC: - out += numpy.tril(out.T, -1) + out += np.tril(out.T, -1) return out else: raise KeyError @@ -46,9 +50,9 @@ def unpack_triu(triu, filltril=HERMITIAN, axis=-1, out=None): out = _unpack_triu(triu, filltril) elif triu.ndim == 2: if axis == -1 or axis == 1: - out = vmap(_unpack_triu, (0,None))(triu, filltril) + out = vmap(_unpack_triu, (0,None), signature='(n)->(m,m)')(triu, filltril) elif axis == 0 or axis == -2: - out = vmap(_unpack_triu, (1,None))(triu, filltril) + out = vmap(_unpack_triu, (1,None), signature='(n)->(m,m)')(triu, filltril) else: raise NotImplementedError return out @@ -60,19 +64,19 @@ def _unpack_tril(tril, filltriu=HERMITIAN): ''' assert tril.ndim == 1 nd = int(math.sqrt(2*tril.size)) - out = numpy.zeros((nd,nd), dtype=tril.dtype) - idx = numpy.tril_indices(nd) + out = np.zeros((nd,nd), dtype=tril.dtype) + idx = np.tril_indices(nd) out = ops.index_update(out, idx, tril) if filltriu == PLAIN: return out elif filltriu == HERMITIAN: - out += numpy.triu(out.T.conj(), 1) + out += np.triu(out.T.conj(), 1) return out elif filltriu == ANTIHERMI: out -= out.T.conj() return out elif filltriu == SYMMETRIC: - out += numpy.triu(out.T, 1) + out += np.triu(out.T, 1) return out else: raise KeyError @@ -100,7 +104,7 @@ def pack_tril(a, axis=-1, out=None): from pyscfad.lib import _numpy_helper_opt return _numpy_helper_opt._pack_tril(a, axis, out) def fn(mat): - idx = numpy.tril_indices(mat.shape[0]) + idx = np.tril_indices(mat.shape[0]) return mat[idx].ravel() if a.ndim == 3: diff --git a/pyscfad/lib/ops.py b/pyscfad/lib/ops.py deleted file mode 100644 index 08d2c9d6..00000000 --- a/pyscfad/lib/ops.py +++ /dev/null @@ -1,62 +0,0 @@ -from packaging.version import Version -import jax -import jax.ops -from jax import numpy as jnp -from pyscfad import config - -_JaxArray = config.numpy_backend == 'jax' - -# pylint: disable=no-member - -if Version(jax.__version__) < Version('0.2.22'): - _index_update = jax.ops.index_update - _index_add = jax.ops.index_add - _index_mul = jax.ops.index_mul -else: - def _index_update(x, idx, y, indices_are_sorted=False, unique_indices=False): - x = jnp.asarray(x) - y = jnp.asarray(y) - return x.at[idx].set(y) - - def _index_add(x, idx, y, indices_are_sorted=False, unique_indices=False): - x = jnp.asarray(x) - y = jnp.asarray(y) - return x.at[idx].add(y) - - def _index_mul(x, idx, y, indices_are_sorted=False, unique_indices=False): - x = jnp.asarray(x) - y = jnp.asarray(y) - return x.at[idx].multiply(y) - -class _Indexable(object): - # pylint: disable=line-too-long - """ - see https://github.com/google/jax/blob/97d00584f8b87dfe5c95e67892b54db993f34486/jax/_src/ops/scatter.py#L87 - """ - __slots__ = () - - def __getitem__(self, idx): - return idx - -index = _Indexable() - -def index_update(a, idx, value): - if _JaxArray: - a = _index_update(a, idx, value) - else: - a[idx] = value - return a - -def index_add(a, idx, value): - if _JaxArray: - a = _index_add(a, idx, value) - else: - a[idx] += value - return a - -def index_mul(a, idx, value): - if _JaxArray: - a = _index_mul(a, idx, value) - else: - a[idx] *= value - return a diff --git a/pyscfad/lo/boys.py b/pyscfad/lo/boys.py index 545ad760..0978b28c 100644 --- a/pyscfad/lo/boys.py +++ b/pyscfad/lo/boys.py @@ -1,10 +1,10 @@ import numpy import scipy import jax -from jax import numpy as np from pyscf.lib import logger from pyscf.lo import boys as pyscf_boys -from pyscfad.lib import stop_grad +from pyscfad import numpy as np +from pyscfad.ops import stop_grad from pyscfad.implicit_diff import make_implicit_diff from pyscfad.soscf.ciah import ( extract_rotation, diff --git a/pyscfad/lo/orth.py b/pyscfad/lo/orth.py index e96b1b6a..56e71c0b 100644 --- a/pyscfad/lo/orth.py +++ b/pyscfad/lo/orth.py @@ -2,8 +2,8 @@ import numpy import scipy from jax import scipy as jax_scipy -from jax import numpy as np -from pyscfad.lib import custom_jvp +from pyscfad import numpy as np +from pyscfad.ops import custom_jvp @partial(custom_jvp, nondiff_argnums=(1,)) def lowdin(s, thresh=1e-15): diff --git a/pyscfad/lo/pipek.py b/pyscfad/lo/pipek.py index 02826ffc..71e14b5a 100644 --- a/pyscfad/lo/pipek.py +++ b/pyscfad/lo/pipek.py @@ -2,9 +2,9 @@ import numpy import scipy import jax -from jax import numpy as np from pyscf.lib import logger from pyscf.lo import pipek as pyscf_pipek +from pyscfad import numpy as np from pyscfad.lib import vmap from pyscfad.implicit_diff import make_implicit_diff from pyscfad.soscf.ciah import ( diff --git a/pyscfad/ml/__init__.py b/pyscfad/ml/__init__.py new file mode 100644 index 00000000..9aac0d91 --- /dev/null +++ b/pyscfad/ml/__init__.py @@ -0,0 +1,3 @@ +""" +Experimental module for machine learning tasks. +""" diff --git a/pyscfad/ml/scf/__init__.py b/pyscfad/ml/scf/__init__.py new file mode 100644 index 00000000..240bd528 --- /dev/null +++ b/pyscfad/ml/scf/__init__.py @@ -0,0 +1,3 @@ +""" +Mean-field theory. +""" diff --git a/pyscfad/ml/scf/hf.py b/pyscfad/ml/scf/hf.py new file mode 100644 index 00000000..119ff972 --- /dev/null +++ b/pyscfad/ml/scf/hf.py @@ -0,0 +1,53 @@ +""" +SCF with given Fock matrix. +""" +from pyscfad import ops +from pyscfad import numpy as np +from pyscfad.scf import hf + +def cholesky_orth(s): + L = np.linalg.cholesky(s) + x = np.linalg.inv(L).T.conj() + return x + +class SCF(hf.SCF): + def _eigh(self, h, s): + if s is None: + return np.linalg.eigh(h) + + # orthogonalize basis + s = np.asarray(s) + x = cholesky_orth(s) + h_orth = x.T.conj() @ h @ x + e, c = np.linalg.eigh(h_orth) + c = x @ c + return e, c + + def get_occ(self, mo_energy=None, mo_coeff=None): + if mo_energy is None: + mo_energy = self.mo_energy + mo_energy = ops.to_numpy(mo_energy) + return super().get_occ(mo_energy) + +if __name__ == '__main__': + import torch + from pyscf import gto + + mol = gto.Mole() + mol.atom = 'H 0 0 0; F 0 0 0.9' + mol.basis = 'sto3g' + mol.build() + + fock = torch.rand(mol.nao, mol.nao, dtype=float) + fock = .5 * (fock + fock.T.conj()) + fock = torch.autograd.Variable(fock, requires_grad=True) + + mf = SCF(mol) + s = mf.get_ovlp() + mo_energy, mo_coeff = mf.eig(fock, s) + mo_occ = np.asarray(mf.get_occ(mo_energy)) # get_occ returns a numpy array + dm1 = mf.make_rdm1(mo_coeff, mo_occ) + dip = mf.dip_moment(dm=dm1) + dip_norm = np.linalg.norm(dip) + dip_norm.backward() + print(fock.grad) diff --git a/pyscfad/mp/dfmp2.py b/pyscfad/mp/dfmp2.py index fb872088..315e5e38 100644 --- a/pyscfad/mp/dfmp2.py +++ b/pyscfad/mp/dfmp2.py @@ -2,13 +2,14 @@ import numpy import jax from jax import custom_vjp -from jax import numpy as np from pyscf import __config__ from pyscf.lib import direct_sum, current_memory #from pyscf.mp.mp2 import _ChemistsERIs from pyscfad import config +from pyscfad import numpy as np from pyscfad import util -from pyscfad.lib import vmap, logger +from pyscfad.ops import vmap +from pyscfad.lib import logger from pyscfad.ao2mo import _ao2mo from pyscfad.mp import mp2 diff --git a/pyscfad/mp/mp2.py b/pyscfad/mp/mp2.py index 37cf0b1f..397b959a 100644 --- a/pyscfad/mp/mp2.py +++ b/pyscfad/mp/mp2.py @@ -1,12 +1,13 @@ from functools import wraps import jax -from jax import numpy as np from pyscf import __config__ as pyscf_config from pyscf.lib import split_reshape from pyscf.mp import mp2 as pyscf_mp2 +from pyscfad import numpy as np from pyscfad import util from pyscfad import lib -from pyscfad.lib import ops, logger +from pyscfad.lib import logger +from pyscfad import ops from pyscfad import ao2mo WITH_T2 = getattr(pyscf_config, 'mp_mp2_with_t2', True) diff --git a/pyscfad/mp/test/test_oomp2.py b/pyscfad/mp/test/test_oomp2.py index 93522b3e..f297fa1d 100644 --- a/pyscfad/mp/test/test_oomp2.py +++ b/pyscfad/mp/test/test_oomp2.py @@ -1,8 +1,8 @@ import pytest import numpy import jax -from jax import numpy as np from scipy.optimize import minimize +from pyscfad import numpy as np from pyscfad import util from pyscfad.tools import rotate_mo1 from pyscfad import gto, scf, mp @@ -48,7 +48,7 @@ def grad(x0, mf): return f, g f, g = grad(x0, mf) - return (numpy.array(f), numpy.array(g)) + return (numpy.asarray(f), numpy.asarray(g)) def test_oomp2_energy(get_mol): mol = get_mol @@ -56,7 +56,7 @@ def test_oomp2_energy(get_mol): mf.kernel() nao = mol.nao size = nao*(nao-1)//2 - x0 = numpy.zeros([size,]) + x0 = np.zeros([size,]) options = {"gtol":1e-5} res = minimize(func, x0, args=(mf,), jac=True, method="BFGS", options = options) e = func(res.x, mf)[0] diff --git a/pyscfad/pbc/df/df_jk.py b/pyscfad/pbc/df/df_jk.py index ce7eacd8..1a05fd23 100644 --- a/pyscfad/pbc/df/df_jk.py +++ b/pyscfad/pbc/df/df_jk.py @@ -1,8 +1,8 @@ import numpy -from jax import numpy as np from pyscf.pbc import tools as pyscf_tools from pyscf.pbc.lib.kpts_helper import is_zero, member -from pyscfad.lib import ops +from pyscfad import numpy as np +from pyscfad import ops def _ewald_exxdiv_for_G0(cell, kpts, dms, vk, kpts_band=None): s = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=kpts) diff --git a/pyscfad/pbc/df/fft_jk.py b/pyscfad/pbc/df/fft_jk.py index d5bc1f02..a8b49259 100644 --- a/pyscfad/pbc/df/fft_jk.py +++ b/pyscfad/pbc/df/fft_jk.py @@ -1,9 +1,9 @@ import numpy -from jax import numpy as np -from jax import vmap from pyscf import lib as pyscf_lib from pyscf.pbc.df.df_jk import _format_dms, _format_kpts_band, _format_jks -from pyscfad.lib import ops +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.ops import vmap from pyscfad.pbc import tools from pyscfad.pbc.df.df_jk import _ewald_exxdiv_for_G0 from pyscfad.pbc.lib.kpts_helper import is_zero, gamma_point diff --git a/pyscfad/pbc/dft/krks.py b/pyscfad/pbc/dft/krks.py index e0db3122..12177a8f 100644 --- a/pyscfad/pbc/dft/krks.py +++ b/pyscfad/pbc/dft/krks.py @@ -1,11 +1,11 @@ import numpy -from jax import numpy as np from pyscf import __config__ from pyscf.lib import logger from pyscf.pbc.dft import krks as pyscf_krks from pyscf.pbc.dft import gen_grid, multigrid -#from pyscfad import util -from pyscfad.lib import stop_grad +from pyscfad import util +from pyscfad import numpy as np +from pyscfad.ops import stop_grad from pyscfad.dft.rks import VXC from pyscfad.pbc.scf import khf from pyscfad.pbc.dft import rks @@ -77,10 +77,26 @@ def get_veff(ks, cell=None, dm=None, dm_last=0, vhf_last=0, hermi=1, del log return vxc -#@util.pytree_node(khf.Traced_Attributes, num_args=1) +@util.pytree_node(khf.Traced_Attributes, num_args=1) class KRKS(rks.KohnShamDFT, khf.KRHF): + """Subclass of :class:`pyscf.pbc.dft.krks.KRKS` with traceable attributes. + + Attributes + ---------- + cell : :class:`pyscfad.pbc.gto.Cell` + :class:`pyscfad.pbc.gto.Cell` instance. + mo_coeff : array + MO coefficients. + mo_energy : array + MO energies. + + Notes + ----- + Grid response is not considered with AD. + """ def __init__(self, cell, kpts=numpy.zeros((1,3)), xc='LDA,VWN', - exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald'), **kwargs): + exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald'), + **kwargs): khf.KRHF.__init__(self, cell, kpts, exxdiv, **kwargs) rks.KohnShamDFT.__init__(self, xc) self.__dict__.update(kwargs) @@ -89,7 +105,6 @@ def __init__(self, cell, kpts=numpy.zeros((1,3)), xc='LDA,VWN', # Currently, no grid response is considered. rks.KohnShamDFT.__post_init__(self) - def dump_flags(self, verbose=None): khf.KRHF.dump_flags(self, verbose) rks.KohnShamDFT.dump_flags(self, verbose) @@ -106,9 +121,9 @@ def energy_elec(self, dm_kpts=None, h1e_kpts=None, vhf=None): weight = 1./len(h1e_kpts) e1 = weight * np.einsum('kij,kji', h1e_kpts, dm_kpts) tot_e = e1 + vhf.ecoul + vhf.exc - self.scf_summary['e1'] = stop_grad(e1.real) - self.scf_summary['coul'] = stop_grad(vhf.ecoul.real) - self.scf_summary['exc'] = stop_grad(vhf.exc.real) + self.scf_summary['e1'] = e1.real + self.scf_summary['coul'] = vhf.ecoul.real + self.scf_summary['exc'] = vhf.exc.real logger.debug(self, 'E1 = %s Ecoul = %s Exc = %s', e1, vhf.ecoul, vhf.exc) return tot_e.real, vhf.ecoul + vhf.exc diff --git a/pyscfad/pbc/dft/numint.py b/pyscfad/pbc/dft/numint.py index feb38f0f..33c7687e 100644 --- a/pyscfad/pbc/dft/numint.py +++ b/pyscfad/pbc/dft/numint.py @@ -1,9 +1,10 @@ import sys import numpy -from jax import numpy as np from pyscf.pbc.dft import numint as pyscf_numint from pyscf.pbc.dft.gen_grid import BLKSIZE -from pyscfad.lib import ops, stop_grad +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.ops import stop_grad from pyscfad.dft import numint from pyscfad.dft.numint import eval_mat, _contract_rho, _dot_ao_dm diff --git a/pyscfad/pbc/dft/rks.py b/pyscfad/pbc/dft/rks.py index d50a6390..f43b906e 100644 --- a/pyscfad/pbc/dft/rks.py +++ b/pyscfad/pbc/dft/rks.py @@ -1,12 +1,14 @@ import numpy -from jax import numpy as np from pyscf import __config__ from pyscf.lib import logger from pyscf.pbc.dft import rks as pyscf_rks from pyscf.pbc.dft import gen_grid, multigrid from pyscf.pbc.dft.rks import prune_small_rho_grids_ -#from pyscfad import util -from pyscfad.lib import stop_grad +from pyscf.pbc.lib.kpts import KPoints + +from pyscfad import util +from pyscfad import numpy as np +from pyscfad.ops import stop_grad from pyscfad.dft import rks as mol_ks from pyscfad.dft.rks import VXC from pyscfad.pbc.scf import hf as pbchf @@ -81,26 +83,51 @@ def _dft_common_init_(mf, xc='LDA,VWN', **kwargs): from pyscfad.pbc.scf import khf mf.xc = xc mf.grids = None - mf.small_rho_cutoff = getattr(__config__, 'dft_rks_RKS_small_rho_cutoff', 1e-7) + mf.nlc = '' + mf.nlcgrids = None + mf.small_rho_cutoff = getattr( + __config__, 'dft_rks_RKS_small_rho_cutoff', 1e-7) if isinstance(mf, khf.KSCF): - mf._numint = numint.KNumInt(mf.kpts) + if isinstance(mf.kpts, KPoints): + mf._numint = numint.KNumInt(mf.kpts.kpts) + else: + mf._numint = numint.KNumInt(mf.kpts) else: mf._numint = numint.NumInt() - mf._keys = mf._keys.union(['xc', 'grids', 'small_rho_cutoff']) def _dft_common_post_init_(mf): from pyscf.pbc.gto import Cell if mf.grids is None: mf.grids = gen_grid.UniformGrids(mf.cell.view(Cell)) + if mf.nlcgrids is None: + mf.nlcgrids = gen_grid.UniformGrids(mf.cell.view(Cell)) -class KohnShamDFT(mol_ks.KohnShamDFT): +class KohnShamDFT(mol_ks.KohnShamDFT, pyscf_rks.KohnShamDFT): __init__ = _dft_common_init_ __post_init__ = _dft_common_post_init_ -#@util.pytree_node(pbchf.Traced_Attributes, num_args=1) + dump_flags = pyscf_rks.KohnShamDFT.dump_flags + +@util.pytree_node(pbchf.Traced_Attributes, num_args=1) class RKS(KohnShamDFT, pbchf.RHF): + """Subclass of :class:`pyscf.pbc.dft.rks.RKS` with traceable attributes. + + Attributes + ---------- + cell : :class:`pyscfad.pbc.gto.Cell` + :class:`pyscfad.pbc.gto.Cell` instance. + mo_coeff : array + MO coefficients. + mo_energy : array + MO energies. + + Notes + ----- + Grid response is not considered with AD. + """ def __init__(self, cell, xc='LDA,VWN', kpt=numpy.zeros(3), - exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald'), **kwargs): + exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald'), + **kwargs): pbchf.RHF.__init__(self, cell, kpt, exxdiv, **kwargs) KohnShamDFT.__init__(self, xc) self.__dict__.update(kwargs) diff --git a/pyscfad/pbc/gto/cell.py b/pyscfad/pbc/gto/cell.py index 57a889bc..3fe908c9 100644 --- a/pyscfad/pbc/gto/cell.py +++ b/pyscfad/pbc/gto/cell.py @@ -1,13 +1,16 @@ import warnings import numpy -from jax import numpy as np from jax.scipy.special import erf, erfc from pyscf import __config__ from pyscf.gto.mole import PTR_COORD from pyscf.gto.moleintor import _get_intor_and_comp from pyscf.pbc.gto import cell as pyscf_cell + +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.ops import stop_grad from pyscfad import util -from pyscfad.lib import stop_grad, ops, logger +from pyscfad.lib import logger from pyscfad.gto import mole from pyscfad.gto._mole_helper import setup_exp, setup_ctr_coeff from pyscfad.pbc.gto import _pbcintor @@ -172,32 +175,32 @@ def gn0(eta,z): @util.pytree_node(mole.Traced_Attributes+['abc']) class Cell(mole.Mole, pyscf_cell.Cell): - def __init__(self, **kwargs): - mole.Mole.__init__(self, **kwargs) - self.a = None # lattice vectors, (a1,a2,a3) - self.abc = None # traced lattice vectors - # if set, defines a spherical cutoff - # of fourier components, with .5 * G**2 < ke_cutoff - self.ke_cutoff = None - - self.pseudo = None - self.dimension = 3 - # TODO: Simple hack for now; the implementation of ewald depends on the - # density-fitting class. This determines how the ewald produces - # its energy. - self.low_dim_ft_type = None - -################################################## -# These attributes are initialized by build function if not given - self.mesh = None - self.rcut = None - -################################################## -# don't modify the following variables, they are not input arguments - keys = ('precision', 'exp_to_discard') - self._keys = self._keys.union(self.__dict__).union(keys) - self.__dict__.update(kwargs) + """Subclass of :class:`pyscf.pbc.gto.Cell` with traceable attributes. + + Attributes + ---------- + coords : array + Atomic coordinates. + exp : array + Exponents of Gaussian basis functions. + ctr_coeff : array + Contraction coefficients of Gaussian basis functions. + r0 : array + Centers of Gaussian basis functions. Currently this is + not used as the basis functions are atom centered. This + is a placeholder for floating Gaussian basis sets. + abc : array + Lattice vectors. + """ + _keys = {'abc'} + def __init__(self, **kwargs): + self.coords = None + self.exp = None + self.ctr_coeff = None + self.r0 = None + self.abc = None + pyscf_cell.Cell.__init__(self, **kwargs) def build(self, *args, **kwargs): trace_coords = kwargs.pop('trace_coords', True) @@ -215,7 +218,7 @@ def build(self, *args, **kwargs): if trace_ctr_coeff: self.ctr_coeff, _, _ = setup_ctr_coeff(self) if trace_r0: - pass + raise NotImplementedError if trace_lattice_vectors: self.abc = np.asarray(self.lattice_vectors()) diff --git a/pyscfad/pbc/gto/eval_gto.py b/pyscfad/pbc/gto/eval_gto.py index e661db44..c609c8f4 100644 --- a/pyscfad/pbc/gto/eval_gto.py +++ b/pyscfad/pbc/gto/eval_gto.py @@ -1,10 +1,10 @@ from functools import partial import numpy -from jax import numpy as np from pyscf.pbc.gto import Cell from pyscf.pbc.gto.eval_gto import _get_intor_and_comp from pyscf.pbc.gto.eval_gto import eval_gto as pyscf_eval_gto -from pyscfad.lib import custom_jvp +from pyscfad import numpy as np +from pyscfad.ops import custom_jvp from pyscfad.gto.eval_gto import _eval_gto_dot_grad_tangent_r0 def eval_gto(cell, eval_name, coords, comp=None, kpts=None, kpt=None, diff --git a/pyscfad/pbc/lib/kpts_helper.py b/pyscfad/pbc/lib/kpts_helper.py index c87c856f..9b473452 100644 --- a/pyscfad/pbc/lib/kpts_helper.py +++ b/pyscfad/pbc/lib/kpts_helper.py @@ -1,6 +1,6 @@ import numpy from pyscf.pbc.lib.kpts_helper import KPT_DIFF_TOL -from pyscfad.lib import stop_grad +from pyscfad.ops import stop_grad def is_zero(kpt): return abs(numpy.asarray(stop_grad(kpt))).sum() < KPT_DIFF_TOL diff --git a/pyscfad/pbc/scf/hf.py b/pyscfad/pbc/scf/hf.py index 6d7b9d65..8afee05e 100644 --- a/pyscfad/pbc/scf/hf.py +++ b/pyscfad/pbc/scf/hf.py @@ -1,29 +1,43 @@ import sys +from functools import wraps import h5py import numpy -from jax import numpy as np from pyscf import __config__ from pyscf.pbc.scf import hf as pyscf_pbc_hf from pyscf.pbc.scf.hf import _format_jks -#from pyscfad import util -from pyscfad.lib import stop_grad, logger +from pyscfad import util +from pyscfad import numpy as np +from pyscfad.ops import stop_grad, stop_trace +from pyscfad.lib import logger from pyscfad.scf import hf as mol_hf from pyscfad.pbc import df -#Traced_Attributes = ['cell', 'mo_coeff', 'mo_energy', 'with_df'] +Traced_Attributes = ['cell', 'mo_coeff', 'mo_energy',]# 'with_df'] +@wraps(pyscf_pbc_hf.get_ovlp) def get_ovlp(cell, kpt=np.zeros(3)): - s = cell.pbc_intor('int1e_ovlp', hermi=1, kpts=kpt) + s = cell.pbc_intor('int1e_ovlp', hermi=0, kpts=kpt) return s -#@util.pytree_node(Traced_Attributes, num_args=1) +@util.pytree_node(Traced_Attributes, num_args=1) class SCF(mol_hf.SCF, pyscf_pbc_hf.SCF): + """Subclass of :class:`pyscf.pbc.scf.hf.SCF` with traceable attributes. + + Attributes + ---------- + cell : :class:`pyscfad.pbc.gto.Cell` + :class:`pyscfad.pbc.gto.Cell` instance. + mo_coeff : array + MO coefficients. + mo_energy : array + MO energies. + """ def __init__(self, cell, kpt=numpy.zeros(3), - exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald')):#, **kwargs): + exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald'), + **kwargs): if not cell._built: sys.stderr.write('Warning: cell.build() is not called in input\n') cell.build() - self.cell = cell mol_hf.SCF.__init__(self, cell) self.with_df = df.FFTDF(cell) @@ -33,13 +47,13 @@ def __init__(self, cell, kpt=numpy.zeros(3), self.kpt = kpt self.conv_tol = max(cell.precision * 10, 1e-8) - self._keys = self._keys.union(['cell', 'exxdiv', 'with_df', 'rsjk']) - #self.__dict__.update(kwargs) + self.__dict__.update(kwargs) - def get_init_guess(self, cell=None, key='minao'): - if cell is None: cell = self.cell + def get_init_guess(self, cell=None, key='minao', s1e=None): + if cell is None: + cell = self.cell dm = mol_hf.SCF.get_init_guess(self, cell, key) - dm = normalize_dm_(self, dm) + dm = normalize_dm_(self, dm, s1e) return dm def get_hcore(self, cell=None, kpt=None): @@ -56,19 +70,28 @@ def get_hcore(self, cell=None, kpt=None): h1 = cell.pbc_intor('int1e_kin', comp=1, hermi=1, kpts=kpt) return nuc + h1 + @wraps(pyscf_pbc_hf.SCF.get_jk) def get_jk(self, cell=None, dm=None, hermi=1, kpt=None, kpts_band=None, with_j=True, with_k=True, omega=None, **kwargs): - #if cell is None: - # cell = self.cell + if cell is None: + cell = self.cell if dm is None: dm = self.make_rdm1() if kpt is None: kpt = self.kpt - cpu0 = (logger.process_clock(), logger.perf_counter()) + log = logger.new_logger(self) + cpu0 = (log._t0, log._w0) + dm = np.asarray(dm) nao = dm.shape[-1] + if (not omega and kpts_band is None and + not self.rsjk and + (self.exxdiv == 'ewald' or not self.exxdiv) and + (self._eri is not None or cell.incore_anyway or + self._is_mem_enough())): + log.warn('pbc.SCF will not construct incore 4-center ERIs.') if self.rsjk: raise NotImplementedError else: @@ -79,7 +102,8 @@ def get_jk(self, cell=None, dm=None, hermi=1, kpt=None, kpts_band=None, vj = _format_jks(vj, dm, kpts_band) if with_k: vk = _format_jks(vk, dm, kpts_band) - logger.timer(self, 'vj and vk', *cpu0) + log.timer('vj and vk', *cpu0) + del log return vj, vk def get_ovlp(self, cell=None, kpt=None): @@ -100,21 +124,13 @@ def dump_chk(self, envs): energy_nuc = pyscf_pbc_hf.SCF.energy_nuc energy_grad = NotImplemented -RHF = SCF - -def normalize_dm_(mf, dm): - cell = mf.cell - s = stop_grad(mf.get_ovlp(cell)) - if getattr(dm, 'ndim', 0) == 2: - ne = numpy.einsum('ij,ji->', stop_grad(dm), s).real - else: - ne = numpy.einsum('xij,ji->', stop_grad(dm), s).real - if abs(ne - cell.nelectron).sum() > 1e-7: - logger.debug(mf, 'Big error detected in the electron number ' - 'of initial guess density matrix (Ne/cell = %g)!\n' - ' This can cause huge error in Fock matrix and ' - 'lead to instability in SCF for low-dimensional ' - 'systems.\n DM is normalized wrt the number ' - 'of electrons %s', ne, cell.nelectron) - dm *= cell.nelectron / ne - return dm + +@util.pytree_node(Traced_Attributes, num_args=1) +class RHF(SCF, pyscf_pbc_hf.RHF): + pass + + +@wraps(pyscf_pbc_hf.normalize_dm_) +def normalize_dm_(mf, dm, s1e=None): + return stop_trace(pyscf_pbc_hf.normalize_dm_)(mf, dm, s1e=s1e) + diff --git a/pyscfad/pbc/scf/khf.py b/pyscfad/pbc/scf/khf.py index 8fe4fdf3..5ca1b4ad 100644 --- a/pyscfad/pbc/scf/khf.py +++ b/pyscfad/pbc/scf/khf.py @@ -1,17 +1,18 @@ import sys import h5py import numpy -from jax import numpy as np from pyscf import __config__ from pyscf.pbc.scf import khf as pyscf_khf -#from pyscfad import util -from pyscfad.lib import stop_grad, logger +from pyscfad import util +from pyscfad import numpy as np +from pyscfad.ops import stop_grad +from pyscfad.lib import logger from pyscfad.scf import hf as mol_hf from pyscfad.pbc import df from pyscfad.pbc.scf import hf as pbchf # TODO add mo_coeff, which requires AD wrt complex numbers -#Traced_Attributes = ['cell', 'mo_energy', 'with_df'] +Traced_Attributes = ['cell', 'mo_energy',]# 'with_df'] def get_ovlp(mf, cell=None, kpts=None): if cell is None: @@ -45,8 +46,8 @@ def energy_elec(mf, dm_kpts=None, h1e_kpts=None, vhf_kpts=None): nkpts = len(dm_kpts) e1 = 1./nkpts * np.einsum('kij,kji', dm_kpts, h1e_kpts) e_coul = 1./nkpts * np.einsum('kij,kji', dm_kpts, vhf_kpts) * 0.5 - mf.scf_summary['e1'] = stop_grad(e1.real) - mf.scf_summary['e2'] = stop_grad(e_coul.real) + mf.scf_summary['e1'] = e1.real + mf.scf_summary['e2'] = e_coul.real logger.debug(mf, 'E1 = %s E_coul = %s', e1, e_coul) if abs(e_coul.imag > mf.cell.precision*10): logger.warn(mf, 'Coulomb energy has imaginary part %s. ' @@ -92,18 +93,26 @@ def make_rdm1(mo_coeff_kpts, mo_occ_kpts, **kwargs): dm = [mol_hf.make_rdm1(mo_coeff_kpts[k], mo_occ_kpts[k]) for k in range(nkpts)] return np.asarray(dm) -#@util.pytree_node(Traced_Attributes, num_args=1) +@util.pytree_node(Traced_Attributes, num_args=1) class KSCF(pbchf.SCF, pyscf_khf.KSCF): + """Subclass of :class:`pyscf.pbc.scf.khf.KSCF` with traceable attributes. + + Attributes + ---------- + cell : :class:`pyscfad.pbc.gto.Cell` + :class:`pyscfad.pbc.gto.Cell` instance. + mo_energy : array + MO energies. + """ def __init__(self, cell, kpts=numpy.zeros((1,3)), - exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald')):#, **kwargs): + exxdiv=getattr(__config__, 'pbc_scf_SCF_exxdiv', 'ewald'), + **kwargs): if not cell._built: sys.stderr.write('Warning: cell.build() is not called in input\n') cell.build() - self.cell = cell mol_hf.SCF.__init__(self, cell) self.with_df = df.FFTDF(cell, kpts=kpts) - # Range separation JK builder self.rsjk = None self.exxdiv = exxdiv @@ -111,33 +120,26 @@ def __init__(self, cell, kpts=numpy.zeros((1,3)), self.conv_tol = max(cell.precision * 10, 1e-8) self.exx_built = False - self._keys = self._keys.union(['cell', 'exx_built', 'exxdiv', 'with_df', 'rsjk']) - #self.__dict__.update(kwargs) + self.__dict__.update(kwargs) def get_jk(self, cell=None, dm_kpts=None, hermi=1, kpts=None, kpts_band=None, with_j=True, with_k=True, omega=None, **kwargs): - #if cell is None: - # cell = self.cell if kpts is None: kpts = self.kpts if dm_kpts is None: dm_kpts = self.make_rdm1() - cpu0 = (logger.process_clock(), logger.perf_counter()) + + log = logger.new_logger(self) + cpu0 = (log._t0, log._w0) if self.rsjk: raise NotImplementedError else: vj, vk = self.with_df.get_jk(dm_kpts, hermi, kpts, kpts_band, with_j, with_k, omega, self.exxdiv) - logger.timer(self, 'vj and vk', *cpu0) + log.timer('vj and vk', *cpu0) + del log return vj, vk - def get_init_guess(self, cell=None, key='minao'): - if cell is None: - cell = self.cell - dm0 = pyscf_khf.KSCF.get_init_guess(self, stop_grad(cell), key) - dm0 = numpy.asarray(dm0) - return dm0 - def eig(self, h_kpts, s_kpts): nkpts = len(h_kpts) eig_kpts = [] @@ -175,4 +177,29 @@ def dump_chk(self, envs): get_k = pyscf_khf.KSCF.get_k get_grad = pyscf_khf.KSCF.get_grad -KRHF = KSCF +@util.pytree_node(Traced_Attributes, num_args=1) +class KRHF(KSCF, pyscf_khf.KRHF): + def get_init_guess(self, cell=None, key='minao', s1e=None): + from pyscf import lib + if s1e is None: + s1e = self.get_ovlp(cell) + dm = mol_hf.SCF.get_init_guess(self, cell, key) + nkpts = len(self.kpts) + if dm.ndim == 2: + # dm[nao,nao] at gamma point -> dm_kpts[nkpts,nao,nao] + dm = numpy.repeat(dm[None,:,:], nkpts, axis=0) + dm_kpts = dm + + ne = lib.einsum('kij,kji->', dm_kpts, stop_grad(s1e)).real + # FIXME: consider the fractional num_electron or not? This maybe + # relate to the charged system. + nelectron = float(self.cell.tot_electrons(nkpts)) + if abs(ne - nelectron) > 0.01*nkpts: + logger.debug(self, 'Big error detected in the electron number ' + 'of initial guess density matrix (Ne/cell = %g)!\n' + ' This can cause huge error in Fock matrix and ' + 'lead to instability in SCF for low-dimensional ' + 'systems.\n DM is normalized wrt the number ' + 'of electrons %s', ne/nkpts, nelectron/nkpts) + dm_kpts *= (nelectron / ne).reshape(-1,1,1) + return dm_kpts diff --git a/pyscfad/pbc/scf/test/test_hf.py b/pyscfad/pbc/scf/test/test_hf.py index 424f2fbd..708ac55e 100644 --- a/pyscfad/pbc/scf/test/test_hf.py +++ b/pyscfad/pbc/scf/test/test_hf.py @@ -33,7 +33,7 @@ def get_cell(): cell.a = lattice cell.basis = basis cell.pseudo = pseudo - cell.build(trace_coords=True) + cell.build(trace_exp=False, trace_ctr_coeff=False) return cell @pytest.fixture diff --git a/pyscfad/pbc/tools/pbc.py b/pyscfad/pbc/tools/pbc.py index ae2f7f97..5b6b1202 100644 --- a/pyscfad/pbc/tools/pbc.py +++ b/pyscfad/pbc/tools/pbc.py @@ -1,10 +1,12 @@ import warnings import copy import numpy as np -from jax import numpy as jnp from pyscf import lib from pyscf.pbc.tools import get_monkhorst_pack_size, cutoff_to_mesh -from pyscfad.lib import stop_grad, ops, logger +from pyscfad import numpy as jnp +from pyscfad import ops +from pyscfad.ops import stop_grad +from pyscfad.lib import logger def fft(f, mesh): ''' diff --git a/pyscfad/scf/chkfile.py b/pyscfad/scf/chkfile.py index 3ce2d320..444fcfb2 100644 --- a/pyscfad/scf/chkfile.py +++ b/pyscfad/scf/chkfile.py @@ -1,5 +1,6 @@ import h5py from pyscf.lib.chkfile import save_mol +from pyscfad.ops import stop_grad from pyscfad.lib.chkfile import save def dump_scf(mol, chkfile, e_tot, mo_energy, mo_coeff, mo_occ, @@ -11,8 +12,8 @@ def dump_scf(mol, chkfile, e_tot, mo_energy, mo_coeff, mo_occ, else: save_mol(mol, chkfile) - scf_dic = {'e_tot' : e_tot, - 'mo_energy': mo_energy, - 'mo_occ' : mo_occ, - 'mo_coeff' : mo_coeff} + scf_dic = {'e_tot' : stop_grad(e_tot), + 'mo_energy': stop_grad(mo_energy), + 'mo_occ' : stop_grad(mo_occ), + 'mo_coeff' : stop_grad(mo_coeff)} save(chkfile, 'scf', scf_dic) diff --git a/pyscfad/scf/diis.py b/pyscfad/scf/diis.py index 47c354af..a9359ae9 100644 --- a/pyscfad/scf/diis.py +++ b/pyscfad/scf/diis.py @@ -1,9 +1,9 @@ import numpy from pyscf.scf import diis as pyscf_cdiis from pyscfad import config +from pyscfad.ops import stop_grad from pyscfad.lib import ( logger, - stop_grad, diis, ) diff --git a/pyscfad/scf/hf.py b/pyscfad/scf/hf.py index 33313d39..320639fa 100644 --- a/pyscfad/scf/hf.py +++ b/pyscfad/scf/hf.py @@ -1,11 +1,9 @@ from functools import ( partial, - reduce, wraps, ) import numpy import jax -from jax import numpy as np from pyscf.data import nist from pyscf.lib import module_method @@ -13,14 +11,15 @@ from pyscf.scf.hf import TIGHT_GRAD_CONV_TOL from pyscfad import config +from pyscfad import numpy as np from pyscfad import util from pyscfad import lib -from pyscfad.lib import ( - logger, - jit, - stop_trace, +from pyscfad.ops import ( stop_grad, + stop_trace, + jit, ) +from pyscfad.lib import logger from pyscfad.implicit_diff import make_implicit_diff from pyscfad import df from pyscfad.scf import _vhf @@ -29,7 +28,7 @@ from pyscfad.scipy.linalg import eigh from pyscfad.tools.linear_solver import gen_gmres -Traced_Attributes = ['mol', '_eri']#, 'mo_coeff', 'mo_energy'] +Traced_Attributes = ['mol', '_eri', 'mo_coeff', 'mo_energy'] def _scf_optimality_cond(dm, mf, s1e, h1e): mol = getattr(mf, 'cell', mf.mol) @@ -123,11 +122,6 @@ def kernel(mf, conv_tol=1e-10, conv_tol_grad=None, mo_energy = mo_coeff = mo_occ = None s1e = mf.get_ovlp(mol) - cond = numpy.linalg.cond(stop_grad(s1e)) - log.debug('cond(S) = %s', cond) - if cond.max()*1e-17 > conv_tol: - log.warn('Singularity detected in overlap matrix (condition number = %4.3g). ' - 'SCF may be inaccurate and hard to converge.', cond.max()) if mf.max_cycle <= 0: # Skip SCF iterations. Compute only the total energy of the initial density @@ -158,7 +152,7 @@ def kernel(mf, conv_tol=1e-10, conv_tol_grad=None, chkfile.save_mol(mol, mf.chkfile) # A preprocessing hook before the SCF iteration - #mf.pre_kernel(locals()) + mf.pre_kernel(locals()) # SCF iteration # NOTE if use implicit differentiation, only dm will have gradient. @@ -211,7 +205,7 @@ def kernel(mf, conv_tol=1e-10, conv_tol_grad=None, log.timer('scf_cycle', *cput0) del log # A post-processing hook before return - #mf.post_kernel(locals()) + mf.post_kernel(locals()) return scf_conv, e_tot, mo_energy, mo_coeff, mo_occ @@ -233,10 +227,10 @@ def _dot_eri_dm_s1(eri, dm, with_j, with_k): def dot_eri_dm(eri, dm, hermi=0, with_j=True, with_k=True): dm = np.asarray(dm) nao = dm.shape[-1] - if eri.dtype == np.complex128 or eri.size == nao**4: + if np.iscomplexobj(eri) or eri.size == nao**4: vj, vk = _dot_eri_dm_s1(eri, dm, with_j, with_k) else: - if eri.dtype == np.complex128: + if np.iscomplexobj(eri): raise NotImplementedError vj, vk = _vhf.incore(eri, dm, hermi, with_j, with_k) return vj, vk @@ -252,23 +246,22 @@ def energy_elec(mf, dm=None, h1e=None, vhf=None): vhf = mf.get_veff(mf.mol, dm) e1 = np.einsum('ij,ji->', h1e, dm).real e_coul = np.einsum('ij,ji->', vhf, dm).real * .5 - mf.scf_summary['e1'] = stop_grad(e1) - mf.scf_summary['e2'] = stop_grad(e_coul) - logger.debug(mf, 'E1 = %s E_coul = %s', - stop_grad(e1), stop_grad(e_coul)) + mf.scf_summary['e1'] = e1 + mf.scf_summary['e2'] = e_coul + logger.debug(mf, 'E1 = %s E_coul = %s', e1, e_coul) return e1+e_coul, e_coul @wraps(pyscf_hf.make_rdm1) def make_rdm1(mo_coeff, mo_occ, **kwargs): mocc = mo_coeff[:,mo_occ>0] - dm = np.dot(mocc*mo_occ[mo_occ>0], mocc.conj().T) + dm = (mocc*mo_occ[mo_occ>0]) @ mocc.conj().T return dm @wraps(pyscf_hf.level_shift) def level_shift(s, d, f, factor): - dm_vir = s - reduce(np.dot, (s, d, s)) + dm_vir = s - s @ d @ s return f + dm_vir * factor @@ -276,7 +269,7 @@ def level_shift(s, d, f, factor): def dip_moment(mol, dm, unit='Debye', verbose=logger.NOTE, **kwargs): log = logger.new_logger(mol, verbose) - if 'unit_symbol' in kwargs: # pragma: no cover + if 'unit_symbol' in kwargs: log.warn('Kwarg "unit_symbol" was deprecated. It was replaced by kwarg ' 'unit since PySCF-1.5.') unit = kwargs['unit_symbol'] @@ -287,10 +280,10 @@ def dip_moment(mol, dm, unit='Debye', verbose=logger.NOTE, **kwargs): with mol.with_common_orig((0,0,0)): ao_dip = mol.intor_symmetric('int1e_r', comp=3) - el_dip = np.einsum('xij,ji->x', ao_dip, dm).real + el_dip = np.einsum('xij,ji->x', np.asarray(ao_dip), dm).real - charges = mol.atom_charges() - coords = mol.atom_coords() + charges = np.asarray(mol.atom_charges(), dtype=float) + coords = np.asarray(mol.atom_coords()) nucl_dip = np.einsum('i,ix->x', charges, coords) mol_dip = nucl_dip - el_dip @@ -304,8 +297,8 @@ def dip_moment(mol, dm, unit='Debye', verbose=logger.NOTE, **kwargs): def damping(s, d, f, factor): - dm_vir = np.eye(s.shape[0]) - np.dot(s, d) - f0 = reduce(np.dot, (dm_vir, f, d, s)) + dm_vir = np.eye(s.shape[0]) - s @ d + f0 = dm_vir @ f @ d @ s f0 = (f0 + f0.conj().T) * (factor/(factor+1.)) return f - f0 @@ -343,24 +336,23 @@ def get_fock(mf, h1e=None, s1e=None, vhf=None, dm=None, cycle=-1, diis=None, @util.pytree_node(Traced_Attributes, num_args=1) class SCF(pyscf_hf.SCF): - ''' - A subclass of :class:`pyscf.scf.hf.SCF` where the following - attributes can be traced. - - Attributes: - mol : :class:`pyscfad.gto.Mole` object - Molecular structure and global options. - mo_coeff : array - Molecular orbital coefficients. - mo_energy : array - Molecular orbital energies. - _eri : array - Two electron repulsion integrals. - ''' + """Subclass of :class:`pyscf.scf.hf.SCF` with traceable attributes. + + Attributes + ---------- + mol : :class:`pyscfad.gto.Mole` + :class:`pyscfad.gto.Mole` instance. + mo_coeff : array + MO coefficients. + mo_energy : array + MO energies. + _eri : array + Two-electron repulsion integrals. + """ DIIS = SCF_DIIS def __init__(self, mol, **kwargs): - pyscf_hf.SCF.__init__(self, mol) + super().__init__(mol) self.__dict__.update(kwargs) def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True, @@ -382,21 +374,30 @@ def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True, vj, vk = dot_eri_dm(_eri, dm, hermi, with_j, with_k) return vj, vk - def get_init_guess(self, mol=None, key='minao'): + def get_init_guess(self, mol=None, key='minao', **kwargs): if mol is None: mol = self.mol - mol = stop_grad(mol) - dm0 = pyscf_hf.SCF.get_init_guess(self, mol, key) + dm0 = pyscf_hf.SCF.get_init_guess(self, stop_grad(mol), key) dm0 = numpy.asarray(dm0) #remove tags return dm0 # pylint: disable=arguments-differ def kernel(self, dm0=None, **kwargs): + self.dump_flags() self.build(self.mol) - self.converged, self.e_tot, \ - self.mo_energy, self.mo_coeff, self.mo_occ = \ - kernel(self, self.conv_tol, self.conv_tol_grad, - dm0=dm0, **kwargs) + + if self.max_cycle > 0 or self.mo_coeff is None: + self.converged, self.e_tot, \ + self.mo_energy, self.mo_coeff, self.mo_occ = \ + kernel(self, self.conv_tol, self.conv_tol_grad, + dm0=dm0, callback=self.callback, + conv_check=self.conv_check, **kwargs) + else: + self.e_tot = kernel(self, self.conv_tol, self.conv_tol_grad, + dm0=dm0, callback=self.callback, + conv_check=self.conv_check, **kwargs)[1] + + self._finalize() return self.e_tot def _eigh(self, h, s): @@ -406,16 +407,33 @@ def eig(self, h, s): return self._eigh(h, s) def energy_grad(self, dm0=None, mode='rev'): - ''' - Energy gradient with respect to AO parameters computed by AD. - In principle, MO response is not needed, and we can just take - the derivative of eigen decomposition with converged - density matrix. But here it is implemented in this way to show - the difference between unrolling for loops and implicit differentiation. - - NOTE: - The attributes of the SCF instance will not be modified - ''' + """Computing energy gradients w.r.t AO parameters. + + In principle, MO response is not needed, and it is sufficient to + compute the gradient of the eigen decomposition with the converged + density matrix. But this function is implemented as to trace the SCF iterations + to show the difference between unrolling for loops and implicit differentiation. + + Parameters + ---------- + dm0 : array, optional + Input density matrix. + mode : string, default='rev' + Differentiating using the ``forward`` or ``reverse`` mode. + + Returns + ------- + mol : :class:`pyscfad.gto.Mole` + :class:`Mole` object that contains the gradients. + + Notes + ----- + The attributes of the :class:`SCF` instance will not be modified. + This function only works with the JAX backend. + + .. deprecated:: 0.2.0 + This function will be deprecated in PySCFAD 0.2.0. + """ if dm0 is None: try: dm0 = self.make_rdm1() @@ -457,6 +475,7 @@ def get_veff(self, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1): vj, vk = self.get_jk(mol, dm, hermi=hermi) return vj - vk * .5 + @wraps(pyscf_hf.SCF.dip_moment) def dip_moment(self, mol=None, dm=None, unit='Debye', verbose=logger.NOTE, **kwargs): if mol is None: @@ -473,6 +492,11 @@ def dump_chk(self, envs): overwrite_mol=False) return self + def energy_nuc(self): + # recompute nuclear energy to trace it + return self.mol.energy_nuc() + + check_sanity = stop_trace(pyscf_hf.SCF.check_sanity) make_rdm1 = module_method(make_rdm1, absences=['mo_coeff', 'mo_occ']) energy_elec = energy_elec get_fock = get_fock @@ -480,6 +504,14 @@ def dump_chk(self, envs): @util.pytree_node(Traced_Attributes, num_args=1) class RHF(SCF, pyscf_hf.RHF): + @wraps(pyscf_hf.RHF.check_sanity) + def check_sanity(self): + mol = self.mol + if mol.nelectron != 1 and mol.spin != 0: + logger.warn(self, 'Invalid number of electrons %d for RHF method.', + mol.nelectron) + return SCF.check_sanity(self) + @wraps(pyscf_hf.RHF.get_veff) def get_veff(self, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1): if mol is None: diff --git a/pyscfad/scf/rohf.py b/pyscfad/scf/rohf.py index 94ab35da..7d23a971 100644 --- a/pyscfad/scf/rohf.py +++ b/pyscfad/scf/rohf.py @@ -1,9 +1,10 @@ from functools import reduce, wraps import numpy -from jax import numpy as np from pyscf.scf import rohf as pyscf_rohf +from pyscfad import numpy as np from pyscfad import util -from pyscfad.lib import logger, stop_grad +from pyscfad.ops import stop_grad +from pyscfad.lib import logger from pyscfad.scf import hf, uhf, chkfile @wraps(pyscf_rohf.energy_elec) @@ -173,7 +174,7 @@ def get_occ(mf, mo_energy=None, mo_coeff=None): @util.pytree_node(hf.Traced_Attributes, num_args=1) class ROHF(hf.SCF, pyscf_rohf.ROHF): def __init__(self, mol, **kwargs): - pyscf_rohf.ROHF.__init__(self, mol) + super().__init__(mol) self.__dict__.update(kwargs) def eig(self, fock, s): diff --git a/pyscfad/scf/uhf.py b/pyscfad/scf/uhf.py index 596d2952..f4b7eac8 100644 --- a/pyscfad/scf/uhf.py +++ b/pyscfad/scf/uhf.py @@ -1,10 +1,11 @@ from functools import wraps import numpy -from jax import numpy as np from pyscf.lib import module_method from pyscf.scf import uhf as pyscf_uhf +from pyscfad import numpy as np from pyscfad import util -from pyscfad.lib import logger, stop_grad +from pyscfad.ops import stop_trace +from pyscfad.lib import logger from pyscfad.scf import hf @@ -71,10 +72,9 @@ def energy_elec(mf, dm=None, h1e=None, vhf=None): e_coul =(np.einsum('ij,ji->', vhf[0], dm[0]) + np.einsum('ij,ji->', vhf[1], dm[1])) * .5 e_elec = (e1 + e_coul).real - mf.scf_summary['e1'] = stop_grad(e1).real - mf.scf_summary['e2'] = stop_grad(e_coul).real - logger.debug(mf, 'E1 = %s Ecoul = %s', - stop_grad(e1), stop_grad(e_coul).real) + mf.scf_summary['e1'] = e1.real + mf.scf_summary['e2'] = e_coul.real + logger.debug(mf, 'E1 = %s Ecoul = %s', e1, e_coul.real) return e_elec, e_coul @@ -91,7 +91,7 @@ def make_rdm1(mo_coeff, mo_occ, **kwargs): @util.pytree_node(hf.Traced_Attributes, num_args=1) class UHF(hf.SCF, pyscf_uhf.UHF): def __init__(self, mol, **kwargs): - pyscf_uhf.UHF.__init__(self, mol) + super().__init__(mol) self.__dict__.update(kwargs) def eig(self, h, s): @@ -117,6 +117,8 @@ def get_veff(self, mol=None, dm=None, dm_last=0, vhf_last=0, hermi=1): vhf += np.asarray(vhf_last) return vhf + spin_square = stop_trace(pyscf_uhf.UHF.spin_square) get_fock = get_fock make_rdm1 = module_method(make_rdm1, absences=['mo_coeff', 'mo_occ']) energy_elec = energy_elec + diff --git a/pyscfad/soscf/ciah.py b/pyscfad/soscf/ciah.py index 8f478121..e19fb45d 100644 --- a/pyscfad/soscf/ciah.py +++ b/pyscfad/soscf/ciah.py @@ -1,7 +1,8 @@ import numpy -from jax import numpy as np +# TODO add other backend for expm from jax.scipy.linalg import expm -from pyscfad.lib import jit +from pyscfad import numpy as np +from pyscfad.ops import jit def pack_uniq_var(mat): nmo = mat.shape[0] diff --git a/pyscfad/tdscf/rhf.py b/pyscfad/tdscf/rhf.py index f6f7f762..48821531 100644 --- a/pyscfad/tdscf/rhf.py +++ b/pyscfad/tdscf/rhf.py @@ -1,10 +1,11 @@ from functools import reduce import numpy -from jax import vmap -from jax import numpy as np from pyscf import symm from pyscf.scf import hf_symm from pyscf.tdscf import rhf as pyscf_tdrhf +from pyscfad import numpy as np +from pyscfad import ops +from pyscfad.ops import vmap, stop_grad from pyscfad import util from pyscfad.lib import logger, chkfile from pyscfad.lib.linalg_helper import davidson1 @@ -150,9 +151,9 @@ def body(idx, mo1_occ): # pylint: disable=abstract-method @util.pytree_node(Traced_Attributes, num_args=1) -class TDMixin(pyscf_tdrhf.TDMixin): +class TDBase(pyscf_tdrhf.TDBase): def __init__(self, mf, **kwargs): - pyscf_tdrhf.TDMixin.__init__(self, mf) + pyscf_tdrhf.TDBase.__init__(self, mf) self.__dict__.update(kwargs) def get_ab(self, mf=None): @@ -163,12 +164,12 @@ def get_ab(self, mf=None): def get_precond(self, hdiag): def precond(x, e, x0): diagd = hdiag - (e-self.level_shift) - diagd = diagd.at[abs(diagd)<1e-8].set(1e-8) + diagd = ops.index_update(diagd, ops.index[abs(diagd)<1e-8], 1e-8) return x/diagd return precond @util.pytree_node(Traced_Attributes, num_args=1) -class TDA(TDMixin, pyscf_tdrhf.TDA): +class TDA(TDBase, pyscf_tdrhf.TDA): def gen_vind(self, mf=None): if mf is None: mf = self._scf @@ -189,7 +190,7 @@ def kernel(self, x0=None, nstates=None): precond = self.get_precond(hdiag) if x0 is None: - x0 = self.init_guess(self._scf, self.nstates) + x0 = self.init_guess(stop_grad(self._scf), self.nstates) def pickeig(w, v, nroots, envs): idx = numpy.where(w > self.positive_eig_threshold)[0] diff --git a/pyscfad/tools/util.py b/pyscfad/tools/util.py index 41639b09..0c5e8e71 100644 --- a/pyscfad/tools/util.py +++ b/pyscfad/tools/util.py @@ -1,4 +1,4 @@ -from jax import numpy as np +from pyscfad import numpy as np from pyscfad.soscf.ciah import extract_rotation def rotate_mo1(mo_coeff, x): diff --git a/pyscfad/util.py b/pyscfad/util.py index 396fed69..bc744d48 100644 --- a/pyscfad/util.py +++ b/pyscfad/util.py @@ -1,4 +1,27 @@ -# pylint: disable=unused-import,useless-import-alias -from pyscfad._src.util import ( - pytree_node as pytree_node -) +from functools import partial +from pyscfad import ops + +def pytree_node(leaf_names, num_args=0): + """Class decorator that registers the underlying class as a pytree node. + + See `jax document `_ + for the definition of pytrees. + + Parameters + ---------- + leaf_names : list or tuple + Attributes of the class that are traced as pytree leaves. + num_args : int, optional + Number of positional arguments in ``leaf_names``. + This is useful when the ``__init__`` method of the class + has positional arguments that are named differently than + the actual attribute names. Default value is 0. + + Notes + ----- + The ``__init__`` method of the class can't have positional arguments + that are not included in ``leaf_names``. If ``num_args`` is greater + than 0, the sequence of positional arguments in ``leaf_names`` must + follow that in the ``__init__`` method. + """ + return partial(ops.class_as_pytree_node, leaf_names=leaf_names, num_args=num_args) diff --git a/pyscfad/version.py b/pyscfad/version.py index bbab0242..1276d025 100644 --- a/pyscfad/version.py +++ b/pyscfad/version.py @@ -1 +1 @@ -__version__ = "0.1.4" +__version__ = "0.1.5" diff --git a/pyscfadlib/pyproject.toml b/pyscfadlib/pyproject.toml index 8e10e5f9..feed643f 100644 --- a/pyscfadlib/pyproject.toml +++ b/pyscfadlib/pyproject.toml @@ -1,7 +1,7 @@ [tool.cibuildwheel] build-verbosity = 1 -test-requires = ["pytest", "pyscf==2.3.0"] +test-requires = ["pytest", "pyscf>=2.3"] test-command = "pytest {package}/test" test-skip = "cp38-macosx_arm64 cp311-macosx_arm64" diff --git a/pyscfadlib/pyscfadlib/libcint.patch b/pyscfadlib/pyscfadlib/libcint.patch index 9c45f4bd..4fedaac4 120000 --- a/pyscfadlib/pyscfadlib/libcint.patch +++ b/pyscfadlib/pyscfadlib/libcint.patch @@ -1 +1 @@ -libcint.patch.5.4 \ No newline at end of file +libcint.patch.6.1 \ No newline at end of file diff --git a/pyscfadlib/pyscfadlib/libcint.patch.6.1 b/pyscfadlib/pyscfadlib/libcint.patch.6.1 new file mode 100644 index 00000000..bd31045e --- /dev/null +++ b/pyscfadlib/pyscfadlib/libcint.patch.6.1 @@ -0,0 +1,30 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index f52e278..12159cd 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -168,22 +168,22 @@ if(ENABLE_STATIC) + set(BUILD_SHARED_LIBS 0) + endif() + +-add_library(cint ${cintSrc}) +-set_target_properties(cint PROPERTIES ++add_library(cintad ${cintSrc}) ++set_target_properties(cintad PROPERTIES + VERSION ${cint_VERSION} + SOVERSION ${cint_SOVERSION} + LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) + if(QUADMATH_FOUND) +- target_link_libraries(cint quadmath) ++ target_link_libraries(cintad quadmath) + endif() +-target_link_libraries(cint "-lm") ++target_link_libraries(cintad "-lm") + + + set(CintHeaders + ${PROJECT_SOURCE_DIR}/include/cint_funcs.h + ${PROJECT_BINARY_DIR}/include/cint.h) + +-install(TARGETS cint DESTINATION "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}" COMPONENT "lib") ++install(TARGETS cintad DESTINATION "${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_LIBDIR}" COMPONENT "lib") + install(FILES ${CintHeaders} DESTINATION ${CMAKE_INSTALL_PREFIX}/${CMAKE_INSTALL_INCLUDEDIR} COMPONENT "dev") diff --git a/pyscfadlib/pyscfadlib/version.py b/pyscfadlib/pyscfadlib/version.py index 7525d199..66a87bb6 100644 --- a/pyscfadlib/pyscfadlib/version.py +++ b/pyscfadlib/pyscfadlib/version.py @@ -1 +1 @@ -__version__ = '0.1.4' +__version__ = '0.1.5' diff --git a/requirements.txt b/requirements.txt index 14d91c18..e47b07da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ scipy<1.12 h5py jax>=0.3.25 jaxlib>=0.3.25 -pyscf==2.3.0 -pyscfadlib==0.1.4 +pyscf>=2.3 +pyscfadlib>=0.1.4 #pyscf-properties @ git+https://github.com/fishjojo/properties.git@ad diff --git a/setup.py b/setup.py index c6d55f48..4b661d6f 100644 --- a/setup.py +++ b/setup.py @@ -20,8 +20,8 @@ 'h5py', 'jax>=0.3.25', 'jaxlib>=0.3.25', - 'pyscf==2.3.0', - 'pyscfadlib==0.1.4', + 'pyscf>=2.3', + 'pyscfadlib>=0.1.4', #'pyscf-properties @ git+https://github.com/fishjojo/properties.git@ad', ], url='https://github.com/fishjojo/pyscfad',