We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hi guys,
I thought you might find our implementation of SCAN for EG-XC useful to integrate into the code base. I attached it below :)
""" This module contains the implementation of the SCAN (Strongly Constrained and Appropriately Normed Semilocal) meta-GGA functional by Sun et al. https://doi.org/10.1103/PhysRevLett.115.036402 https://journals.aps.org/prl/supplemental/10.1103/PhysRevLett.115.036402 The SCAN functional is a meta-GGA functional, which means that it depends on: r_s = (3 / (4 * pi * n))**(1 / 3) the spin polarization: xi = (n_up - n_down) / n the reduced gradient: s = |grad(n)| / (2 * (3 * pi**2)**(1 / 3) * n**(4 / 3)) """ import jax import jax.numpy as jnp from functools import partial from mldft.xc_energy.ueg_xc import e_x_ueg, ec_LSDA1_fn pi = jnp.pi def e_x_ueg(n: jax.Array) -> jax.Array: """ The exchange energy per particle of the uniform electron gas. """ return - (3 / 4) * (3 / pi)**(1 / 3) * n**(1 / 3) def ec_LSDA1_fn(n: jax.Array, xi: jax.Array, use_RPA=False, modified=False) -> jax.Array: """ The correlation energy of the uniform electron gas by Perdew and Wang (1992) (PW92). https://doi.org/10.1103/PhysRevB.45.13244, libxc reference implementation: https://github.com/ElectronicStructureLibrary/libxc/blob/master/src/lda_c_pw.c https://github.com/ElectronicStructureLibrary/libxc/blob/master/src/maple2c/lda_exc/lda_c_pw.c#L14 https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/lda_exc/lda_c_pw.mpl https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/util.mpl """ r_s = _calc_r_s(n) def analytic_base_from(p, A, a1, b1, b2, b3, b4) -> jax.Array: """ p and A are constrained the remaining parameters were fitted (see ref) """ beta_sum = b1 * r_s**(1 / 2) \ + b2 * r_s \ + b3 * r_s**(3 / 2) \ + b4 * r_s**(p + 1) log_term = jnp.log1p(1 / (2 * A * beta_sum)) return - 2 * A * (1 + a1 * r_s) * log_term A_unpolarized = 0.031091 if not modified else 0.0310907 A_polarized = 0.015545 if not modified else 0.01554535 A_alpha_c = 0.016887 if not modified else 0.0168869 if use_RPA: # random-phase approximation (RPA) ec_unpolarized =analytic_base_from(p=0.75, 
A=A_unpolarized, a1=0.082477, b1=5.1486, b2=1.6483, b3=0.2347, b4=0.20614) ec_polarized = analytic_base_from(p=0.75, A=A_polarized, a1=0.035374, b1=6.4869, b2=1.3083, b3=0.15180, b4=0.082349) alpha_c = - analytic_base_from(p=1, A=A_alpha_c, a1=0.028829, b1=10.357, b2=3.6231, b3=0.47990, b4=0.12279) else: ec_unpolarized =analytic_base_from(p=1, A=A_unpolarized, a1=0.21370, b1=7.5957, b2=3.5876, b3=1.6382, b4=0.49294) ec_polarized = analytic_base_from(p=1, A=A_polarized, a1=0.20548, b1=14.1189, b2=6.1977, b3=3.3662, b4=0.62517) alpha_c = - analytic_base_from(p=1, A=A_alpha_c, a1=0.11125, b1=10.357, b2=3.6231, b3=0.88026, b4=0.49671) dd_f_zero = 1.709921 if not modified else 1.709920934161365617563962776245 # alpha_c = dd_f_zero * (ec_polarized - ec_unpolarized) # from another reference f_xi = ((1 + xi)**(4 / 3) + (1 - xi)**(4 / 3) - 2) / (2**(4 / 3) - 2) return ec_unpolarized + alpha_c * f_xi / dd_f_zero * (1 - xi**4) \ + (ec_polarized - ec_unpolarized) * f_xi * xi**4 def _calc_r_s(n: jax.Array, epsilon: float = 0) -> jax.Array: """ The Wigner-Seitz radius """ return (3 / (4 * pi * n + epsilon))**(1 / 3) @jax.custom_jvp def _f_interp(alpha: jax.Array, c1: float, c2: float, d: float) -> jax.Array: """ TODO: add nondiff_argnums=(1,2,3) """ term1 = jnp.exp(-c1 * alpha / (1 - alpha)) term2 = -d * jnp.exp(c2 / (1 - alpha)) return jnp.where(alpha < 1, term1, term2) @_f_interp.defjvp def jvp_f_interp(primals, tangents): """ does not account for derivative w.r.t. 
constants c1, c2, d """ alpha, c1, c2, d = primals alpha_dot, _, _, _ = tangents df = _f_interp(alpha, c1, c2, d) dterm1_factor = -c1 / (1 - alpha)**2 dterm2_factor = c2 / (1 - alpha)**2 df_dot = jnp.where(alpha < 1, dterm1_factor, dterm2_factor) * df * alpha_dot return df, df_dot @jax.jit def e_x_scan(n: jax.Array, s:jax.Array, alpha: jax.Array) -> jax.Array: """ The exchange energy per particle of SCAN """ def F_x(s: jax.Array, alpha: jax.Array) -> jax.Array: """ The exchange enhancement factor """ # fit parameters k1 = 0.065 c1x = 0.667 c2x = 0.8 dx = 1.24 def h1x_fn(s: jax.Array, alpha: jax.Array) -> jax.Array: mu_ak = 10 / 81 b2 = jnp.sqrt(5913 / 405000) b1 = (511 / 13500) / (2 * b2) b3 = 0.5 b4 = mu_ak**2 / k1 - 1606 / 18225 - b1**2 exp1 = jnp.exp(- jnp.abs(b4) * s**2 / mu_ak) exp2 = jnp.exp(- b3 * (1 - alpha)**2) x = mu_ak * s**2 * (1 + (b4 * s**2 / mu_ak) * exp1) \ + (b1 * s**2 + b2 * (1 - alpha) * exp2)**2 return 1 + k1 - k1 / (1 + x / k1) a1 = 4.9479 h0x = 1.174 def gx(s: jax.Array) -> jax.Array: return - jnp.expm1(-a1 / jnp.sqrt(s)) h1x = h1x_fn(s, alpha) fx_alpha = _f_interp(alpha, c1x, c2x, dx) return (h1x + fx_alpha * (h0x - h1x)) * gx(s) return e_x_ueg(n) * F_x(s, alpha) @jax.jit # TODO: add static_argnames=("xi") ? 
def e_c_scan(n: jax.Array, s:jax.Array, xi: jax.Array, alpha: jax.Array) -> jax.Array: """ The correlation energy per particle of SCAN TODO: verify xi != 0 correctness if polarized systems are added https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/mgga_exc/mgga_c_scan.mpl """ # fit parameters c1c = 0.64 c2c = 1.5 dc = 0.7 #fixed constants b1c = 0.0285764 b2c = 0.0889 b3c = 0.125541 def Psi_fn(xi: jax.Array) -> jax.Array: return ((1 + xi) ** (2 / 3) + (1 - xi) ** (2 / 3)) / 2 def ec1_fn(n: jax.Array, s:jax.Array, xi: jax.Array) -> jax.Array: """ Perdew-Ernzerhof-Wang 1996 (PEW96)-like correlation energy https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/gga_exc/gga_c_pbe.mpl """ ec_LSDA1 = ec_LSDA1_fn(n, xi) def H1(r_s: jax.Array, s:jax.Array, xi: jax.Array) -> jax.Array: # gamma = 0.031091 gamma = (1 - jnp.log(2)) / jnp.pi**2 # beta = 0.06672455060314922 beta = 0.066725 * (1 + 0.1 * r_s) / (1 + 0.1778 * r_s) # SCAN # beta = 0.066725 # PBE0 Psi = Psi_fn(xi) w1 = jnp.expm1(- ec_LSDA1 / (gamma * Psi**3)) A = beta / (gamma * w1) t = ((3 * pi**2 / 16) ** (1 / 3) * s) / (Psi * jnp.sqrt(r_s)) g = (1 + 4 * A * t**2)**(- 1 / 4) # g = 1 / (1 + A * t**2 + A**2 * t**4) # PBE0 return gamma * Psi**3 * jnp.log1p(w1 * (1 - g)) return ec_LSDA1 + H1(r_s, s, xi) def ec0_fn(r_s: jax.Array, s:jax.Array, xi: jax.Array) -> jax.Array: ec_LDA0 = - b1c / (1 + b2c * jnp.sqrt(r_s) + b3c * r_s) def dx(xi: jax.Array) -> jax.Array: return ((1 + xi)**(4 / 3) + (1 - xi)**(4 / 3)) / 2 def Gc(xi: jax.Array) -> jax.Array: return (1 - 2.3631 * (dx(xi) - 1)) * (1 - xi**12) w0 = jnp.expm1(- ec_LDA0 / b1c) chi_unpolarized = 0.12802585262625815 # 0.128026 g = 1 / (1 + 4 * chi_unpolarized * s**2)**(1 / 4) H0 = b1c * jnp.log1p(w0 * (1 - g)) return (ec_LDA0 + H0) * Gc(xi) r_s = _calc_r_s(n) ec1 = ec1_fn(n, s, xi) fc_alpha = _f_interp(alpha, c1c, c2c, dc) return ec1 + fc_alpha * (ec0_fn(r_s, s, xi) - ec1) @partial(jax.jit, static_argnames=("xi")) def e_xc_scan(n: 
jax.Array, s:jax.Array, xi: jax.Array, alpha: jax.Array) -> jax.Array: """ The exchange-correlation energy per particle of SCAN """ return e_x_scan(n, s, alpha) + e_c_scan(n, s, xi, alpha)
The text was updated successfully, but these errors were encountered:
Thank you @ESEberhard! SCAN is a functional I have wanted to learn more about and integrate into MESS, so this is a great help 🤗
Sorry, something went wrong.
No branches or pull requests
Hi guys,
I thought you might find our implementation of SCAN for EG-XC useful to integrate into the code base.
I attached it below :)
The text was updated successfully, but these errors were encountered: