Skip to content

Commit

Permalink
MRA fix (#131)
Browse files Browse the repository at this point in the history
* Add checks for MRA compatibility

* Evaluate FunctionTree out of bounds to zero
  • Loading branch information
stigrj committed Mar 16, 2020
1 parent 146852c commit 7ef0269
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/treebuilders/add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ void add(double prec,
*
*/
template <int D> void add(double prec, FunctionTree<D> &out, FunctionTreeVector<D> &inp, int maxIter, bool absPrec) {
for (auto i = 0; i < inp.size(); i++)
if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA");

int maxScale = out.getMRA().getMaxScale();
TreeBuilder<D> builder;
WaveletAdaptor<D> adaptor(prec, maxScale, absPrec);
Expand Down
6 changes: 6 additions & 0 deletions src/treebuilders/apply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ void apply(double prec,
FunctionTree<D> &inp,
int maxIter,
bool absPrec) {
if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");

Timer pre_t;
oper.calcBandWidths(prec);
int maxScale = out.getMRA().getMaxScale();
Expand Down Expand Up @@ -109,6 +111,8 @@ void apply(double prec,
*
*/
template <int D> void apply(FunctionTree<D> &out, DerivativeOperator<D> &oper, FunctionTree<D> &inp, int dir) {
if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");

TreeBuilder<D> builder;
int maxScale = out.getMRA().getMaxScale();

Expand Down Expand Up @@ -181,6 +185,8 @@ template <int D> FunctionTreeVector<D> gradient(DerivativeOperator<D> &oper, Fun
*/
template <int D> void divergence(FunctionTree<D> &out, DerivativeOperator<D> &oper, FunctionTreeVector<D> &inp) {
if (inp.size() != D) MSG_ABORT("Dimension mismatch");
for (auto i = 0; i < inp.size(); i++)
if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA");

FunctionTreeVector<D> tmp_vec;
for (int d = 0; d < D; d++) {
Expand Down
6 changes: 6 additions & 0 deletions src/treebuilders/grid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ template <int D> void build_grid(FunctionTree<D> &out, const GaussExp<D> &inp, i
*
*/
template <int D> void build_grid(FunctionTree<D> &out, FunctionTree<D> &inp, int maxIter) {
if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");
auto maxScale = out.getMRA().getMaxScale();
TreeBuilder<D> builder;
CopyAdaptor<D> adaptor(inp, maxScale, nullptr);
Expand Down Expand Up @@ -171,6 +172,9 @@ template <int D> void build_grid(FunctionTree<D> &out, FunctionTree<D> &inp, int
*
*/
template <int D> void build_grid(FunctionTree<D> &out, FunctionTreeVector<D> &inp, int maxIter) {
for (auto i = 0; i < inp.size(); i++)
if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA");

auto maxScale = out.getMRA().getMaxScale();
TreeBuilder<D> builder;
CopyAdaptor<D> adaptor(inp, maxScale, nullptr);
Expand Down Expand Up @@ -209,6 +213,7 @@ template <int D> void copy_func(FunctionTree<D> &out, FunctionTree<D> &inp) {
*
*/
template <int D> void copy_grid(FunctionTree<D> &out, FunctionTree<D> &inp) {
if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA")
out.clear();
build_grid(out, inp);
}
Expand Down Expand Up @@ -284,6 +289,7 @@ template <int D> int refine_grid(FunctionTree<D> &out, double prec, bool absPrec
*
*/
template <int D> int refine_grid(FunctionTree<D> &out, FunctionTree<D> &inp) {
if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA")
auto maxScale = out.getMRA().getMaxScale();
TreeBuilder<D> builder;
CopyAdaptor<D> adaptor(inp, maxScale, nullptr);
Expand Down
15 changes: 11 additions & 4 deletions src/treebuilders/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ void multiply(double prec,
*/
template <int D>
void multiply(double prec, FunctionTree<D> &out, FunctionTreeVector<D> &inp, int maxIter, bool absPrec) {
for (auto i = 0; i < inp.size(); i++)
if (out.getMRA() != get_func(inp, i).getMRA()) MSG_ABORT("Incompatible MRA");

int maxScale = out.getMRA().getMaxScale();
TreeBuilder<D> builder;
WaveletAdaptor<D> adaptor(prec, maxScale, absPrec);
Expand Down Expand Up @@ -148,6 +151,8 @@ void multiply(double prec, FunctionTree<D> &out, FunctionTreeVector<D> &inp, int
*
*/
template <int D> void square(double prec, FunctionTree<D> &out, FunctionTree<D> &inp, int maxIter, bool absPrec) {
if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");

int maxScale = out.getMRA().getMaxScale();
TreeBuilder<D> builder;
WaveletAdaptor<D> adaptor(prec, maxScale, absPrec);
Expand Down Expand Up @@ -193,6 +198,8 @@ template <int D> void square(double prec, FunctionTree<D> &out, FunctionTree<D>
*/
template <int D>
void power(double prec, FunctionTree<D> &out, FunctionTree<D> &inp, double p, int maxIter, bool absPrec) {
if (out.getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");

int maxScale = out.getMRA().getMaxScale();
TreeBuilder<D> builder;
WaveletAdaptor<D> adaptor(prec, maxScale, absPrec);
Expand Down Expand Up @@ -246,8 +253,6 @@ void dot(double prec,
double coef_b = get_coef(inp_b, d);
FunctionTree<D> &tree_a = get_func(inp_a, d);
FunctionTree<D> &tree_b = get_func(inp_b, d);
if (out.getMRA() != tree_a.getMRA()) MSG_ABORT("Trees not compatible");
if (out.getMRA() != tree_b.getMRA()) MSG_ABORT("Trees not compatible");
auto *out_d = new FunctionTree<D>(out.getMRA());
build_grid(*out_d, out);
multiply(prec, *out_d, 1.0, tree_a, tree_b, maxIter, absPrec);
Expand All @@ -272,7 +277,8 @@ void dot(double prec,
*
*/
template <int D> double dot(FunctionTree<D> &bra, FunctionTree<D> &ket) {
if (bra.getMRA() != ket.getMRA()) { MSG_ABORT("Trees not compatible"); }
if (bra.getMRA() != ket.getMRA()) MSG_ABORT("Trees not compatible");

MWNodeVector<D> nodeTable;
HilbertIterator<D> it(&bra);
it.setReturnGenNodes(false);
Expand Down Expand Up @@ -318,6 +324,7 @@ template <int D> double dot(FunctionTree<D> &bra, FunctionTree<D> &ket) {
* If the product is zero, the functions are disjoints.
*/
template <int D> double node_norm_dot(FunctionTree<D> &bra, FunctionTree<D> &ket, bool exact) {
if (bra.getMRA() != ket.getMRA()) MSG_ABORT("Incompatible MRA");

double result = 0.0;
int ncoef = bra.getKp1_d() * bra.getTDim();
Expand All @@ -331,7 +338,7 @@ template <int D> double node_norm_dot(FunctionTree<D> &bra, FunctionTree<D> &ket
if (exact) {
// convert to interpolating coef, take abs, convert back
FunctionNode<D> *mwNode = static_cast<FunctionNode<D> *>(ket.findNode(idx));
if (mwNode == nullptr) { MSG_ABORT("Trees must have same grid"); }
if (mwNode == nullptr) MSG_ABORT("Trees must have same grid");
node.getAbsCoefs(valA);
mwNode->getAbsCoefs(valB);
for (int i = 0; i < ncoef; i++) result += valA[i] * valB[i];
Expand Down
1 change: 1 addition & 0 deletions src/treebuilders/project.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ void project(double prec,
std::vector<std::function<double(const Coord<D> &r)>> func,
int maxIter,
bool absPrec) {
if (out.size() != func.size()) MSG_ABORT("Size mismatch");
for (auto j = 0; j < D; j++) mrcpp::project<D>(prec, get_func(out, j), func[j], maxIter, absPrec);
}

Expand Down
10 changes: 8 additions & 2 deletions src/trees/FunctionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ namespace mrcpp {
template <int D>
FunctionTree<D>::FunctionTree(const MultiResolutionAnalysis<D> &mra, SharedMemory *sh_mem)
: MWTree<D>(mra)
, RepresentableFunction<D>(nullptr, nullptr) {
, RepresentableFunction<D>(mra.getWorldBox().getLowerBounds().data(),
mra.getWorldBox().getUpperBounds().data()) {
this->serialTree_p = new SerialFunctionTree<D>(this, sh_mem);
this->serialTree_p->allocRoots(*this);
this->resetEndNodeTable();
Expand Down Expand Up @@ -191,7 +192,7 @@ template <int D> double FunctionTree<D>::integrate() const {
return jacobian * result;
}

/** @returns Function value in a point
/** @returns Function value in a point, out of bounds returns zero
*
* @param[in] r: Cartesian coordinate
*
Expand All @@ -218,6 +219,9 @@ template <int D> double FunctionTree<D>::evalf(const Coord<D> &r) const {
// always 1.0 from the point of view of this function.
if (this->getRootBox().isPeriodic()) { periodic::coord_manipulation<D>(arg, this->getRootBox().getPeriodic()); }

// Function is zero outside the domain
if (this->outOfBounds(arg)) return 0.0;

const MWNode<D> &mw_node = this->getNodeOrEndNode(arg);
auto &f_node = static_cast<const FunctionNode<D> &>(mw_node);
auto result = f_node.evalScaling(arg);
Expand Down Expand Up @@ -329,6 +333,7 @@ template <int D> void FunctionTree<D>::normalize() {
*
*/
template <int D> void FunctionTree<D>::add(double c, FunctionTree<D> &inp) {
if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");
if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared");
#pragma omp parallel firstprivate(c), shared(inp)
{
Expand Down Expand Up @@ -358,6 +363,7 @@ template <int D> void FunctionTree<D>::add(double c, FunctionTree<D> &inp) {
*
*/
template <int D> void FunctionTree<D>::multiply(double c, FunctionTree<D> &inp) {
if (this->getMRA() != inp.getMRA()) MSG_ABORT("Incompatible MRA");
if (this->getNGenNodes() != 0) MSG_ABORT("GenNodes not cleared");
#pragma omp parallel firstprivate(c), shared(inp)
{
Expand Down

0 comments on commit 7ef0269

Please sign in to comment.