Skip to content

Commit

Permalink
replace opaque SUNCheckCallLastErr macro with explicit version
Browse files Browse the repository at this point in the history
  • Loading branch information
balos1 committed Nov 6, 2023
1 parent 6bcaa1a commit a6f9327
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 71 deletions.
22 changes: 4 additions & 18 deletions doc/sundials_developers/source/style_guide/SourceCode.rst
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,6 @@ Coding Conventions and Rules

#. Spaces not tabs.

#. All comments should use ``/* */``.

#. Comments should use proper spelling and grammar.

#. Following the Google Style Guide [GoogleStyle]_, TODO comments are used to note
Expand Down Expand Up @@ -254,12 +252,6 @@ Coding Conventions and Rules
/* invalid number of vectors */
SUNAssert(nvec >= 1, SUN_ERR_ARG_OUTOFRANGE);
/* should have called N_VScale */
if (nvec == 1) {
SUNCheckCallLastErr(N_VScale_Serial(c[0], X[0], z));
return SUN_SUCCESS;
}
// ...
}
Expand All @@ -275,12 +267,6 @@ Coding Conventions and Rules
/* invalid number of vectors */
SUNAssert(nvec >= 1, SUN_ERR_ARG_OUTOFRANGE);
/* should have called N_VScale */
if (nvec == 1) {
SUNCheckCallLastErr(N_VScale_Serial(c[0], X[0], z));
return SUN_SUCCESS;
}
// ...
}
Expand Down Expand Up @@ -324,18 +310,18 @@ Coding Conventions and Rules
be followed by checking the last error stored in the ``SUNContext``.
The exception to this rule is for internal helper functions.
These should not be checked unless they return a ``SUNErrCode``.
These checks are done with the ``SUNCheckCallLastErr`` macros.
These checks are done with the ``SUNCheckLastErr`` macros.

.. code-block:: c
// Correct
SUNCheckCallLastErr(N_VLinearSum(...));
N_VLinearSum(...); SUNCheckLastErr();
// Incorrect
SUNCheckCallLastErr(SUNRsqrt(N_VDotProd(...)));
SUNRsqrt(N_VDotProd(...)); SUNCheckLastErr();
// Correct
sunrealtype tmp = SUNCheckCallLastErr(N_VDotProd(...));
sunrealtype tmp = N_VDotProd(...); SUNCheckLastErr();
tmp = SUNRsqrt(tmp);
#. Programmer errors should be checked with the ``SUNAssert`` or ``SUNMPIAssert`` macro.
Expand Down
1 change: 1 addition & 0 deletions include/sundials/impl/sundials_context_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct SUNContext_
sunbooleantype own_logger;
SUNErrCode last_err;
SUNErrHandler err_handler;
void* comm;
};

#ifdef __cplusplus
Expand Down
26 changes: 11 additions & 15 deletions include/sundials/impl/sundials_errors_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,29 +129,25 @@
value, and calls the error handler. */

#if !defined(SUNDIALS_DISABLE_ERROR_CHECKS)
#define SUNCheckCallLastErrNoRet(call) \
call; \
#define SUNCheckLastErrNoRet() \
SUNCheckCallNoRet(SUNGetLastErr(sunctx_))

/* Same as SUNCheckCallLastErrNoRet, but returns with the error code. */
#define SUNCheckCallLastErr(call) \
call; \
/* Same as SUNCheckLastErrNoRet, but returns with the error code. */
#define SUNCheckLastErr() \
SUNCheckCall(SUNGetLastErr(sunctx_))

/* Same as SUNCheckCallLastErrNoRet, but returns void. */
#define SUNCheckCallLastErrVoid(call) \
call; \
/* Same as SUNCheckLastErrNoRet, but returns void. */
#define SUNCheckLastErrVoid() \
SUNCheckCallVoid(SUNGetLastErr(sunctx_))

/* Same as SUNCheckCallLastErrNoRet, but returns NULL. */
#define SUNCheckCallLastErrNull(call) \
call; \
/* Same as SUNCheckLastErrNoRet, but returns NULL. */
#define SUNCheckLastErrNull() \
SUNCheckCallNull(SUNGetLastErr(sunctx_))
#else
#define SUNCheckCallLastErrNoRet(call) call
#define SUNCheckCallLastErr(call) call
#define SUNCheckCallLastErrVoid(call) call
#define SUNCheckCallLastErrNull(call) call
#define SUNCheckLastErrNoRet()
#define SUNCheckLastErr()
#define SUNCheckLastErrVoid()
#define SUNCheckLastErrNull()
#endif

