Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Q constraint #68

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 92 additions & 47 deletions pyci/rdm/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

import numpy as np

from scipy.optimize import root

from scipy.optimize import root

__all__ = [
"find_closest_sdp",
Expand All @@ -30,6 +29,7 @@
"calc_T2_prime",
]


def find_closest_sdp(dm, constraint, alpha):
r"""
Projection onto a semidefinite constraint.
Expand All @@ -44,15 +44,15 @@ def find_closest_sdp(dm, constraint, alpha):
Value of the correct trace.

"""
#symmetrize if necessary
# symmetrize if necessary
constrained = constraint(dm)
L = constrained + constrained.conj().T
#find eigendecomposition
# find eigendecomposition
vals, vecs = np.linalg.eigh(L)
#calculate the shift, sigma0
# calculate the shift, sigma0
sigma0 = calculate_shift(vals, alpha)
#calculate the closest semidefinite positive matrix with correct trace

# calculate the closest semidefinite positive matrix with correct trace
L_closest = vals @ np.diag(vecs - sigma0) @ vecs.conj().T

# return the reconstructed density matrix
Expand All @@ -72,17 +72,62 @@ def calculate_shift(eigenvalues, alpha):
Value of the coprrect trace.

"""
#sample code, to be confirmed
trace = lambda sigma0: np.sum(np.heaviside(eigenvalues - sigma0, 0.5)*(eigenvalues - sigma0))
constraint = lambda x: trace(x) - alpha
res = root(constraint, 0)
# sample code, to be confirmed
trace = lambda sigma0: np.sum(np.heaviside(eigenvalues - sigma0, 0.5) * (eigenvalues - sigma0))
constraint = lambda x: trace(x) - alpha
res = root(constraint, 0)
return res.x


def calc_P():
pass

def calc_Q():
pass

def calc_Q(gamma, N, conjugate=False):
"""
Calculate the Q tensor.

Parameters
----------
gamma: np.ndarray
One-particle density matrix (1DM) tensor.
N: int
Number of electrons in the system.
conjugate: bool
conjugate or regular condition

Returns
-------
np.ndarray

Notes
-----
Q is defined as:

.. math::
\mathcal{Q}_{\alpha \beta ; \gamma \delta}(\Gamma) = \delta_{\alpha \gamma}\delta_{\beta \delta} -
\delta_{\alpha \delta}\delta_{\beta \gamma} + \Gamma_{\alpha \beta ; \gamma \delta} -
\delta_{\alpha \gamma}\rho_{\beta \delta} + \delta_{\beta \gamma}\rho_{\alpha \delta} +
\delta_{\alpha \delta}\rho_{\beta \gamma} - \delta_{\beta \delta}\rho_{\alpha \gamma}

Q is self-adjoint as stated in Appendix F of Poelman's thesis https://biblio.ugent.be/publication/6933577
"""
dim = gamma.shape[0]
eye = np.eye(dim)

a_bar = np.einsum('abgd -> ag', gamma)
rho = 1 / (N - 1) * a_bar

delta_ag_bd = np.einsum('ag,bd->abgd', eye, eye)
delta_ad_bg = np.einsum('ad,bg->abgd', eye, eye)

ag_rho_bd = np.einsum('ag,bd->abgd', eye, rho)
bg_rho_ad = np.einsum('bg,ad->abgd', eye, rho)
ad_rho_bg = np.einsum('ad,bg->abgd', eye, rho)
bd_rho_ag = np.einsum('bd,ag->abgd', eye, rho)

return delta_ag_bd - delta_ad_bg + gamma - ag_rho_bd + bg_rho_ad + ad_rho_bg - bd_rho_ag


def calc_G(gamma, N, conjugate=False):
"""
Expand Down Expand Up @@ -111,17 +156,18 @@ def calc_G(gamma, N, conjugate=False):
"""
eye = np.eye(gamma.shape[0])
a_bar = np.einsum('abgb -> ag', gamma)
rho = 1/(N - 1) * a_bar
rho = 1 / (N - 1) * a_bar
if not conjugate:
return np.einsum('bd, ag -> abgd', eye, rho) - np.einsum('adgb -> abgd', gamma)
term_1 = 1/(N-1) *\
(np.einsum('bd, ag -> abgd', eye, a_bar) - np.einsum('ad, bg -> abgd', eye, a_bar) -\
np.einsum('bg, ad -> abgd', eye, a_bar) + np.einsum('ag, bd -> abgd', eye, a_bar)
)
term_2 = -np.einsum('adgb -> abgd', gamma) + np.einsum('bdga -> abgd', gamma) +\
np.einsum('agdb -> abgd', gamma) - np.einsum('bgda -> abgd', gamma)
term_1 = 1 / (N - 1) * \
(np.einsum('bd, ag -> abgd', eye, a_bar) - np.einsum('ad, bg -> abgd', eye, a_bar) - \
np.einsum('bg, ad -> abgd', eye, a_bar) + np.einsum('ag, bd -> abgd', eye, a_bar)
)
term_2 = -np.einsum('adgb -> abgd', gamma) + np.einsum('bdga -> abgd', gamma) + \
np.einsum('agdb -> abgd', gamma) - np.einsum('bgda -> abgd', gamma)
return term_1 + term_2


def calc_T1(gamma, N, conjugate):
"""
Calculating T1 tensor
Expand Down Expand Up @@ -174,30 +220,30 @@ def calc_T1(gamma, N, conjugate):
eye = np.eye(gamma.shape[0])

