Skip to content

Commit

Permalink
Tidy up and use JET rather than inferred
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Sep 27, 2024
1 parent d95557e commit 0391be3
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 105 deletions.
27 changes: 22 additions & 5 deletions test/front_matter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
using AbstractGPs,
BlockDiagonals,
FillArrays,
LinearAlgebra,
JET,
KernelFunctions,
LinearAlgebra,
Mooncake,
Random,
StaticArrays,
Expand All @@ -23,7 +24,26 @@ using TemporalGPs:
scan_emit,
transform_model_and_obs,
RectilinearGrid,
RegularInTime
RegularInTime,
posterior_and_lml,
predict,
predict_marginals,
step_marginals,
step_logpdf,
step_filter,
step_rand,
invert_dynamics,
step_posterior,
storage_type,
is_of_storage_type,
ArrayStorage,
SArrayStorage,
SmallOutputLGC,
LargeOutputLGC,
ScalarOutputLGC,
Forward,
Reverse,
ordering

ENV["TESTING"] = "TRUE"

Expand All @@ -34,8 +54,5 @@ ENV["TESTING"] = "TRUE"
# ENV["GROUP"] = "test gp"
const GROUP = get(ENV, "GROUP", "all")

const TEST_TYPE_INFER = false # Test type stability over the tests
const TEST_ALLOC = false # Test allocations over the tests

include("test_util.jl")
include(joinpath("models", "model_test_utils.jl"))
59 changes: 13 additions & 46 deletions test/models/lgssm.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,3 @@
using TemporalGPs:
TemporalGPs,
predict,
step_marginals,
step_logpdf,
step_filter,
step_rand,
invert_dynamics,
step_posterior,
storage_type,
is_of_storage_type,
ArrayStorage,
SArrayStorage,
SmallOutputLGC,
LargeOutputLGC,
ScalarOutputLGC,
Forward,
Reverse,
ordering

println("lgssm:")
@testset "lgssm" begin

Expand Down Expand Up @@ -58,7 +38,8 @@ println("lgssm:")
# Print current iteration to prevent CI timing out.
println(
"(time_varying=$tv, N=$N, Dlat=$Dlat, Dobs=$Dobs, " *
"storage=$(storage.name), emissions=$(emission.val), ordering=$order)",
"storage=$(storage.name), emissions=$(emission.val), ordering=$order, " *
"Q=$Q)",
)

# Build LGSSM.
Expand All @@ -82,30 +63,16 @@ println("lgssm:")
y = first(rand(model))
x = TemporalGPs.x0(model)

interface_only = true
@testset "step_marginals" begin
@inferred step_marginals(x, model[1])
test_rule(rng, step_marginals, x, model[1]; is_primitive=false, interface_only)
end
@testset "step_logpdf" begin
args = (ordering(model[1]), x, (model[1], y))
@inferred step_logpdf(args...)
test_rule(rng, step_logpdf, args...; is_primitive=false, interface_only)
end
@testset "step_filter" begin
args = (ordering(model[1]), x, (model[1], y))
@inferred step_filter(args...)
test_rule(rng, step_filter, args...; is_primitive=false, interface_only)
end
@testset "invert_dynamics" begin
args = (x, x, model[1].transition)
@inferred invert_dynamics(args...)
test_rule(rng, invert_dynamics, args...; is_primitive=false, interface_only)
end
@testset "step_posterior" begin
args = (ordering(model[1]), x, (model[1], y))
@inferred step_posterior(args...)
test_rule(rng, step_posterior, args...; is_primitive=false, interface_only)
perf_flag = storage.val isa SArrayStorage ? :allocs : :none
@testset "$f" for (f, args...) in Any[
(step_marginals, x, model[1]),
(step_logpdf, ordering(model[1]), x, (model[1], y)),
(step_filter, ordering(model[1]), x, (model[1], y)),
(invert_dynamics, x, x, model[1].transition),
(step_posterior, ordering(model[1]), x, (model[1], y)),
]
@test_opt target_modules=[TemporalGPs] f(args...)
test_rule(rng, f, args...; is_primitive=false, interface_only=true, perf_flag)
end