/* SUNAssert checks if an expression is true.
Expand Down
72 changes: 36 additions & 36 deletions src/sundials/sundials_iterative.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ SUNErrCode SUNModifiedGS(N_Vector* v, sunrealtype** h, int k, int p,
int i, k_minus_1, i0;
sunrealtype new_norm_2, new_product, vk_norm, temp;

vk_norm = SUNCheckCallLastErr((N_VDotProd(v[k], v[k])));
vk_norm = N_VDotProd(v[k], v[k]); SUNCheckLastErr();
vk_norm = SUNRsqrt(vk_norm);
k_minus_1 = k - 1;
i0 = SUNMAX(k - p, 0);
Expand All @@ -55,13 +55,13 @@ SUNErrCode SUNModifiedGS(N_Vector* v, sunrealtype** h, int k, int p,

for (i = i0; i < k; i++)
{
h[i][k_minus_1] = SUNCheckCallLastErr(N_VDotProd(v[i], v[k]));
SUNCheckCallLastErr(N_VLinearSum(ONE, v[k], -h[i][k_minus_1], v[i], v[k]));
h[i][k_minus_1] = N_VDotProd(v[i], v[k]); SUNCheckLastErr();
N_VLinearSum(ONE, v[k], -h[i][k_minus_1], v[i], v[k]); SUNCheckLastErr();
}

/* Compute the norm of the new vector at v[k] */

*new_vk_norm = SUNCheckCallLastErr(N_VDotProd(v[k], v[k]));
*new_vk_norm = N_VDotProd(v[k], v[k]); SUNCheckLastErr();
*new_vk_norm = SUNRsqrt(*new_vk_norm);

/* If the norm of the new vector at v[k] is less than
Expand All @@ -77,11 +77,11 @@ SUNErrCode SUNModifiedGS(N_Vector* v, sunrealtype** h, int k, int p,

for (i = i0; i < k; i++)
{
new_product = SUNCheckCallLastErr(N_VDotProd(v[i], v[k]));
new_product = N_VDotProd(v[i], v[k]); SUNCheckLastErr();
temp = FACTOR * h[i][k_minus_1];
if ((temp + new_product) == temp) continue;
h[i][k_minus_1] += new_product;
SUNCheckCallLastErr(N_VLinearSum(ONE, v[k], -new_product, v[i], v[k]));
N_VLinearSum(ONE, v[k], -new_product, v[i], v[k]); SUNCheckLastErr();
new_norm_2 += SUNSQR(new_product);
}

Expand Down Expand Up @@ -131,7 +131,7 @@ SUNErrCode SUNClassicalGS(N_Vector* v, sunrealtype** h, int k, int p,

/* Compute the norm of the new vector at v[k] */

*new_vk_norm = SUNCheckCallLastErr(SUNRsqrt(N_VDotProd(v[k], v[k])));
*new_vk_norm = SUNRsqrt(N_VDotProd(v[k], v[k])); SUNCheckLastErr();

/* Reorthogonalize if necessary */

Expand All @@ -150,7 +150,7 @@ SUNErrCode SUNClassicalGS(N_Vector* v, sunrealtype** h, int k, int p,

SUNCheckCall(N_VLinearCombination(k + 1, stemp, vtemp, v[k]));

*new_vk_norm = SUNCheckCallLastErr(SUNRsqrt(N_VDotProd(v[k], v[k])));
*new_vk_norm = SUNRsqrt(N_VDotProd(v[k], v[k])); SUNCheckLastErr();
}

