Skip to content

Commit

Permalink
Feature : add initialization of rhoij (#3164)
Browse files Browse the repository at this point in the history
Co-authored-by: wenfei-li <[email protected]>
  • Loading branch information
wenfei-li and wenfei-li authored Nov 7, 2023
1 parent 5acb805 commit d50b0ae
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 1 deletion.
11 changes: 11 additions & 0 deletions source/module_cell/module_paw/paw_atom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ void Paw_Atom::accumulate_rhoij(const int current_spin)
}
}

void Paw_Atom::set_rhoij(std::vector<double> & rhoij_in)
{
for(int i = 0; i < nproj*(nproj+1)/2; i ++)
{
for(int is = 0; is < GlobalV::NSPIN; is ++)
{
rhoij[is][i] = rhoij_in[i];
}
}
}

void Paw_Atom::convert_rhoij()
{
nrhoijsel = 0;
Expand Down
3 changes: 3 additions & 0 deletions source/module_cell/module_paw/paw_atom.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ class Paw_Atom
//pass <psi|ptilde> from outside and saves it
void set_ca(std::vector<std::complex<double>> & ca_in, const double weight_in);

void init_rhoij(); //set rhoij according to occupation number in xml file

void reset_rhoij(); //set rhoij = 0
void accumulate_rhoij(const int current_spin); //calculate and accumulate <psi|ptilde><ptilde|psi> from <psi|ptilde>
void set_rhoij(std::vector<double> & rhoij_in);

void set_dij(double** dij_in); //sets dij from input
void reset_dij(); //set dij = 0
Expand Down
32 changes: 32 additions & 0 deletions source/module_cell/module_paw/paw_cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,38 @@ void Paw_Cell::init_paw_cell(
int nproj = paw_element_list[it].get_mstates();
paw_atom_list[iat].init_paw_atom(nproj);
}

this -> init_rhoij();
}

void Paw_Cell::init_rhoij()
{
ModuleBase::TITLE("Paw_Cell","init_rhoij");

for(int iat = 0; iat < nat; iat ++)
{
const int it = atom_type[iat];
const int nproj = paw_element_list[it].get_mstates();

const int size_rhoij = nproj * (nproj + 1) / 2;

std::vector<double> mstate_occ = paw_element_list[it].get_mstate_occ();

std::vector<double> rhoij_in;
rhoij_in.resize(size_rhoij);
for(int i = 0; i < size_rhoij; i ++)
{
rhoij_in[i] = 0.0;
}

for(int iproj = 0; iproj < nproj; iproj ++)
{
int i0 = iproj * (iproj + 1) / 2;
rhoij_in[i0 + iproj] = mstate_occ[iproj] / GlobalV::NSPIN;
}

paw_atom_list[iat].set_rhoij(rhoij_in);
}
}

void Paw_Cell::set_eigts(const int nx_in, const int ny_in, const int nz_in,
Expand Down
2 changes: 2 additions & 0 deletions source/module_cell/module_paw/paw_cell.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class Paw_Cell
void accumulate_rhoij(const std::complex<double> * psi, const double weight);
void reset_rhoij();

void init_rhoij(); // set rhoij according to occupation number in xml file

// returns rhoij for each atom
//std::vector<std::vector<double>> get_rhoij();
// returns rhoijp and related info for each atom
Expand Down
15 changes: 14 additions & 1 deletion source/module_cell/module_paw/paw_element.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void Paw_Element::read_paw_xml(std::string filename)
this->reset_buffer(ifs);

// ============================================================
// 3. number of projector channels and corresponding l values
// 3. number of projector channels and corresponding l values, and occupation numbers
// example :
// <valence_states>
// <state n=" 1" l="0" f=" 1.0000000E+00" rc=" 0.9949503343" e="-2.3345876E-01" id= "H1"/>
Expand All @@ -57,13 +57,24 @@ void Paw_Element::read_paw_xml(std::string filename)
this->reset_buffer(ifs);

lstate.resize(nstates);
lstate_occ.resize(nstates);
lmax = 0;
for(int istate = 0; istate < nstates; istate ++)
{
line = this->scan_file(ifs, "<state");

this->lstate[istate] = this->extract_int(line,"l=");
lmax = std::max(lmax, lstate[istate]);

int pos = line.find("f=");
if(pos!=std::string::npos)
{
this->lstate_occ[istate] = this->extract_double(line,"f=");
}
else
{
this->lstate_occ[istate] = 0.0;
}
}

this->nstates_to_mstates();
Expand Down Expand Up @@ -273,6 +284,7 @@ void Paw_Element::nstates_to_mstates()

mstate.resize(mstates);
im_to_istate.resize(mstates);
mstate_occ.resize(mstates);

int index = 0;
for(int istate = 0; istate < nstates; istate ++)
Expand All @@ -282,6 +294,7 @@ void Paw_Element::nstates_to_mstates()
{
mstate[index] = im - lstate[istate];
im_to_istate[index] = istate;
mstate_occ[index] = lstate_occ[istate] / double(nm);
index ++;
}
}
Expand Down
6 changes: 6 additions & 0 deletions source/module_cell/module_paw/paw_element.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,14 @@ class Paw_Element
std::vector<int> lstate; //l quantum number of each channel
int lmax; // max of quantum number l

