Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ODESolvers: Solve only in the interior #301

Merged
merged 13 commits into from
Aug 1, 2024
Merged
22 changes: 15 additions & 7 deletions FluxWaveToyX/schedule.ccl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ SCHEDULE FluxWaveToyX_Initial AT initial
{
LANG: C
WRITES: state(interior)
SYNC: state
} "Initialize scalar wave state"

SCHEDULE FluxWaveToyX_Fluxes IN ODESolvers_RHS
Expand All @@ -16,6 +15,7 @@ SCHEDULE FluxWaveToyX_Fluxes IN ODESolvers_RHS
WRITES: flux_x(interior)
WRITES: flux_y(interior)
WRITES: flux_z(interior)
# Sync for test output
SYNC: flux_x
SYNC: flux_y
SYNC: flux_z
Expand All @@ -24,23 +24,31 @@ SCHEDULE FluxWaveToyX_Fluxes IN ODESolvers_RHS
SCHEDULE FluxWaveToyX_RHS IN ODESolvers_RHS AFTER FluxWaveToyX_Fluxes
{
LANG: C
READS: state(everywhere)
READS: flux_x(everywhere)
READS: flux_y(everywhere)
READS: flux_z(everywhere)
READS: state(interior)
READS: flux_x(interior)
READS: flux_y(interior)
READS: flux_z(interior)
WRITES: rhs(interior)
# Sync for test output
SYNC: rhs
} "Calculate scalar wave RHS"

SCHEDULE FluxWaveToyX_Constraints IN ODESolvers_PostStep
SCHEDULE FluxWaveToyX_Boundaries IN ODESolvers_PostStep
{
LANG: C
OPTIONS: global
SYNC: state
} "Apply boundary conditions"

SCHEDULE FluxWaveToyX_Constraints IN ODESolvers_PostStep AFTER FluxWaveToyX_Boundaries
{
LANG: C
READS: state(everywhere)
WRITES: cons(interior)
SYNC: cons
} "Calculate scalar wave constraints"

SCHEDULE FluxWaveToyX_Energy IN ODESolvers_PostStep
SCHEDULE FluxWaveToyX_Energy IN ODESolvers_PostStep AFTER FluxWaveToyX_Boundaries
{
LANG: C
READS: state(everywhere)
Expand Down
7 changes: 7 additions & 0 deletions FluxWaveToyX/src/fluxwavetoyx.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,13 @@ extern "C" void FluxWaveToyX_RHS(CCTK_ARGUMENTS) {
});
}

extern "C" void FluxWaveToyX_Boundaries(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTSX_FluxWaveToyX_Boundaries;
DECLARE_CCTK_PARAMETERS;

// Do nothing
}