return (0);
Expand Down Expand Up @@ -331,15 +331,15 @@ SUNErrCode SUNQRAdd_MGS(N_Vector* Q, sunrealtype* R, N_Vector df, int m, int mMa
SUNQRData qrdata = (SUNQRData)QRdata;
SUNAssignSUNCTX(Q[0]->sunctx);

SUNCheckCallLastErr(N_VScale(ONE, df, qrdata->vtemp));
N_VScale(ONE, df, qrdata->vtemp); SUNCheckLastErr();
for (j = 0; j < m; j++)
{
R[m * mMax + j] = SUNCheckCallLastErr(N_VDotProd(Q[j], qrdata->vtemp));
SUNCheckCallLastErr(N_VLinearSum(ONE, qrdata->vtemp, -R[m * mMax + j], Q[j], qrdata->vtemp));
R[m * mMax + j] = N_VDotProd(Q[j], qrdata->vtemp); SUNCheckLastErr();
N_VLinearSum(ONE, qrdata->vtemp, -R[m * mMax + j], Q[j], qrdata->vtemp); SUNCheckLastErr();
}
R[m * mMax + m] = SUNCheckCallLastErr(N_VDotProd(qrdata->vtemp, qrdata->vtemp));
R[m * mMax + m] = N_VDotProd(qrdata->vtemp, qrdata->vtemp); SUNCheckLastErr();
R[m * mMax + m] = SUNRsqrt(R[m * mMax + m]);
SUNCheckCallLastErr(N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]));
N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]); SUNCheckLastErr();

return SUN_SUCCESS;
}
Expand All @@ -360,7 +360,7 @@ SUNErrCode SUNQRAdd_ICWY(N_Vector* Q, sunrealtype* R, N_Vector df, int m, int mM
SUNQRData qrdata = (SUNQRData)QRdata;
SUNAssignSUNCTX(Q[0]->sunctx);

SUNCheckCallLastErr(N_VScale(ONE, df, qrdata->vtemp)); /* stores d_fi in temp */
N_VScale(ONE, df, qrdata->vtemp); SUNCheckLastErr(); /* stores d_fi in temp */

if (m > 0)
{
Expand All @@ -387,14 +387,14 @@ SUNErrCode SUNQRAdd_ICWY(N_Vector* Q, sunrealtype* R, N_Vector df, int m, int mM

/* Q(:,k-1) = df - Q_k-1 R(1:k-1,k) */
SUNCheckCall(N_VLinearCombination(m, R + m * mMax, Q, qrdata->vtemp2));
SUNCheckCallLastErr(N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp));
N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp); SUNCheckLastErr();
}

/* R(k,k) = \| df \| */
R[m * mMax + m] = SUNCheckCallLastErr(N_VDotProd(qrdata->vtemp, qrdata->vtemp));
R[m * mMax + m] = N_VDotProd(qrdata->vtemp, qrdata->vtemp); SUNCheckLastErr();
R[m * mMax + m] = SUNRsqrt(R[m * mMax + m]);
/* Q(:,k) = df / \| df \| */
SUNCheckCallLastErr(N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]));
N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]); SUNCheckLastErr();

return SUN_SUCCESS;
}
Expand All @@ -415,7 +415,7 @@ SUNErrCode SUNQRAdd_ICWY_SB(N_Vector* Q, sunrealtype* R, N_Vector df, int m,
SUNQRData qrdata = (SUNQRData)QRdata;
SUNAssignSUNCTX(Q[0]->sunctx);

SUNCheckCallLastErr(N_VScale(ONE, df, qrdata->vtemp)); /* stores d_fi in temp */
N_VScale(ONE, df, qrdata->vtemp); SUNCheckLastErr(); /* stores d_fi in temp */

if (m > 0)
{
Expand Down Expand Up @@ -455,14 +455,14 @@ SUNErrCode SUNQRAdd_ICWY_SB(N_Vector* Q, sunrealtype* R, N_Vector df, int m,

/* Q(:,k-1) = df - Q_k-1 R(1:k-1,k) */
SUNCheckCall(N_VLinearCombination(m, R + m * mMax, Q, qrdata->vtemp2));
SUNCheckCallLastErr(N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp));
N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp); SUNCheckLastErr();
}

/* R(k,k) = \| df \| */
R[m * mMax + m] = SUNCheckCallLastErr(N_VDotProd(qrdata->vtemp, qrdata->vtemp));
R[m * mMax + m] = N_VDotProd(qrdata->vtemp, qrdata->vtemp); SUNCheckLastErr();
R[m * mMax + m] = SUNRsqrt(R[m * mMax + m]);
/* Q(:,k) = df / \| df \| */
SUNCheckCallLastErr(N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]));
N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]); SUNCheckLastErr();

