-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpmx_csda.hpp
127 lines (117 loc) · 4.43 KB
/
pmx_csda.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#ifndef TORSTEN_COMPLEX_STEP_DERIVATIVE_HPP
#define TORSTEN_COMPLEX_STEP_DERIVATIVE_HPP
#include <stan/math/prim/fun/value_of.hpp>
#include <stan/math/prim/meta/return_type.hpp>
#include <stan/math/rev/core/callback_vari.hpp>
#include <stan/math/rev/core/vari.hpp>
#include <stan/math/rev/core/var.hpp>
#include <algorithm>
#include <complex>
#include <iostream>
#include <stdexcept>
#include <vector>
namespace torsten {
/**
 * Return a var whose value is f(theta) and whose derivative with
 * respect to theta is obtained by the complex-step derivative
 * approximation:
 *
 *   f(theta + i h) = f(theta) + i h f'(theta) + O(h^2)
 *
 * so Re f(theta + i h) is the value and Im f(theta + i h) / h is the
 * derivative. No subtraction is involved, so h can be tiny without
 * cancellation error. "f" does not have to support "var"; its signature
 * should be
 * (complex, std::vector<double>, std::vector<int>, stream*) : complex
 *
 * @tparam F type of functor F
 * @param[in] f functor for the complex number evaluation,
 * must support @c std::complex<double> as arg.
 * @param[in] theta parameter where f and df/d(theta) is requested.
 * @param[in] x_r real data vector forwarded to f.
 * @param[in] x_i integer data vector forwarded to f.
 * @param[in] h complex step size; must be nonzero (it is a divisor).
 * @param[out] msgs the print stream for warning messages, forwarded to f.
 * @return a var with value Re f(theta.val() + i h) (equal to f(theta.val())
 * up to O(h^2)) and derivative Im f(theta.val() + i h) / h w.r.t. theta.
 */
template <typename F>
stan::math::var pk_csda(const F& f,
                        const stan::math::var& theta,
                        const std::vector<double>& x_r,
                        const std::vector<int>& x_i,
                        const double h,
                        std::ostream* msgs) {
  using std::complex;
  const double theta_d = theta.val();
  const complex<double> res = f(complex<double>(theta_d, h), x_r, x_i, msgs);
  const double fx = std::real(res);
  const double g = std::imag(res) / h;
  // Reverse pass: the chain rule propagates the result's own adjoint,
  // d(target)/d(theta) += d(target)/d(result) * d(result)/d(theta).
  // Adding g alone would only be right if the result's adjoint were 1.
  return stan::math::make_callback_var(fx, [theta, g](auto& vi) mutable {
    theta.adj() += vi.adj() * g;
  });
}
/**
 * CSDA, calculate the directional derivative of a vector function.
 *
 * Evaluates f at the complex point y + i h dy, with h = 1e-20, and
 * returns Im(f) / h. This approximates the Jacobian-vector product
 * J_f(y) * dy. The output of f may have a different length from its
 * input, and the returned vector always has the length of f's output.
 *
 * @tparam F type of functor F; called as
 * f(std::vector<std::complex<double>>, x_r, x_i, msgs) and expected to
 * return a std::vector<std::complex<double>>.
 * @param[in] f functor for the complex number evaluation.
 * @param[in] y point at which the derivative is requested.
 * @param[in] dy direction of differentiation; must have the same size as y.
 * @param[in] x_r real data vector forwarded to f.
 * @param[in] x_i integer data vector forwarded to f.
 * @param[out] msgs print stream forwarded to f.
 * @return directional derivative, one entry per output of f.
 * @throw std::invalid_argument if y and dy differ in size.
 */
template <typename F>
std::vector<double> pk_csda(const F& f,
                            const std::vector<double>& y,
                            const std::vector<double>& dy,
                            const std::vector<double>& x_r,
                            const std::vector<int>& x_i,
                            std::ostream* msgs) {
  using cplx = std::complex<double>;
  const double h = 1.E-20;
  // std::transform below walks dy in lockstep with y; a shorter dy
  // would be read out of bounds.
  if (dy.size() != y.size()) {
    throw std::invalid_argument("pk_csda: y and dy must have the same size");
  }
  std::vector<cplx> cplx_y(y.size());
  std::transform(y.begin(), y.end(), dy.begin(), cplx_y.begin(),
                 [h](double r, double i) { return cplx(r, h * i); });
  const std::vector<cplx> res = f(cplx_y, x_r, x_i, msgs);
  // Size by f's output, not its input: f may map R^n to R^m with m != n.
  std::vector<double> g(res.size());
  std::transform(res.begin(), res.end(), g.begin(),
                 [h](const cplx& x) { return std::imag(x) / h; });
  return g;
}
/**
 * CSDA, calculate the directional derivative of a vector
 * function, without system input @c x_r, x_i, msgs.
 *
 * Evaluates f at the complex point y + i h dy, with h = 1e-20, and
 * returns Im(f) / h. This approximates the Jacobian-vector product
 * J_f(y) * dy. The returned vector has the length of f's output, which
 * may differ from the length of y.
 *
 * @tparam F type of functor F; called as
 * f(std::vector<std::complex<double>>) and expected to return a
 * std::vector<std::complex<double>>.
 * @param[in] f functor for the complex number evaluation.
 * @param[in] y point at which the derivative is requested.
 * @param[in] dy direction of differentiation; must have the same size as y.
 * @return directional derivative, one entry per output of f.
 * @throw std::invalid_argument if y and dy differ in size.
 */
template <typename F>
std::vector<double> pk_csda(const F& f,
                            const std::vector<double>& y,
                            const std::vector<double>& dy) {
  using cplx = std::complex<double>;
  const double h = 1.E-20;
  // std::transform below walks dy in lockstep with y; a shorter dy
  // would be read out of bounds.
  if (dy.size() != y.size()) {
    throw std::invalid_argument("pk_csda: y and dy must have the same size");
  }
  std::vector<cplx> cplx_y(y.size());
  std::transform(y.begin(), y.end(), dy.begin(), cplx_y.begin(),
                 [h](double r, double i) { return cplx(r, h * i); });
  const std::vector<cplx> res = f(cplx_y);
  // Size by f's output, not its input: f may map R^n to R^m with m != n.
  std::vector<double> g(res.size());
  std::transform(res.begin(), res.end(), g.begin(),
                 [h](const cplx& x) { return std::imag(x) / h; });
  return g;
}
/**
 * Complex-step derivative of a scalar functor, using the default
 * step size h = 1e-20.
 *
 * This is a convenience overload: it forwards every argument, plus the
 * default step, to the overload that takes an explicit step size.
 *
 * @tparam F type of functor F
 * @param[in] f functor for the complex number evaluation,
 * must support @c std::complex<double> as arg.
 * @param[in] theta parameter where f and df/d(theta) is requested.
 * @param[in] x_r real data vector forwarded to f.
 * @param[in] x_i integer data vector forwarded to f.
 * @param[out] msgs the print stream for warning messages, forwarded to f.
 * @return a var with value f(theta.val()) and derivative at theta.
 */
template <typename F>
stan::math::var pk_csda(const F& f,
                        const stan::math::var& theta,
                        const std::vector<double>& x_r,
                        const std::vector<int>& x_i,
                        std::ostream* msgs) {
  const double default_step = 1.E-20;
  return pk_csda(f, theta, x_r, x_i, default_step, msgs);
}
} // namespace torsten
#endif