# Run standard battery of LGSSM tests.
Expand All @@ -116,7 +83,7 @@ println("lgssm:")
max_primal_allocs=25,
max_forward_allocs=25,
max_backward_allocs=25,
check_allocs=TEST_ALLOC && storage.val isa SArrayStorage,
check_allocs=storage.val isa SArrayStorage,
)
end
end
50 changes: 15 additions & 35 deletions test/models/linear_gaussian_conditionals.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
using TemporalGPs: posterior_and_lml, predict, predict_marginals
using Test

println("linear_gaussian_conditionals:")
@testset "linear_gaussian_conditionals" begin
Dlats = [1, 3]
Dobss = [1, 2]
# Dlats = [3]
# Dobss = [2]
storages = [
(name="dense storage Float64", val=ArrayStorage(Float64)),
]
Q_types = [
Val(:dense),
Val(:diag),
]
storages = [(name="dense storage Float64", val=ArrayStorage(Float64))]
Q_types = [Val(:dense), Val(:diag)]

@testset "SmallOutputLGC (Dlat=$Dlat, Dobs=$Dobs, Q=$(Q_type), $(storage.name))" for
Dlat in Dlats,
Expand All @@ -27,11 +17,9 @@ println("linear_gaussian_conditionals:")
x = random_gaussian(rng, Dlat, storage.val)
model = random_small_output_lgc(rng, Dlat, Dobs, Q_type, storage.val)

check_allocs = storage.val isa SArrayStorage
test_interface(
rng, model, x;
check_adjoints=true,
check_inferred=TEST_TYPE_INFER,
check_allocs=TEST_ALLOC && storage.val isa SArrayStorage,
rng, model, x; check_adjoints=true, check_inferred=true, check_allocs
)

Q_type == Val(:diag) && @testset "missing data" begin
Expand All @@ -57,9 +45,8 @@ println("linear_gaussian_conditionals:")
@test lml lml_new atol=1e-8 rtol=1e-8

# Check that everything infers and AD gives the right answer.
@inferred posterior_and_lml(x, model, y_missing)
# BROKEN: gradients with Zygote look fine but are failing because of ChainRulesTestUtils checks see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/270
# test_zygote_grad(posterior_and_lml, x, model, y_missing)
@test_opt target_modules=[TemporalGPs] posterior_and_lml(x, model, y_missing)
test_rule(rng, posterior_and_lml, x, model, y_missing; is_primitive=false)
end
end

Expand Down Expand Up @@ -103,27 +90,25 @@ println("linear_gaussian_conditionals:")
@test lml_vanilla lml_large rtol=1e-8 atol=1e-8

# Check that everything infers and AD gives the right answer.
@inferred posterior_and_lml(x, model, y_missing)
@test_opt target_modules=[TemporalGPs] posterior_and_lml(x, model, y_missing)
test_rule(rng, posterior_and_lml, x, model, y_missing; is_primitive=false)
end
end

check_allocs = storage.val isa SArrayStorage
test_interface(
rng, model, x;
check_adjoints=true,
check_inferred=TEST_TYPE_INFER,
check_allocs=TEST_ALLOC && storage.val isa SArrayStorage,
rng, model, x; check_adjoints=true, check_inferred=true, check_allocs
)
end

@testset "ScalarOutputLGC (Dlat=$Dlat, ($storage.name))" for
@testset "ScalarOutputLGC (Dlat=$Dlat, $(storage.name))" for
Dlat in Dlats,
storage in [
(name="dense storage Float64", val=ArrayStorage(Float64)),
(name="static storage Float64", val=SArrayStorage(Float64)),
]

