-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
27 changed files
with
735 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/usr/bin/env bash | ||
|
||
cd pyscfadlib/pyscfadlib | ||
mkdir build | ||
cmake -B build | ||
cmake --build build -j2 | ||
rm -rf build |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
#!/usr/bin/env bash | ||
export OMP_NUM_THREADS=1 | ||
export PYTHONPATH=$(pwd):$(pwd)/pyscf:$PYTHONPATH | ||
echo "pyscfad = True" >> $HOME/.pyscf_conf.py | ||
|
||
export PYTHONPATH=$(pwd):$(pwd)/pyscfadlib:$PYTHONPATH | ||
|
||
pylint pyscfad |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
#!/usr/bin/env bash | ||
export OMP_NUM_THREADS=1 | ||
export PYTHONPATH=$(pwd):$PYTHONPATH | ||
|
||
export PYTHONPATH=$(pwd):$(pwd)/pyscfadlib:$PYTHONPATH | ||
|
||
export OMP_NUM_THREADS=2 | ||
pytest ./pyscfad --cov-report xml --cov=. --verbosity=1 --durations=10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
""" | ||
Custom jax.lax functions | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
""" | ||
Custom jax.lax.linalg functions | ||
""" | ||
from functools import partial | ||
import numpy as np | ||
import jax | ||
from jax import lax | ||
from jax._src import ad_util | ||
from jax._src import api | ||
from jax._src import dispatch | ||
from jax._src.core import ( | ||
Primitive, | ||
ShapedArray, | ||
is_constant_shape, | ||
) | ||
from jax._src.interpreters import ad | ||
from jax._src.interpreters import batching | ||
from jax._src.interpreters import mlir | ||
from jax._src.lax import lax as lax_internal | ||
from jax._src.lax.linalg import ( | ||
_H, | ||
symmetrize, | ||
_nan_like_hlo, | ||
_broadcasting_select_hlo, | ||
) | ||
from jax._src.numpy import lax_numpy as jnp | ||
from jax._src.numpy import ufuncs | ||
|
||
from pyscfadlib import lapack as lp | ||
|
||
def eigh_gen(a, b, *, | ||
lower=True, | ||
itype=1, | ||
deg_thresh=1e-9): | ||
a = symmetrize(a) | ||
b = symmetrize(b) | ||
w, v = eigh_gen_p.bind(a, b, lower=lower, itype=itype, deg_thresh=deg_thresh) | ||
return w, v | ||
|
||
def _eigh_gen_impl(a, b, *, lower, itype, deg_thresh): | ||
w, v = dispatch.apply_primitive( | ||
eigh_gen_p, | ||
a, b, | ||
lower=lower, | ||
itype=itype, | ||
deg_thresh=deg_thresh) | ||
return w, v | ||
|
||
def _eigh_gen_abstract_eval(a, b, *, lower, itype, deg_thresh): | ||
if isinstance(a, ShapedArray): | ||
if a.ndim < 2 or a.shape[-2] != a.shape[-1]: | ||
raise ValueError( | ||
"Argument \'a\' to eigh must have shape [..., n, n], " | ||
"but got shape {}".format(a.shape)) | ||
|
||
batch_dims = a.shape[:-2] | ||
n = a.shape[-1] | ||
v = a.update(shape=batch_dims + (n, n)) | ||
w = a.update( | ||
shape=batch_dims + (n,), | ||
dtype=lax_internal._complex_basetype(a.dtype)) | ||
else: | ||
w, v = a, a | ||
return w, v | ||
|
||
def _eigh_gen_jvp_rule(primals, tangents, *, lower, itype, deg_thresh): | ||
if itype != 1: | ||
raise NotImplementedError(f"JVP for itype={itype} is not implemented.") | ||
a, b = primals | ||
n = a.shape[-1] | ||
at, bt = tangents | ||
|
||
w_real, v = eigh_gen_p.bind( | ||
symmetrize(a), | ||
symmetrize(b), | ||
lower=lower, | ||
itype=itype, | ||
deg_thresh=deg_thresh) | ||
|
||
w = w_real.astype(a.dtype) | ||
eji = w[..., jnp.newaxis, :] - w[..., jnp.newaxis] | ||
Fmat = ufuncs.reciprocal( | ||
jnp.where(ufuncs.absolute(eji) > deg_thresh, eji, jnp.inf) | ||
) | ||
|
||
dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul, | ||
precision=lax.Precision.HIGHEST) | ||
|
||
if type(at) is ad_util.Zero: | ||
vt_at_v = lax.zeros_like_array(a) | ||
else: | ||
vt_at_v = dot(_H(v), dot(at, v)) | ||
|
||
if type(bt) is not ad_util.Zero: | ||
if a.ndim == 2: | ||
w_diag = jnp.diag(w) | ||
else: | ||
batch_dims = a.shape[:-2] | ||
w_diag = api.vmap(jnp.diag, in_axes=batch_dims, out_axes=batch_dims)(w) | ||
vt_bt_v = dot(_H(v), dot(bt, v)) | ||
vt_bt_v_w = dot(vt_bt_v, w_diag) | ||
vt_at_v -= vt_bt_v_w | ||
|
||
dw = ufuncs.real(jnp.diagonal(vt_at_v, axis1=-2, axis2=-1)) | ||
|
||
F_vt_at_v = ufuncs.multiply(Fmat, vt_at_v) | ||
if type(bt) is not ad_util.Zero: | ||
bmask = jnp.where(ufuncs.absolute(eji) > deg_thresh, jnp.zeros_like(a), 1) | ||
F_vt_at_v -= ufuncs.multiply(bmask, vt_bt_v) * .5 | ||
|
||
dv = dot(v, F_vt_at_v) | ||
return (w_real, v), (dw, dv) | ||
|
||
def _eigh_gen_cpu_lowering(ctx, a, b, *, lower, itype, deg_thresh): | ||
del deg_thresh | ||
a_aval, b_aval = ctx.avals_in | ||
w_aval, v_aval = ctx.avals_out | ||
n = a_aval.shape[-1] | ||
batch_dims = a_aval.shape[:-2] | ||
|
||
if not is_constant_shape(a_aval.shape[-2:]): | ||
raise NotImplementedError( | ||
"Shape polymorphism for native lowering for eigh is implemented " | ||
f"only for the batch dimensions: {a_aval.shape}") | ||
|
||
a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) | ||
b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape) | ||
v, w, info = lp.sygvd_hlo(a_aval.dtype, a, b, | ||
a_shape_vals=a_shape_vals, | ||
b_shape_vals=b_shape_vals, | ||
lower=lower, itype=itype) | ||
|
||
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) | ||
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED") | ||
select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) | ||
v = _broadcasting_select_hlo( | ||
ctx, | ||
mlir.broadcast_in_dim(ctx, ok, select_v_aval, | ||
broadcast_dimensions=range(len(batch_dims))), | ||
select_v_aval, | ||
v, v_aval, _nan_like_hlo(ctx, v_aval), v_aval) | ||
select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_)) | ||
w = _broadcasting_select_hlo( | ||
ctx, | ||
mlir.broadcast_in_dim(ctx, ok, select_w_aval, | ||
broadcast_dimensions=range(len(batch_dims))), | ||
select_w_aval, | ||
w, w_aval, _nan_like_hlo(ctx, w_aval), w_aval) | ||
return [w, v] | ||
|
||
def _eigh_gen_batching_rule(batched_args, batch_dims, *, | ||
lower, itype, deg_thresh): | ||
a, b = batched_args | ||
bd_a, bd_b = batch_dims | ||
size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) | ||
if i is not None) | ||
a = batching.bdim_at_front(a, bd_a, size) | ||
b = batching.bdim_at_front(b, bd_b, size) | ||
return eigh_gen_p.bind(a, b, | ||
lower=lower, | ||
itype=itype, | ||
deg_thresh=deg_thresh), (0, 0) | ||
|
||
def _eigh_gen_lowering(*args, **kwargs): | ||
raise NotImplementedError("Generalized eigh is only implemented for CPU.") | ||
|
||
eigh_gen_p = Primitive('eigh_gen') | ||
eigh_gen_p.multiple_results = True | ||
eigh_gen_p.def_impl(_eigh_gen_impl) | ||
eigh_gen_p.def_abstract_eval(_eigh_gen_abstract_eval) | ||
ad.primitive_jvps[eigh_gen_p] = _eigh_gen_jvp_rule | ||
batching.primitive_batchers[eigh_gen_p] = _eigh_gen_batching_rule | ||
mlir.register_lowering(eigh_gen_p, _eigh_gen_lowering) | ||
mlir.register_lowering(eigh_gen_p, _eigh_gen_cpu_lowering, platform='cpu') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
""" | ||
Custom jax.scipy functions | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
""" | ||
Custom jax.scipy.linalg functions | ||
""" | ||
import warnings | ||
import jax | ||
from jax import numpy as jnp | ||
from jax import scipy as jsp | ||
from jax._src.numpy.util import promote_dtypes_inexact | ||
from jax._src.lax.linalg import _T, _H | ||
from ..lax import linalg as lax_linalg | ||
|
||
def eigh(a, b=None, *, | ||
lower=True, | ||
eigvals_only=False, | ||
overwrite_a=False, | ||
overwrite_b=False, | ||
type=1, | ||
check_finite=False, | ||
subset_by_index=None, | ||
subset_by_value=None, | ||
driver=None, | ||
deg_thresh=1e-9): | ||
if overwrite_a or overwrite_b: | ||
warnings.warn('Arguments \'overwrite_a\' and \'overwrite_b\' have no effect.') | ||
if check_finite: | ||
warnings.warn('Argument \'check_finite\' has no effect.') | ||
if subset_by_index or subset_by_value: | ||
raise NotImplementedError('Computing subset of eigenvalues is not supported.') | ||
if driver: | ||
warnings.warn('Argument \'driver\' has no effect.') | ||
|
||
del (overwrite_a, overwrite_b, check_finite, | ||
subset_by_index, subset_by_value, driver) | ||
return _eigh(a, b, lower, type, eigvals_only, deg_thresh) | ||
|
||
def _eigh(a, b, lower, itype, eigvals_only, deg_thresh): | ||
if b is None: | ||
b = jnp.zeros_like(a) + jnp.eye(a.shape[-1]) | ||
|
||
a, b = promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b)) | ||
w, v = lax_linalg.eigh_gen(a, b, lower=lower, itype=itype, deg_thresh=deg_thresh) | ||
|
||
if eigvals_only: | ||
return w | ||
else: | ||
return w, v | ||
|
||
def svd(a, full_matrices=True, compute_uv=True, | ||
overwrite_a=False, check_finite=False, | ||
lapack_driver=None): | ||
if overwrite_a: | ||
warnings.warn('Argument \'overwrite_a\' has no effect.') | ||
if check_finite: | ||
warnings.warn('Argument \'check_finite\' has no effect.') | ||
if lapack_driver: | ||
warnings.warn('Argument \'lapack_driver\' has no effect.') | ||
del overwrite_a, check_finite, lapack_driver | ||
|
||
if not full_matrices or not compute_uv: | ||
return jsp.linalg.svd(a, | ||
full_matrices=full_matrices, | ||
compute_uv=compute_uv) | ||
else: | ||
return _svd(a) | ||
|
||
@jax.custom_jvp | ||
def _svd(a): | ||
return jsp.linalg.svd(a, full_matrices=True, compute_uv=True) | ||
|
||
@_svd.defjvp | ||
def _svd_jvp(primals, tangents): | ||
A, = primals | ||
dA, = tangents | ||
m, n = A.shape | ||
if m > n: | ||
raise NotImplementedError('Use svd(A.conj().T) instead.') | ||
|
||
U, s, Vt = _svd(A) | ||
Ut = _H(U) | ||
V = _H(Vt) | ||
s_dim = s[..., jnp.newaxis, :] | ||
|
||
dS = Ut @ dA @ V | ||
ds = jnp.diagonal(dS, 0, -2, -1).real | ||
|
||
s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim)) | ||
s_diffs_zeros = (s_diffs == 0).astype(s.dtype) | ||
F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros | ||
|
||
dP1 = dS[..., :, :m] | ||
dP2 = dS[..., :, m:] | ||
dSS = dP1 * s_dim | ||
SdS = _T(s_dim) * dP1 | ||
|
||
dU = U @ (F * (dSS + _H(dSS))) | ||
dD1 = F * (SdS + _H(SdS)) | ||
|
||
s_zeros = (s == 0).astype(s.dtype) | ||
s_inv = 1 / (s + s_zeros) - s_zeros | ||
dD2 = s_inv[..., :, jnp.newaxis] * dP2 | ||
|
||
dV = jnp.zeros_like(V) | ||
dV = dV.at[..., :m, :m].set(dD1) | ||
dV = dV.at[..., :m, m:].set(-dD2) | ||
dV = dV.at[..., m:, :m].set(dD2.conj().T) | ||
dV = V @ dV | ||
return (U, s, Vt), (dU, ds, _H(dV)) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,17 @@ | ||
import torch | ||
from .core import convert_to_tensor | ||
|
||
def cholesky(a, **kwargs): | ||
a = convert_to_tensor(a) | ||
a = torch.as_tensor(a) | ||
return torch.linalg.cholesky(a, **kwargs) | ||
|
||
def eigh(a, UPLO='L', **kwargs): | ||
a = convert_to_tensor(a) | ||
a = torch.as_tensor(a) | ||
return torch.linalg.eigh(a, UPLO, **kwargs) | ||
|
||
def inv(a, **kwargs): | ||
a = convert_to_tensor(a) | ||
a = torch.as_tensor(a) | ||
return torch.linalg.inv(a, **kwargs) | ||
|
||
def norm(x, ord=None, axis=None, keepdims=False, **kwargs): | ||
x = convert_to_tensor(x) | ||
x = torch.as_tensor(x) | ||
return torch.linalg.norm(x, ord, axis, keepdims, **kwargs) |
Oops, something went wrong.