Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SCAN functional #21

Open
ESEberhard opened this issue Oct 23, 2024 · 1 comment
Open

Add SCAN functional #21

ESEberhard opened this issue Oct 23, 2024 · 1 comment

Comments

@ESEberhard
Copy link

Hi guys,

I thought you might find our implementation of SCAN for EG-XC useful to integrate into the code base.
I attached it below :)

"""
This module contains the implementation of the SCAN (Strongly Constrained and
Appropriately Normed Semilocal) meta-GGA functional by Sun et al.
https://doi.org/10.1103/PhysRevLett.115.036402
https://journals.aps.org/prl/supplemental/10.1103/PhysRevLett.115.036402

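The SCAN functional is a meta-GGA functional, i.e. it depends on the density,
its gradient and the kinetic-energy density tau. Here these enter through
the Wigner-Seitz radius:
    r_s = (3 / (4 * pi * n))**(1 / 3)
the spin polarization:
    xi  = (n_up - n_down) / n
the reduced gradient:
    s = |grad(n)| / (2 * (3 * pi**2)**(1 / 3) * n**(4 / 3))
and the iso-orbital indicator:
    alpha = (tau - tau_W) / tau_unif
with tau_W = |grad(n)|**2 / (8 * n) and, for a spin-unpolarized density,
tau_unif = (3 / 10) * (3 * pi**2)**(2 / 3) * n**(5 / 3).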
"""
import jax
import jax.numpy as jnp
from functools import partial

from mldft.xc_energy.ueg_xc import e_x_ueg, ec_LSDA1_fn

# Shorthand for pi used by the formulas in this module.
pi = jnp.pi


def e_x_ueg(n: jax.Array) -> jax.Array:
    """
    Exchange energy per particle of the spin-unpolarized uniform electron
    gas (Dirac / LDA exchange):

        e_x = -(3 / 4) * (3 / pi)**(1 / 3) * n**(1 / 3)

    Args:
        n: electron density; the formula is applied elementwise.

    Returns:
        The exchange energy per particle with the shape of ``n``.
    """
    dirac_prefactor = (3 / 4) * (3 / pi) ** (1 / 3)
    return -dirac_prefactor * n ** (1 / 3)


def ec_LSDA1_fn(n: jax.Array,
                xi: jax.Array,
                use_RPA=False,
                modified=False) -> jax.Array:
    """
    The correlation energy per particle of the uniform electron gas by
    Perdew and Wang (1992) (PW92).
    https://doi.org/10.1103/PhysRevB.45.13244

    Uses the PW92 spin interpolation

        e_c(r_s, xi) = e_c(r_s, 0)
                       + alpha_c(r_s) * f(xi) / f''(0) * (1 - xi**4)
                       + [e_c(r_s, 1) - e_c(r_s, 0)] * f(xi) * xi**4

    with f(xi) = ((1 + xi)**(4/3) + (1 - xi)**(4/3) - 2) / (2**(4/3) - 2).
    Each of e_c(r_s, 0), e_c(r_s, 1) and -alpha_c(r_s) has the same analytic
    form G(r_s; p, A, a1, b1..b4) (see ``analytic_base_from``).

    Args:
        n: electron density.
        xi: spin polarization (n_up - n_down) / n.
        use_RPA: use the random-phase-approximation parameter set of PW92
            (Table I) instead of the default fit.
        modified: use the higher-precision values of A and f''(0)
            (cf. the "modified" PW parameter set in libxc's lda_c_pw.c).

    Returns:
        The correlation energy per particle, broadcast over ``n`` and ``xi``.

    libxc reference implementation:
    https://github.com/ElectronicStructureLibrary/libxc/blob/master/src/lda_c_pw.c
    https://github.com/ElectronicStructureLibrary/libxc/blob/master/src/maple2c/lda_exc/lda_c_pw.c#L14
    https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/lda_exc/lda_c_pw.mpl
    https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/util.mpl
    """
    r_s = _calc_r_s(n)

    def analytic_base_from(p, A, a1, b1, b2, b3, b4) -> jax.Array:
        """
        PW92 analytic form

            G(r_s) = -2 A (1 + a1 r_s)
                     * ln[1 + 1 / (2 A (b1 r_s^(1/2) + b2 r_s
                                        + b3 r_s^(3/2) + b4 r_s^(p+1)))]

        p and A are constrained; the remaining parameters were fitted (see ref).
        """
        beta_sum = b1 * r_s**(1 / 2) \
                 + b2 * r_s          \
                 + b3 * r_s**(3 / 2) \
                 + b4 * r_s**(p + 1)
        log_term = jnp.log1p(1 / (2 * A * beta_sum))
        return - 2 * A * (1 + a1 * r_s) * log_term

    A_unpolarized = 0.031091 if not modified else 0.0310907
    A_polarized   = 0.015545 if not modified else 0.01554535
    A_alpha_c     = 0.016887 if not modified else 0.0168869
    if use_RPA:
        # random-phase approximation (RPA) parameters, PW92 Table I
        # (b3 of the unpolarized channel is 0.23647).
        ec_unpolarized = analytic_base_from(p=0.75, A=A_unpolarized,
                            a1=0.082477, b1=5.1486, b2=1.6483, b3=0.23647, b4=0.20614)
        ec_polarized =   analytic_base_from(p=0.75, A=A_polarized,
                            a1=0.035374, b1=6.4869, b2=1.3083, b3=0.15180, b4=0.082349)
        alpha_c =      - analytic_base_from(p=1,    A=A_alpha_c,
                            a1=0.028829, b1=10.357, b2=3.6231, b3=0.47990, b4=0.12279)
    else:
        ec_unpolarized = analytic_base_from(p=1, A=A_unpolarized,
                            a1=0.21370, b1=7.5957,  b2=3.5876, b3=1.6382,  b4=0.49294)
        ec_polarized =   analytic_base_from(p=1, A=A_polarized,
                            a1=0.20548, b1=14.1189, b2=6.1977, b3=3.3662,  b4=0.62517)
        # PW92 fits -alpha_c directly with the G form, hence the sign flip.
        alpha_c =      - analytic_base_from(p=1, A=A_alpha_c,
                            a1=0.11125, b1=10.357,  b2=3.6231, b3=0.88026, b4=0.49671)

    # f''(0) = 8 / (9 (2**(4/3) - 2)); the fitted alpha_c above is used rather
    # than the approximation f''(0) * (ec_polarized - ec_unpolarized).
    dd_f_zero = 1.709921 if not modified else 1.709920934161365617563962776245
    f_xi = ((1 + xi)**(4 / 3) + (1 - xi)**(4 / 3) - 2) / (2**(4 / 3) - 2)

    return ec_unpolarized + alpha_c * f_xi / dd_f_zero * (1 - xi**4) \
          + (ec_polarized - ec_unpolarized) * f_xi * xi**4

def _calc_r_s(n: jax.Array, epsilon: float = 0) -> jax.Array:
    """
    Wigner-Seitz radius r_s of the density ``n``: the radius of a sphere
    holding on average one electron, i.e. (4 / 3) * pi * r_s**3 * n = 1.

    Args:
        n: electron density; the formula is applied elementwise.
        epsilon: added to ``4 * pi * n`` before dividing. With the default
            of 0 the formula is exact; a small positive value keeps r_s
            finite where n == 0.

    Returns:
        (3 / (4 * pi * n + epsilon)) ** (1 / 3)
    """
    denominator = 4 * pi * n + epsilon
    return (3 / denominator) ** (1 / 3)


@jax.custom_jvp
def _f_interp(alpha: jax.Array, c1: float, c2: float, d: float) -> jax.Array:
    """
    TODO: add nondiff_argnums=(1,2,3)
    """
    term1 = jnp.exp(-c1 * alpha / (1 - alpha))
    term2 = -d * jnp.exp(c2 / (1 - alpha))
    return jnp.where(alpha < 1, term1, term2)


@_f_interp.defjvp
def jvp_f_interp(primals, tangents):
    """
    does not account for derivative w.r.t. constants c1, c2, d
    """
    alpha, c1, c2, d = primals
    alpha_dot, _, _, _ = tangents
    df = _f_interp(alpha, c1, c2, d)
    dterm1_factor = -c1 / (1 - alpha)**2
    dterm2_factor =  c2 / (1 - alpha)**2
    df_dot = jnp.where(alpha < 1, dterm1_factor, dterm2_factor) * df * alpha_dot
    return df, df_dot


@jax.jit
def e_x_scan(n: jax.Array, s: jax.Array, alpha: jax.Array) -> jax.Array:
    """
    The exchange energy per particle of SCAN (PRL 115, 036402 (2015)):

        e_x = e_x_ueg(n) * F_x(s, alpha)
        F_x = [h1x(s, alpha) + f_x(alpha) * (h0x - h1x(s, alpha))] * g_x(s)

    No spin-scaling is applied here (there is no ``xi`` argument), i.e. this
    is the spin-unpolarized form.

    Args:
        n: electron density.
        s: reduced density gradient (assumed non-negative).
        alpha: iso-orbital indicator.

    Returns:
        The exchange energy per particle, broadcast over the inputs.
    """

    def F_x(s: jax.Array, alpha: jax.Array) -> jax.Array:
        """
        The exchange enhancement factor
        """
        # fit parameters
        k1 = 0.065
        c1x = 0.667
        c2x = 0.8
        dx = 1.24

        def h1x_fn(s: jax.Array, alpha: jax.Array) -> jax.Array:
            # slowly-varying (alpha ~ 1) enhancement factor h1x
            mu_ak = 10 / 81
            b2 = jnp.sqrt(5913 / 405000)
            b1 = (511 / 13500) / (2 * b2)
            b3 = 0.5
            b4 = mu_ak**2 / k1 - 1606 / 18225 - b1**2

            exp1 = jnp.exp(- jnp.abs(b4) * s**2 / mu_ak)
            exp2 = jnp.exp(- b3 * (1 - alpha)**2)
            x = mu_ak * s**2 * (1 + (b4 * s**2 / mu_ak) * exp1) \
                + (b1 * s**2 + b2 * (1 - alpha) * exp2)**2
            return 1 + k1 - k1 / (1 + x / k1)

        a1 = 4.9479
        h0x = 1.174

        def gx(s: jax.Array) -> jax.Array:
            # g_x(s) = 1 - exp(-a1 / sqrt(s)) tends to 1 as s -> 0 with every
            # derivative vanishing. Evaluated literally at s == 0 the value
            # is right, but reverse-mode AD produces 0 * inf = NaN. Route
            # s == 0 through a harmless placeholder and select the exact
            # limit instead ("double where").
            at_zero = s == 0
            s_safe = jnp.where(at_zero, 1.0, s)
            return jnp.where(at_zero, 1.0, -jnp.expm1(-a1 / jnp.sqrt(s_safe)))

        h1x = h1x_fn(s, alpha)
        fx_alpha = _f_interp(alpha, c1x, c2x, dx)
        return (h1x + fx_alpha * (h0x - h1x)) * gx(s)

    return e_x_ueg(n) * F_x(s, alpha)


@jax.jit
def e_c_scan(n: jax.Array,
             s: jax.Array,
             xi: jax.Array,
             alpha: jax.Array) -> jax.Array:
    """
    The correlation energy per particle of SCAN:

        e_c = e_c1 + f_c(alpha) * (e_c0 - e_c1)

    e_c1 is a PBE-like gradient correction on top of PW92 (LSDA1), and e_c0
    is the alpha = 0 correlation. f_c is the SCAN interpolation function,
    with f_c(0) = 1 and f_c -> 0 as alpha -> 1.

    All arguments are traced by ``jax.jit`` (``xi`` is not static), so each
    may be an array broadcast against the others.

    TODO: verify xi != 0 correctness if polarized systems are added

    https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/mgga_exc/mgga_c_scan.mpl
    """

    # fit parameters of the interpolation function f_c
    c1c = 0.64
    c2c = 1.5
    dc = 0.7

    # fixed constants of the alpha = 0 correlation e_c0
    b1c = 0.0285764
    b2c = 0.0889
    b3c = 0.125541

    def Psi_fn(xi: jax.Array) -> jax.Array:
        # spin-scaling factor phi(xi)
        return ((1 + xi) ** (2 / 3) + (1 - xi) ** (2 / 3)) / 2

    def ec1_fn(n: jax.Array, r_s: jax.Array, s: jax.Array,
               xi: jax.Array) -> jax.Array:
        """
        Perdew-Ernzerhof-Wang 1996 (PEW96)-like correlation energy
        e_c1 = e_c^LSDA1 + H1.

        ``r_s`` is passed explicitly (it must be the Wigner-Seitz radius of
        ``n``) instead of being captured from the enclosing scope.
        https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/gga_exc/gga_c_pbe.mpl
        """
        ec_LSDA1 = ec_LSDA1_fn(n, xi)

        # gamma = (1 - ln 2) / pi^2 ~= 0.031091
        gamma = (1 - jnp.log(2)) / jnp.pi**2
        # SCAN's r_s-dependent beta (PBE uses the constant 0.066725)
        beta = 0.066725 * (1 + 0.1 * r_s) / (1 + 0.1778 * r_s)
        Psi = Psi_fn(xi)
        w1 = jnp.expm1(- ec_LSDA1 / (gamma * Psi**3))
        A = beta / (gamma * w1)
        t = ((3 * pi**2 / 16) ** (1 / 3) * s) / (Psi * jnp.sqrt(r_s))
        g = (1 + 4 * A * t**2)**(- 1 / 4)
        H1 = gamma * Psi**3 * jnp.log1p(w1 * (1 - g))
        return ec_LSDA1 + H1

    def ec0_fn(r_s: jax.Array, s: jax.Array, xi: jax.Array) -> jax.Array:
        """alpha = 0 correlation e_c0 = (e_c^LDA0 + H0) * G_c(xi)."""
        ec_LDA0 = - b1c / (1 + b2c * jnp.sqrt(r_s) + b3c * r_s)

        def d_x(xi: jax.Array) -> jax.Array:
            return ((1 + xi)**(4 / 3) + (1 - xi)**(4 / 3)) / 2

        def Gc(xi: jax.Array) -> jax.Array:
            return (1 - 2.3631 * (d_x(xi) - 1)) * (1 - xi**12)

        w0 = jnp.expm1(- ec_LDA0 / b1c)
        # chi_infinity evaluated at xi = 0 (~0.128026)
        chi_unpolarized = 0.12802585262625815
        g = 1 / (1 + 4 * chi_unpolarized * s**2)**(1 / 4)
        H0 = b1c * jnp.log1p(w0 * (1 - g))
        return (ec_LDA0 + H0) * Gc(xi)

    r_s = _calc_r_s(n)
    ec1 = ec1_fn(n, r_s, s, xi)
    fc_alpha = _f_interp(alpha, c1c, c2c, dc)
    return ec1 + fc_alpha * (ec0_fn(r_s, s, xi) - ec1)

@jax.jit
def e_xc_scan(n: jax.Array,
              s: jax.Array,
              xi: jax.Array,
              alpha: jax.Array) -> jax.Array:
    """
    The exchange-correlation energy per particle of SCAN:

        e_xc = e_x_scan(n, s, alpha) + e_c_scan(n, s, xi, alpha)

    ``xi`` is traced like the other arguments. Declaring it static would
    require it to be hashable, and ``jax.Array`` is not, so an array-valued
    ``xi`` (e.g. one value per grid point) would be rejected. A scalar ``xi``
    would also trigger a recompilation for every distinct value.
    """
    return e_x_scan(n, s, alpha) + e_c_scan(n, s, xi, alpha)
@hatemhelal
Copy link
Member

Thank you, @ESEberhard! SCAN is a functional I have wanted to learn more about and integrate into MESS, so this is a great help 🤗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants