diff --git a/docs/src/conf.py b/docs/src/conf.py index 3b45ea3f6..a8d8a21da 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -140,6 +140,7 @@ def setup(app): "metatensor": ("https://lab-cosmo.github.io/metatensor/latest/", None), "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), "skmatter": ("https://scikit-matter.readthedocs.io/en/latest/", None), "torch": ("https://pytorch.org/docs/stable/", None), "python": ("https://docs.python.org/3", None), diff --git a/docs/src/explanations/index.rst b/docs/src/explanations/index.rst index 8643a7235..a0c8c81d0 100644 --- a/docs/src/explanations/index.rst +++ b/docs/src/explanations/index.rst @@ -12,4 +12,5 @@ all about. :maxdepth: 1 concepts + radial-integral soap diff --git a/docs/src/explanations/radial-integral.rst b/docs/src/explanations/radial-integral.rst new file mode 100644 index 000000000..03cbd0eb4 --- /dev/null +++ b/docs/src/explanations/radial-integral.rst @@ -0,0 +1,278 @@ +.. _radial-integral: + +Concepts behind the radial integral +=================================== + +On this page, we describe the exact mathematical expression that are implemented in the +radial integral and the splined radial integral classes i.e. +:ref:`python-splined-radial-integral`. + +Preliminaries +------------- + +In this subsection, we briefly provide all the preliminary knowledge that is needed to +understand what the radial integral class is doing. The actual explanation for what is +computed in the radial integral class can be found in the next subsection (1.2). The +spherical expansion coefficients :math:`\langle anlm | \rho_i \rangle` are completely +determined by specifying two ingredients: + +- the atomic density function :math:`g(r)` as implemented in + :ref:`python-atomic-density`, often chosen to be a Gaussian or Delta function, that + defined the type of density under consideration. For a given center atom :math:`i` in + the structure, the total density function :math:`\rho_i(\boldsymbol{r})` around is + then defined as :math:`\rho_i(\boldsymbol{r}) = \sum_{j} g(\boldsymbol{r} - + \boldsymbol{r}_{ij})`. + +- the radial basis functions :math:`R_{nl}(r)` as implementated + :ref:`python-radial-basis`, on which the density :math:`\rho_i` is projected. To be + more precise, the actual basis functions are of the form + + .. math:: + + B_{nlm}(\boldsymbol{r}) = R_{nl}(r)Y_{lm}(\hat{r}), + + where :math:`Y_{lm}(\hat{r})` are the real spherical harmonics evaluated at the point + :math:`\hat{r}`, i.e. at the spherical angles :math:`(\theta, \phi)` that determine + the orientation of the unit vector :math:`\hat{r} = \boldsymbol{r}/r`. + +The spherical expansion coefficient :math:`\langle nlm | \rho_i \rangle` (we ommit the +chemical species index :math:`a` for simplicity) is then defined as + +.. math:: + + \begin{aligned} + \langle nlm | \rho_i \rangle & = \int \mathrm{d}^3\boldsymbol{r} + B_{nlm}(\boldsymbol{r}) \rho_i(\boldsymbol{r}) \\ \label{expansion_coeff_def} & = + \int \mathrm{d}^3\boldsymbol{r} R_{nl}(r)Y_{lm}(\hat{r})\rho_i(\boldsymbol{r}). + \end{aligned} + +In practice, the atom centered density :math:`\rho_i` is a superposition of the neighbor +contributions, namely :math:`\rho_i(\boldsymbol{r}) = \sum_{j} g(\boldsymbol{r} - +\boldsymbol{r}_{ij})`. Due to linearity of integration, evaluating the integral can then +be simplified to + +.. math:: + + \begin{aligned} + \langle nlm | \rho_i \rangle & = \int \mathrm{d}^3\boldsymbol{r} + R_{nl}(r)Y_{lm}(\hat{r})\rho_i(\boldsymbol{r}) \\ & = \int + \mathrm{d}^3\boldsymbol{r} R_{nl}(r)Y_{lm}(\hat{r})\left( \sum_{j} + g(\boldsymbol{r} - \boldsymbol{r}_{ij})\right) \\ & = \sum_{j} \int + \mathrm{d}^3\boldsymbol{r} R_{nl}(r)Y_{lm}(\hat{r}) g(\boldsymbol{r} - + \boldsymbol{r}_{ij}) \\ & = \sum_j \langle nlm | g;\boldsymbol{r}_{ij} \rangle. + \end{aligned} + +Thus, instead of having to compute integrals for arbitrary densities :math:`\rho_i`, we +have reduced our problem to the evaluation of integrals of the form + +.. math:: + + \begin{aligned} + \langle nlm | g;\boldsymbol{r}_{ij} \rangle & = \int \mathrm{d}^3\boldsymbol{r} + R_{nl}(r)Y_{lm}(\hat{r})g(\boldsymbol{r} - \boldsymbol{r}_{ij}), + \end{aligned} + +which are completely specified by + +- the density function :math:`g(\boldsymbol{r})` + +- the radial basis :math:`R_{nl}(r)` + +- the position of the neighbor atom :math:`\boldsymbol{r}_{ij}` relative to the center + atom + +The Radial Integral Class +------------------------- + +In the previous subsection, we have explained how the computation of the spherical +expansion coefficients can be reduced to integrals of the form + +.. math:: + + \begin{aligned} + \langle nlm | g;\boldsymbol{r}_{ij} \rangle & = \int \mathrm{d}^3\boldsymbol{r} + R_{nl}(r)Y_{lm}(\hat{r})g(\boldsymbol{r} - \boldsymbol{r}_{ij}). + \end{aligned} + +If the atomic density is spherically symmetric, i.e. if :math:`g(\boldsymbol{r}) = g(r)` +this integral can always be written in the following form: + +.. math:: + + \begin{aligned} \label{expansion_coeff_spherical_symmetric} + \langle nlm | g;\boldsymbol{r}_{ij} \rangle & = + Y_{lm}(\hat{r}_{ij})I_{nl}(r_{ij}). + \end{aligned} + +The key point is that the dependence on the vectorial position +:math:`\boldsymbol{r}_{ij}` is split into a factor that contains information about the +orientation of this vector, namely :math:`Y_{lm}(\hat{r}_{ij})`, which is just the +spherical harmonic evaluated at :math:`\hat{r}_{ij}`, and a remaining part that captures +the dependence on the distance of atom :math:`j` from the center atom :math:`i`, namely +:math:`I_{nl}(r_{ij})`, which we shall call the radial integral. The radial integral +class computes and outputs this radial part :math:`I_{nl}(r_{ij})`. Since the angular +part is just the usual spherical harmonic, this is the part that also depends on the +choice of atomic density :math:`g(r)`, as well as the radial basis :math:`R_{nl}(r)`. In +the following, for users only interested in a specific type of density, we provide the +explicit expressions of :math:`I_{nl}(r)` for the Delta and Gaussian densities, followed +by the general expression. + +Delta Densities +~~~~~~~~~~~~~~~ + +Here, we consider the especially simple special case where the atomic density function +:math:`g(\boldsymbol{r}) = \delta(\boldsymbol{r})`. Then: + +.. math:: + + \begin{aligned} + \langle nlm | g;\boldsymbol{r}_{ij} \rangle & = \int \mathrm{d}^3\boldsymbol{r} + R_{nl}(r)Y_{lm}(\hat{r})g(\boldsymbol{r} - \boldsymbol{r}_{ij}) \\ & = \int + \mathrm{d}^3\boldsymbol{r} R_{nl}(r)Y_{lm}(\hat{r})\delta(\boldsymbol{r} - + \boldsymbol{r}_{ij}) \\ & = R_{nl}(r) Y_{lm}(\hat{r}_{ij}) = + B_{nlm}(\boldsymbol{r}_{ij}). + \end{aligned} + +Thus, in this particularly simple case, the radial integral is simply the radial basis +function evaluated at the pair distance :math:`r_{ij}`, and we see that the integrals +have indeed the form presented above. + +Gaussian Densities +~~~~~~~~~~~~~~~~~~ + +Here, we consider another popular use case, where the atomic density function is a +Gaussian. In rascaline, we use the convention + +.. math:: + + g(r) = \frac{1}{(\pi \sigma^2)^{3/4}}e^{-\frac{r^2}{2\sigma^2}}. + +The prefactor was chosen such that the “L2-norm” of the Gaussian + +.. math:: + + \begin{aligned} + \|g\|^2 = \int \mathrm{d}^3\boldsymbol{r} |g(r)|^2 = 1, + \end{aligned} + +but does not affect the following calculations in any way. With these conventions, it +can be shown that the integral has the desired form + +.. math:: + + \begin{aligned} + \langle nlm | g;\boldsymbol{r}_{ij} \rangle & = \int \mathrm{d}^3\boldsymbol{r} + R_{nl}(r)Y_{lm}(\hat{r})g(\boldsymbol{r} - \boldsymbol{r}_{ij}) \\ & = + Y_{lm}(\hat{r}_{ij}) \cdot I_{nl}(r_{ij}) + \end{aligned} + +with + +.. math:: + + I_{nl}(r_{ij}) = \frac{1}{(\pi \sigma^2)^{3/4}}4\pi e^{-\frac{r_{ij}^2}{2\sigma^2}} + \int_0^\infty \mathrm{d}r r^2 R_{nl}(r) e^{-\frac{r^2}{2\sigma^2}} + i_l\left(\frac{rr_{ij}}{\sigma^2}\right), + +where :math:`i_l` is a modified spherical Bessel function. The first factor, of course, +is just the normalization factor of the Gaussian density. See the next two subsections +for a derivation of this formula. + +Derivation of the General Case +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We now derive an explicit formula for radial integral that works for any density. Let +:math:`g(r)` be a generic spherically symmetric density function. Our goal will be to +show that + +.. math:: + + \langle nlm | g;\boldsymbol{r}_{ij} \rangle = Y_{lm}(\hat{r}_{ij}) \left[2\pi + \int_0^\infty \mathrm{d}r r^2 R_{nl}(r) \int_{-1}^1 \mathrm{d}(\cos\theta) + P_l(\cos\theta) g(\sqrt{r^2+r_{ij}^2-2rr_{ij}\cos\theta}) \right] + +and thus we have the desired form :math:`\langle nlm | g;\boldsymbol{r}_{ij} \rangle = +Y_{lm}(\hat{r}_{ij}) I_{nl}(r_{ij})` with + +.. math:: + + \begin{aligned} + I_{nl}(r_{ij}) = 2\pi \int_0^\infty \mathrm{d}r r^2 R_{nl}(r) \int_{-1}^1 + \mathrm{d}u P_l(u) g(\sqrt{r^2+r_{ij}^2-2rr_{ij}u}), + \end{aligned} + +where :math:`P_l(x)` is the :math:`l`-th Legendre polynomial. + +Derivation of the explicit radial integral for Gaussian densities +----------------------------------------------------------------- + +Denoting by :math:`\theta(\boldsymbol{r},\boldsymbol{r}_{ij})` the angle between a +generic position vector :math:`\boldsymbol{r}` and the vector +:math:`\boldsymbol{r}_{ij}`, we can write + +.. math:: + + \begin{aligned} + g(\boldsymbol{r}- \boldsymbol{r}_{ij}) & = \frac{1}{(\pi + \sigma^2)^{3/4}}e^{-\frac{(\boldsymbol{r}- \boldsymbol{r}_{ij})^2}{2\sigma^2}} \\ + & = \frac{1}{(\pi + \sigma^2)^{3/4}}e^{-\frac{(r_{ij})^2}{2\sigma^2}}e^{-\frac{(\boldsymbol{r}^2- + 2\boldsymbol{r}\boldsymbol{r}_{ij})}{2\sigma^2}}, + \end{aligned} + +where the first factor no longer depends on the integration variable :math:`r`. + +Analytical Expressions for the GTO Basis +---------------------------------------- + +While the above integrals are hard to compute in general, the GTO basis is one of the +few sets of basis functions for which many of the integrals can be evaluated +analytically. This is also useful to test the correctness of more numerical +implementations. + +The primitive basis functions are defined as + +.. math:: + + \begin{aligned} + R_{nl}(r) = R_n(r) = r^n e^{-\frac{r^2}{2\sigma_n^2}} + \end{aligned} + +In this form, the basis functions are not yet orthonormal, which requires an extra +linear transformation. Since this transformation can also be applied after computing the +integrals, we simply evaluate the radial integral with respect to these primitive basis +functions. + +Real Space Integral for Gaussian Densities +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +We now evaluate + +.. math:: + + \begin{aligned} + I_{nl}(r_{ij}) & = \frac{1}{(\pi \sigma^2)^{3/4}}4\pi + e^{-\frac{r_{ij}^2}{2\sigma^2}} \int_0^\infty \mathrm{d}r r^2 R_{nl}(r) + e^{-\frac{r^2}{2\sigma^2}} i_l\left(\frac{rr_{ij}}{\sigma^2}\right) \\ & = + \frac{1}{(\pi \sigma^2)^{3/4}}4\pi e^{-\frac{r_{ij}^2}{2\sigma^2}} \int_0^\infty + \mathrm{d}r r^2 r^n e^{-\frac{r^2}{2\sigma_n^2}} e^{-\frac{r^2}{2\sigma^2}} + i_l\left(\frac{rr_{ij}}{\sigma^2}\right), + \end{aligned} + +the result of which can be conveniently expressed using :math:`a=\frac{1}{2\sigma^2}`, +:math:`b_n = \frac{1}{2\sigma_n^2}`, :math:`n_\mathrm{eff}=\frac{n+l+3}{2}` and +:math:`l_\mathrm{eff}=l+\frac{3}{2}` as + +.. math:: + + \begin{aligned} + I_{nl}(r_{ij}) = \frac{1}{(\pi \sigma^2)^{3/4}} \cdot + \pi^{\frac{3}{2}}\frac{\Gamma\left(n_\mathrm{eff}\right)}{\Gamma\left(l_\mathrm{eff}\right)}\frac{(ar_{ij})^l}{(a+b)^{n_\mathrm{eff}}}M\left(n_\mathrm{eff},l_\mathrm{eff},\frac{a^2r_{ij}^2}{a^2+b^2}\right), + \end{aligned} + +where :math:`M(a,b,z)` is the confluent hypergeometric function (hyp1f1). + +.. _k-space-radial-integral-1: + +K-space Radial Integral +----------------------- diff --git a/docs/src/references/api/python/utils/atomic-density.rst b/docs/src/references/api/python/utils/atomic-density.rst new file mode 100644 index 000000000..0c57e7d00 --- /dev/null +++ b/docs/src/references/api/python/utils/atomic-density.rst @@ -0,0 +1 @@ +.. automodule:: rascaline.utils.atomic_density diff --git a/docs/src/references/api/python/utils/index.rst b/docs/src/references/api/python/utils/index.rst index abe7e11a0..3d88c4da9 100644 --- a/docs/src/references/api/python/utils/index.rst +++ b/docs/src/references/api/python/utils/index.rst @@ -7,5 +7,7 @@ Utility functions and classes that extend the core usage of rascaline. .. toctree:: :maxdepth: 1 + atomic-density + radial-basis power-spectrum splines diff --git a/docs/src/references/api/python/utils/radial-basis.rst b/docs/src/references/api/python/utils/radial-basis.rst new file mode 100644 index 000000000..461a39945 --- /dev/null +++ b/docs/src/references/api/python/utils/radial-basis.rst @@ -0,0 +1 @@ +.. automodule:: rascaline.utils.radial_basis diff --git a/python/rascaline-torch/tests/calculator.py b/python/rascaline-torch/tests/calculator.py index 4f6d4017a..28c2c6999 100644 --- a/python/rascaline-torch/tests/calculator.py +++ b/python/rascaline-torch/tests/calculator.py @@ -62,7 +62,6 @@ def test_compute(system): assert torch.all(gradient.values[i, 2, :] == torch.tensor([0, 1])) assert len(gradient.samples) == 8 - print(gradient.samples.values) assert gradient.samples.names == ["sample", "structure", "atom"] assert tuple(gradient.samples[0]) == (0, 0, 0) assert tuple(gradient.samples[1]) == (0, 0, 1) diff --git a/python/rascaline/rascaline/utils/__init__.py b/python/rascaline/rascaline/utils/__init__.py index 88d7e9e90..681944e06 100644 --- a/python/rascaline/rascaline/utils/__init__.py +++ b/python/rascaline/rascaline/utils/__init__.py @@ -3,7 +3,12 @@ import metatensor from .power_spectrum import PowerSpectrum # noqa -from .splines import RadialIntegralFromFunction, RadialIntegralSplinerBase # noqa +from .splines import ( # noqa + KSpaceSpliner, + RadialIntegralFromFunction, + RadialIntegralSplinerBase, + RealSpaceSpliner, +) # path that can be used with cmake to access the rascaline library and headers diff --git a/python/rascaline/rascaline/utils/atomic_density.py b/python/rascaline/rascaline/utils/atomic_density.py new file mode 100644 index 000000000..eb3744045 --- /dev/null +++ b/python/rascaline/rascaline/utils/atomic_density.py @@ -0,0 +1,166 @@ +r""" +.. _python-atomic-density: + +Atomic Density +============== + +the atomic density function :math:`g(r)`, often chosen to be a Gaussian or Delta +function, that defined the type of density under consideration. For a given center atom +:math:`i` in the structure, the total density function :math:`\rho_i(\boldsymbol{r})` +around is then defined as :math:`\rho_i(\boldsymbol{r}) = \sum_{j} g(\boldsymbol{r} - +\boldsymbol{r}_{ij})`. + +All atomic densities are based on + +.. autoclass:: rascaline.utils.atomic_density.AtomicDensityBase + :members: + :show-inheritance: + +In addition, we provide the following explicit implementations + +.. autoclass:: rascaline.utils.atomic_density.DeltaDensity + :members: + :show-inheritance: + +.. autoclass:: rascaline.utils.atomic_density.GaussianDensity + :members: + :show-inheritance: + +.. autoclass:: LodeDensity + :members: + :show-inheritance: + +""" +from abc import ABC, abstractmethod +from typing import Union + +import numpy as np + + +try: + from scipy.special import gamma, gammainc + + HAS_SCIPY = True +except ImportError: + HAS_SCIPY = False + + +class AtomicDensityBase(ABC): + """Base class implemententing atomic densities.""" + + @abstractmethod + def compute(self, positions: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + """Method calculating the atomic density. + + :param positions: positions to evaluate the atomic densities + :returns: evaluated atomic density + """ + ... + + +class DeltaDensity(AtomicDensityBase): + r"""Delta atomic densities of the form :math:`g(r)=\delta(r)`.""" + + def compute(self, positions: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + raise ValueError( + "Compute function of the delta density should never called directly." + ) + + +class GaussianDensity(AtomicDensityBase): + r"""Gaussian atomic density function. + + In rascaline, we use the convention + + .. math:: + + g(r) = \frac{1}{(\pi \sigma^2)^{3/4}}e^{-\frac{r^2}{2\sigma^2}} \,. + + The prefactor was chosen such that the "L2-norm" of the Gaussian + + .. math:: + + \|g\|^2 = \int \mathrm{d}^3\boldsymbol{r} |g(r)|^2 = 1\,, + + :param atomic_gaussian_width: Width of the atom-centered gaussian used to create the + atomic density + """ + + def __init__(self, atomic_gaussian_width: float): + self.atomic_gaussian_width = atomic_gaussian_width + + def compute(self, positions: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + atomic_gaussian_width_sq = self.atomic_gaussian_width**2 + return np.exp(-0.5 * positions**2 / atomic_gaussian_width_sq) / ( + np.pi * atomic_gaussian_width_sq + ) ** (3 / 4) + + +class LodeDensity(AtomicDensityBase): + r"""Smeared Power Law Densities/ + + It is defined as + + .. math:: + + g(r) = \frac{1}{\Gamma\left(\frac{p}{2}\right)} + \frac{\gamma\left( \frac{p}{2}, \frac{r^2}{2\sigma^2} \right)} + {r^p}, + + where :math:`\Gamma(z)` is the Gamma function and :math:`\gamma(a, x)` is the + incomplete lower Gamma function. However its evaluation at :math:`r=0` is + problematic because :math:`g(r)` is of the form :math:`0/0`. For practical + implementations, it is thus more convenient to rewrite the density as + + .. math:: + + g(r) = \frac{1}{\Gamma(a)}\frac{1}{\left(2 \sigma^2\right)^a} + \begin{cases} + \frac{1}{a} - \frac{x}{a+1} + \frac{x^2}{2(a+2)} + \mathcal{O}(x^3) + & x < 10^{-5} \\ + \frac{\gamma(a,x)}{x^a} + & x \geq 10^{-5} + \end{cases} + + It is convenient to use the expression for sufficiently small :math:`x` since the + relative weight of the first neglected term is on the order of :math:`1/6x^3`. + Therefore, the threshold :math:`x = 10^{-5}` leads to relative errors on the order + of the machine epsilon. + + :param atomic_gaussian_width: Width of the atom-centered gaussian used to create the + atomic density + :param potential_exponent: Potential exponent of the decorated atom density. + Currently only implemented for potential_exponent < 10. Some exponents can be + connected to SOAP or physics-based quantities: p=0 uses Gaussian densities as in + SOAP, p=1 uses 1/r Coulomb like densities, p=6 uses 1/r^6 dispersion like + densities. + """ + + def __init__(self, atomic_gaussian_width: float, potential_exponent: int): + if not HAS_SCIPY: + raise ValueError("LodeDensity requires scipy!") + + self.potential_exponent = potential_exponent + self.atomic_gaussian_width = atomic_gaussian_width + + def _f_sr(self, a, x): + return 1 / a - x / (a + 1) + x**2 / (2 * (a + 2)) + + def _f_lr(self, a, x): + return gamma(a) * gammainc(a, x) / x**a + + def compute(self, positions: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + if self.potential_exponent == 0: + return GaussianDensity.compute(self, positions=positions) + else: + atomic_gaussian_width_sq = self.atomic_gaussian_width**2 + a = self.potential_exponent / 2 + x = positions**2 / (2 * atomic_gaussian_width_sq) + + prefac = 1 / gamma(a) / (2 * atomic_gaussian_width_sq) ** a + + return prefac * np.where( + x < 1e-5, + self._f_sr(a, x), + self._f_lr(a, x), + ) diff --git a/python/rascaline/rascaline/utils/radial_basis.py b/python/rascaline/rascaline/utils/radial_basis.py new file mode 100644 index 000000000..284b99236 --- /dev/null +++ b/python/rascaline/rascaline/utils/radial_basis.py @@ -0,0 +1,325 @@ +r""" +.. _python-radial-basis: + +Radial Basis +============ + +Radial basis functions :math:`R_{nl}(\boldsymbol{r})` are besides :ref:`atomic densities +` :math:`\rho_i` the central ingridents to compute spherical +expansion coefficients :math:`\langle anlm\vert\rho_i\rangle`. Radial basis functions, +define how which the atomic density is projected. To be more precise, the actual basis +functions are of + +.. math:: + + B_{nlm}(\boldsymbol{r}) = R_{nl}(r)Y_{lm}(\hat{r}) \,, + +where :math:`Y_{lm}(\hat{r})` are the real spherical harmonics evaluated at the point +:math:`\hat{r}`, i.e. at the spherical angles :math:`(\theta, \phi)` that determine the +orientation of the unit vector :math:`\hat{r} = \boldsymbol{r}/r`. + +All radial basis function are based on + +.. autoclass:: rascaline.utils.radial_basis.RadialBasisBase + :members: + :show-inheritance: + +In addition, we provide the following explicit implementations + +.. autoclass:: rascaline.utils.radial_basis.GtoBasis + :members: + :show-inheritance: + +.. autoclass:: rascaline.utils.radial_basis.MonomialBasis + :members: + :show-inheritance: +""" + +from abc import ABC, abstractmethod +from typing import Union + +import numpy as np + + +try: + from scipy.integrate import quad + from scipy.optimize import fsolve + from scipy.special import spherical_jn + + HAS_SCIPY = True +except ImportError: + HAS_SCIPY = False + + +class RadialBasisBase(ABC): + r""" + Base class for evaluating the radial basis. + + The class provides methods to evaluate the radial basis :math:`R_{nl}(r)` as well as + its (numerical) derivative with respect to positions :math:`r`. + + :parameter integeration_radius: Value up to which the radial integral should be + performed. The usual value is :math:`\infty`. + """ + + def __init__(self, integeration_radius: float): + self.integeration_radius = integeration_radius + + @abstractmethod + def compute( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + """Method calculating the radial basis. + + :param n: radial channel + :param ell: angular channel + :param integrand_positions: positions to evaluate the radial basis + :returns: evaluated radial basis + """ + ... + + def compute_derivative( + self, n: float, ell: float, integrand_positions: np.ndarray + ) -> np.ndarray: + """Derivative of the radial basis. + + Used for radial integrals with delta like atomic densities. If not defined in a + child class a numerical derivatice based on finite differences of + ``integrand_positions``. + + :param n: radial channel + :param ell: angular channel + :param integrand_positions: positions to evaluate the radial basis + :returns: evaluated derivative of the radial basis + """ + displacement = 1e-6 + mean_abs_positions = np.abs(integrand_positions).mean() + + if mean_abs_positions < 1.0: + raise ValueError( + "Numerically derivative of the radial integral can not be performed " + "since positions are too small. Mean of the absolute positions is " + f"{mean_abs_positions:.1e} but should be at least 1." + ) + + radial_basis_pos = self.compute(n, ell, integrand_positions + displacement / 2) + radial_basis_neg = self.compute(n, ell, integrand_positions - displacement / 2) + + return (radial_basis_pos - radial_basis_neg) / displacement + + def compute_gram_matrix( + self, + max_radial: int, + max_angular: int, + ) -> np.ndarray: + """Gram matrix of the current basis. + + :parameter max_radial: number of angular components + :parameter max_angular: number of radial components + :returns: orthornomalization matrix of shape + ``(max_angular + 1, max_radial, max_radial)`` + """ + + if not HAS_SCIPY: + raise ValueError("Orthornomalization requires scipy!") + + # Gram matrix (also called overlap matrix or inner product matrix) + gram_matrix = np.zeros((max_angular + 1, max_radial, max_radial)) + + def integrand( + integrand_positions: np.ndarray, + n1: int, + n2: int, + ell: int, + ) -> np.ndarray: + return ( + integrand_positions**2 + * self.compute(n1, ell, integrand_positions) + * self.compute(n2, ell, integrand_positions) + ) + + for ell in range(max_angular + 1): + for n1 in range(max_radial): + for n2 in range(max_radial): + gram_matrix[ell, n1, n2] = quad( + func=integrand, + a=0, + b=self.integeration_radius, + args=(n1, n2, ell), + )[0] + + return gram_matrix + + def compute_orthonormalization_matrix( + self, + max_radial: int, + max_angular: int, + ) -> np.ndarray: + """Compute orthonormalization matrix + + :parameter max_radial: number of angular components + :parameter max_angular: number of radial components + :returns: orthornomalization matrix of shape (max_angular + 1, max_radial, + max_radial) + """ + + gram_matrix = self.compute_gram_matrix(max_radial, max_angular) + + # Get the normalization constants from the diagonal entries + normalizations = np.zeros((max_angular + 1, max_radial)) + + for ell in range(max_angular + 1): + for n in range(max_radial): + normalizations[ell, n] = 1 / np.sqrt(gram_matrix[ell, n, n]) + + # Rescale orthonormalization matrix to be defined + # in terms of the normalized (but not yet orthonormalized) + # basis functions + gram_matrix[ell, n, :] *= normalizations[ell, n] + gram_matrix[ell, :, n] *= normalizations[ell, n] + + orthonormalization_matrix = np.zeros_like(gram_matrix) + for ell in range(max_angular + 1): + eigvals, eigvecs = np.linalg.eigh(gram_matrix[ell]) + orthonormalization_matrix[ell] = ( + eigvecs @ np.diag(np.sqrt(1.0 / eigvals)) @ eigvecs.T + ) + + # Rescale the orthonormalization matrix so that it + # works with respect to the primitive (not yet normalized) + # radial basis functions + for ell in range(max_angular + 1): + for n in range(max_radial): + orthonormalization_matrix[ell, :, n] *= normalizations[ell, n] + + return orthonormalization_matrix + + +class GtoBasis(RadialBasisBase): + r"""Primitive (not normalized nor orthonormlized) GTO radial basis. + + It is defined as + + .. math:: + + R_{nl}(r) = R_n(r) = r^n e^{-\frac{r^2}{2\sigma_n^2}}, + + where :math:`\sigma_n = \sqrt{n} r_\mathrm{cut}/n_\mathrm{max}` with + :math:`r_\mathrm{cut}` being the ``cutoff`` and :math:`n_\mathrm{max}` the maximal + number of radial components. + + :parameter cutoff: spherical cutoff for the radial basis + :parameter max_radial: number of radial components + """ + + def __init__(self, cutoff, max_radial): + # choosing infinity leads to problems when calculating the radial integral with + # `quad`! + super().__init__(integeration_radius=5 * cutoff) + self.max_radial = max_radial + self.cutoff = cutoff + self.sigmas = np.ones(self.max_radial, dtype=float) + + for n in range(1, self.max_radial): + self.sigmas[n] = np.sqrt(n) + self.sigmas *= self.cutoff / self.max_radial + + def compute( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + return integrand_positions**n * np.exp( + -0.5 * (integrand_positions / self.sigmas[n]) ** 2 + ) + + def compute_derivative( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + return n / integrand_positions * self.compute( + n, ell, integrand_positions + ) - integrand_positions / self.sigmas[n] ** 2 * self.compute( + n, ell, integrand_positions + ) + + +class MonomialBasis(RadialBasisBase): + r"""Monomial basis. + + Basis is consisting of functions + + .. math:: + R_{nl}(r) = r^{l+2n}, + + where :math:`n` runs from :math:`0,1,...,n_\mathrm{max}-1`. These capture precisely + the radial dependence if we compute the Taylor expansion of a generic funct m-lgion + defined in 3D space. + + :parameter cutoff: spherical cutoff for the radial basis + """ + + def __init__(self, cutoff): + super().__init__(integeration_radius=cutoff) + + def compute( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + return integrand_positions ** (ell + 2 * n) + + def compute_derivative( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + return (ell + 2 * n) * integrand_positions ** (ell + 2 * n - 1) + + +def _spherical_jn_swapped(z, n): + """spherical_jn with swapped arguments for usage in `fsolve`.""" + return spherical_jn(n=n, z=z) + + +class SphericalBesselBasis(RadialBasisBase): + """Spherical Bessel functions used in the Laplacian eigenstate (LE) basis. + + :parameter cutoff: spherical cutoff for the radial basis + :parameter max_radial: number of angular components + :parameter max_angular: number of radial components + """ + + def __init__(self, cutoff, max_radial, max_angular): + if not HAS_SCIPY: + raise ValueError("SphericalBesselBasis requires scipy!") + + super().__init__(integeration_radius=cutoff) + + self.max_radial = max_radial + self.max_angular = max_angular + self.roots = np.zeros([max_angular + 1, self.max_radial]) + + # Define target function and the estimated location of roots obtained from the + # asymptotic expansion of the spherical Bessel functions for large arguments x + for ell in range(self.max_angular + 1): + roots_guesses = np.pi * (np.arange(1, self.max_radial + 1) + ell / 2) + # Compute roots from initial guess using Newton method + for n, root_guess in enumerate(roots_guesses): + self.roots[ell, n] = fsolve( + func=_spherical_jn_swapped, x0=root_guess, args=(ell,) + )[0] + + def compute( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + return spherical_jn( + ell, + integrand_positions * self.roots[ell, n] / self.integeration_radius, + ) + + def compute_derivative( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + return ( + self.roots[ell, n] + / self.integeration_radius + * spherical_jn( + ell, + integrand_positions * self.roots[ell, n] / self.integeration_radius, + derivative=True, + ) + ) diff --git a/python/rascaline/rascaline/utils/splines.py b/python/rascaline/rascaline/utils/splines.py index af8534b5e..f85a9e43d 100644 --- a/python/rascaline/rascaline/utils/splines.py +++ b/python/rascaline/rascaline/utils/splines.py @@ -1,18 +1,43 @@ """ +.. _python-splined-radial-integral: + Splined radial integrals ======================== Classes for generating splines which can be used as tabulated radial integrals in the -various SOAP and LODE calculators. For an complete example of how to use these classes -see :ref:`example-splines`. +various SOAP and LODE calculators. + +All classes are based on .. autoclass:: rascaline.utils.RadialIntegralSplinerBase :members: :show-inheritance: +Rascaline splining provides several ways to compute a radial integral based. You may +chose and initlize a pre defined atomic density and radial basis and provide them to + +.. autoclass:: rascaline.utils.RealSpaceSpliner + :members: + :show-inheritance: + +or + +.. autoclass:: rascaline.utils.KSpaceSpliner + :members: + :show-inheritance: + +Note that :class:`RealSpaceSpliner` and :class:`KSpaceSpliner` require `scipy`_ to +be installed in order to perform the numercial integrals. + +Besides the two classes you can also explicitly provide functions for the radial +integral and its derivive and passing them to + .. autoclass:: rascaline.utils.RadialIntegralFromFunction :members: :show-inheritance: + + +.. _`scipy`: https://scipy.org """ from abc import ABC, abstractmethod @@ -21,6 +46,18 @@ import numpy as np +try: + from scipy.integrate import quad, quad_vec + from scipy.special import spherical_in, spherical_jn + + HAS_SCIPY = True +except ImportError: + HAS_SCIPY = False + +from .atomic_density import AtomicDensityBase, DeltaDensity, GaussianDensity +from .radial_basis import RadialBasisBase + + class RadialIntegralSplinerBase(ABC): """Base class for splining arbitrary radial integrals. @@ -31,6 +68,8 @@ class RadialIntegralSplinerBase(ABC): :parameter max_radial: number of angular components :parameter spline_cutoff: cutoff radius for the spline interpolation. This is also the maximal value that can be interpolated. + :parameter basis: Provide a :class:`RadialBasisBase` instance to orthornomalize the + radial integral. :parameter accuracy: accuracy of the numerical integration and the splining. Accuracy is reached when either the mean absolute error or the mean relative error gets below the ``accuracy`` threshold. @@ -41,11 +80,13 @@ def __init__( max_radial: int, max_angular: int, spline_cutoff: float, + basis: Optional[RadialBasisBase], accuracy: float, ): self.max_radial = max_radial self.max_angular = max_angular self.spline_cutoff = spline_cutoff + self.basis = basis self.accuracy = accuracy @abstractmethod @@ -55,7 +96,11 @@ def _radial_integral(self, n: int, ell: int, positions: np.ndarray) -> np.ndarra @property def _center_contribution(self) -> Union[None, np.ndarray]: - r"""Contribution of the central atom required for LODE calculations.""" + r"""Contribution of the central atom. + + Required for LODE calculations. The central atom contribution will not be + orthornomalized! + """ return None @@ -64,9 +109,9 @@ def _radial_integral_derivative( ) -> np.ndarray: """Method calculating the derivatice of the radial integral.""" displacement = 1e-6 - mean_abs_positions = np.abs(positions).mean() + mean_abs_positions = np.mean(np.abs(positions)) - if mean_abs_positions <= 1.0: + if mean_abs_positions < 1.0: raise ValueError( "Numerically derivative of the radial integral can not be performed " "since positions are too small. Mean of the absolute positions is " @@ -85,6 +130,7 @@ def _radial_integral_derivative( def _value_evaluator_3D( self, positions: np.ndarray, + orthonormalization_matrix: Optional[np.ndarray], derivative: bool, ): values = np.zeros([len(positions), self.max_angular + 1, self.max_radial]) @@ -97,6 +143,15 @@ def _value_evaluator_3D( else: values[:, ell, n] = self._radial_integral(n, ell, positions) + if orthonormalization_matrix is not None: + # For each l channel we do a dot product of the orthonormalization_matrix of + # shape (n, n) with the values which should have the shape (n, n_positions). + # To achieve the correct broadcasting we have to transpose twice. + for ell in range(self.max_angular + 1): + values[:, ell, :] = ( + orthonormalization_matrix[ell] @ values[:, ell, :].T + ).T + return values def compute( @@ -111,11 +166,22 @@ def compute( rascaline calculator. """ + if self.basis is not None: + orthonormalization_matrix = self.basis.compute_orthonormalization_matrix( + self.max_radial, self.max_angular + ) + else: + orthonormalization_matrix = None + def value_evaluator_3D(positions): - return self._value_evaluator_3D(positions, derivative=False) + return self._value_evaluator_3D( + positions, orthonormalization_matrix, derivative=False + ) def derivative_evaluator_3D(positions): - return self._value_evaluator_3D(positions, derivative=True) + return self._value_evaluator_3D( + positions, orthonormalization_matrix, derivative=True + ) if n_spline_points is not None: positions = np.linspace(0, self.spline_cutoff, n_spline_points) @@ -154,7 +220,12 @@ def derivative_evaluator_3D(positions): center_contribution = self._center_contribution if center_contribution is not None: - parameters["center_contribution"] = center_contribution + if self.basis is not None: + parameters["center_contribution"] = list( + orthonormalization_matrix[0] @ center_contribution + ) + else: + parameters["center_contribution"] = center_contribution return {"TabulatedRadialIntegral": parameters} @@ -293,7 +364,7 @@ def _compute_from_spline(self, positions): class RadialIntegralFromFunction(RadialIntegralSplinerBase): - r"""Compute the radial integral spline points based on a provided function. + r"""Compute radial integral spline points based on a provided function. :parameter radial_integral: Function to compute the radial integral. Function must take ``n``, ``l``, and ``positions`` as inputs, where ``n`` and ``l`` are @@ -358,8 +429,8 @@ class RadialIntegralFromFunction(RadialIntegralSplinerBase): The ``atomic_gaussian_width`` paramater is required by the calculator but will be will be ignored during the feature computation. - A more in depth example using a "rectangular" Laplacian eigenstate basis - is provided in the :ref:`example section`. + A more in depth example using a "rectangular" Laplacian eigenstate basis is provided + in the :ref:`example section`. """ def __init__( @@ -376,12 +447,20 @@ def __init__( ): self.radial_integral_function = radial_integral self.radial_integral_derivative_funcion = radial_integral_derivative + + if center_contribution is not None and len(center_contribution) != max_radial: + raise ValueError( + f"center contribution has {len(center_contribution)} entries but " + f"should be the same as max_radial ({max_radial})" + ) + self.center_contribution = center_contribution super().__init__( max_radial=max_radial, max_angular=max_angular, spline_cutoff=spline_cutoff, + basis=None, # do no orthornormlize the radial integral accuracy=accuracy, ) @@ -390,8 +469,6 @@ def _radial_integral(self, n: int, ell: int, positions: np.ndarray) -> np.ndarra @property def _center_contribution(self) -> Union[None, np.ndarray]: - # Test that ``len(self.center_contribution) == max_radial`` is performed by the - # calculator. return self.center_contribution def _radial_integral_derivative( @@ -401,3 +478,331 @@ def _radial_integral_derivative( return super()._radial_integral_derivative(n, ell, positions) else: return self.radial_integral_derivative_funcion(n, ell, positions) + + +class RealSpaceSpliner(RadialIntegralSplinerBase): + """Compute radial integral spline points for real space calculators. + + Use only in combination with a real space calculators like + :class:`rascaline.SphericalExpansion` or :class:`rascaline.SoapPowerSpectrum`. For + k-space spherical expansions use :class:`KSpaceSpliner`. + + If ``density`` is either :class:`rascaline.utils.atomic_density.DeltaDensity` or + :class:`rascaline.utils.atomic_density.GaussianDensity` the radial integral will be + partly solved analytical. These simpler expressions result in a faster and more + stable evaluation. + + :parameter cutoff: spherical cutoff for the radial basis + :parameter max_radial: number of angular components + :parameter max_angular: number of radial components + :parameter basis: instance defining the radial basis + :parameter density: instancel defining the atomic density + :parameter accuracy: accuracy of the numerical integration and the splining. + Accuracy is reached when either the mean absolute error or the mean relative + error gets below the ``accuracy`` threshold. + :raise ValueError: if `scipy`_ is not available + + Example + ------- + + .. seealso:: + :class:`KSpaceSpliner` for a spliner class that works with + :class:`rascaline.LodeSphericalExpansion` + """ + + def __init__( + self, + cutoff: float, + max_radial: int, + max_angular: int, + basis: RadialBasisBase, + density: AtomicDensityBase, + accuracy: float = 1e-8, + ): + if not HAS_SCIPY: + raise ValueError("Spliner class requires scipy!") + + self.density = density + + super().__init__( + max_radial=max_radial, + max_angular=max_angular, + spline_cutoff=cutoff, + basis=basis, + accuracy=accuracy, + ) + + def _radial_integral(self, n: int, ell: int, positions: np.ndarray) -> np.ndarray: + if type(self.density) is DeltaDensity: + return self._radial_integral_delta(n, ell, positions) + elif type(self.density) is GaussianDensity: + return self._radial_integral_gaussian(n, ell, positions) + else: + return self._radial_integral_custom(n, ell, positions) + + def _radial_integral_derivative( + self, n: int, ell: int, positions: np.ndarray + ) -> np.ndarray: + if type(self.density) is DeltaDensity: + return self._radial_integral_delta_derivative(n, ell, positions) + elif type(self.density) is GaussianDensity: + return self._radial_integral_gaussian_derivative(n, ell, positions) + else: + return self._radial_integral_custom_derivative(n, ell, positions) + + def _radial_integral_delta( + self, n: int, ell: int, positions: np.ndarray + ) -> np.ndarray: + return self.basis.compute(n, ell, positions) + + def _radial_integral_delta_derivative( + self, n: int, ell: int, positions: np.ndarray + ) -> np.ndarray: + return self.basis.compute_derivative(n, ell, positions) + + def _radial_integral_gaussian( + self, n: int, ell: int, positions: np.ndarray + ) -> np.ndarray: + atomic_gaussian_width_sq = self.density.atomic_gaussian_width**2 + + prefac = ( + (4 * np.pi) + / (np.pi * atomic_gaussian_width_sq) ** (3 / 4) + * np.exp(-0.5 * positions**2 / atomic_gaussian_width_sq) + ) + + def integrand( + integrand_position: float, n: int, ell: int, positions: np.array + ) -> np.ndarray: + return ( + integrand_position**2 + * self.basis.compute(n, ell, integrand_position) + * np.exp(-0.5 * integrand_position**2 / atomic_gaussian_width_sq) + * spherical_in( + ell, + integrand_position * positions / atomic_gaussian_width_sq, + ) + ) + + return ( + prefac + * quad_vec( + f=lambda x: integrand(x, n, ell, positions), + a=0, + b=self.basis.integeration_radius, + )[0] + ) + + def _radial_integral_gaussian_derivative( + self, n: int, ell: int, positions: np.ndarray + ) -> np.ndarray: + atomic_gaussian_width_sq = self.density.atomic_gaussian_width**2 + + prefac = ( + (4 * np.pi) + / (np.pi * atomic_gaussian_width_sq) ** (3 / 4) + * np.exp(-0.5 * positions**2 / atomic_gaussian_width_sq) + ) + + def integrand( + integrand_position: float, n: int, ell: int, positions: np.array + ) -> np.ndarray: + return ( + integrand_position**3 + * self.basis.compute(n, ell, integrand_position) + * np.exp(-(integrand_position**2) / (2 * atomic_gaussian_width_sq)) + * spherical_in( + ell, + integrand_position * positions / atomic_gaussian_width_sq, + derivative=True, + ) + ) + + return atomic_gaussian_width_sq**-1 * ( + prefac + * quad_vec( + f=lambda x: integrand(x, n, ell, positions), + a=0, + b=self.basis.integeration_radius, + )[0] + - positions * self._radial_integral_gaussian(n, ell, positions) + ) + + def _radial_integral_custom( + self, n: int, ell: int, positions: np.ndarray, derivative: bool + ) -> np.ndarray: + raise NotImplementedError( + "Radial integral with custom atomic densities is not implemented yet!" + ) + + def _radial_integral_custom_derivative( + self, n: int, ell: int, positions: np.ndarray, derivative: bool + ) -> np.ndarray: + raise NotImplementedError( + "Radial integral with custom atomic densities is not implemented yet!" + ) + + +class KSpaceSpliner(RadialIntegralSplinerBase): + r"""Compute radial integral spline points for k-space calculators. + + Use only in combination with a k/fourier space calculators like + :class:`rascaline.LodeSphericalExpansion`. For real space spherical expansions use + :class:`RealSpaceSpliner`. + + :parameter k_cutoff: spherical reciprocal cutoff + :parameter max_radial: number of angular components + :parameter max_angular: number of radial components + :parameter basis: instance defining the radial basis + :parameter density: instancel defining the atomic density + :parameter accuracy: accuracy of the numerical integration and the splining. + Accuracy is reached when either the mean absolute error or the mean relative + error gets below the ``accuracy`` threshold. + :raise ValueError: if `scipy`_ is not available + + Example + ------- + + First import the necessary classed and define hyper parameters for the spherical + expansions. + + >>> from rascaline import LodeSphericalExpansion + >>> from rascaline.utils.atomic_density import GaussianDensity + >>> from rascaline.utils.radial_basis import GtoBasis + + Note that ``cutoff`` defined below denotes the maximal distance for the projection + of the density. In contrast to SOAP, LODE also takes atoms outside of this + ``cutoff`` into account for the density. + + >>> cutoff = 2 + >>> max_radial = 6 + >>> max_angular = 4 + >>> atomic_gaussian_width = 1.0 + + :math:`1.2 \, \pi \, \sigma` where :math:`\sigma` is the ``atomic_gaussian_width`` + which is a reasonable value for most systems. + + >>> k_cutoff = 1.2 * np.pi / atomic_gaussian_width + + Next we initlize our radial basis and the density + + >>> basis = GtoBasis(cutoff=cutoff, max_radial=max_radial) + >>> density = GaussianDensity(atomic_gaussian_width=atomic_gaussian_width) + + And finally the actual spliner instance + + >>> spliner = KSpaceSpliner( + ... k_cutoff=k_cutoff, + ... max_radial=max_radial, + ... max_angular=max_angular, + ... basis=basis, + ... density=density, + ... ) + + As for all spliner classes you can use the output + :meth:`RadialIntegralSplinerBase.compute` method directly as the + ``radial_basis`` parameter. + + >>> calculator = LodeSphericalExpansion( + ... cutoff=cutoff, + ... max_radial=max_radial, + ... max_angular=max_angular, + ... center_atom_weight=1.0, + ... atomic_gaussian_width=atomic_gaussian_width, + ... potential_exponent=1, + ... radial_basis=spliner.compute(), + ... ) + + You can now use ``calculator`` to obtain the spherical expansion coefficents of your + systems. Note that the the spliner based used here will produce the same coefficents + as if ``radial_basis={"Gto": {}}`` would be used. + + .. seealso:: + :class:`RealSpaceSpliner` for a spliner class that works with + :class:`rascaline.SphericalExpansion` + """ + + def __init__( + self, + k_cutoff: float, + max_radial: int, + max_angular: int, + basis: RadialBasisBase, + density: AtomicDensityBase, + accuracy: float = 1e-8, + ): + if not HAS_SCIPY: + raise ValueError("Spliner class requires scipy!") + + self.density = density + + super().__init__( + max_radial=max_radial, + max_angular=max_angular, + basis=basis, + spline_cutoff=k_cutoff, # use k_cutoff here because we spline in k-space + accuracy=accuracy, + ) + + def _radial_integral(self, n: int, ell: int, positions: np.ndarray) -> np.ndarray: + def integrand( + integrand_position: float, n: int, ell: int, positions: np.ndarray + ) -> np.ndarray: + return ( + integrand_position**2 + * self.basis.compute(n, ell, integrand_position) + * spherical_jn(ell, integrand_position * positions) + ) + + return quad_vec( + f=lambda x: integrand(x, n, ell, positions), + a=0, + b=self.basis.integeration_radius, + )[0] + + def _radial_integral_derivative( + self, n: int, ell: int, positions: np.ndarray + ) -> np.ndarray: + def integrand( + integrand_position: float, n: int, ell: int, positions: np.ndarray + ) -> np.ndarray: + return ( + integrand_position**3 + * self.basis.compute(n, ell, integrand_position) + * spherical_jn(ell, integrand_position * positions, derivative=True) + ) + + return quad_vec( + f=lambda x: integrand(x, n, ell, positions), + a=0, + b=self.basis.integeration_radius, + )[0] + + @property + def _center_contribution(self) -> np.ndarray: + if type(self.density) is DeltaDensity: + center_contrib = self._center_contribution_delta + else: + center_contrib = self._center_contribution_custom + + return [np.sqrt(4 * np.pi) * center_contrib(n) for n in range(self.max_radial)] + + def _center_contribution_delta(self, n: int): + raise NotImplementedError( + "center contribution for delta disributions is not implemented yet." + ) + + def _center_contribution_custom(self, n: int): + def integrand(integrand_position: float, n: int) -> np.ndarray: + return ( + integrand_position**2 + * self.basis.compute(n, 0, integrand_position) + * self.density.compute(integrand_position) + ) + + return quad( + func=integrand, + a=0, + b=self.basis.integeration_radius, + args=(n), + )[0] diff --git a/python/rascaline/tests/utils/radial_basis.py b/python/rascaline/tests/utils/radial_basis.py new file mode 100644 index 000000000..424f2ec66 --- /dev/null +++ b/python/rascaline/tests/utils/radial_basis.py @@ -0,0 +1,91 @@ +from typing import Union + +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from rascaline.utils.radial_basis import ( + GtoBasis, + MonomialBasis, + RadialBasisBase, + SphericalBesselBasis, +) + + +pytest.importorskip("scipy") + + +class RtonRadialBasis(RadialBasisBase): + def compute( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + return integrand_positions**n + + +def test_radial_basis_gram(): + """Test that quad integration of the gram matrix is the same as an analytical.""" + + integeration_radius = 1 + max_radial = 4 + max_angular = 2 + + test_basis = RtonRadialBasis(integeration_radius=integeration_radius) + + numerical_gram = test_basis.compute_gram_matrix(max_radial, max_angular) + analytical_gram = np.zeros_like(numerical_gram) + + for ell in range(max_angular + 1): + for n1 in range(max_radial): + for n2 in range(max_radial): + exp = 3 + n1 + n2 + analytical_gram[ell, n1, n2] = integeration_radius**exp / exp + + assert_allclose(numerical_gram, analytical_gram) + + +def test_radial_basis_orthornormalization(): + integeration_radius = 1 + max_radial = 4 + max_angular = 2 + + test_basis = RtonRadialBasis(integeration_radius=integeration_radius) + + gram = test_basis.compute_gram_matrix(max_radial, max_angular) + ortho = test_basis.compute_orthonormalization_matrix(max_radial, max_angular) + + for ell in range(max_angular): + eye = ortho[ell] @ gram[ell] @ ortho[ell].T + assert_allclose(eye, np.eye(max_radial, max_radial), atol=1e-11) + + +@pytest.mark.parametrize( + "analytical_basis", + [ + GtoBasis(cutoff=4, max_radial=6), + MonomialBasis(cutoff=4), + SphericalBesselBasis(cutoff=4, max_radial=6, max_angular=4), + ], +) +def test_derivative(analytical_basis: RadialBasisBase): + """Finite difference test for testing the derivatice of a raidal basis""" + + class NumericalRadialBasis(RadialBasisBase): + def compute( + self, n: float, ell: float, integrand_positions: Union[float, np.ndarray] + ) -> Union[float, np.ndarray]: + return analytical_basis.compute(n, ell, integrand_positions) + + numerical_basis = NumericalRadialBasis(integeration_radius=np.inf) + + cutoff = 4 + max_radial = 6 + max_angular = 4 + positions = np.linspace(2, cutoff) + + for n in range(max_radial): + for ell in range(max_angular): + assert_allclose( + numerical_basis.compute_derivative(n, ell, positions), + analytical_basis.compute_derivative(n, ell, positions), + atol=1e-9, + ) diff --git a/python/rascaline/tests/utils/splines.py b/python/rascaline/tests/utils/splines.py index a08cc54e3..d5f2dd062 100644 --- a/python/rascaline/tests/utils/splines.py +++ b/python/rascaline/tests/utils/splines.py @@ -2,7 +2,16 @@ import pytest from numpy.testing import assert_allclose, assert_equal -from rascaline.utils import RadialIntegralFromFunction +from rascaline import LodeSphericalExpansion, SphericalExpansion +from rascaline.utils import KSpaceSpliner, RadialIntegralFromFunction, RealSpaceSpliner +from rascaline.utils.atomic_density import DeltaDensity, GaussianDensity, LodeDensity +from rascaline.utils.radial_basis import GtoBasis + +from ..test_systems import SystemForTests + + +pytest.importorskip("scipy") +from scipy.special import gamma, hyp1f1 # noqa def sine(n: int, ell: int, positions: np.ndarray) -> np.ndarray: @@ -102,3 +111,261 @@ def test_splines_numerical_derivative_error(): match = "Numerically derivative of the radial integral can not be performed" with pytest.raises(ValueError, match=match): RadialIntegralFromFunction(**kwargs).compute() + + +def test_kspace_radial_integral(): + """Test against anayliycal integral with Gaussian densities and GTOs""" + + cutoff = 2 + max_radial = 6 + max_angular = 3 + atomic_gaussian_width = 1.0 + k_cutoff = 1.2 * np.pi / atomic_gaussian_width + + basis = GtoBasis(cutoff=cutoff, max_radial=max_radial) + + spliner = KSpaceSpliner( + max_radial=max_radial, + max_angular=max_angular, + k_cutoff=k_cutoff, + basis=basis, + density=DeltaDensity(), # density does not enter in a Kspace radial integral + accuracy=1e-8, + ) + + Neval = 100 + kk = np.linspace(0, k_cutoff, Neval) + + sigma = np.ones(max_radial, dtype=float) + for i in range(1, max_radial): + sigma[i] = np.sqrt(i) + sigma *= cutoff / max_radial + + factors = np.sqrt(np.pi) * np.ones((max_radial, max_angular + 1)) + + coeffs_num = np.zeros([max_radial, max_angular + 1, Neval]) + coeffs_exact = np.zeros_like(coeffs_num) + + for ell in range(max_angular + 1): + for n in range(max_radial): + i1 = 0.5 * (3 + n + ell) + i2 = 1.5 + ell + factors[n, ell] *= ( + 2 ** (0.5 * (n - ell - 1)) + * gamma(i1) + / gamma(i2) + * sigma[n] ** (2 * i1) + ) + coeffs_exact[n, ell] = ( + factors[n, ell] + * kk**ell + * hyp1f1(i1, i2, -0.5 * (kk * sigma[n]) ** 2) + ) + + coeffs_num[n, ell] = spliner._radial_integral(n, ell, kk) + + assert_allclose(coeffs_num, coeffs_exact) + + +# def test_rspace_radial_integral(): +# """Test against anayliycal integral with Gaussian densities and GTOs""" + +# cutoff = 2 +# max_radial = 6 +# max_angular = 3 +# atomic_gaussian_width = 1.0 + +# basis = GtoBasis(cutoff=cutoff, max_radial=max_radial) +# density = GaussianDensity(atomic_gaussian_width=atomic_gaussian_width) + +# spliner = RealSpaceSpliner( +# cutoff=cutoff, +# max_radial=max_radial, +# max_angular=max_angular, +# basis=basis, +# density=density, +# accuracy=1e-8, +# ) + +# Neval = 100 +# positions = np.linspace(0, cutoff, Neval) + +# coeffs_num = np.zeros([max_radial, max_angular + 1, Neval]) +# coeffs_exact = np.zeros_like(coeffs_num) + +# a = 1 / (2 * atomic_gaussian_width**2) +# b_n = basis.sigmas + +# for ell in range(max_angular + 1): +# for n in range(max_radial): +# n_eff = (n + ell + 3) / 2 +# l_eff = ell + 1.5 +# x = a**2 * positions**2 / (a**2 + b_n[n] ** 2) + +# coeffs_exact[n, ell] = ( +# (gamma(n_eff) * (a * positions) ** ell) +# / (gamma(l_eff) * (a + b_n[n]) ** n_eff) +# * hyp1f1(n_eff, l_eff, x) +# ) + +# coeffs_num[n, ell] = spliner._radial_integral(n, ell, positions) + +# coeffs_exact *= np.pi**1.5 / (np.pi * atomic_gaussian_width**2) ** 0.75 +# assert_allclose(coeffs_num, coeffs_exact) + + +def test_rspace_delta(): + cutoff = 2 + max_radial = 6 + max_angular = 3 + + basis = GtoBasis(cutoff=cutoff, max_radial=max_radial) + density = DeltaDensity() + + spliner = RealSpaceSpliner( + max_radial=max_radial, + max_angular=max_angular, + cutoff=cutoff, + basis=basis, + density=density, + accuracy=1e-8, + ) + + positions = np.linspace(0, cutoff) + + for ell in range(max_angular + 1): + for n in range(max_radial): + assert_equal( + spliner._radial_integral(n, ell, positions), + basis.compute(n, ell, positions), + ) + assert_equal( + spliner._radial_integral_derivative(n, ell, positions), + basis.compute_derivative(n, ell, positions), + ) + + +def test_real_space_spliner(): + """Compare splined spherical expansion with GTOs and a Gaussian density to + analytical implementation.""" + cutoff = 8.0 + max_radial = 12 + max_angular = 9 + atomic_gaussian_width = 1.2 + + # We choose an accuracy that is larger then the default one (1e-8) to limit the time + # consumption of the test. + accuracy = 1e-4 + + spliner = RealSpaceSpliner( + cutoff=cutoff, + max_radial=max_radial, + max_angular=max_angular, + basis=GtoBasis(cutoff=cutoff, max_radial=max_radial), + density=GaussianDensity(atomic_gaussian_width=atomic_gaussian_width), + accuracy=accuracy, + ) + + hypers_spherical_expansion = { + "cutoff": cutoff, + "max_radial": max_radial, + "max_angular": max_angular, + "center_atom_weight": 1.0, + "atomic_gaussian_width": atomic_gaussian_width, + "cutoff_function": {"Step": {}}, + } + + analytic = SphericalExpansion( + radial_basis={"Gto": {}}, **hypers_spherical_expansion + ).compute(SystemForTests()) + splined = SphericalExpansion( + radial_basis=spliner.compute(), **hypers_spherical_expansion + ).compute(SystemForTests()) + + for key, block_analytic in analytic.items(): + block_splined = splined.block(key) + assert_allclose( + block_splined.values, block_analytic.values, rtol=5e-4, atol=2e-5 + ) + + +@pytest.mark.parametrize("center_atom_weight", [1.0, 0.0]) +@pytest.mark.parametrize("potential_exponent", [0, 1]) +def test_fourier_space_spliner(center_atom_weight, potential_exponent): + """Compare splined LODE spherical expansion with GTOs and a Gaussian density to + analytical implementation.""" + + cutoff = 2 + max_radial = 6 + max_angular = 4 + atomic_gaussian_width = 0.8 + k_cutoff = 1.2 * np.pi / atomic_gaussian_width + + spliner = KSpaceSpliner( + k_cutoff=k_cutoff, + max_radial=max_radial, + max_angular=max_angular, + basis=GtoBasis(cutoff=cutoff, max_radial=max_radial), + density=LodeDensity( + atomic_gaussian_width=atomic_gaussian_width, + potential_exponent=potential_exponent, + ), + ) + + hypers_spherical_expansion = { + "cutoff": cutoff, + "max_radial": max_radial, + "max_angular": max_angular, + "center_atom_weight": center_atom_weight, + "atomic_gaussian_width": atomic_gaussian_width, + "potential_exponent": potential_exponent, + } + + analytic = LodeSphericalExpansion( + radial_basis={"Gto": {}}, **hypers_spherical_expansion + ).compute(SystemForTests()) + splined = LodeSphericalExpansion( + radial_basis=spliner.compute(), **hypers_spherical_expansion + ).compute(SystemForTests()) + + for key, block_analytic in analytic.items(): + block_splined = splined.block(key) + assert_allclose(block_splined.values, block_analytic.values, atol=1e-14) + + +def test_center_contribution_gto_gaussian(): + cutoff = 2.0 + max_radial = 6 + max_angular = 4 + atomic_gaussian_width = 0.8 + k_cutoff = 1.2 * np.pi / atomic_gaussian_width + + # Numerical evaluation of center contributions + spliner = KSpaceSpliner( + k_cutoff=k_cutoff, + max_radial=max_radial, + max_angular=max_angular, + basis=GtoBasis(cutoff=cutoff, max_radial=max_radial), + density=GaussianDensity(atomic_gaussian_width=atomic_gaussian_width), + ) + + # Analytical evaluation of center contributions + center_contr_analytical = np.zeros((max_radial)) + + normalization = 1.0 / (np.pi * atomic_gaussian_width**2) ** (3 / 4) + sigma_radial = np.ones(max_radial, dtype=float) + + for n in range(1, max_radial): + sigma_radial[n] = np.sqrt(n) + sigma_radial *= cutoff / max_radial + + for n in range(max_radial): + sigmatemp_sq = 1.0 / ( + 1.0 / atomic_gaussian_width**2 + 1.0 / sigma_radial[n] ** 2 + ) + neff = 0.5 * (3 + n) + center_contr_analytical[n] = (2 * sigmatemp_sq) ** neff * gamma(neff) + + center_contr_analytical *= normalization * 2 * np.pi / np.sqrt(4 * np.pi) + + assert_allclose(spliner._center_contribution, center_contr_analytical, rtol=1e-14) diff --git a/rascaline/src/calculators/lode/radial_integral/gto.rs b/rascaline/src/calculators/lode/radial_integral/gto.rs index 95507a878..89afbe439 100644 --- a/rascaline/src/calculators/lode/radial_integral/gto.rs +++ b/rascaline/src/calculators/lode/radial_integral/gto.rs @@ -205,7 +205,6 @@ impl LodeRadialIntegral for LodeRadialIntegralGto { } let gto_orthonormalization = basis.orthonormalization_matrix(); - return gto_orthonormalization.dot(&(contrib)); } } diff --git a/tox.ini b/tox.ini index 351f8da60..20f7fb92b 100644 --- a/tox.ini +++ b/tox.ini @@ -53,7 +53,7 @@ deps = chemfiles pytest pytest-cov - + scipy commands = pytest --cov={env_site_packages_dir}/rascaline --cov-report xml:.tox/coverage.xml --import-mode=append {posargs}