std::vector<double> lstate_occ; //occupation number of each (n,l) channel

int mstates; //#. m states (for each (n,l) channel, there will be 2l+1 m states)
std::vector<int> mstate; //m quantum number of each mstate
std::vector<int> im_to_istate; //map from mstate to (n,l) channel (namely nstates)

std::vector<double> mstate_occ; //occupation number of each mstate

//for log grid, r_i = rstep * exp[(lstep * i)-1]
//rstep <-> a, lstep <-> d from xml file
double lstep, rstep;
Expand Down Expand Up @@ -140,6 +144,8 @@ class Paw_Element
//max quantum nubmer l
int get_lmax(){return lmax;}

std::vector<double> get_mstate_occ(){return mstate_occ;}

// return ptilde_q for a given channel at a given q_in, using spline
double get_ptilde(const int istate_in, const double q_in, const double omega);

Expand Down
1 change: 1 addition & 0 deletions source/module_cell/module_paw/test/test_paw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class Test_Ptilde : public testing::Test

TEST_F(Test_Ptilde, test_paw)
{
paw_element.init_paw_element(50,1.2);
paw_element.read_paw_xml("Si_test.xml");

const int npw = 411;
Expand Down
22 changes: 22 additions & 0 deletions source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,28 @@ namespace ModuleESolver
GlobalC::paw_cell.set_eigts(
this->pw_wfc->nx,this->pw_wfc->ny,this->pw_wfc->nz,
this->sf.eigts1.c,this->sf.eigts2.c,this->sf.eigts3.c);

std::vector<std::vector<double>> rhoijp;
std::vector<std::vector<int>> rhoijselect;
std::vector<int> nrhoijsel;
#ifdef __MPI
if(GlobalV::RANK_IN_POOL == 0)
{
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for(int iat = 0; iat < GlobalC::ucell.nat; iat ++)
{
GlobalC::paw_cell.set_rhoij(iat,nrhoijsel[iat],rhoijselect[iat].size(),rhoijselect[iat].data(),rhoijp[iat].data());
}
}
#else
this->get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for(int iat = 0; iat < GlobalC::ucell.nat; iat ++)
{
GlobalC::paw_cell.set_rhoij(iat,nrhoijsel[iat],rhoijselect[iat].size(),rhoijselect[iat].data(),rhoijp[iat].data());
}
#endif
}
#endif
}
Expand Down
22 changes: 22 additions & 0 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,28 @@ void ESolver_KS_PW<T, Device>::init_after_vc(Input& inp, UnitCell& ucell)
GlobalC::paw_cell.prepare_paw();
#endif
GlobalC::paw_cell.set_sij();

std::vector<std::vector<double>> rhoijp;
std::vector<std::vector<int>> rhoijselect;
std::vector<int> nrhoijsel;
#ifdef __MPI
if(GlobalV::RANK_IN_POOL == 0)
{
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for(int iat = 0; iat < GlobalC::ucell.nat; iat ++)
{
GlobalC::paw_cell.set_rhoij(iat,nrhoijsel[iat],rhoijselect[iat].size(),rhoijselect[iat].data(),rhoijp[iat].data());
}
}
#else
GlobalC::paw_cell.get_rhoijp(rhoijp, rhoijselect, nrhoijsel);

for(int iat = 0; iat < GlobalC::ucell.nat; iat ++)
{
GlobalC::paw_cell.set_rhoij(iat,nrhoijsel[iat],rhoijselect[iat].size(),rhoijselect[iat].data(),rhoijp[iat].data());
}
#endif
}
#endif

Expand Down

0 comments on commit d50b0ae

Please sign in to comment.