println("ScalarOutputLGC (Dlat=$Dlat, ($storage.name))")
println("ScalarOutputLGC (Dlat=$Dlat, $(storage.name))")

rng = MersenneTwister(123456)
x = random_gaussian(rng, Dlat, storage.val)
Expand All @@ -140,11 +125,9 @@ println("linear_gaussian_conditionals:")
@test lml_vanilla lml_scalar
end

check_allocs = storage.val isa SArrayStorage
test_interface(
rng, model, x;
check_adjoints=true,
check_inferred=TEST_TYPE_INFER,
check_allocs=TEST_ALLOC && storage.val isa SArrayStorage,
rng, model, x; check_adjoints=true, check_inferred=true, check_allocs
)
end

Expand All @@ -167,10 +150,7 @@ println("linear_gaussian_conditionals:")
@test TemporalGPs.dim_in(model) == Din

test_interface(
rng, model, x;
check_adjoints=true,
check_inferred=TEST_TYPE_INFER,
check_allocs=TEST_ALLOC,
rng, model, x; check_adjoints=true, check_inferred=true, check_allocs=false
)

@testset "consistency with SmallOutputLGC" begin
Expand Down Expand Up @@ -202,7 +182,7 @@ println("linear_gaussian_conditionals:")
@test lml_vanilla lml_large rtol=1e-8 atol=1e-8

# Check that everything infers and AD gives the right answer.
@inferred posterior_and_lml(x, model, y_missing)
@test_opt target_modules=[TemporalGPs] posterior_and_lml(x, model, y_missing)
test_rule(rng, posterior_and_lml, x, model, y_missing; is_primitive=false)
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/models/missings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
end

# Check logpdf and inference run, infer.
@inferred logpdf(model, y_missing)
@inferred posterior(model, y_missing)
@test_opt target_modules=[TemporalGPs] logpdf(model, y_missing)
@test_opt target_modules=[TemporalGPs] posterior(model, y_missing)
end
end;
38 changes: 21 additions & 17 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,24 @@ using TemporalGPs:

function test_interface(
rng::AbstractRNG, conditional::AbstractLGC, x::Gaussian;
check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, kwargs...,
check_inferred=true, check_adjoints=true, check_allocs=true,
)
x_val = rand(rng, x)
y = conditional_rand(rng, conditional, x_val)
perf_flag = check_allocs ? :allocs : :none

is_primitive = false
@testset "rand" begin
@test length(y) == dim_out(conditional)
args = (TemporalGPs.ε_randn(rng, conditional), conditional, x_val)
check_inferred && @inferred conditional_rand(args...)
check_adjoints && test_rule(rng, conditional_rand, args...; perf_flag, is_primitive=false)
check_inferred && @test_opt target_modules=[TemporalGPs] conditional_rand(args...)
check_adjoints && test_rule(rng, conditional_rand, args...; perf_flag, is_primitive)
end

@testset "predict" begin
@test predict(x, conditional) isa Gaussian
check_inferred && @inferred predict(x, conditional)
check_adjoints && test_rule(rng, predict, x, conditional; perf_flag, is_primitive=false)
check_inferred && @test_opt target_modules=[TemporalGPs] predict(x, conditional)
check_adjoints && test_rule(rng, predict, x, conditional; perf_flag, is_primitive)
end