return SUN_SUCCESS;
}
Expand All @@ -483,7 +483,7 @@ SUNErrCode SUNQRAdd_CGS2(N_Vector* Q, sunrealtype* R, N_Vector df, int m, int mM
SUNQRData qrdata = (SUNQRData)QRdata;
SUNAssignSUNCTX(Q[0]->sunctx);

SUNCheckCallLastErr(N_VScale(ONE, df, qrdata->vtemp)); /* temp = df */
N_VScale(ONE, df, qrdata->vtemp); SUNCheckLastErr(); /* temp = df */

if (m > 0)
{
Expand All @@ -492,14 +492,14 @@ SUNErrCode SUNQRAdd_CGS2(N_Vector* Q, sunrealtype* R, N_Vector df, int m, int mM

/* y = df - Q_k-1 s_k */
SUNCheckCall(N_VLinearCombination(m, R + m * mMax, Q, qrdata->vtemp2));
SUNCheckCallLastErr(N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp2));
N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp2); SUNCheckLastErr();

/* z_k = Q_k-1^T y */
SUNCheckCall(N_VDotProdMulti(m, qrdata->vtemp2, Q, qrdata->temp_array));

/* df = y - Q_k-1 z_k */
SUNCheckCall(N_VLinearCombination(m, qrdata->temp_array, Q, Q[m]));
SUNCheckCallLastErr(N_VLinearSum(ONE, qrdata->vtemp2, -ONE, Q[m], qrdata->vtemp));
N_VLinearSum(ONE, qrdata->vtemp2, -ONE, Q[m], qrdata->vtemp); SUNCheckLastErr();

/* R(1:k-1,k) = s_k + z_k */
for (j = 0; j < m; j++)
Expand All @@ -509,10 +509,10 @@ SUNErrCode SUNQRAdd_CGS2(N_Vector* Q, sunrealtype* R, N_Vector df, int m, int mM
}

/* R(k,k) = \| df \| */
R[m * mMax + m] = SUNCheckCallLastErr(N_VDotProd(qrdata->vtemp, qrdata->vtemp));
R[m * mMax + m] = N_VDotProd(qrdata->vtemp, qrdata->vtemp); SUNCheckLastErr();
R[m * mMax + m] = SUNRsqrt(R[m * mMax + m]);
/* Q(:,k) = df / R(k,k) */
SUNCheckCallLastErr(N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]));
N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]); SUNCheckLastErr();

return SUN_SUCCESS;
}
Expand All @@ -533,7 +533,7 @@ SUNErrCode SUNQRAdd_DCGS2(N_Vector* Q, sunrealtype* R, N_Vector df, int m,
SUNQRData qrdata = (SUNQRData)QRdata;
SUNAssignSUNCTX(Q[0]->sunctx);

SUNCheckCallLastErr(N_VScale(ONE, df, qrdata->vtemp)); /* temp = df */
N_VScale(ONE, df, qrdata->vtemp); SUNCheckLastErr(); /* temp = df */

if (m > 0)
{
Expand All @@ -548,7 +548,7 @@ SUNErrCode SUNQRAdd_DCGS2(N_Vector* Q, sunrealtype* R, N_Vector df, int m,
/* Q(:,k-1) = Q(:,k-1) - Q_k-2 s */
SUNCheckCall(N_VLinearCombination(m - 1, qrdata->temp_array, Q,
qrdata->vtemp2));
SUNCheckCallLastErr(N_VLinearSum(ONE, Q[m - 1], -ONE, qrdata->vtemp2, Q[m - 1]));
N_VLinearSum(ONE, Q[m - 1], -ONE, qrdata->vtemp2, Q[m - 1]); SUNCheckLastErr();

/* R(1:k-2,k-1) = R(1:k-2,k-1) + s */
for (j = 0; j < m - 1; j++)
Expand All @@ -559,14 +559,14 @@ SUNErrCode SUNQRAdd_DCGS2(N_Vector* Q, sunrealtype* R, N_Vector df, int m,

/* df = df - Q(:,k-1) R(1:k-1,k) */
SUNCheckCall(N_VLinearCombination(m, R + m * mMax, Q, qrdata->vtemp2));
SUNCheckCallLastErr(N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp));
N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp); SUNCheckLastErr();
}

/* R(k,k) = \| df \| */
R[m * mMax + m] = SUNCheckCallLastErr(N_VDotProd(qrdata->vtemp, qrdata->vtemp));
R[m * mMax + m] = N_VDotProd(qrdata->vtemp, qrdata->vtemp); SUNCheckLastErr();
R[m * mMax + m] = SUNRsqrt(R[m * mMax + m]);
/* Q(:,k) = df / R(k,k) */
SUNCheckCallLastErr(N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]));
N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]); SUNCheckLastErr();

