Skip to content

Commit

Permalink
added k and r space radial basis
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Oct 6, 2023
1 parent 365e79f commit 570fa99
Show file tree
Hide file tree
Showing 7 changed files with 730 additions and 9 deletions.
7 changes: 6 additions & 1 deletion python/rascaline/rascaline/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import metatensor

from .power_spectrum import PowerSpectrum # noqa
from .splines import RadialIntegralFromFunction, RadialIntegralSplinerBase # noqa
from .splines import ( # noqa
KSpaceSpliner,
RadialIntegralFromFunction,
RadialIntegralSplinerBase,
RealSpaceSpliner,
)


# path that can be used with cmake to access the rascaline library and headers
Expand Down
19 changes: 19 additions & 0 deletions python/rascaline/rascaline/utils/atomic_density.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC


class AtomicDensityBase(ABC):
...


class GaussianDensity(AtomicDensityBase):
def __init__(self, atomic_gaussian_width: float):
self.atomic_gaussian_width = atomic_gaussian_width


class DeltaDensity(AtomicDensityBase):
...


class LODEDensity(AtomicDensityBase):
def __init__(self, potential_exponent: int):
self.potential_exponent = potential_exponent
171 changes: 171 additions & 0 deletions python/rascaline/rascaline/utils/radial_basis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from abc import ABC, abstractmethod
from typing import Union

import numpy as np


try:
from scipy.integrate import quad

HAS_SCIPY = True
except ImportError:
HAS_SCIPY = False


class RadialBasisBase(ABC):
"""
Base class for evaluating the radial basis.
:parameter orthonormalization_cutoff: Provide value if the radial integral should be
orthonormalized. If :py:obj:`None` no orthonormalization will be performed.
"""

def __init__(self, orthonormalization_cutoff: float):
self.orthonormalization_cutoff = orthonormalization_cutoff

@abstractmethod
def compute(
self, n: float, ell: float, integrand_positions: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
"""Method calculating the radial basis.
Explicitly implemented in child classes."""
...

def compute_derivative(
self, n: float, ell: float, integrand_positions: np.ndarray
) -> np.ndarray:
"""Derivative of the radial basis.
Used for radial integrals with delta like atomic densities."""
displacement = 1e-6
mean_abs_positions = np.abs(integrand_positions).mean()

if mean_abs_positions < 1.0:
raise ValueError(
"Numerically derivative of the radial integral can not be performed "
"since positions are too small. Mean of the absolute positions is "
f"{mean_abs_positions:.1e} but should be at least 1."
)

radial_basis_pos = self.compute(n, ell, integrand_positions + displacement / 2)
radial_basis_neg = self.compute(n, ell, integrand_positions - displacement / 2)

return (radial_basis_pos - radial_basis_neg) / displacement

def compute_gram_matrix(
self,
max_radial: int,
max_angular: int,
) -> np.ndarray:
"""Orthornomalize the basis.
:returns: orthornomalization matrix of shape (max_angular + 1, max_radial,
max_radial)
"""

if not HAS_SCIPY:
raise ValueError("Orthornomalization requires scipy!")

# Gram matrix (also called overlap matrix or inner product matrix)
gram_matrix = np.zeros((max_angular + 1, max_radial, max_radial))

def integrand(
integrand_positions: np.ndarray,
n1: int,
n2: int,
ell: int,
) -> np.ndarray:
return (
integrand_positions**2
* self.compute(n1, ell, integrand_positions)
* self.compute(n2, ell, integrand_positions)
)

for ell in range(max_angular + 1):
for n1 in range(max_radial):
for n2 in range(max_radial):
gram_matrix[ell, n1, n2] = quad(
func=integrand,
a=0,
b=self.orthonormalization_cutoff,
args=(n1, n2, ell),
)[0]

return gram_matrix

def compute_orthonormalization_matrix(
self,
max_radial: int,
max_angular: int,
) -> np.ndarray:
"""Compute orthonormalization matrix
:returns: orthornomalization matrix of shape (max_angular + 1, max_radial,
max_radial)
"""

gram_matrix = self.compute_gram_matrix(max_radial, max_angular)

# Get the normalization constants from the diagonal entries
normalizations = np.zeros((max_angular + 1, max_radial))

for ell in range(max_angular + 1):
for n in range(max_radial):
normalizations[ell, n] = 1 / np.sqrt(gram_matrix[ell, n, n])

# Rescale orthonormalization matrix to be defined
# in terms of the normalized (but not yet orthonormalized)
# basis functions
gram_matrix[ell, n, :] *= normalizations[ell, n]
gram_matrix[ell, :, n] *= normalizations[ell, n]

orthonormalization_matrix = np.zeros_like(gram_matrix)
for ell in range(max_angular + 1):
eigvals, eigvecs = np.linalg.eigh(gram_matrix[ell])
orthonormalization_matrix[ell] = (
eigvecs @ np.diag(np.sqrt(1.0 / eigvals)) @ eigvecs.T
)

# Rescale the orthonormalization matrix so that it
# works with respect to the primitive (not yet normalized)
# radial basis functions
for ell in range(max_angular + 1):
for n in range(max_radial):
orthonormalization_matrix[ell, :, n] *= normalizations[ell, n]

return orthonormalization_matrix


class GTOBasis(RadialBasisBase):
"""Primitive (not normolized nor orthonormlized) GTO radial basis."""

def __init__(self, max_radial, cutoff):
super().__init__(orthonormalization_cutoff=np.inf)
self.max_radial = max_radial
self.cutoff = cutoff
self.sigmas = np.ones(self.max_radial, dtype=float)

for i in range(1, self.max_radial):
self.sigmas[i] = np.sqrt(i)
self.sigmas *= self.cutoff / self.max_radial

def compute(
self, n: float, ell: float, integrand_positions: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
return integrand_positions**n * np.exp(
-0.5 * (integrand_positions / self.sigmas[n]) ** 2
)

def compute_derivative(
self, n: float, ell: float, integrand_positions: Union[float, np.ndarray]
) -> Union[float, np.ndarray]:
return n * integrand_positions ** (n - 1) * self.compute(
n, ell, integrand_positions
) - integrand_positions / self.sigmas[n] ** 2 * self.compute(
n, ell, integrand_positions
)


class MonomialBasis(RadialBasisBase):
...
Loading

0 comments on commit 570fa99

Please sign in to comment.