conditional isa ScalarOutputLGC || @testset "predict_marginals" begin
Expand All @@ -52,15 +53,15 @@ function test_interface(
@testset "posterior_and_lml" begin
args = (x, conditional, y)
@test posterior_and_lml(args...) isa Tuple{Gaussian, Real}
check_inferred && @inferred posterior_and_lml(args...)
check_adjoints && test_rule(rng, posterior_and_lml, args...; perf_flag, is_primitive=false)
check_inferred && @test_opt target_modules=[TemporalGPs] posterior_and_lml(args...)
check_adjoints && test_rule(rng, posterior_and_lml, args...; perf_flag, is_primitive)
end
end

"""
test_interface(
rng::AbstractRNG, ssm::AbstractLGSSM;
check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, kwargs...
check_inferred=true, check_adjoints=true, check_allocs=true,
)
Basic consistency tests that any LGSSM should be able to satisfy. The purpose of these tests
Expand All @@ -69,7 +70,7 @@ consistent and implements the required interface.
"""
function test_interface(
rng::AbstractRNG, ssm::AbstractLGSSM;
check_inferred=TEST_TYPE_INFER, check_adjoints=true, check_allocs=TEST_ALLOC, rtol, atol, kwargs...
check_inferred=true, check_adjoints=true, check_allocs=true,
)
perf_flag = check_allocs ? :allocs : :none
y_no_missing = rand(rng, ssm)
Expand All @@ -78,13 +79,17 @@ function test_interface(
@test is_of_storage_type(y_no_missing[1], storage_type(ssm))
@test y_no_missing isa AbstractVector
@test length(y_no_missing) == length(ssm)
check_inferred && @inferred rand(rng, ssm)
check_inferred && @test_opt target_modules=[TemporalGPs] rand(rng, ssm)
rng = MersenneTwister(123456)
check_adjoints && test_rule(rng, rand, rng, ssm; perf_flag, interface_only=true, is_primitive=false)
if check_adjoints
test_rule(
rng, rand, rng, ssm; perf_flag, interface_only=true, is_primitive=false
)
end
end

@testset "basics" begin
@inferred storage_type(ssm)
@test_opt target_modules=[TemporalGPs] storage_type(ssm)
@test length(ssm) == length(y_no_missing)
end

Expand All @@ -93,7 +98,7 @@ function test_interface(
@test is_of_storage_type(xs, storage_type(ssm))
@test xs isa AbstractVector{<:Gaussian}
@test length(xs) == length(ssm)
check_inferred && @inferred marginals(ssm)
check_inferred && @test_opt target_modules=[TemporalGPs] marginals(ssm)
if check_adjoints
test_rule(
rng, scan_emit, step_marginals, ssm, x0(ssm), eachindex(ssm);
Expand All @@ -104,7 +109,6 @@ function test_interface(

@testset "$(data.name)" for data in [
(name="no-missings", y=y_no_missing),
# (name="with-missings", y=y_missing),
]
_check_inferred = data.name == "with-missings" ? false : check_inferred

Expand All @@ -113,7 +117,7 @@ function test_interface(
lml = logpdf(ssm, y)
@test lml isa Real
@test is_of_storage_type(lml, storage_type(ssm))
_check_inferred && @inferred logpdf(ssm, y)
_check_inferred && @test_opt target_modules=[TemporalGPs] logpdf(ssm, y)
if check_adjoints
test_rule(
rng, scan_emit, step_logpdf, zip(ssm, y), x0(ssm), eachindex(ssm);
Expand All @@ -126,7 +130,7 @@ function test_interface(
@test is_of_storage_type(xs, storage_type(ssm))
@test xs isa AbstractVector{<:Gaussian}
@test length(xs) == length(ssm)
_check_inferred && @inferred _filter(ssm, y)
_check_inferred && @test_opt target_modules=[TemporalGPs] _filter(ssm, y)
if check_adjoints
test_rule(
rng, scan_emit, step_filter, zip(ssm, y), x0(ssm), eachindex(ssm);
Expand All @@ -138,7 +142,7 @@ function test_interface(
posterior_ssm = posterior(ssm, y)
@test length(posterior_ssm) == length(ssm)
@test ordering(posterior_ssm) != ordering(ssm)
_check_inferred && @inferred posterior(ssm, y)
_check_inferred && @test_opt target_modules=[TemporalGPs] posterior(ssm, y)
if check_adjoints
test_rule(
rng, posterior, ssm, y;
Expand Down

0 comments on commit 0391be3

Please sign in to comment.