Skip to content

Commit

Permalink
Merge pull request lammps#4080 from rbberger/compute_reaxff_atom_over…
Browse files Browse the repository at this point in the history
…flow_fix

Fix buffer overflow in compute reaxff/atom
  • Loading branch information
akohlmey authored Feb 21, 2024
2 parents 0849863 + 90ebca6 commit 81609d0
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
11 changes: 7 additions & 4 deletions src/KOKKOS/compute_reaxff_atom_kokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ void ComputeReaxFFAtomKokkos<DeviceType>::init()
template<class DeviceType>
void ComputeReaxFFAtomKokkos<DeviceType>::compute_bonds()
{
if (atom->nlocal > nlocal) {
if (atom->nmax > nmax) {
memory->destroy(array_atom);
nlocal = atom->nlocal;
memory->create(array_atom, nlocal, 3, "reaxff/atom:array_atom");
nmax = atom->nmax;
memory->create(array_atom, nmax, 3, "reaxff/atom:array_atom");
}

// retrieve bond information from kokkos pair style. the data potentially
Expand All @@ -85,6 +85,7 @@ void ComputeReaxFFAtomKokkos<DeviceType>::compute_bonds()
else
host_pair()->FindBond(maxnumbonds, groupbit);

const int nlocal = atom->nlocal;
nbuf = ((store_bonds ? maxnumbonds*2 : 0) + 3)*nlocal;

if (!buf || ((int)k_buf.extent(0) < nbuf)) {
Expand Down Expand Up @@ -135,6 +136,7 @@ void ComputeReaxFFAtomKokkos<DeviceType>::compute_local()
int b = 0;
int j = 0;
auto tag = atom->tag;
const int nlocal = atom->nlocal;

for (int i = 0; i < nlocal; ++i) {
const int numbonds = static_cast<int>(buf[j+2]);
Expand All @@ -161,6 +163,7 @@ void ComputeReaxFFAtomKokkos<DeviceType>::compute_peratom()
compute_bonds();

// extract peratom bond information from buffer
const int nlocal = atom->nlocal;

int j = 0;
for (int i = 0; i < nlocal; ++i) {
Expand All @@ -180,7 +183,7 @@ void ComputeReaxFFAtomKokkos<DeviceType>::compute_peratom()
template<class DeviceType>
double ComputeReaxFFAtomKokkos<DeviceType>::memory_usage()
{
double bytes = (double)(nlocal*3) * sizeof(double);
double bytes = (double)(nmax*3) * sizeof(double);
if (store_bonds)
bytes += (double)(nbonds*3) * sizeof(double);
bytes += (double)(nbuf > 0 ? nbuf * sizeof(double) : 0);
Expand Down
26 changes: 16 additions & 10 deletions src/REAXFF/compute_reaxff_atom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ComputeReaxFFAtom::ComputeReaxFFAtom(LAMMPS *lmp, int narg, char **arg) :

// initialize output

nlocal = -1;
nmax = -1;
nbonds = 0;
prev_nbonds = -1;

Expand Down Expand Up @@ -162,20 +162,22 @@ void ComputeReaxFFAtom::compute_bonds()
{
invoked_bonds = update->ntimestep;

if (atom->nlocal > nlocal) {
if (atom->nmax > nmax) {
memory->destroy(abo);
memory->destroy(neighid);
memory->destroy(bondcount);
memory->destroy(array_atom);
nlocal = atom->nlocal;
nmax = atom->nmax;
if (store_bonds) {
memory->create(abo, nlocal, MAXREAXBOND, "reaxff/atom:abo");
memory->create(neighid, nlocal, MAXREAXBOND, "reaxff/atom:neighid");
memory->create(abo, nmax, MAXREAXBOND, "reaxff/atom:abo");
memory->create(neighid, nmax, MAXREAXBOND, "reaxff/atom:neighid");
}
memory->create(bondcount, nlocal, "reaxff/atom:bondcount");
memory->create(array_atom, nlocal, 3, "reaxff/atom:array_atom");
memory->create(bondcount, nmax, "reaxff/atom:bondcount");
memory->create(array_atom, nmax, 3, "reaxff/atom:array_atom");
}

const int nlocal = atom->nlocal;

for (int i = 0; i < nlocal; i++) {
bondcount[i] = 0;
for (int j = 0; store_bonds && j < MAXREAXBOND; j++) {
Expand Down Expand Up @@ -208,6 +210,8 @@ void ComputeReaxFFAtom::compute_local()

int b = 0;

const int nlocal = atom->nlocal;

for (int i = 0; i < nlocal; ++i) {
const int numbonds = bondcount[i];

Expand All @@ -230,6 +234,8 @@ void ComputeReaxFFAtom::compute_peratom()
compute_bonds();
}

const int nlocal = atom->nlocal;

for (int i = 0; i < nlocal; ++i) {
auto ptr = array_atom[i];
ptr[0] = reaxff->api->workspace->total_bond_order[i];
Expand All @@ -244,10 +250,10 @@ void ComputeReaxFFAtom::compute_peratom()

double ComputeReaxFFAtom::memory_usage()
{
double bytes = (double)(nlocal*3) * sizeof(double);
bytes += (double)(nlocal) * sizeof(int);
double bytes = (double)(nmax*3) * sizeof(double);
bytes += (double)(nmax) * sizeof(int);
if (store_bonds) {
bytes += (double)(2*nlocal*MAXREAXBOND) * sizeof(double);
bytes += (double)(2*nmax*MAXREAXBOND) * sizeof(double);
bytes += (double)(nbonds*3) * sizeof(double);
}
return bytes;
Expand Down
2 changes: 1 addition & 1 deletion src/REAXFF/compute_reaxff_atom.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ComputeReaxFFAtom : public Compute {

protected:
bigint invoked_bonds; // last timestep on which compute_bonds() was invoked
int nlocal;
int nmax;
int nbonds;
int prev_nbonds;
int nsub;
Expand Down

0 comments on commit 81609d0

Please sign in to comment.