extern "C" void FluxWaveToyX_Constraints(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTSX_FluxWaveToyX_Constraints;

Expand Down
85 changes: 45 additions & 40 deletions ODESolvers/src/solve.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ struct statecomp_t {
template <size_t N>
static void combine_valids(const statecomp_t &dst, const CCTK_REAL scale,
const array<CCTK_REAL, N> &factors,
const array<const statecomp_t *, N> &srcs);
const array<const statecomp_t *, N> &srcs,
const valid_t where);
void check_valid(const valid_t required, const function<string()> &why) const;
void check_valid(const valid_t required, const string &why) const {
check_valid(required, [=]() { return why; });
Expand Down Expand Up @@ -189,7 +190,8 @@ void statecomp_t::set_valid(const valid_t valid) const {
template <size_t N>
void statecomp_t::combine_valids(const statecomp_t &dst, const CCTK_REAL scale,
const array<CCTK_REAL, N> &factors,
const array<const statecomp_t *, N> &srcs) {
const array<const statecomp_t *, N> &srcs,
const valid_t where) {
const int ngroups = dst.groupdatas.size();
for (const auto &src : srcs)
assert(int(src->groupdatas.size()) == ngroups);
Expand All @@ -207,7 +209,7 @@ void statecomp_t::combine_valids(const statecomp_t &dst, const CCTK_REAL scale,
const int nvars = dstgroup->numvars;
const int tl = 0;
for (int vi = 0; vi < nvars; ++vi) {
valid_t valid = valid_t(true);
valid_t valid = where;
bool did_set_valid = false;
if (scale != 0) {
valid &= dstgroup->valid.at(tl).at(vi).get();
Expand Down Expand Up @@ -255,26 +257,29 @@ statecomp_t statecomp_t::copy(const valid_t where) const {
result.mfabs.reserve(size);
for (size_t n = 0; n < size; ++n) {
const auto groupdata = groupdatas.at(n);
#ifdef CCTK_DEBUG
const auto &x = mfabs.at(n);
if (x->contains_nan())
CCTK_VERROR("statecomp_t::copy.x: Group %s contains nans",
groupdata->groupname.c_str());
#endif
// This global nan-check doesn't work since we don't care about the
// boundaries
// #ifdef CCTK_DEBUG
// const auto &x = mfabs.at(n);
// if (x->contains_nan())
// CCTK_VERROR("statecomp_t::copy.x: Group %s contains nans",
// groupdata->groupname.c_str());
// #endif
auto y = groupdata->alloc_tmp_mfab();
result.groupdatas.push_back(groupdata);
result.mfabs.push_back(y);
}
lincomb(result, 0, make_array(CCTK_REAL(1)), make_array(this), where);
#ifdef CCTK_DEBUG
for (size_t n = 0; n < size; ++n) {
const auto groupdata = result.groupdatas.at(n);
const auto &y = result.mfabs.at(n);
if (y->contains_nan())
CCTK_VERROR("statecomp_t::copy.y: Group %s contains nans",
groupdata->groupname.c_str());
}
#endif
// This global nan-check doesn't work since we don't care about the boundaries
// #ifdef CCTK_DEBUG
// for (size_t n = 0; n < size; ++n) {
// const auto groupdata = result.groupdatas.at(n);
// const auto &y = result.mfabs.at(n);
// if (y->contains_nan())
// CCTK_VERROR("statecomp_t::copy.y: Group %s contains nans",
// groupdata->groupname.c_str());
// }
// #endif
return result;
}

Expand All @@ -300,7 +305,7 @@ void statecomp_t::lincomb(const statecomp_t &dst, const CCTK_REAL scale,
for (size_t n = 0; n < N; ++n)
assert(isfinite(factors[n]));

statecomp_t::combine_valids(dst, scale, factors, srcs);
statecomp_t::combine_valids(dst, scale, factors, srcs, where);

#ifndef AMREX_USE_GPU
vector<function<void()> > tasks;
Expand Down Expand Up @@ -746,8 +751,8 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
static Timer timer_rhs("ODESolvers::Solve::rhs");
static Timer timer_poststep("ODESolvers::Solve::poststep");

const auto copy_state = [](const auto &var) {
return var.copy(make_valid_int());
const auto copy_state = [](const auto &var, const valid_t where) {
return var.copy(where);
};
const auto calcrhs = [&](const int n) {
Interval interval_rhs(timer_rhs);
Expand Down Expand Up @@ -801,7 +806,7 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
// k2 = f(y0 + h/2 k1)
// y1 = y0 + h k2

const auto old = copy_state(var);
const auto old = copy_state(var, make_valid_all());

calcrhs(1);
calcupdate(1, dt / 2, 1.0, reals<1>{dt / 2}, states<1>{&rhs});
Expand All @@ -816,14 +821,14 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
// k3 = f(y0 - h k1 + 2 h k2)
// y1 = y0 + h/6 k1 + 2/3 h k2 + h/6 k3

const auto old = copy_state(var);
const auto old = copy_state(var, make_valid_all());

calcrhs(1);
const auto k1 = copy_state(rhs);
const auto k1 = copy_state(rhs, make_valid_int());
calcupdate(1, dt / 2, 1.0, reals<1>{dt / 2}, states<1>{&k1});

calcrhs(2);
const auto k2 = copy_state(rhs);
const auto k2 = copy_state(rhs, make_valid_int());
calcupdate(2, dt, 0.0, reals<3>{1.0, -dt, 2 * dt},
states<3>{&old, &k1, &k2});

Expand All @@ -838,14 +843,14 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
// k3 = f(y0 + h/4 k1 + h/4 k2)
// y1 = y0 + h/6 k1 + h/6 k2 + 2/3 h k3

const auto old = copy_state(var);
const auto old = copy_state(var, make_valid_all());

calcrhs(1);
const auto k1 = copy_state(rhs);
const auto k1 = copy_state(rhs, make_valid_int());
calcupdate(1, dt, 1.0, reals<1>{dt}, states<1>{&k1});

calcrhs(2);
const auto k2 = copy_state(rhs);
const auto k2 = copy_state(rhs, make_valid_int());
calcupdate(2, dt / 2, 0.0, reals<3>{1.0, dt / 4, dt / 4},
states<3>{&old, &k1, &k2});

Expand All @@ -861,10 +866,10 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
// k4 = f(y0 + h k3)
// y1 = y0 + h/6 k1 + h/3 k2 + h/3 k3 + h/6 k4

const auto old = copy_state(var);
const auto old = copy_state(var, make_valid_all());

calcrhs(1);
const auto kaccum = copy_state(rhs);
const auto kaccum = copy_state(rhs, make_valid_int());
calcupdate(1, dt / 2, 1.0, reals<1>{dt / 2}, states<1>{&kaccum});

calcrhs(2);
Expand Down Expand Up @@ -948,15 +953,15 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
assert(fabs(x - 1) <= 10 * numeric_limits<T>::epsilon());
}

const auto old = copy_state(var);
const auto old = copy_state(var, make_valid_all());

vector<statecomp_t> ks;
ks.reserve(nsteps);
for (size_t step = 0; step < nsteps; ++step) {
// Skip the first state vector calculation, it is always trivial
if (step > 0) {
const auto &c = get<0>(get<0>(tableau).at(step));
const auto &as = get<1>(get<0>(tableau).at(step));
const auto &c = get<0>(get<0>(tableau).at(step));
const auto &as = get<1>(get<0>(tableau).at(step));

// Add scaled RHS to state vector
vector<CCTK_REAL> factors;
Expand All @@ -976,7 +981,7 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
}

calcrhs(step + 1);
ks.push_back(copy_state(rhs));
ks.push_back(copy_state(rhs, make_valid_int()));
}

// Calculate new state vector
Expand Down Expand Up @@ -1058,17 +1063,17 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
assert(fabs(x - 1) <= 10 * numeric_limits<T>::epsilon());
}

const auto old = copy_state(var);
const auto old = copy_state(var, make_valid_all());

vector<statecomp_t> ks;
ks.reserve(nsteps);
for (size_t step = 0; step < nsteps; ++step) {
// Skip the first state vector calculation, it is always trivial
if (step > 0) {
const auto &as = get<0>(tableau).at(step);
T c = 0;
for (const auto &a : as)
c += a;
const auto &as = get<0>(tableau).at(step);
T c = 0;
for (const auto &a : as)
c += a;

// Add scaled RHS to state vector
vector<CCTK_REAL> factors;
Expand All @@ -1088,7 +1093,7 @@ extern "C" void ODESolvers_Solve(CCTK_ARGUMENTS) {
}

calcrhs(step + 1);
ks.push_back(copy_state(rhs));
ks.push_back(copy_state(rhs, make_valid_int()));
}

// Calculate new state vector
Expand Down
11 changes: 9 additions & 2 deletions SIMDWaveToyX/schedule.ccl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,25 @@ SCHEDULE SIMDWaveToyX_Initial AT initial
{
LANG: C
WRITES: state(interior)
SYNC: state
} "Initialize scalar wave state"

SCHEDULE SIMDWaveToyX_RHS IN ODESolvers_RHS
{
LANG: C
READS: state(everywhere)
WRITES: rhs(interior)
# Sync for test output
SYNC: rhs
} "Calculate scalar wave RHS"

SCHEDULE SIMDWaveToyX_Energy IN ODESolvers_PostStep
SCHEDULE SIMDWaveToyX_Boundaries IN ODESolvers_PostStep
{
LANG: C
OPTIONS: global
SYNC: state
} "Apply boundary conditions"

SCHEDULE SIMDWaveToyX_Energy IN ODESolvers_PostStep AFTER SIMDWaveToyX_Boundaries
{
LANG: C
READS: state(everywhere)
Expand Down
7 changes: 7 additions & 0 deletions SIMDWaveToyX/src/simdwavetoyx.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ extern "C" void SIMDWaveToyX_RHS(CCTK_ARGUMENTS) {
});
}

extern "C" void SIMDWaveToyX_Boundaries(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTSX_SIMDWaveToyX_Boundaries;
DECLARE_CCTK_PARAMETERS;

// Do nothing
}

extern "C" void SIMDWaveToyX_Energy(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTSX_SIMDWaveToyX_Energy;

Expand Down
1 change: 1 addition & 0 deletions StaggeredWaveToyX/par/standing.par
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ $out_every = 16
Cactus::cctk_show_schedule = no
Cactus::presync_mode = "mixed-error"

CarpetX::verbose = no
CarpetX::poison_undefined_values = no

CarpetX::periodic_x = yes
Expand Down
15 changes: 11 additions & 4 deletions StaggeredWaveToyX/schedule.ccl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ SCHEDULE StaggeredWaveToyX_Initial AT initial
WRITES: fxstate(interior)
WRITES: fystate(interior)
WRITES: fzstate(interior)
SYNC: ustate
SYNC: fxstate
SYNC: fystate
SYNC: fzstate
} "Initialize scalar wave state"

SCHEDULE StaggeredWaveToyX_RHS IN ODESolvers_RHS
Expand All @@ -26,12 +22,23 @@ SCHEDULE StaggeredWaveToyX_RHS IN ODESolvers_RHS
WRITES: fxrhs(interior)
WRITES: fyrhs(interior)
WRITES: fzrhs(interior)
# Sync for test output
SYNC: urhs
SYNC: fxrhs
SYNC: fyrhs
SYNC: fzrhs
} "Calculate scalar wave RHS"

SCHEDULE StaggeredWaveToyX_Boundaries IN ODESolvers_PostStep
{
LANG: C
OPTIONS: global
SYNC: ustate
SYNC: fxstate
SYNC: fystate
SYNC: fzstate
} "Apply boundary conditions to scalar wave state"

SCHEDULE StaggeredWaveToyX_Constraints IN ODESolvers_PostStep
{
LANG: C
Expand Down
6 changes: 6 additions & 0 deletions StaggeredWaveToyX/src/staggeredwavetoyx.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ extern "C" void StaggeredWaveToyX_RHS(CCTK_ARGUMENTS) {
});
}

extern "C" void StaggeredWaveToyX_Boundaries(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTSX_StaggeredWaveToyX_Boundaries;

// Do nothing
}

extern "C" void StaggeredWaveToyX_Constraints(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTSX_StaggeredWaveToyX_Constraints;

Expand Down
1 change: 1 addition & 0 deletions StaggeredWaveToyX/test/standing.par
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ $out_every = 16
Cactus::cctk_show_schedule = no
Cactus::presync_mode = "mixed-error"

CarpetX::verbose = no
CarpetX::poison_undefined_values = yes

CarpetX::ncells_x = 8
Expand Down
Loading
Loading