Skip to content

Commit

Permalink
test: mark edge case getu type inference as broken
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 31, 2024
1 parent 423a27a commit 47a9224
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions test/downstream/symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,16 @@ for (sym, val, check_inference) in [
((x, x), [(i, i) for i in x_val], true),
((x, x_idx), [(i, i) for i in x_val], true),
((x, x[1] + y), [(i, j) for (i, j) in zip(x_val, obs_val)], true),
((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], true),
((x, (x[1] + y, y)), [(i, (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], missing),
([x, [x[1] + y, y]], [[i, [k, j]] for (i, j, k) in zip(x_val, y_val, obs_val)], false),
((x, [x[1] + y, y], (x[1] + y, y_idx)),
[(i, [k, j], (k, j)) for (i, j, k) in zip(x_val, y_val, obs_val)], false),
([x, [x[1] + y, y], (x[1] + y, y_idx)],
[[i, [k, j], (k, j)] for (i, j, k) in zip(x_val, y_val, obs_val)], false)
]
if check_inference
if check_inference === missing
@test_broken @inferred getu(prob, sym)(sol)
elseif check_inference
@inferred getu(prob, sym)(sol)
end
@test getu(prob, sym)(sol) == val
Expand Down Expand Up @@ -319,7 +321,7 @@ for (sym, oldval, newval, check_inference) in [
((p[1], p[2:3]), (pval[1], pval[2:3]), (pval_new[1], pval_new[2:3]), true),
([p[1], p[2:3]], [pval[1], pval[2:3]], [pval_new[1], pval_new[2:3]], false),
((p[1], (p[2],), [p[3]]), (pval[1], (pval[2],), [pval[3]]),
(pval_new[1], (pval_new[2],), [pval_new[3]]), true),
(pval_new[1], (pval_new[2],), [pval_new[3]]), false),
([p[1], (p[2],), [p[3]]], [pval[1], (pval[2],), [pval[3]]],
[pval_new[1], (pval_new[2],), [pval_new[3]]], false)
]
Expand Down Expand Up @@ -460,7 +462,7 @@ for (sym, val, buffer, check_inference) in [
end
end

int = init(sol.prob.p, Tsit5())
int = init(sol.prob, Tsit5())
step!(int, 0.1, true)
ud1obsval = vcat(ud1val[2:end], int.ps[Hold(ud1)])

Expand Down

0 comments on commit 47a9224

Please sign in to comment.