Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/adjoint odes #1905

Merged
merged 236 commits into from
May 17, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
236 commits
Select commit Hold shift + click to select a range
0389837
Revert "Revert back to forward mode sensitivity for cvodes (to make s…
bbbales2 May 22, 2020
7811368
Merge branch 'feature/parameter-pack-odes' into feature/adjoint-odes
bbbales2 May 22, 2020
fb36702
Merge remote-tracking branch 'origin/develop' into feature/adjoint-od…
wds15 Jan 2, 2021
27fbb73
add adjoint integrator first version
wds15 Jan 2, 2021
c480212
adding sho tests
wds15 Jan 2, 2021
942dd51
reroute ode_adams to ode_bdf_adjoint
wds15 Jan 2, 2021
d86cbc9
bump
wds15 Jan 2, 2021
6383fa3
introduce plain_type_t
wds15 Jan 3, 2021
f978440
fix eigen expression stuff for cvodes_integrator_adjoint
wds15 Jan 3, 2021
79cfeca
Changed how tuples initialized in ODE adjoint memory
bbbales2 Jan 3, 2021
a65573c
Merge branch 'feature/adjoint-odes-v2' of https://github.com/stan-dev…
wds15 Jan 3, 2021
cdddc43
more expression things
wds15 Jan 3, 2021
c27331b
Merge branch 'feature/adjoint-odes-v2' into feature/adjoint-odes
wds15 Jan 3, 2021
c4f530a
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 3, 2021
37ec72c
add simple benchmark in the form of a test
wds15 Jan 3, 2021
91aaaa2
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 3, 2021
64acf4a
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 3, 2021
ccc5026
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 3, 2021
edea9b0
recycle argument vari's
wds15 Jan 3, 2021
c81e6aa
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 3, 2021
418c0c1
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 3, 2021
fd9d531
fix
wds15 Jan 4, 2021
bd35bd9
add plain_type_t
wds15 Jan 4, 2021
b649c25
Merge commit 'b2eaa329ee3427245dbe06482ee937569cb6032f' into HEAD
yashikno Jan 4, 2021
e0a4e40
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 4, 2021
b6b8782
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 4, 2021
a15e58a
bump
wds15 Jan 4, 2021
2adfcc0
route calls correctly
wds15 Jan 4, 2021
b15c7b1
bump
wds15 Jan 4, 2021
47abeb8
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 4, 2021
2a2bce5
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 5, 2021
cfcc616
too much work error message
wds15 Jan 5, 2021
83d7c41
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 5, 2021
78f87e0
bump
wds15 Jan 5, 2021
c940514
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 5, 2021
9da66fb
Merge commit '7d5f626b55d822f2544138cc5fbeabbd4d86a92c' into HEAD
yashikno Jan 5, 2021
6be8f54
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 5, 2021
6fff468
add maximal number of steps option to backward solve
wds15 Jan 6, 2021
e553c23
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 6, 2021
a09696b
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 6, 2021
7e78900
add scaling test
wds15 Jan 7, 2021
724ddec
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 7, 2021
2d1a65d
Merge commit 'aa1c1c48372fb8e60da9ca7af6f0403b83f99d9d' into HEAD
yashikno Jan 7, 2021
4bb9482
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 7, 2021
c4e8280
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 7, 2021
a1fd530
align tolerances
wds15 Jan 7, 2021
72cbc3f
fix scaling benchmark
wds15 Jan 11, 2021
12f9b53
Merge commit '3b0ddba22615efe779da52f7185ec15863b0a7e6' into HEAD
yashikno Jan 11, 2021
4d68049
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 11, 2021
94012ea
Set adjoints to 1 in benchmark test.
charlesm93 Jan 11, 2021
6fa6208
Merge commit '222e1ad9426ba13c364a8a65d2f3958104047b54' into HEAD
yashikno Jan 11, 2021
14dc05d
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 11, 2021
9a87420
add toy model
wds15 Jan 13, 2021
67eb800
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 13, 2021
1ce94e2
Merge commit '7085ef69bdb399db7d918240659f56305885a2fd' into HEAD
yashikno Jan 13, 2021
c5509e7
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 13, 2021
4934995
fix
wds15 Jan 13, 2021
89b60fd
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 13, 2021
9926246
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 13, 2021
999f766
fix cvodes init / reinit logic
wds15 Jan 13, 2021
6480b4d
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 13, 2021
90e6f93
Merge commit '6ac9c2e5ab14bb53a5f4a0e31d3a783015fe1dd7' into HEAD
yashikno Jan 13, 2021
12b8408
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 13, 2021
1d85054
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 13, 2021
18ff5b3
add sum test
wds15 Jan 14, 2021
964e87e
Merge commit 'd2e485cf1597d1a30fbfdf9330540033c432683c' into HEAD
yashikno Jan 14, 2021
71cefa1
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 14, 2021
d1280cf
fix linked mass flow example
wds15 Jan 14, 2021
1615e23
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 14, 2021
61aa6f8
add SBC codes
weberse2 Jan 17, 2021
4cbdb04
bump
weberse2 Jan 17, 2021
ff1571f
adding eval and some notes on how the math works out
wds15 Jan 18, 2021
ac36f82
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 18, 2021
f49a7b3
Merge commit '7a0d613e607fecd3b6f039dca796456559f18856' into HEAD
yashikno Jan 18, 2021
ed3ce6e
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Jan 18, 2021
d36b62f
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Jan 18, 2021
5c47c05
Merge remote-tracking branch 'origin/develop' into feature/adjoint-odes
wds15 Feb 22, 2021
9735ffd
revert adams back to actual adams
wds15 Feb 22, 2021
920fc7c
make signature work as decribed in design doc
wds15 Feb 22, 2021
59eb1d8
make first test work with new function
wds15 Feb 22, 2021
1c4482c
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Feb 22, 2021
3157e49
make vector absolute tolerance work and expose solver choice
wds15 Feb 23, 2021
0acc62e
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Feb 23, 2021
1eda2f9
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Feb 23, 2021
475216e
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Feb 24, 2021
1faf260
add tol call
wds15 Feb 24, 2021
9f2b1c7
Merge commit '94e16d87d9cb464b4e05b8914f6cbbc8ea20ac0b' into HEAD
yashikno Feb 24, 2021
5d74d68
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Feb 24, 2021
6a1efb2
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Feb 24, 2021
3413734
add various benchmarks
wds15 Feb 26, 2021
ff87fd5
tune sparse matrix support
wds15 Mar 2, 2021
db633c2
Merge commit '5533797fb87828108b4412e80537084fc168b092' into HEAD
yashikno Mar 2, 2021
69e2e91
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 2, 2021
4a8afdf
disable sparse matrix stuff on CVODES and switch to new function name…
wds15 Mar 4, 2021
d8ce935
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Mar 4, 2021
73041c1
enable more tests
wds15 Mar 4, 2021
9e30580
Merge commit 'b7eff4c4d5d0f6b55523be91e4b757c8b766d676' into HEAD
yashikno Mar 4, 2021
6e238d7
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 4, 2021
ee94ed3
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Mar 4, 2021
efba64e
use vector for abs_tol_b
rok-cesnovar Mar 5, 2021
6db3dad
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 5, 2021
024e94a
fix test model
rok-cesnovar Mar 5, 2021
4450cca
Switch to reference types
bbbales2 Mar 6, 2021
4c478fd
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Mar 7, 2021
149c649
align defaults of simplified ode_adjoint_tol interface with doc
weberse2 Mar 11, 2021
b31a728
revert to hermite polynomials which appear more stable
weberse2 Mar 11, 2021
fd74b60
bump examples
weberse2 Mar 11, 2021
a382fee
fix
weberse2 Mar 11, 2021
d72dd8a
align notation with user guide
weberse2 Mar 12, 2021
217f077
export all tolerances for example
weberse2 Mar 15, 2021
5f0f2b7
make use of Eigen vector expressions and use coeff indexing
weberse2 Mar 15, 2021
0e90995
put apply call in separate function to clean code
weberse2 Mar 15, 2021
86f586f
Merge commit '8a3ebcb864c92343fda00360278013622e171bc0' into HEAD
yashikno Mar 15, 2021
f4bde7e
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Mar 15, 2021
b315e77
Merge branch 'feature/adjoint-odes' of https://github.com/stan-dev/ma…
wds15 Mar 22, 2021
681ab4f
Merge remote-tracking branch 'origin/develop' into feature/adjoint-odes
wds15 Mar 31, 2021
4738307
Merge remote-tracking branch 'origin/develop' into feature/adjoint-odes
wds15 Apr 15, 2021
338de15
backport adjoint code and integrate into testing framework in a first…
wds15 Apr 15, 2021
4abf989
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 15, 2021
c2e9e39
cpplint
wds15 Apr 16, 2021
f637c3a
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 16, 2021
9390361
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 16, 2021
74f6e0c
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 16, 2021
527fa2e
cpplint
wds15 Apr 16, 2021
4377411
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 16, 2021
3e2e653
lint
wds15 Apr 19, 2021
5c0179a
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 19, 2021
8f4aa86
refactor adjoint solver to not use extra memory object
wds15 Apr 22, 2021
b393615
some more simplifcations
wds15 Apr 22, 2021
1de856d
move build_varis function inside cvodes_integrator_adjoint
wds15 Apr 22, 2021
db528e0
undo last change
wds15 Apr 22, 2021
a710e79
Merge commit 'eae2d6d1f526e31f97286829737329d789897544' into HEAD
yashikno Apr 22, 2021
bbb3bd2
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 22, 2021
d0e76a3
make benchmarks work with newer interface
wds15 Apr 22, 2021
bcd09f6
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 22, 2021
5ee531a
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 22, 2021
2b887ca
remove obsolete test
wds15 Apr 22, 2021
abbdebf
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 22, 2021
7e02e5d
get rid of helper build_vari and replace it with overload withing var…
wds15 Apr 22, 2021
dbd07ea
mark methods as const where possible
wds15 Apr 22, 2021
0d46d07
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 22, 2021
643c602
reorder freeing
wds15 Apr 22, 2021
7afcd30
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 22, 2021
fa3dbb2
port over changes from review3 branch
wds15 Apr 26, 2021
9d55a24
mark functions const as appropiate
wds15 Apr 26, 2021
7b6dd5e
more merging from review3
wds15 Apr 26, 2021
ace3931
optimize ts only var case
wds15 Apr 26, 2021
a314b0a
Merge commit 'f1905f68276b9cc578bb1c7ab3f8e42266e0074a' into HEAD
yashikno Apr 26, 2021
6e10a7d
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 26, 2021
a836ff6
headers compliance
wds15 Apr 27, 2021
efe107f
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 27, 2021
1654918
update zero_adjoints test
SteveBronder Apr 27, 2021
c5b8b71
Merge remote-tracking branch 'origin/develop' into feature/adjoint-odes
SteveBronder Apr 27, 2021
6eff162
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 27, 2021
1641cdb
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 27, 2021
d99f69e
remove obsolete benchmarks
wds15 Apr 27, 2021
e8138b8
delete benchmark R code and move ode functor into chainable_alloc cla…
wds15 Apr 27, 2021
d14f9f0
cleanups & added some more inline & constexpr as appropiate
wds15 Apr 27, 2021
875082b
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 27, 2021
fd6fd13
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 28, 2021
c619587
fix memory issues
wds15 Apr 28, 2021
ee27184
Merge commit '2e6fdbaef13d208d47c69c8349a258bd65dd2eb7' into HEAD
yashikno Apr 28, 2021
0150ed6
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 28, 2021
bd6073a
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 28, 2021
273a82b
some more const declarations
wds15 Apr 28, 2021
c6510f8
cleanup
wds15 Apr 28, 2021
18ecc0d
fix recover memory issues of ode test framework
wds15 Apr 28, 2021
a2af67d
fix leak in coupled ode system and move the nested to be local in the…
SteveBronder Apr 28, 2021
62281d2
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 28, 2021
e721414
adding extra bad tests for ode_adjoint_tol_ctl interface
wds15 Apr 28, 2021
976f237
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 Apr 28, 2021
c513e94
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot Apr 28, 2021
c645587
fix cpplint error
SteveBronder Apr 28, 2021
8db9c2e
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 1, 2021
ab4acbe
Update stan/math/rev/functor/cvodes_integrator_adjoint.hpp
wds15 May 1, 2021
7f87d1c
Merge commit 'eec0ad1dffa5906e7717db9afac675171260ce49' into HEAD
yashikno May 1, 2021
ea7e808
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 1, 2021
9ec431a
Update stan/math/rev/functor/cvodes_integrator_adjoint.hpp
wds15 May 1, 2021
2adc594
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 1, 2021
7d914d1
review comments
wds15 May 1, 2021
f0104fc
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 1, 2021
3ab73f1
update adjoint ode to not have to make seperate vari
SteveBronder May 3, 2021
ba82ee4
remove num_vars
SteveBronder May 3, 2021
784be28
remove num_vars
SteveBronder May 3, 2021
6015c6b
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 4, 2021
894e55b
Merge remote-tracking branch 'origin/develop' into review/ode-adjoint4
SteveBronder May 4, 2021
528441e
Merge remote-tracking branch 'origin/develop' into feature/adjoint-odes
SteveBronder May 4, 2021
fb84cfe
Merge branch 'review/ode-adjoint4' into feature/adjoint-odes
SteveBronder May 4, 2021
3f91c24
more testing
wds15 May 4, 2021
a5b7b9f
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 4, 2021
19774a2
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 4, 2021
8a4525a
doc and review comments
wds15 May 4, 2021
13b7224
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 4, 2021
eb62f8c
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 4, 2021
12b4a92
move functor into deallocated solve object
wds15 May 4, 2021
d779c7a
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 4, 2021
c0fa201
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 4, 2021
1e1602f
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 4, 2021
2bf6d94
fix get
SteveBronder May 4, 2021
49ad007
Merge branch 'feature/adjoint-odes' of github.com:stan-dev/math into …
SteveBronder May 4, 2021
79f0970
return value for const ref input for get()
SteveBronder May 4, 2021
46412b1
return value for scalar const ref in get()
SteveBronder May 4, 2021
e1b7960
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 5, 2021
ff16c04
fix tests (hopefully)
wds15 May 5, 2021
0c42867
Merge remote-tracking branch 'origin/develop' into feature/adjoint-odes
SteveBronder May 7, 2021
3320504
1. Cleaning up the actual functions these are called from in order to…
SteveBronder May 9, 2021
0b73247
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 9, 2021
8612ac6
remove extra whitespace for cpplint
SteveBronder May 9, 2021
cf07e1a
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 10, 2021
68c3f88
address review comments (split tests, test for no vari on stack when …
wds15 May 10, 2021
828f409
save function name in a string
wds15 May 10, 2021
cc488c0
register solver adjoint only in case of AD
wds15 May 10, 2021
e5ad9c1
optimize case when only ts is a var
wds15 May 10, 2021
459d64b
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 10, 2021
3bedb86
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 10, 2021
0749581
nest adjoint ODE integrator vari call for double only case
wds15 May 11, 2021
a83fbd6
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 11, 2021
2f618c9
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 11, 2021
8b01a90
make use of adjoint_of where it works
wds15 May 11, 2021
437791c
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 11, 2021
ae9521a
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 11, 2021
c26a82f
fix
wds15 May 11, 2021
d88e021
fix tests
wds15 May 11, 2021
ba01fad
improve error messages
wds15 May 15, 2021
826e0c9
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 15, 2021
86d32bf
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 15, 2021
784226c
revert harmonic oscillator
wds15 May 16, 2021
69158dd
make Eigen inputs expression compatible
wds15 May 16, 2021
8d71067
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 16, 2021
1b00008
move ode argument tuple into the cvodes_solvers object
wds15 May 16, 2021
3ef833f
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 16, 2021
67c7c8f
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 16, 2021
03d50c6
add in decay_t again as before
wds15 May 16, 2021
de69763
Merge branch 'feature/adjoint-odes' of ssh://github.com/stan-dev/math…
wds15 May 16, 2021
630e1ed
review comments
wds15 May 16, 2021
01c9fa1
[Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.0…
stan-buildbot May 16, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 47 additions & 40 deletions stan/math/rev/functor/cvodes_integrator_adjoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,14 @@ class cvodes_integrator_adjoint_vari : public vari_base {

static constexpr bool is_var_ts_{is_var<T_ts>::value};
static constexpr bool is_var_t0_{is_var<T_t0>::value};
static constexpr bool is_var_y0_{is_var<T_y0>::value};
static constexpr bool is_var_y0_t0_{is_var<T_y0_t0>::value};
static constexpr bool is_any_var_args_{
disjunction<is_var<scalar_type_t<T_Args>>...>::value};
static constexpr bool is_var_return_{is_var<T_Return>::value};
static constexpr bool is_var_only_ts_{
is_var_ts_ && !(is_var_t0_ || is_var_y0_t0_ || is_any_var_args_)};

std::tuple<arena_t<T_Args>...> local_args_tuple_;
std::tuple<arena_t<
promote_scalar_t<partials_type_t<scalar_type_t<T_Args>>, T_Args>>...>
value_of_args_tuple_;

arena_t<std::vector<Eigen::VectorXd>> y_;
arena_t<std::vector<T_ts>> ts_;
arena_t<Eigen::Matrix<T_y0_t0, Eigen::Dynamic, 1>> y0_;
Expand Down Expand Up @@ -82,7 +78,9 @@ class cvodes_integrator_adjoint_vari : public vari_base {
* vari class).
*/
struct cvodes_solver : public chainable_alloc {
std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> y_return_;
const std::string function_name_str_;
SteveBronder marked this conversation as resolved.
Show resolved Hide resolved
const std::decay_t<F> f_;
const size_t N_;
N_Vector nv_state_forward_;
N_Vector nv_state_backward_;
N_Vector nv_quad_;
Expand All @@ -92,19 +90,23 @@ class cvodes_integrator_adjoint_vari : public vari_base {
SUNLinearSolver LS_forward_;
SUNMatrix A_backward_;
SUNLinearSolver LS_backward_;
const size_t N_;
const std::string function_name_str_;
void* cvodes_mem_;
const std::decay_t<F> f_;
std::vector<Eigen::Matrix<T_Return, Eigen::Dynamic, 1>> y_return_;
std::tuple<T_Args...> local_args_tuple_;
const std::tuple<
promote_scalar_t<partials_type_t<scalar_type_t<T_Args>>, T_Args>...>
value_of_args_tuple_;

template <typename FF, typename StateFwd, typename StateBwd, typename Quad,
typename AbsTolFwd, typename AbsTolBwd>
cvodes_solver(const char* function_name, FF&& f, size_t N,
size_t num_args_vars, size_t ts_size, int solver_forward,
StateFwd& state_forward, StateBwd& state_backward, Quad& quad,
AbsTolFwd& absolute_tolerance_forward,
AbsTolBwd& absolute_tolerance_backward)
AbsTolBwd& absolute_tolerance_backward, const T_Args&... args)
: chainable_alloc(),
wds15 marked this conversation as resolved.
Show resolved Hide resolved
f_(std::forward<FF>(f)),
function_name_str_(function_name),
y_return_(ts_size),
nv_state_forward_(N_VMake_Serial(N, state_forward.data())),
nv_state_backward_(N_VMake_Serial(N, state_backward.data())),
Expand All @@ -122,9 +124,9 @@ class cvodes_integrator_adjoint_vari : public vari_base {
N == 0 ? nullptr
: SUNDenseLinearSolver(nv_state_backward_, A_backward_)),
N_(N),
function_name_str_(function_name),
cvodes_mem_(CVodeCreate(solver_forward)),
f_(std::forward<FF>(f)) {
local_args_tuple_(deep_copy_vars(args)...),
value_of_args_tuple_(value_of(args)...) {
if (cvodes_mem_ == nullptr) {
throw std::runtime_error("CVodeCreate failed to allocate memory");
}
Expand Down Expand Up @@ -202,8 +204,8 @@ class cvodes_integrator_adjoint_vari : public vari_base {
int interpolation_polynomial, int solver_forward, int solver_backward,
std::ostream* msgs, const T_Args&... args)
: vari_base(),
local_args_tuple_(to_arena(deep_copy_vars(args))...),
value_of_args_tuple_(to_arena(value_of(args))...),
// local_args_tuple_(deep_copy_vars(args)...),
// value_of_args_tuple_(value_of(args)...),
wds15 marked this conversation as resolved.
Show resolved Hide resolved
y_(ts.size()),
ts_(ts.begin(), ts.end()),
y0_(y0),
Expand Down Expand Up @@ -238,12 +240,6 @@ class cvodes_integrator_adjoint_vari : public vari_base {
check_finite(function_name, "initial time", t0);
check_finite(function_name, "times", ts);

stan::math::for_each(
[func_name = function_name](auto&& arg) {
check_finite(func_name, "ode parameters and data", arg);
},
local_args_tuple_);

check_nonzero_size(function_name, "times", ts);
check_nonzero_size(function_name, "initial state", y0);
check_sorted(function_name, "times", ts);
Expand All @@ -268,16 +264,29 @@ class cvodes_integrator_adjoint_vari : public vari_base {
check_positive(function_name, "num_steps_between_checkpoints",
num_steps_between_checkpoints_);
// for polynomial: 1=CV_HERMITE / 2=CV_POLYNOMIAL
check_bounded(function_name, "interpolation_polynomial",
interpolation_polynomial_, 1, 2);
if (interpolation_polynomial_ != 1 && interpolation_polynomial_ != 2)
invalid_argument(function_name, "interpolation_polynomial",
interpolation_polynomial_, "",
", must be 1 for Hermite or 2 for polynomial "
"interpolation of ODE solution");
// 1=Adams=CV_ADAMS, 2=BDF=CV_BDF
check_bounded(function_name, "solver_forward", solver_forward_, 1, 2);
check_bounded(function_name, "solver_backward", solver_backward_, 1, 2);
if (solver_forward_ != 1 && solver_forward_ != 2)
invalid_argument(function_name, "solver_forward", solver_forward_, "",
", must be 1 for Adams or 2 for BDF forward solver");
if (solver_backward_ != 1 && solver_backward_ != 2)
invalid_argument(function_name, "solver_backward", solver_backward_, "",
", must be 1 for Adams or 2 for BDF backward solver");

solver_ = new cvodes_solver(
function_name, f, N_, num_args_vars_, ts_.size(), solver_forward_,
state_forward_, state_backward_, quad_, absolute_tolerance_forward_,
absolute_tolerance_backward_);
absolute_tolerance_backward_, args...);

stan::math::for_each(
[func_name = function_name](auto&& arg) {
check_finite(func_name, "ode parameters and data", arg);
},
solver_->local_args_tuple_);

check_flag_sundials(
CVodeInit(solver_->cvodes_mem_, &cvodes_integrator_adjoint_vari::cv_rhs,
Expand Down Expand Up @@ -338,13 +347,12 @@ class cvodes_integrator_adjoint_vari : public vari_base {
} else {
check_flag_sundials(error_code, "CVodeF");
}

} else {
int error_code
= CVode(solver_->cvodes_mem_, t_final, solver_->nv_state_forward_,
&t_init, CV_NORMAL);

if (error_code == CV_TOO_MUCH_WORK) {
if (unlikely(error_code == CV_TOO_MUCH_WORK)) {
throw_domain_error(solver_->function_name_str_.c_str(), "", t_final,
"Failed to integrate to next output time (",
") in less than max_num_steps steps");
Expand All @@ -357,9 +365,7 @@ class cvodes_integrator_adjoint_vari : public vari_base {

t_init = t_final;
}
if (is_var_return_) {
ChainableStack::instance_->var_stack_.push_back(this);
}
ChainableStack::instance_->var_stack_.push_back(this);
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
}

private:
Expand Down Expand Up @@ -414,8 +420,8 @@ class cvodes_integrator_adjoint_vari : public vari_base {
+= forward_as<var>(solver_->y_return_[i].coeff(j)).adj();
}

forward_as<var>(ts_[i])->adj_ += step_sens.dot(
rhs(value_of(ts_[i]), y_[i], value_of_args_tuple_));
adjoint_of(ts_[i]) += step_sens.dot(
rhs(value_of(ts_[i]), y_[i], solver_->value_of_args_tuple_));
step_sens.setZero();
}

Expand Down Expand Up @@ -543,8 +549,8 @@ class cvodes_integrator_adjoint_vari : public vari_base {
}

if (is_var_t0_) {
forward_as<var>(t0_)->adj_ += -state_backward_.dot(
rhs(t_init, value_of(y0_), value_of_args_tuple_));
adjoint_of(t0_) += -state_backward_.dot(
rhs(t_init, value_of(y0_), solver_->value_of_args_tuple_));
}

// After integrating all the way back to t0, we finally have the
Expand Down Expand Up @@ -589,7 +595,8 @@ class cvodes_integrator_adjoint_vari : public vari_base {
*/
inline int rhs(double t, const double* y, double*& dy_dt) const {
const Eigen::VectorXd y_vec = Eigen::Map<const Eigen::VectorXd>(y, N_);
const Eigen::VectorXd dy_dt_vec = rhs(t, y_vec, value_of_args_tuple_);
const Eigen::VectorXd dy_dt_vec
= rhs(t, y_vec, solver_->value_of_args_tuple_);
check_size_match(solver_->function_name_str_.c_str(), "dy_dt",
dy_dt_vec.size(), "states", N_);
Eigen::Map<Eigen::VectorXd>(dy_dt, N_) = dy_dt_vec;
Expand Down Expand Up @@ -624,7 +631,7 @@ class cvodes_integrator_adjoint_vari : public vari_base {
Eigen::Matrix<var, Eigen::Dynamic, 1> y_vars(
Eigen::Map<const Eigen::VectorXd>(NV_DATA_S(y), N_));
Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars
= rhs(t, y_vars, value_of_args_tuple_);
= rhs(t, y_vars, solver_->value_of_args_tuple_);
check_size_match(solver_->function_name_str_.c_str(), "dy_dt",
f_y_t_vars.size(), "states", N_);
f_y_t_vars.adj() = -Eigen::Map<Eigen::VectorXd>(NV_DATA_S(yB), N_);
Expand Down Expand Up @@ -662,9 +669,9 @@ class cvodes_integrator_adjoint_vari : public vari_base {
// The vars here do not live on the nested stack so must be zero'd
// separately
stan::math::for_each([](auto&& arg) { zero_adjoints(arg); },
local_args_tuple_);
solver_->local_args_tuple_);
Eigen::Matrix<var, Eigen::Dynamic, 1> f_y_t_vars
= rhs(t, y_vec, local_args_tuple_);
= rhs(t, y_vec, solver_->local_args_tuple_);
check_size_match(solver_->function_name_str_.c_str(), "dy_dt",
f_y_t_vars.size(), "states", N_);
f_y_t_vars.adj() = -Eigen::Map<Eigen::VectorXd>(NV_DATA_S(yB), N_);
Expand All @@ -673,7 +680,7 @@ class cvodes_integrator_adjoint_vari : public vari_base {
[&qBdot](auto&&... args) {
accumulate_adjoints(NV_DATA_S(qBdot), args...);
},
local_args_tuple_);
solver_->local_args_tuple_);
return 0;
}

Expand All @@ -698,7 +705,7 @@ class cvodes_integrator_adjoint_vari : public vari_base {
Eigen::Matrix<var, Eigen::Dynamic, 1> y_var(
Eigen::Map<const Eigen::VectorXd>(NV_DATA_S(y), N_));
Eigen::Matrix<var, Eigen::Dynamic, 1> fy_var
= rhs(t, y_var, value_of_args_tuple_);
= rhs(t, y_var, solver_->value_of_args_tuple_);

check_size_match(solver_->function_name_str_.c_str(), "dy_dt",
fy_var.size(), "states", N_);
Expand Down
95 changes: 94 additions & 1 deletion stan/math/rev/functor/ode_adjoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ namespace math {
template <typename F, typename T_y0, typename T_t0, typename T_ts,
typename T_abs_tol_fwd, typename T_abs_tol_bwd, typename... T_Args,
require_all_eigen_col_vector_t<T_y0, T_abs_tol_fwd,
T_abs_tol_bwd>* = nullptr>
T_abs_tol_bwd>* = nullptr,
require_any_not_st_arithmetic<T_y0, T_t0, T_ts, T_Args...>* = nullptr>
auto ode_adjoint_impl(
const char* function_name, F&& f, const T_y0& y0, const T_t0& t0,
const std::vector<T_ts>& ts, double relative_tolerance_forward,
Expand All @@ -91,6 +92,98 @@ auto ode_adjoint_impl(
return integrator->solution();
}

/**
* Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
* times, { t1, t2, t3, ... } using the stiff backward differentiation formula
* BDF solver or the non-stiff Adams solver from CVODES. The ODE system is
* integrated using the adjoint sensitivity approach of CVODES. This
* implementation handles the case of a double return type which ensures that no
* resources are left on the AD stack.
*
* \p f must define an operator() with the signature as:
* template<typename T_t, typename T_y, typename... T_Args>
* Eigen::Matrix<stan::return_type_t<T_t, T_y, T_Args...>, Eigen::Dynamic, 1>
* operator()(const T_t& t, const Eigen::Matrix<T_y, Eigen::Dynamic, 1>& y,
* std::ostream* msgs, const T_Args&... args);
*
* t is the time, y is the state, msgs is a stream for error messages, and args
* are optional arguments passed to the ODE solve function (which are passed
* through to \p f without modification).
*
* @tparam F Type of ODE right hand side
* @tparam T_y0 Type of initial state
* @tparam T_t0 Type of initial time
* @tparam T_ts Type of output times
* @tparam T_Args Types of pass-through parameters
*
* @param function_name Calling function name (for printing debugging messages)
* @param f Right hand side of the ODE
* @param y0 Initial state
* @param t0 Initial time
* @param ts Times at which to solve the ODE at. All values must be sorted and
* not less than t0.
* @param relative_tolerance_forward Relative tolerance for forward problem
* passed to CVODES
* @param absolute_tolerance_forward Absolute tolerance per ODE state for
* forward problem passed to CVODES
* @param relative_tolerance_backward Relative tolerance for backward problem
* passed to CVODES
* @param absolute_tolerance_backward Absolute tolerance per ODE state for
* backward problem passed to CVODES
* @param relative_tolerance_quadrature Relative tolerance for quadrature
* problem passed to CVODES
* @param absolute_tolerance_quadrature Absolute tolerance for quadrature
* problem passed to CVODES
* @param max_num_steps Upper limit on the number of integration steps to
* take between each output (error if exceeded)
* @param num_steps_between_checkpoints Number of integrator steps after which a
* checkpoint is stored for the backward pass
* @param interpolation_polynomial type of polynomial used for interpolation
* @param solver_forward solver used for forward pass
* @param solver_backward solver used for backward pass
* @param[in, out] msgs the print stream for warning messages
* @param args Extra arguments passed unmodified through to ODE right hand side
* @return An `std::vector` of Eigen column vectors with scalars equal to
* the least upper bound of `T_y0`, `T_t0`, `T_ts`, and the lambda's arguments.
* This represents the solution to ODE at times \p ts
*/
template <typename F, typename T_y0, typename T_t0, typename T_ts,
typename T_abs_tol_fwd, typename T_abs_tol_bwd, typename... T_Args,
require_all_eigen_col_vector_t<T_y0, T_abs_tol_fwd,
T_abs_tol_bwd>* = nullptr,
require_all_st_arithmetic<T_y0, T_t0, T_ts, T_Args...>* = nullptr>
std::vector<Eigen::VectorXd> ode_adjoint_impl(
const char* function_name, F&& f, const T_y0& y0, const T_t0& t0,
const std::vector<T_ts>& ts, double relative_tolerance_forward,
const T_abs_tol_fwd& absolute_tolerance_forward,
double relative_tolerance_backward,
const T_abs_tol_bwd& absolute_tolerance_backward,
double relative_tolerance_quadrature, double absolute_tolerance_quadrature,
long int max_num_steps, // NOLINT(runtime/int)
long int num_steps_between_checkpoints, // NOLINT(runtime/int)
int interpolation_polynomial, int solver_forward, int solver_backward,
std::ostream* msgs, const T_Args&... args) {
std::vector<Eigen::VectorXd> ode_solution;
{
nested_rev_autodiff nested;

using integrator_vari
= cvodes_integrator_adjoint_vari<F, plain_type_t<T_y0>, T_t0, T_ts,
plain_type_t<T_Args>...>;

auto integrator = new integrator_vari(
function_name, std::forward<F>(f), y0, t0, ts,
relative_tolerance_forward, absolute_tolerance_forward,
relative_tolerance_backward, absolute_tolerance_backward,
relative_tolerance_quadrature, absolute_tolerance_quadrature,
max_num_steps, num_steps_between_checkpoints, interpolation_polynomial,
solver_forward, solver_backward, msgs, args...);

ode_solution = integrator->solution();
}
return ode_solution;
}

/**
* Solve the ODE initial value problem y' = f(t, y), y(t0) = y0 at a set of
* times, { t1, t2, t3, ... } using the stiff backward differentiation formula
Expand Down
Loading