update scipy.linalg
fishjojo committed Jul 14, 2024
1 parent 2000f59 commit b051eae
Showing 27 changed files with 735 additions and 44 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/build_pyscfadlib.sh
@@ -0,0 +1,7 @@
#!/usr/bin/env bash

cd pyscfadlib/pyscfadlib
mkdir build
cmake -B build
cmake --build build -j2
rm -rf build
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- os: [ubuntu-20.04]
+ os: [ubuntu-latest]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
environment: ci
steps:
@@ -24,10 +24,12 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: install pyscf
run: ./.github/workflows/install_pyscf.sh
+ - name: build pyscfadlib
+ run: ./.github/workflows/build_pyscfadlib.sh
- name: test
run: ./.github/workflows/run_test.sh
- name: Upload coverage to codecov
- uses: codecov/codecov-action@v3
+ uses: codecov/codecov-action@v4
with:
token: ${{secrets.CODECOV_TOKEN}}
files: ./pyscfad/coverage.xml
1 change: 0 additions & 1 deletion .github/workflows/install_pyscf.sh
@@ -10,4 +10,3 @@ pip install h5py
pip install jaxlib
pip install jax
pip install 'pyscf>=2.3,<2.7'
- pip install 'pyscfadlib>=0.1.4'
5 changes: 2 additions & 3 deletions .github/workflows/lint.sh
@@ -1,6 +1,5 @@
#!/usr/bin/env bash
- export OMP_NUM_THREADS=1
- export PYTHONPATH=$(pwd):$(pwd)/pyscf:$PYTHONPATH
- echo "pyscfad = True" >> $HOME/.pyscf_conf.py

+ export PYTHONPATH=$(pwd):$(pwd)/pyscfadlib:$PYTHONPATH

pylint pyscfad
2 changes: 2 additions & 0 deletions .github/workflows/pylint.yml
@@ -13,6 +13,8 @@ jobs:
python-version: 3.11
- name: install pyscf
run: ./.github/workflows/install_pyscf.sh
+ - name: build pyscfadlib
+ run: ./.github/workflows/build_pyscfadlib.sh
- name: install pylint
run: pip install 'pylint==2.17.7'
- name: style check
5 changes: 3 additions & 2 deletions .github/workflows/run_test.sh
@@ -1,5 +1,6 @@
#!/usr/bin/env bash
- export OMP_NUM_THREADS=1
- export PYTHONPATH=$(pwd):$PYTHONPATH

+ export PYTHONPATH=$(pwd):$(pwd)/pyscfadlib:$PYTHONPATH

+ export OMP_NUM_THREADS=2
pytest ./pyscfad --cov-report xml --cov=. --verbosity=1 --durations=10
2 changes: 1 addition & 1 deletion pyscfad/_src/_config.py
@@ -27,7 +27,7 @@ def reset(self):

config.set_default('pyscfad_scf_implicit_diff', False)
config.set_default('pyscfad_ccsd_implicit_diff', False)
- config.set_default('pyscfad_ccsd_checkpoint', False)
+ #config.set_default('pyscfad_ccsd_checkpoint', False)
config.set_default('pyscfad_moleintor_opt', False)


3 changes: 3 additions & 0 deletions pyscfad/backend/_jax/lax/__init__.py
@@ -0,0 +1,3 @@
"""
Custom jax.lax functions
"""
175 changes: 175 additions & 0 deletions pyscfad/backend/_jax/lax/linalg.py
@@ -0,0 +1,175 @@
"""
Custom jax.lax.linalg functions
"""
from functools import partial
import numpy as np
import jax
from jax import lax
from jax._src import ad_util
from jax._src import api
from jax._src import dispatch
from jax._src.core import (
Primitive,
ShapedArray,
is_constant_shape,
)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax as lax_internal
from jax._src.lax.linalg import (
_H,
symmetrize,
_nan_like_hlo,
_broadcasting_select_hlo,
)
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs

from pyscfadlib import lapack as lp

def eigh_gen(a, b, *,
lower=True,
itype=1,
deg_thresh=1e-9):
a = symmetrize(a)
b = symmetrize(b)
w, v = eigh_gen_p.bind(a, b, lower=lower, itype=itype, deg_thresh=deg_thresh)
return w, v

def _eigh_gen_impl(a, b, *, lower, itype, deg_thresh):
w, v = dispatch.apply_primitive(
eigh_gen_p,
a, b,
lower=lower,
itype=itype,
deg_thresh=deg_thresh)
return w, v

def _eigh_gen_abstract_eval(a, b, *, lower, itype, deg_thresh):
if isinstance(a, ShapedArray):
if a.ndim < 2 or a.shape[-2] != a.shape[-1]:
raise ValueError(
"Argument \'a\' to eigh must have shape [..., n, n], "
"but got shape {}".format(a.shape))

batch_dims = a.shape[:-2]
n = a.shape[-1]
v = a.update(shape=batch_dims + (n, n))
w = a.update(
shape=batch_dims + (n,),
dtype=lax_internal._complex_basetype(a.dtype))
else:
w, v = a, a
return w, v

def _eigh_gen_jvp_rule(primals, tangents, *, lower, itype, deg_thresh):
if itype != 1:
raise NotImplementedError(f"JVP for itype={itype} is not implemented.")
a, b = primals
n = a.shape[-1]
at, bt = tangents

w_real, v = eigh_gen_p.bind(
symmetrize(a),
symmetrize(b),
lower=lower,
itype=itype,
deg_thresh=deg_thresh)

w = w_real.astype(a.dtype)
eji = w[..., jnp.newaxis, :] - w[..., jnp.newaxis]
Fmat = ufuncs.reciprocal(
jnp.where(ufuncs.absolute(eji) > deg_thresh, eji, jnp.inf)
)

dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
precision=lax.Precision.HIGHEST)

if type(at) is ad_util.Zero:
vt_at_v = lax.zeros_like_array(a)
else:
vt_at_v = dot(_H(v), dot(at, v))

if type(bt) is not ad_util.Zero:
if a.ndim == 2:
w_diag = jnp.diag(w)
else:
batch_dims = a.shape[:-2]
w_diag = api.vmap(jnp.diag, in_axes=batch_dims, out_axes=batch_dims)(w)
vt_bt_v = dot(_H(v), dot(bt, v))
vt_bt_v_w = dot(vt_bt_v, w_diag)
vt_at_v -= vt_bt_v_w

dw = ufuncs.real(jnp.diagonal(vt_at_v, axis1=-2, axis2=-1))

F_vt_at_v = ufuncs.multiply(Fmat, vt_at_v)
if type(bt) is not ad_util.Zero:
bmask = jnp.where(ufuncs.absolute(eji) > deg_thresh, jnp.zeros_like(a), 1)
F_vt_at_v -= ufuncs.multiply(bmask, vt_bt_v) * .5

dv = dot(v, F_vt_at_v)
return (w_real, v), (dw, dv)

def _eigh_gen_cpu_lowering(ctx, a, b, *, lower, itype, deg_thresh):
del deg_thresh
a_aval, b_aval = ctx.avals_in
w_aval, v_aval = ctx.avals_out
n = a_aval.shape[-1]
batch_dims = a_aval.shape[:-2]

if not is_constant_shape(a_aval.shape[-2:]):
raise NotImplementedError(
"Shape polymorphism for native lowering for eigh is implemented "
f"only for the batch dimensions: {a_aval.shape}")

a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape)
b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
v, w, info = lp.sygvd_hlo(a_aval.dtype, a, b,
a_shape_vals=a_shape_vals,
b_shape_vals=b_shape_vals,
lower=lower, itype=itype)

zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
v = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_v_aval,
broadcast_dimensions=range(len(batch_dims))),
select_v_aval,
v, v_aval, _nan_like_hlo(ctx, v_aval), v_aval)
select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
w = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_w_aval,
broadcast_dimensions=range(len(batch_dims))),
select_w_aval,
w, w_aval, _nan_like_hlo(ctx, w_aval), w_aval)
return [w, v]

def _eigh_gen_batching_rule(batched_args, batch_dims, *,
lower, itype, deg_thresh):
a, b = batched_args
bd_a, bd_b = batch_dims
size = next(t.shape[i] for t, i in zip(batched_args, batch_dims)
if i is not None)
a = batching.bdim_at_front(a, bd_a, size)
b = batching.bdim_at_front(b, bd_b, size)
return eigh_gen_p.bind(a, b,
lower=lower,
itype=itype,
deg_thresh=deg_thresh), (0, 0)

def _eigh_gen_lowering(*args, **kwargs):
raise NotImplementedError("Generalized eigh is only implemented for CPU.")

eigh_gen_p = Primitive('eigh_gen')
eigh_gen_p.multiple_results = True
eigh_gen_p.def_impl(_eigh_gen_impl)
eigh_gen_p.def_abstract_eval(_eigh_gen_abstract_eval)
ad.primitive_jvps[eigh_gen_p] = _eigh_gen_jvp_rule
batching.primitive_batchers[eigh_gen_p] = _eigh_gen_batching_rule
mlir.register_lowering(eigh_gen_p, _eigh_gen_lowering)
mlir.register_lowering(eigh_gen_p, _eigh_gen_cpu_lowering, platform='cpu')
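
For reference, a minimal sketch of how the new generalized eigendecomposition primitive might be exercised under JAX forward-mode autodiff. This is not part of the commit: the matrices and the lowest_eig helper are made up for illustration, and the call assumes a locally built pyscfadlib, since the primitive is only lowered on CPU through the sygvd kernel.

import jax
import jax.numpy as jnp
from pyscfad.backend._jax.lax.linalg import eigh_gen

# Illustrative symmetric matrix `a` and positive-definite metric `b`.
a = jnp.array([[2.0, 0.5],
               [0.5, 1.0]])
b = jnp.array([[1.0, 0.1],
               [0.1, 1.0]])

# Forward pass: eigenvalues w and eigenvectors v with a @ v = b @ v @ diag(w).
w, v = eigh_gen(a, b, lower=True, itype=1)

# Forward-mode derivative of the lowest eigenvalue with respect to `a`,
# which goes through _eigh_gen_jvp_rule defined above.
def lowest_eig(a):
    w, _ = eigh_gen(a, b, lower=True, itype=1)
    return w[0]

dw0_da = jax.jacfwd(lowest_eig)(a)

Since the tangent expressions in the JVP rule are linear in the input tangents, reverse-mode differentiation should also be possible through JAX's standard linearize-and-transpose machinery.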

3 changes: 3 additions & 0 deletions pyscfad/backend/_jax/scipy/__init__.py
@@ -0,0 +1,3 @@
"""
Custom jax.scipy functions
"""
108 changes: 108 additions & 0 deletions pyscfad/backend/_jax/scipy/linalg.py
@@ -0,0 +1,108 @@
"""
Custom jax.scipy.linalg functions
"""
import warnings
import jax
from jax import numpy as jnp
from jax import scipy as jsp
from jax._src.numpy.util import promote_dtypes_inexact
from jax._src.lax.linalg import _T, _H
from ..lax import linalg as lax_linalg

def eigh(a, b=None, *,
lower=True,
eigvals_only=False,
overwrite_a=False,
overwrite_b=False,
type=1,
check_finite=False,
subset_by_index=None,
subset_by_value=None,
driver=None,
deg_thresh=1e-9):
if overwrite_a or overwrite_b:
warnings.warn('Arguments \'overwrite_a\' and \'overwrite_b\' have no effect.')
if check_finite:
warnings.warn('Argument \'check_finite\' has no effect.')
if subset_by_index or subset_by_value:
raise NotImplementedError('Computing subset of eigenvalues is not supported.')
if driver:
warnings.warn('Argument \'driver\' has no effect.')

del (overwrite_a, overwrite_b, check_finite,
subset_by_index, subset_by_value, driver)
return _eigh(a, b, lower, type, eigvals_only, deg_thresh)

def _eigh(a, b, lower, itype, eigvals_only, deg_thresh):
if b is None:
b = jnp.zeros_like(a) + jnp.eye(a.shape[-1])

a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
w, v = lax_linalg.eigh_gen(a, b, lower=lower, itype=itype, deg_thresh=deg_thresh)

if eigvals_only:
return w
else:
return w, v

def svd(a, full_matrices=True, compute_uv=True,
overwrite_a=False, check_finite=False,
lapack_driver=None):
if overwrite_a:
warnings.warn('Argument \'overwrite_a\' has no effect.')
if check_finite:
warnings.warn('Argument \'check_finite\' has no effect.')
if lapack_driver:
warnings.warn('Argument \'lapack_driver\' has no effect.')
del overwrite_a, check_finite, lapack_driver

if not full_matrices or not compute_uv:
return jsp.linalg.svd(a,
full_matrices=full_matrices,
compute_uv=compute_uv)
else:
return _svd(a)

@jax.custom_jvp
def _svd(a):
return jsp.linalg.svd(a, full_matrices=True, compute_uv=True)

@_svd.defjvp
def _svd_jvp(primals, tangents):
A, = primals
dA, = tangents
m, n = A.shape
if m > n:
raise NotImplementedError('Use svd(A.conj().T) instead.')

U, s, Vt = _svd(A)
Ut = _H(U)
V = _H(Vt)
s_dim = s[..., jnp.newaxis, :]

dS = Ut @ dA @ V
ds = jnp.diagonal(dS, 0, -2, -1).real

s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
s_diffs_zeros = (s_diffs == 0).astype(s.dtype)
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros

dP1 = dS[..., :, :m]
dP2 = dS[..., :, m:]
dSS = dP1 * s_dim
SdS = _T(s_dim) * dP1

dU = U @ (F * (dSS + _H(dSS)))
dD1 = F * (SdS + _H(SdS))

s_zeros = (s == 0).astype(s.dtype)
s_inv = 1 / (s + s_zeros) - s_zeros
dD2 = s_inv[..., :, jnp.newaxis] * dP2

dV = jnp.zeros_like(V)
dV = dV.at[..., :m, :m].set(dD1)
dV = dV.at[..., :m, m:].set(-dD2)
dV = dV.at[..., m:, :m].set(dD2.conj().T)
dV = V @ dV
return (U, s, Vt), (dU, ds, _H(dV))
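
Similarly, a minimal usage sketch for the scipy-level wrappers added above. Again not part of the commit: the inputs are illustrative and, as with eigh_gen, the generalized path assumes a built pyscfadlib on CPU.

import jax
import jax.numpy as jnp
from pyscfad.backend._jax.scipy.linalg import eigh, svd

a = jnp.array([[2.0, 0.3],
               [0.3, 1.0]])
b = jnp.eye(2) + 0.05 * jnp.ones((2, 2))   # overlap-like SPD metric

# Generalized eigenvalues only; dispatches to the eigh_gen primitive.
w = eigh(a, b, eigvals_only=True)

# Full SVD with forward-mode tangents through the custom _svd_jvp rule.
da = 1e-3 * jnp.ones_like(a)
(u, s, vt), (du, ds, dvt) = jax.jvp(svd, (a,), (da,))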

2 changes: 1 addition & 1 deletion pyscfad/backend/_numpy/__init__.py
@@ -43,7 +43,7 @@ def __getattr__(self, name):

backend = NumpyBackend(np)

- # FIXME maybe separate ops from numpy
+ # TODO maybe separate ops from numpy
backend._cache['is_array'] = is_array
backend._cache['to_numpy'] = to_numpy
backend._cache['stop_gradient'] = stop_gradient
9 changes: 4 additions & 5 deletions pyscfad/backend/_torch/linalg.py
@@ -1,18 +1,17 @@
import torch
from .core import convert_to_tensor

def cholesky(a, **kwargs):
- a = convert_to_tensor(a)
+ a = torch.as_tensor(a)
return torch.linalg.cholesky(a, **kwargs)

def eigh(a, UPLO='L', **kwargs):
- a = convert_to_tensor(a)
+ a = torch.as_tensor(a)
return torch.linalg.eigh(a, UPLO, **kwargs)

def inv(a, **kwargs):
- a = convert_to_tensor(a)
+ a = torch.as_tensor(a)
return torch.linalg.inv(a, **kwargs)

def norm(x, ord=None, axis=None, keepdims=False, **kwargs):
- x = convert_to_tensor(x)
+ x = torch.as_tensor(x)
return torch.linalg.norm(x, ord, axis, keepdims, **kwargs)
(Diffs for the remaining changed files were not loaded on this page.)
