diff --git a/test/unit/math/rev/prob/wiener_full_test.cpp b/test/unit/math/rev/prob/wiener_full_test.cpp index e13a51e84f6..262e6ad354a 100644 --- a/test/unit/math/rev/prob/wiener_full_test.cpp +++ b/test/unit/math/rev/prob/wiener_full_test.cpp @@ -207,3 +207,96 @@ TEST(ProbWienerFullPrec, wiener_full_prec_all_scalar) { check_scalar_types(f_st0, st0[i], result[i], dst0[i]); } } + + + + + + + + +// CHECK THAT ALL VALID Vector TYPES ARE ACCEPTED +template +void check_vector_types(F& f, std::vector value, double res) { + // - f: Function where all inputs are vectors + // - value: value to be used for the parameter + // - res: expected result of calling `f` with `value` + // - deriv: expected result of partial of f with respect to + // the parameter in `value` + using stan::math::var; + double err_tol = 2e-4; + + // type double + EXPECT_NEAR(f(value), res, err_tol); + + // type var with derivative + var result_var = f(value); + result_var.grad(); + EXPECT_NEAR(value_of(result_var), res, err_tol); +} + +TEST(ProbWienerFull, wiener_full_all_vector) { + // tests all parameter types individually, with other + // parameters set to std::vector + using stan::math::wiener_full_lpdf; + + std::vector rt{1, 1, 1, 1, 1, 1, 1, 1, 1}; + std::vector a{1, 1, 1, 1, 1, 1, 1, 1, 1}; + std::vector v{1, 1, 1, 1, 1, 1, 1, 1, 1}; + std::vector w{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}; + std::vector t0{0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0}; + std::vector sv{0.1, 0.1, 0.1, 0, 0.1, 0, 0, 0, 0}; + std::vector sw{0.1, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0}; + std::vector st0{0.1, 0, 0.1, 0.1, 0, 0, 0.1, 0, 0}; + + double result{-24.307593}; + + // rt + auto f_rt = [&](auto value) { + return wiener_full_lpdf(value, a, t0, w, v, sv, sw, + st0); + }; + check_vector_types(f_rt, rt, result); + // a + auto f_a = [&](auto value) { + return wiener_full_lpdf(rt, value, t0, w, v, sv, sw, + st0); + }; + check_vector_types(f_a, a, result); + // v + auto f_v = [&](auto value) { + return wiener_full_lpdf(rt, a, t0, w, value, sv, sw, + st0); + }; + check_vector_types(f_v, v, result); + // w + auto f_w = [&](auto value) { + return wiener_full_lpdf(rt, a, t0, value, v, sv, sw, + st0); + }; + check_vector_types(f_w, w, result); + // t0 + auto f_t0 = [&](auto value) { + return wiener_full_lpdf(rt, a, value, w, v, sv, sw, + st0); + }; + check_vector_types(f_t0, t0, result); + // sv + auto f_sv = [&](auto value) { + return wiener_full_lpdf(rt, a, t0, w, v, value, sw, + st0); + }; + check_vector_types(f_sv, sv, result); + // sw + auto f_sw = [&](auto value) { + return wiener_full_lpdf(rt, a, t0, w, v, sv, value, + st0); + }; + check_vector_types(f_sw, sw, result); + // st0 + auto f_st0 = [&](auto value) { + return wiener_full_lpdf(rt, a, t0, w, v, sv, sw, + value); + }; + check_vector_types(f_st0, st0, result); + } \ No newline at end of file