if not conjugate:
rho = 1 / (N-1) * np.einsum('abgb -> ag', gamma)
rho = 1 / (N - 1) * np.einsum('abgb -> ag', gamma)
term_1 = np.einsum('gz, be, ad -> abgdez', eye, eye, eye) + \
np.einsum('ge, ad, bz -> abgdez', eye, eye, eye) + \
np.einsum('az, ge, bd -> abgdez', eye, eye, eye) + \
np.einsum('gz, ae, bd -> abgdez', eye, eye, eye) + \
np.einsum('az, be, gd -> abgdez', eye, eye, eye)
term_2 = - np.einsum('gz, be, ad -> abgdez', eye, eye, rho) + \
np.einsum('bz, ge, ad -> abgdez', eye, eye, rho) + \
np.einsum('gz, ae, bd -> abgdez', eye, eye, rho) - \
np.einsum('az, ge, bd -> abgdez', eye, eye, rho) - \
np.einsum('bz, ae, gd -> abgdez', eye, eye, rho) + \
np.einsum('az, be, gd -> abgdez', eye, eye, rho)
np.einsum('bz, ge, ad -> abgdez', eye, eye, rho) + \
np.einsum('gz, ae, bd -> abgdez', eye, eye, rho) - \
np.einsum('az, ge, bd -> abgdez', eye, eye, rho) - \
np.einsum('bz, ae, gd -> abgdez', eye, eye, rho) + \
np.einsum('az, be, gd -> abgdez', eye, eye, rho)
term_3 = np.einsum('gz, bd, ae -> abgdez', eye, eye, rho) - \
np.einsum('bz, gd, ae -> abgdez', eye, eye, rho) - \
np.einsum('gz, ad, eb -> abgdez', eye, eye, rho) + \
np.einsum('az, gd, eb -> abgdez', eye, eye, rho) + \
np.einsum('bz, ad, ge -> abgdez', eye, eye, rho) - \
np.einsum('az, bd, ge -> abgdez', eye, eye, rho)
term_4 = - np.einsum('bd, ge, az -> abgdez', eye, eye, rho) + \
np.einsum('be, gd, az -> abgdez', eye, eye, rho) + \
np.einsum('ge, ad, bz -> abgdez', eye, eye, rho) - \
np.einsum('ae, gd, bz -> abgdez', eye, eye, rho) - \
np.einsum('be, ad, gz -> abgdez', eye, eye, rho) + \
np.einsum('ae, bd, gz -> abgdez', eye, eye, rho)
np.einsum('be, gd, az -> abgdez', eye, eye, rho) + \
np.einsum('ge, ad, bz -> abgdez', eye, eye, rho) - \
np.einsum('ae, gd, bz -> abgdez', eye, eye, rho) - \
np.einsum('be, ad, gz -> abgdez', eye, eye, rho) + \
np.einsum('ae, bd, gz -> abgdez', eye, eye, rho)
term_5 = np.einsum('gz, abde -> abgdez', eye, gamma) - np.einsum('bz, agde -> abgdez', eye, gamma) + \
np.einsum('az, bgde -> abgdez', eye, gamma) - np.einsum('ge, abdz -> abgdez', eye, gamma) + \
np.einsum('be, agdz -> abgdez', eye, gamma) - np.einsum('ae, bgdz -> abgdez', eye, gamma) + \
Expand All @@ -208,17 +254,17 @@ def calc_T1(gamma, N, conjugate):
else:
tr_gamma = np.einsum('aaaaaa', gamma)
gamma_abgd = np.einsum('ablgdl -> abgd', gamma)
term_1 = 2 / (N*N - N) *\
(np.einsum('ag, bd -> abgd', eye, eye) - np.einsum('ad, bg -> abgd', eye, eye)) * tr_gamma + gamma_abgd
term_1 = 2 / (N * N - N) * \
(np.einsum('ag, bd -> abgd', eye, eye) - np.einsum('ad, bg -> abgd', eye, eye)) * tr_gamma + gamma_abgd