return SUN_SUCCESS;
}
Expand All @@ -587,7 +587,7 @@ SUNErrCode SUNQRAdd_DCGS2_SB(N_Vector* Q, sunrealtype* R, N_Vector df, int m,
SUNQRData qrdata = (SUNQRData)QRdata;
SUNAssignSUNCTX(Q[0]->sunctx);

SUNCheckCallLastErr(N_VScale(ONE, df, qrdata->vtemp)); /* temp = df */
N_VScale(ONE, df, qrdata->vtemp); SUNCheckLastErr(); /* temp = df */

if (m > 0)
{
Expand Down Expand Up @@ -616,7 +616,7 @@ SUNErrCode SUNQRAdd_DCGS2_SB(N_Vector* Q, sunrealtype* R, N_Vector df, int m,
/* Q(:,k-1) = Q(:,k-1) - Q_k-2 s */
SUNCheckCall(N_VLinearCombination(m - 1, qrdata->temp_array + m, Q,
qrdata->vtemp2));
SUNCheckCallLastErr(N_VLinearSum(ONE, Q[m - 1], -ONE, qrdata->vtemp2, Q[m - 1]));
N_VLinearSum(ONE, Q[m - 1], -ONE, qrdata->vtemp2, Q[m - 1]); SUNCheckLastErr();

/* R(1:k-2,k-1) = R(1:k-2,k-1) + s */
for (j = 0; j < m - 1; j++)
Expand All @@ -627,14 +627,14 @@ SUNErrCode SUNQRAdd_DCGS2_SB(N_Vector* Q, sunrealtype* R, N_Vector df, int m,

/* df = df - Q(:,k-1) R(1:k-1,k) */
SUNCheckCall(N_VLinearCombination(m, R + m * mMax, Q, qrdata->vtemp2));
SUNCheckCallLastErr(N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp));
N_VLinearSum(ONE, qrdata->vtemp, -ONE, qrdata->vtemp2, qrdata->vtemp); SUNCheckLastErr();
}

/* R(k,k) = \| df \| */
R[m * mMax + m] = SUNCheckCallLastErr(N_VDotProd(qrdata->vtemp, qrdata->vtemp));
R[m * mMax + m] = N_VDotProd(qrdata->vtemp, qrdata->vtemp); SUNCheckLastErr();
R[m * mMax + m] = SUNRsqrt(R[m * mMax + m]);
/* Q(:,k) = df / R(k,k) */
SUNCheckCallLastErr(N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]));
N_VScale((1 / R[m * mMax + m]), qrdata->vtemp, Q[m]); SUNCheckLastErr();

return SUN_SUCCESS;
}
4 changes: 2 additions & 2 deletions src/sundials/sundials_nvector.c
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,7 @@ N_Vector* N_VCloneEmptyVectorArray(int count, N_Vector w)
SUNAssert(vs, SUN_ERR_MALLOC_FAIL);

for (j = 0; j < count; j++) {
vs[j] = SUNCheckCallLastErrNoRet(N_VCloneEmpty(w));
vs[j] = N_VCloneEmpty(w); SUNCheckLastErrNoRet();
if (SUNGetLastErr(w->sunctx) < 0) {
N_VDestroyVectorArray(vs, j-1);
return(NULL);
Expand All @@ -1024,7 +1024,7 @@ N_Vector* N_VCloneVectorArray(int count, N_Vector w)
SUNAssert(vs, SUN_ERR_MALLOC_FAIL);

for (j = 0; j < count; j++) {
vs[j] = SUNCheckCallLastErrNoRet(N_VClone(w));
vs[j] = N_VClone(w); SUNCheckLastErrNoRet();
if (SUNGetLastErr(w->sunctx) < 0) {
N_VDestroyVectorArray(vs, j-1);
return(NULL);
Expand Down

0 comments on commit a6f9327

Please sign in to comment.