
Allow arena_matrix to use move semantics #2928

Merged
merged 73 commits into from
May 31, 2024
Changes from 30 commits
Commits (73 total)
00a9f43
Adds move semantics for arena matrix types
SteveBronder Aug 1, 2023
056e297
use forwarding in sum
SteveBronder Aug 2, 2023
3a6c11e
add docs related to auto dangers with the math library
SteveBronder Aug 2, 2023
fa8a464
add docs related to auto dangers with the math library
SteveBronder Aug 2, 2023
3f07884
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 2, 2023
e5e226b
newline
SteveBronder Aug 2, 2023
d66673b
Merge remote-tracking branch 'refs/remotes/origin/feature/reverse-mod…
SteveBronder Aug 2, 2023
53dc399
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 2, 2023
fcca84d
fix constructor alias bug
SteveBronder Aug 2, 2023
c13855f
Merge remote-tracking branch 'refs/remotes/origin/feature/reverse-mod…
SteveBronder Aug 2, 2023
2f42a0a
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 2, 2023
4189367
fix normal_lpdf templates
SteveBronder Aug 3, 2023
b297df5
Merge remote-tracking branch 'refs/remotes/origin/feature/reverse-mod…
SteveBronder Aug 3, 2023
5b76fc8
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 3, 2023
5fbaf55
fix transpose issues with arena matrix
SteveBronder Aug 4, 2023
1bfd431
Merge remote-tracking branch 'origin/develop' into feature/reverse-mo…
SteveBronder Aug 4, 2023
119099d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 4, 2023
90b23ab
cleanup after reduce_sum is called in tests
SteveBronder Aug 7, 2023
9ce5c50
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 7, 2023
1871c64
remove tmp sundials files
SteveBronder Aug 7, 2023
42cef50
Merge branch 'feature/reverse-mode-move-semantics' of github.com:stan…
SteveBronder Aug 7, 2023
0ec5253
use agradrev in mix/probs test
SteveBronder Aug 8, 2023
d6b892d
use forwarding in normal_lpdf functions
SteveBronder Aug 17, 2023
984bdf8
Merge commit 'd4eab2773347ca6fbe03d49f70828c08ff248269' into HEAD
yashikno Aug 17, 2023
a85e786
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Aug 17, 2023
1b0504a
fix typo
SteveBronder Aug 17, 2023
4f926cf
Merge remote-tracking branch 'origin/develop' into feature/reverse-mo…
SteveBronder Jan 3, 2024
0d34e03
update docs
SteveBronder Jan 3, 2024
3193ad4
merge from develop
SteveBronder Feb 29, 2024
b37d163
remove double include for hypergeo2f1
SteveBronder Feb 29, 2024
36e0bd3
Merge remote-tracking branch 'origin/develop' into feature/reverse-mo…
SteveBronder Mar 22, 2024
eb6276c
update constructors and assignment operators for arena_matrix
SteveBronder Mar 22, 2024
8acdb6d
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Mar 22, 2024
11da0dd
update get_rows and get_cols for arena_matrix
SteveBronder Mar 22, 2024
8a1f9f0
update docs
SteveBronder Mar 22, 2024
c29860e
update to develop
SteveBronder Apr 2, 2024
0707438
only allow the move operator for arena matrix types if the input type…
SteveBronder Apr 2, 2024
8b96c45
fixes aos csr matrix bug, still debugging soa matrix bug
Apr 12, 2024
0a168c3
Merge remote-tracking branch 'origin/develop' into fix/csr-matrix-tim…
SteveBronder Apr 15, 2024
fec3689
update csr matrix multiply to avoid linker error for windows. Adds to…
SteveBronder Apr 15, 2024
6b8ae15
small fixes
SteveBronder Apr 15, 2024
c83cbfc
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 15, 2024
f594b2c
fix header error
SteveBronder Apr 15, 2024
d1feb19
use static_cast for bool conversion in sparse matrix loops
SteveBronder Apr 15, 2024
33f0825
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 15, 2024
5c5dfc6
update csr_matrix_times_vector w adjoint update. Uncomment tests for …
SteveBronder Apr 18, 2024
b2af1cd
Merge commit '11663a2e79e6dc4286ebf1399573a7048667b1c5' into HEAD
yashikno Apr 18, 2024
b0815c4
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 18, 2024
515d621
Merge remote-tracking branch 'origin/develop' into feature/reverse-mo…
SteveBronder Apr 18, 2024
a604fa4
update header includes for var_test
SteveBronder Apr 18, 2024
6a71cfb
ad require_not_arena_matrix_t
SteveBronder Apr 18, 2024
a2124c1
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 18, 2024
1e906d9
fix definition of require_not_arena_matrix_t
SteveBronder Apr 18, 2024
de7d11e
fix definition of require_not_arena_matrix_t
SteveBronder Apr 18, 2024
d89610e
Merge remote-tracking branch 'origin/feature/reverse-mode-move-semant…
SteveBronder Apr 18, 2024
eee02d8
merge
SteveBronder Apr 18, 2024
c5f983a
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 18, 2024
86a3e83
Merge pull request #3048 from stan-dev/fix/csr-matrix-times-vector
SteveBronder Apr 19, 2024
1f25ef7
Initial updates for winarm64
andrjohns Apr 21, 2024
c9f76db
Fix path error
andrjohns Apr 21, 2024
7a5a009
Update comments
andrjohns Apr 21, 2024
df305d6
Document TBB changes
andrjohns Apr 21, 2024
08d8a22
Merge pull request #3051 from stan-dev/tbb-winarm64
WardBrian Apr 22, 2024
34cf554
use a seperate class for csr_matrix adjoint
SteveBronder Apr 24, 2024
a3a88a5
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 24, 2024
f748825
update docs for new vari for csr_matrix_times_vector
SteveBronder Apr 25, 2024
04124da
[Jenkins] auto-formatting by clang-format version 10.0.0-4ubuntu1
stan-buildbot Apr 25, 2024
9f759e1
Merge pull request #3053 from stan-dev/fix/csr-matrix-seperate-vari
SteveBronder Apr 26, 2024
045073f
Don't set build and clean rules for sundials if external libs used
andrjohns Apr 26, 2024
e73651b
Merge pull request #3054 from stan-dev/sundials-targets
WardBrian Apr 26, 2024
7a9601d
Merge remote-tracking branch 'origin/develop' into feature/reverse-mo…
SteveBronder Apr 26, 2024
d45dff2
update to 5.0
SteveBronder Apr 26, 2024
91ea4c1
fix docs
SteveBronder Apr 29, 2024
31 changes: 31 additions & 0 deletions doxygen/contributor_help_pages/common_pitfalls.md
@@ -190,6 +190,37 @@ The general rules to follow for passing values to a function are:
2. If you are writing a function for reverse mode, pass values by `const&`
3. In prim, if you are confident and working with larger types, use perfect forwarding to pass values that can be moved from. Otherwise simply pass values by `const&`.
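These rules can be sketched in plain C++; the function names below are hypothetical illustrations of the two calling conventions, not Stan Math API:

```cpp
#include <cstddef>
#include <string>
#include <utility>

// Rule 2 style: a reverse-mode-like function takes `const&`; it only reads
// its argument, so there is nothing to gain from forwarding.
inline std::size_t length_of(const std::string& s) { return s.size(); }

// Rule 3 style: a prim-like function perfectly forwards, so an rvalue
// argument is moved into the result instead of being copied.
template <typename T>
inline std::string into_owned(T&& s) {
  return std::string(std::forward<T>(s));  // moves when `s` binds an rvalue
}
```

Calling `into_owned(std::move(big_string))` hands the buffer over by move, which is the same saving this PR targets for large matrix arguments.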

### Using auto is Dangerous With Eigen Matrix Functions in Reverse Mode

`auto` should be used with care with the Stan Math library, just as in [Eigen](https://eigen.tuxfamily.org/dox/TopicPitfalls.html). Along with the cautions mentioned in the Eigen docs, there are also memory considerations when using reverse mode automatic differentiation. When a Stan Math function returns an Eigen matrix with a scalar `var` type, the actual returned type will often be an `arena_matrix<Eigen::Matrix<...>>`. The `arena_matrix` class is an Eigen matrix whose underlying array of memory is located in Stan's memory arena. The `arena_matrix` returned by a Stan function is normally the same one held by the callback used to calculate gradients in the reverse pass. Directly changing the elements of this matrix would also change the memory the reverse pass callback sees, which would result in incorrect calculations.

The simple solution is that when you call a math library function that returns a matrix and then want to assign to any of the individual elements of the result, assign to an actual Eigen matrix type instead of using `auto`. In the example below, the first case uses `auto` and will change the memory of the `arena_matrix` returned in the callback for multiply's reverse mode. Directly below it is the safe version, which assigns to an Eigen matrix type and is therefore safe for element insertion.

```c++
Eigen::Matrix<var, -1, 1> y;
Eigen::Matrix<var, -1, -1> X;
// Bad!! Will change memory used by reverse pass callback within multiply!
auto mu = multiply(X, y);
mu(4) = 1.0;
// Good! Will not change memory used by reverse pass callback within multiply
Eigen::Matrix<var, -1, 1> mu_good = multiply(X, y);
mu_good(4) = 1.0;
```

The reason we do this is for cases where functions returns are passe to other functions. An `arena_matrix` will always make a shallow copy when being constructed from another `arena_matrix`, which let's the functions avoid unnecessary copies.
Collaborator:
Suggested change
The reason we do this is for cases where functions returns are passe to other functions. An `arena_matrix` will always make a shallow copy when being constructed from another `arena_matrix`, which let's the functions avoid unnecessary copies.
The reason we do this is for cases where function returns are passed to other functions. An `arena_matrix` will always make a shallow copy when being constructed from another `arena_matrix`, which lets the functions avoid unnecessary copies.


```c++
Eigen::Matrix<var, -1, 1> y1;
Eigen::Matrix<var, -1, -1> X1;
Eigen::Matrix<var, -1, 1> y2;
Eigen::Matrix<var, -1, -1> X2;
auto mu1 = multiply(X1, y1);
auto mu2 = multiply(X2, y2);
// Inputs not copied in this case!
auto z = add(mu1, mu2);
```
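The shallow-copy behavior can be illustrated with a toy view type: copying the view copies only a pointer and a size, so both copies alias the same buffer, which is exactly the property that makes passing `arena_matrix` results between functions cheap. This is a simplified stand-in, not the real `arena_matrix`:

```cpp
#include <cstddef>

// Toy stand-in for arena_matrix: a non-owning view over arena-owned memory.
struct ArenaView {
  double* data;
  std::size_t size;
};

// "Copying" the view is O(1): only the pointer and size are copied,
// so the copy aliases the original buffer rather than duplicating it.
inline ArenaView shallow_copy(const ArenaView& v) { return v; }
```

Because the copies alias, writing through one view is visible through the other; this is also why the previous section warns against mutating elements of an `arena_matrix` result.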


### Passing variables that need destructors called after the reverse pass (`make_chainable_ptr`)

When possible, non-arena variables should be copied to the arena to be used in the reverse pass.
16 changes: 8 additions & 8 deletions stan/math/prim/prob/normal_log.hpp
@@ -29,20 +29,20 @@ namespace math {
* @tparam T_loc Type of location parameter.
*/
template <bool propto, typename T_y, typename T_loc, typename T_scale>
inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
const T_loc& mu,
const T_scale& sigma) {
return normal_lpdf<propto, T_y, T_loc, T_scale>(y, mu, sigma);
inline return_type_t<T_y, T_loc, T_scale> normal_log(T_y&& y, T_loc&& mu,
T_scale&& sigma) {
return normal_lpdf<propto>(std::forward<T_y>(y), std::forward<T_loc>(mu),
std::forward<T_scale>(sigma));
}

/** \ingroup prob_dists
* @deprecated use <code>normal_lpdf</code>
*/
template <typename T_y, typename T_loc, typename T_scale>
inline return_type_t<T_y, T_loc, T_scale> normal_log(const T_y& y,
const T_loc& mu,
const T_scale& sigma) {
return normal_lpdf<T_y, T_loc, T_scale>(y, mu, sigma);
inline return_type_t<T_y, T_loc, T_scale> normal_log(T_y&& y, T_loc&& mu,
T_scale&& sigma) {
return normal_lpdf(std::forward<T_y>(y), std::forward<T_loc>(mu),
std::forward<T_scale>(sigma));
}

} // namespace math
23 changes: 11 additions & 12 deletions stan/math/prim/prob/normal_lpdf.hpp
@@ -41,19 +41,18 @@ namespace math {
template <bool propto, typename T_y, typename T_loc, typename T_scale,
require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
T_y, T_loc, T_scale>* = nullptr>
inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
const T_loc& mu,
const T_scale& sigma) {
inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(T_y&& y, T_loc&& mu,
T_scale&& sigma) {
using T_partials_return = partials_return_t<T_y, T_loc, T_scale>;
using T_y_ref = ref_type_if_not_constant_t<T_y>;
using T_mu_ref = ref_type_if_not_constant_t<T_loc>;
using T_sigma_ref = ref_type_if_not_constant_t<T_scale>;
static constexpr const char* function = "normal_lpdf";
check_consistent_sizes(function, "Random variable", y, "Location parameter",
mu, "Scale parameter", sigma);
T_y_ref y_ref = y;
T_mu_ref mu_ref = mu;
T_sigma_ref sigma_ref = sigma;
T_y_ref y_ref = std::forward<T_y>(y);
T_mu_ref mu_ref = std::forward<T_loc>(mu);
T_sigma_ref sigma_ref = std::forward<T_scale>(sigma);

decltype(auto) y_val = to_ref(as_value_column_array_or_scalar(y_ref));
decltype(auto) mu_val = to_ref(as_value_column_array_or_scalar(mu_ref));
@@ -63,7 +62,7 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
check_finite(function, "Location parameter", mu_val);
check_positive(function, "Scale parameter", sigma_val);

if (size_zero(y, mu, sigma)) {
if (size_zero(y_ref, mu_ref, sigma_ref)) {
return 0.0;
}
if (!include_summand<propto, T_y, T_loc, T_scale>::value) {
@@ -78,7 +77,7 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
const auto& y_scaled_sq
= to_ref_if<!is_constant_all<T_scale>::value>(y_scaled * y_scaled);

size_t N = max_size(y, mu, sigma);
size_t N = max_size(y_ref, mu_ref, sigma_ref);
T_partials_return logp = -0.5 * sum(y_scaled_sq);
if (include_summand<propto>::value) {
logp += NEG_LOG_SQRT_TWO_PI * N;
@@ -106,10 +105,10 @@ inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
}

template <typename T_y, typename T_loc, typename T_scale>
inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(const T_y& y,
const T_loc& mu,
const T_scale& sigma) {
return normal_lpdf<false>(y, mu, sigma);
inline return_type_t<T_y, T_loc, T_scale> normal_lpdf(T_y&& y, T_loc&& mu,
T_scale&& sigma) {
return normal_lpdf<false>(std::forward<T_y>(y), std::forward<T_loc>(mu),
std::forward<T_scale>(sigma));
}

} // namespace math
87 changes: 70 additions & 17 deletions stan/math/rev/core/arena_matrix.hpp
@@ -4,7 +4,7 @@
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/rev/core/chainable_alloc.hpp>
#include <stan/math/rev/core/chainablestack.hpp>

#include <stan/math/rev/core/chainable_object.hpp>
namespace stan {
namespace math {

@@ -54,7 +54,7 @@ class arena_matrix : public Eigen::Map<MatrixType> {
size) {}

/**
* Constructs `arena_matrix` from an expression.
* Constructs `arena_matrix` from an expression
* @param other expression
*/
template <typename T, require_eigen_t<T>* = nullptr>
@@ -73,6 +73,50 @@ class arena_matrix : public Eigen::Map<MatrixType> {
*this = other;
}

/**
* Constructs `arena_matrix` from an expression, then sends it to either the
* object stack or memory arena.
* @tparam T A type that inherits from Eigen::DenseBase that is not an
* `arena_matrix`.
* @param other expression
* @note When T is both an rvalue and a plain type, the expression is moved to
* the object stack. However when T is an lvalue, or an rvalue that is not a
* plain type, the expression is copied to the memory arena.
*/
template <typename T, require_eigen_t<T>* = nullptr,
require_not_arena_matrix_t<T>* = nullptr>
arena_matrix(T&& other) // NOLINT
Collaborator:
Would it be possible to split this to two constructors using require_t<std::is_rvalue_reference<T>::value> rather than having the constructor instantiate and call a lambda? Just feels a little unnecessarily complex

Collaborator Author:

Yes that's much cleaner. Though I still kept the immediately evaluated lambda as that keeps the constructor empty which I find kind of nice

: Base::Map([](auto&& x) {
using base_map_t =
typename stan::math::arena_matrix<MatrixType>::Base;
using T_t = std::decay_t<T>;
if (std::is_rvalue_reference<decltype(x)>::value
&& is_plain_type<T_t>::value) {
// Note: plain_type_t here does nothing since T_t is plain type
auto other
= make_chainable_ptr(plain_type_t<MatrixType>(std::move(x)));
// other already has its rows and cols swapped if that was needed
return base_map_t(&(other->coeffRef(0)), other->rows(),
other->cols());
} else {
base_map_t map(
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(
x.size()),
(RowsAtCompileTime == 1 && T_t::ColsAtCompileTime == 1)
|| (ColsAtCompileTime == 1
&& T_t::RowsAtCompileTime == 1)
? x.cols()
: x.rows(),
(RowsAtCompileTime == 1 && T_t::ColsAtCompileTime == 1)
|| (ColsAtCompileTime == 1
&& T_t::RowsAtCompileTime == 1)
? x.rows()
: x.cols());
map = x;
return map;
}
}(std::forward<T>(other))) {}
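The split the reviewer suggests can be sketched on a toy type in standard C++: one constructor constrained with `enable_if` to rvalues of the plain (decayed) type, which takes ownership by move, and a fallback that copies. This illustrates only the dispatch idea, using `std::string` in place of a matrix; it is not the actual `arena_matrix` code:

```cpp
#include <string>
#include <type_traits>
#include <utility>

struct Holder {
  std::string value;
  bool moved_in;

  // Selected only when T&& is an rvalue reference to the plain stored type:
  // take ownership by moving, analogous to sending a plain-type rvalue
  // matrix to the object stack.
  template <typename T,
            std::enable_if_t<std::is_rvalue_reference<T&&>::value
                                 && std::is_same<std::decay_t<T>,
                                                 std::string>::value,
                             int> = 0>
  explicit Holder(T&& s) : value(std::move(s)), moved_in(true) {}

  // Fallback for lvalues and non-plain expressions: copy/convert, analogous
  // to evaluating the expression into arena memory.
  template <typename T,
            std::enable_if_t<!(std::is_rvalue_reference<T&&>::value
                               && std::is_same<std::decay_t<T>,
                                               std::string>::value),
                             int> = 0>
  explicit Holder(T&& s) : value(s), moved_in(false) {}
};
```

In the real class the same condition appears as `std::is_rvalue_reference` combined with `is_plain_type` on the deduced template parameter.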

/**
* Constructs `arena_matrix` from an expression. This makes an assumption that
* any other `Eigen::Map` also contains memory allocated in the arena.
@@ -110,23 +154,32 @@ class arena_matrix : public Eigen::Map<MatrixType> {
* @param a expression to evaluate into this
* @return `*this`
*/
template <typename T>
arena_matrix& operator=(const T& a) {
// do we need to transpose?
if ((RowsAtCompileTime == 1 && T::ColsAtCompileTime == 1)
|| (ColsAtCompileTime == 1 && T::RowsAtCompileTime == 1)) {
// placement new changes what data map points to - there is no allocation
new (this) Base(
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(a.size()),
a.cols(), a.rows());

template <typename T, require_not_arena_matrix_t<T>* = nullptr>
arena_matrix& operator=(T&& a) {
using T_t = std::decay_t<T>;
Collaborator:

Should this be plain_type_t instead of just decaying? I don't think all expressions will have a ColsAtCompileTime/RowsAtCompileTime that can be called

if (std::is_rvalue_reference<T&&>::value && is_plain_type<T_t>::value) {
// Note: plain_type_t here does nothing since T_t is plain type
auto other = make_chainable_ptr(plain_type_t<MatrixType>(std::move(a)));
Collaborator:

If the plain_type_t is needed for some reason even though it does nothing, can you change the comment to explain why? Otherwise better to just remove the plain_type_t

new (this) Base(&(other->coeffRef(0)), other->rows(), other->cols());
return *this;
} else {
new (this) Base(
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(a.size()),
a.rows(), a.cols());
// do we need to transpose?
if ((RowsAtCompileTime == 1 && T_t::ColsAtCompileTime == 1)
|| (ColsAtCompileTime == 1 && T_t::RowsAtCompileTime == 1)) {
// placement new changes what data map points to - there is no
// allocation
new (this) Base(
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(a.size()),
a.cols(), a.rows());

} else {
new (this) Base(
ChainableStack::instance_->memalloc_.alloc_array<Scalar>(a.size()),
a.rows(), a.cols());
}
Base::operator=(a);
return *this;
}
Base::operator=(a);
return *this;
}
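The placement-new pattern in the assignment operator above — re-seating the `Eigen::Map` base onto freshly allocated arena memory without allocating the map object itself — can be mimicked with a toy view type. This is a simplified sketch: the fixed static buffer stands in for `ChainableStack`'s arena allocator:

```cpp
#include <algorithm>
#include <cstddef>
#include <new>
#include <vector>

// Toy non-owning view; assignment re-seats it onto a fresh slab via
// placement new, mirroring the `new (this) Base(...)` idiom above.
struct View {
  double* data;
  std::size_t size;

  View(double* d, std::size_t n) : data(d), size(n) {}

  View& operator=(const std::vector<double>& v) {
    static double arena[64];  // stand-in for Stan's memory arena
    // placement new changes what the view points to - the view object
    // itself is not reallocated
    new (this) View(arena, v.size());
    std::copy(v.begin(), v.end(), data);
    return *this;
  }
};
```

As in the real code, assignment never frees or reallocates the map/view object; it only rebinds it to new arena storage and then copies the expression's values in.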
/**
* Forces hard copying matrices into an arena matrix
6 changes: 2 additions & 4 deletions stan/math/rev/core/chainable_object.hpp
@@ -1,11 +1,9 @@
#ifndef STAN_MATH_REV_CORE_CHAINABLE_OBJECT_HPP
#define STAN_MATH_REV_CORE_CHAINABLE_OBJECT_HPP

#include <stan/math/rev/meta.hpp>
#include <stan/math/rev/core/chainable_alloc.hpp>
#include <stan/math/rev/core/typedefs.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/typedefs.hpp>
#include <stan/math/prim/meta.hpp>
#include <stan/math/rev/core/chainable_alloc.hpp>
#include <vector>

namespace stan {
34 changes: 17 additions & 17 deletions stan/math/rev/core/operator_addition.hpp
@@ -108,12 +108,12 @@ inline var operator+(Arith a, const var& b) {
*/
template <typename VarMat1, typename VarMat2,
require_all_rev_matrix_t<VarMat1, VarMat2>* = nullptr>
inline auto add(const VarMat1& a, const VarMat2& b) {
inline auto add(VarMat1&& a, VarMat2&& b) {
check_matching_dims("add", "a", a, "b", b);
using op_ret_type = decltype(a.val() + b.val());
using ret_type = return_var_matrix_t<op_ret_type, VarMat1, VarMat2>;
arena_t<VarMat1> arena_a(a);
arena_t<VarMat2> arena_b(b);
arena_t<VarMat1> arena_a(std::forward<VarMat1>(a));
arena_t<VarMat2> arena_b(std::forward<VarMat2>(b));
arena_t<ret_type> ret(arena_a.val() + arena_b.val());
reverse_pass_callback([ret, arena_a, arena_b]() mutable {
for (Eigen::Index j = 0; j < ret.cols(); ++j) {
@@ -124,7 +124,7 @@ inline auto add(const VarMat1& a, const VarMat2& b) {
}
}
});
return ret_type(ret);
return ret;
}

/**
@@ -139,18 +139,18 @@ inline auto add(const VarMat1& a, const VarMat2& b) {
template <typename Arith, typename VarMat,
require_st_arithmetic<Arith>* = nullptr,
require_rev_matrix_t<VarMat>* = nullptr>
inline auto add(const VarMat& a, const Arith& b) {
inline auto add(VarMat&& a, const Arith& b) {
if (is_eigen<Arith>::value) {
check_matching_dims("add", "a", a, "b", b);
}
using op_ret_type
= decltype((a.val().array() + as_array_or_scalar(b)).matrix());
using ret_type = return_var_matrix_t<op_ret_type, VarMat>;
arena_t<VarMat> arena_a(a);
arena_t<VarMat> arena_a(std::forward<VarMat>(a));
arena_t<ret_type> ret(arena_a.val().array() + as_array_or_scalar(b));
reverse_pass_callback(
[ret, arena_a]() mutable { arena_a.adj() += ret.adj_op(); });
return ret_type(ret);
return ret;
}

/**
@@ -165,8 +165,8 @@ inline auto add(const VarMat& a, const Arith& b) {
template <typename Arith, typename VarMat,
require_st_arithmetic<Arith>* = nullptr,
require_rev_matrix_t<VarMat>* = nullptr>
inline auto add(const Arith& a, const VarMat& b) {
return add(b, a);
inline auto add(const Arith& a, VarMat&& b) {
return add(std::forward<VarMat>(b), a);
}

/**
@@ -185,7 +185,7 @@ inline auto add(const Arith& a, const VarMat& b) {
using ret_type = return_var_matrix_t<EigMat>;
arena_t<ret_type> ret(a.val() + b.array());
reverse_pass_callback([ret, a]() mutable { a.adj() += ret.adj().sum(); });
return ret_type(ret);
return ret;
}

/**
@@ -217,9 +217,9 @@ inline auto add(const EigMat& a, const Var& b) {
template <typename Var, typename VarMat,
require_var_vt<std::is_arithmetic, Var>* = nullptr,
require_rev_matrix_t<VarMat>* = nullptr>
inline auto add(const Var& a, const VarMat& b) {
inline auto add(const Var& a, VarMat&& b) {
using ret_type = return_var_matrix_t<VarMat>;
arena_t<VarMat> arena_b(b);
arena_t<VarMat> arena_b(std::forward<VarMat>(b));
arena_t<ret_type> ret(a.val() + arena_b.val().array());
reverse_pass_callback([ret, a, arena_b]() mutable {
for (Eigen::Index j = 0; j < ret.cols(); ++j) {
Expand All @@ -230,7 +230,7 @@ inline auto add(const Var& a, const VarMat& b) {
}
}
});
return ret_type(ret);
return ret;
}

/**
@@ -246,8 +246,8 @@ inline auto add(const Var& a, const VarMat& b) {
template <typename Var, typename VarMat,
require_var_vt<std::is_arithmetic, Var>* = nullptr,
require_rev_matrix_t<VarMat>* = nullptr>
inline auto add(const VarMat& a, const Var& b) {
return add(b, a);
inline auto add(VarMat&& a, const Var& b) {
return add(b, std::forward<VarMat>(a));
}

template <typename T1, typename T2,
@@ -274,8 +274,8 @@ inline auto add(const T1& a, const T2& b) {
*/
template <typename VarMat1, typename VarMat2,
require_any_var_matrix_t<VarMat1, VarMat2>* = nullptr>
inline auto operator+(const VarMat1& a, const VarMat2& b) {
return add(a, b);
inline auto operator+(VarMat1&& a, VarMat2&& b) {
return add(std::forward<VarMat1>(a), std::forward<VarMat2>(b));
}

} // namespace math
4 changes: 2 additions & 2 deletions stan/math/rev/fun/fill.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace math {
template <typename VarMat, typename S, require_var_matrix_t<VarMat>* = nullptr,
require_var_t<S>* = nullptr>
inline void fill(VarMat& x, const S& y) {
arena_t<plain_type_t<value_type_t<VarMat>>> prev_vals = x.val().eval();
arena_t<plain_type_t<value_type_t<VarMat>>> prev_vals(x.val().eval());
x.vi_->val_.fill(y.val());
reverse_pass_callback([x, y, prev_vals]() mutable {
x.vi_->val_ = prev_vals;
@@ -46,7 +46,7 @@ inline void fill(VarMat& x, const S& y) {
template <typename VarMat, typename S, require_var_matrix_t<VarMat>* = nullptr,
require_arithmetic_t<S>* = nullptr>
inline void fill(VarMat& x, const S& y) {
arena_t<plain_type_t<value_type_t<VarMat>>> prev_vals = x.val().eval();
arena_t<plain_type_t<value_type_t<VarMat>>> prev_vals(x.val().eval());
x.vi_->val_.fill(y);
reverse_pass_callback([x, prev_vals]() mutable {
x.vi_->val_ = prev_vals;