gamma_ag = np.einsum('abgb -> ag', gamma_abgd)
gamma_bg = np.einsum('abag -> bg', gamma_abgd)
gamma_ad = np.einsum('agdg -> ad', gamma_abgd)
gamma_bd = np.einsum('abda -> bd', gamma_abgd)
term_2 = - 2 / (2*N - 2)*\
(np.einsum('bd, ag -> abgd', eye, gamma_ag) - np.einsum('ad, bg -> abgd', eye, gamma_bg) -\
np.einsum('bg, ad -> abgd', eye, gamma_ad) + np.einsum('ag, bd -> abgd', eye, gamma_bd))

term_2 = - 2 / (2 * N - 2) * \
(np.einsum('bd, ag -> abgd', eye, gamma_ag) - np.einsum('ad, bg -> abgd', eye, gamma_bg) - \
np.einsum('bg, ad -> abgd', eye, gamma_ad) + np.einsum('ag, bd -> abgd', eye, gamma_bd))

return term_1 + term_2

Expand Down Expand Up @@ -257,8 +303,8 @@ def calc_T2(gamma, N, conjugate=False):
"""
eye = np.eye(gamma.shape[0])
if not conjugate:
rho = 1/(N-1) * np.einsum('abgb -> ag', gamma)
term_1 = np.einsum('ad, be, gz -> abgdez', eye, eye, rho) -\
rho = 1 / (N - 1) * np.einsum('abgb -> ag', gamma)
term_1 = np.einsum('ad, be, gz -> abgdez', eye, eye, rho) - \
np.einsum('ae, bd, gz -> abgdez', eye, eye, rho)
term_2 = np.einsum('gz, abde -> abgdez', eye, gamma)
term_3 = np.einsum('ad, gezb -> abgdez', eye, gamma)
Expand All @@ -273,18 +319,18 @@ def calc_T2(gamma, N, conjugate=False):
term_1 = np.einsum('bd, ag -> abgd', eye, a_dtilda)
term_2 = np.einsum('ad, bg -> abgd', eye, a_dtilda)
term_3 = np.einsum('bg, ad -> abgd', eye, a_dtilda)
term_4 = np.einsum('ag, bd -> abgd', eye, a_dtilda)
term_4 = np.einsum('ag, bd -> abgd', eye, a_dtilda)
# term_5 = a_bar
term_6 = np.einsum('dabg -> abgd', a_tilda)
term_7 = np.einsum('dbag -> abgd', a_tilda)
term_8 = np.einsum('gabd -> abgd', a_tilda)
term_9 = np.einsum('gbad -> abgd', a_tilda)
return 0.5/(N-1) * (term_1 - term_2 - term_3 + term_4) +\
return 0.5 / (N - 1) * (term_1 - term_2 - term_3 + term_4) + \
a_bar - (term_6 - term_7 - term_8 + term_9)
eye = np.eye(gamma.shape[0])
rho = 1/(N-1) * np.einsum('abgb -> ag', gamma)
rho = 1 / (N - 1) * np.einsum('abgb -> ag', gamma)
if not conjugate:
term_1 = np.einsum('ad, be, gz -> abgdez', eye, eye, rho) -\
term_1 = np.einsum('ad, be, gz -> abgdez', eye, eye, rho) - \
np.einsum('ae, bd, gz -> abgdez', eye, eye, rho)
term_2 = np.einsum('gz, abde -> abgdez', eye, gamma)
term_3 = np.einsum('ad, gezb -> abgdez', eye, gamma)
Expand All @@ -299,16 +345,15 @@ def calc_T2(gamma, N, conjugate=False):
term_1 = np.einsum('bd, ag -> abgd', eye, a_dtilda)
term_2 = np.einsum('ad, bg -> abgd', eye, a_dtilda)
term_3 = np.einsum('bg, ad -> abgd', eye, a_dtilda)
term_4 = np.einsum('ag, bd -> abgd', eye, a_dtilda)
term_4 = np.einsum('ag, bd -> abgd', eye, a_dtilda)
# term_5 = a_bar
term_6 = np.einsum('dabg -> abgd', a_tilda)
term_7 = np.einsum('dbag -> abgd', a_tilda)
term_8 = np.einsum('gabd -> abgd', a_tilda)
term_9 = np.einsum('gbad -> abgd', a_tilda)
return 0.5/(N-1) * (term_1 - term_2 - term_3 + term_4) +\
return 0.5 / (N - 1) * (term_1 - term_2 - term_3 + term_4) + \
a_bar - (term_6 - term_7 - term_8 + term_9)


def calc_T2_prime():
pass

Loading