Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CI #115

Merged
merged 27 commits into from
Oct 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1f40106
Test on Julia 1.8
Sep 20, 2023
d5676f3
Check in Manifest
Sep 21, 2023
3d51bfb
Add rrules for `kron`
Sep 21, 2023
9da27ab
Revert "Check in Manifest"
Sep 21, 2023
9373a1e
Remove unnecessary method
Sep 21, 2023
f11e0a7
Use new `DTC` API
Sep 22, 2023
f846d58
Fix `rand_tangent` deprecation warnings
Sep 22, 2023
48d41de
Bump compat for AbstractGPs
Sep 22, 2023
de5a921
Drop Julia 1.8 from tests
simsurace Sep 22, 2023
283b2b7
Remove disclaimer about dependencies
simsurace Sep 22, 2023
7ecc719
Revert "Fix `rand_tangent` deprecation warnings"
simsurace Sep 22, 2023
4af9328
Try replacing chainrules by Zygote rules
Sep 26, 2023
5f1d160
Bump Zygote and remove pirated rules
Sep 27, 2023
76f209d
Implement changes re: `Symmetric`
Sep 27, 2023
8e3a4a3
Update comment re: symmetric wrapper
Sep 27, 2023
23a9638
Replace two-argument `map`s by broadcasts
simsurace Oct 3, 2023
228f7f6
Make space-time examples smaller
simsurace Oct 3, 2023
ce86901
Further optimize performance
simsurace Oct 3, 2023
f4bd059
Correct logical typo
simsurace Oct 3, 2023
4062b4f
Add tests and fix error
simsurace Oct 3, 2023
61c3d9a
Put back symmetric wrapper
simsurace Oct 3, 2023
c974e4a
Add some `@info` to example runs
simsurace Oct 3, 2023
e2ddedd
Name image files to not overwrite each other
simsurace Oct 3, 2023
ad05422
Revert "Update comment re: symmetric wrapper"
simsurace Oct 3, 2023
9286cca
Revert "Make space-time examples smaller"
simsurace Oct 3, 2023
7e0147a
Get rid of warning about `params`
simsurace Oct 4, 2023
07588b1
Decrease size of exact space-time learning
simsurace Oct 4, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.5.15"
AbstractGPs = "0.5.17"
Bessels = "0.2.8"
BlockDiagonals = "0.1.7"
ChainRulesCore = "1"
FillArrays = "0.13.0 - 0.13.7"
KernelFunctions = "0.9, 0.10.1"
StaticArrays = "1"
StructArrays = "0.5, 0.6"
Zygote = "0.6"
Zygote = "0.6.65"
julia = "1.6"
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ TemporalGPs.jl is a tool to make Gaussian processes (GPs) defined using [Abstrac

[JuliaCon 2020 Talk](https://www.youtube.com/watch?v=dysmEpX1QoE)

# Dependency Status

In the interest of managing expectations, please note that TemporalGPs does not currently operate with the most current version of AbstractGPs / Zygote / ChainRules. I (Will) am aware of this problem, and will sort it out as soon as I have the time!

# Installation

TemporalGPs.jl is registered, so simply type the following at the REPL:
Expand Down
2 changes: 1 addition & 1 deletion examples/approx_space_time_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
heatmap(reshape(σ_post_marginals, N_pr, T));
layout=(1, 2),
),
"posterior.png",
"approx_space_time_inference.png",
);
end
2 changes: 1 addition & 1 deletion examples/approx_space_time_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
heatmap(reshape(σ_post_marginals, N_pr, T));
layout=(1, 2),
),
"posterior.png",
"approx_space_time_learning.png",
);
end
4 changes: 2 additions & 2 deletions examples/augmented_inference.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using AbstractGPs
using TemporalGPs
using Distributions
using Distributions: Bernoulli
using StatsFuns: logistic

# In this example we are showing how to work with non-Gaussian likelihoods,
Expand Down Expand Up @@ -73,5 +73,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
plot!(plt, x_pr, f_post_samples; color=:red, alpha=0.3, label="");
plot!(plt, x, f_true; label="", lw=2.0, color=:blue); # Plot the true latent GP on top
scatter!(plt, x, y; label="", markersize=1.0, alpha=1.0); # Plot the data
savefig(plt, "posterior.png");
savefig(plt, "augmented_inference.png");
end
2 changes: 1 addition & 1 deletion examples/exact_space_time_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
heatmap(reshape(σ_post_marginals, N, T_pr));
layout=(1, 2),
),
"posterior.png",
"exact_space_time_inference.png",
);
end
6 changes: 3 additions & 3 deletions examples/exact_space_time_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ end
# Exact inference only works for such grids.
# Times must be increasing, points in space can be anywhere.
N = 50;
T = 1_000;
T = 500;
points_in_space = collect(range(-3.0, 3.0; length=N));
points_in_time = RegularSpacing(0.0, 0.01, T);
x = RectilinearGrid(points_in_space, points_in_time);
Expand Down Expand Up @@ -73,7 +73,7 @@ final_params = unpack(training_results.minimizer)
f_post = posterior(build_gp(final_params)(x, final_params.var_noise), y);

# Specify some locations at which to make predictions.
T_pr = 1200;
T_pr = 600;
points_in_time_pr = RegularSpacing(0.0, 0.01, T_pr);
x_pr = RectilinearGrid(points_in_space, points_in_time_pr);

Expand All @@ -93,6 +93,6 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
heatmap(reshape(σ_post_marginals, N, T_pr));
layout=(1, 2),
),
"posterior.png",
"exact_space_time_learning.png",
);
end
2 changes: 1 addition & 1 deletion examples/exact_time_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
scatter!(plt, x, y; label="", markersize=0.1, alpha=0.1);
plot!(plt, f_post(x_pr); ribbon_scale=3.0, label="");
plot!(x_pr, f_post_samples; color=:red, label="");
savefig(plt, "posterior.png");
savefig(plt, "exact_time_inference.png");
end
2 changes: 1 addition & 1 deletion examples/exact_time_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,5 @@ if get(ENV, "TESTING", "FALSE") == "FALSE"
scatter!(plt, x, y; label="", markersize=0.1, alpha=0.1);
plot!(plt, f_post(x_pr); ribbon_scale=3.0, label="");
plot!(plt, x_pr, f_post_samples; color=:red, label="");
savefig(plt, "posterior.png");
savefig(plt, "exact_time_learning.png");
end
1 change: 1 addition & 0 deletions src/models/linear_gaussian_conditionals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ be equivalent to
function predict(x::Gaussian, f::AbstractLGC)
A, a, Q = get_fields(f)
m, P = get_fields(x)

# Symmetric wrapper needed for numerical stability. Do not unwrap.
return Gaussian(A * m + a, (A * symmetric(P)) * A' + Q)
end
Expand Down
11 changes: 6 additions & 5 deletions src/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,11 @@ function kernel_diagonals(k::DTCSeparable, x::RegularInTime)
space_kernel = k.k.l
time_kernel = k.k.r
time_vars = kernelmatrix_diag(time_kernel, get_times(x))
return map(
(s_t, x_r) -> Diagonal(kernelmatrix_diag(space_kernel, x_r) * s_t),
time_vars,
x.vs,
return Diagonal.(
kernelmatrix_diag.(
Ref(space_kernel),
x.vs
) .* time_vars
)
end

Expand Down Expand Up @@ -185,7 +186,7 @@ function lgssm_components(k_dtc::DTCSeparable, x::RegularInTime, storage::Storag
C = \(K_space_z_chol, C__)
Cs = partition(ChainRulesCore.ignore_derivatives(map(length, x.vs)), C)

cs = _map((h, v) -> fill(h, length(v)), hs_t, x.vs) # This should currently be zero.
cs = fill.(hs_t, length.(x.vs)) # This should currently be zero.
Hs = _map(
((I, H_t), ) -> kron(I, H_t),
zip(Fill(ident_M, N), Hs_t),
Expand Down
11 changes: 10 additions & 1 deletion src/space_time/regular_in_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,16 @@ function Base.collect(x::RegularInTime)
return [(x, t) for (x, t) in zip(space_inputs, time_inputs)]
end

Base.getindex(x::RegularInTime, n::Int) = collect(x)[n]
function Base.getindex(x::RegularInTime, n::Int)
n ≤ 0 && throw(BoundsError(x, n))
sum_of_lengths = 0
for (i, v) in enumerate(x.vs)
temp = sum_of_lengths + length(v)
temp ≥ n && return (v[n - sum_of_lengths], x.ts[i])
sum_of_lengths = temp
end
throw(BoundsError(x, n))
end

Base.show(io::IO, x::RegularInTime) = Base.show(io::IO, collect(x))

Expand Down
7 changes: 0 additions & 7 deletions src/util/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,6 @@ Zygote.accum(a::Tuple, b::Tuple, c::Tuple) = map(Zygote.accum, a, b, c)
# StaticArrays #
# ---------------------------------------------------------------------------- #

function ProjectTo(x::SArray{S,T}) where {S, T}
return ProjectTo{SArray}(; element=_eltype_projectto(T), axes=axes(x), static_size=S)
end

(proj::ProjectTo{SArray})(dx::SArray) = SArray{proj.static_size}(dx.data)
(proj::ProjectTo{SArray})(dx::AbstractArray) = SArray{proj.static_size}(Tuple(dx))

function rrule(::Type{T}, x::Tuple) where {T<:SArray}
SArray_rrule(Δ) = begin
(NoTangent(), Tangent{typeof(x)}(unthunk(Δ).data...))
Expand Down
19 changes: 12 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,16 @@ if GROUP == "examples"
Pkg.resolve()
Pkg.instantiate()

include(joinpath(pkgpath, "examples", "exact_time_inference.jl"))
include(joinpath(pkgpath, "examples", "exact_time_learning.jl"))
include(joinpath(pkgpath, "examples", "exact_space_time_inference.jl"))
include(joinpath(pkgpath, "examples", "exact_space_time_learning.jl"))
include(joinpath(pkgpath, "examples", "approx_space_time_inference.jl"))
include(joinpath(pkgpath, "examples", "approx_space_time_learning.jl"))
include(joinpath(pkgpath, "examples", "augmented_inference.jl"))
function include_with_info(filename)
@info "Running examples/$filename"
include(joinpath(pkgpath, "examples", filename))
end

include_with_info("exact_time_inference.jl")
include_with_info("exact_time_learning.jl")
include_with_info("exact_space_time_inference.jl")
include_with_info("exact_space_time_learning.jl")
include_with_info("approx_space_time_inference.jl")
include_with_info("approx_space_time_learning.jl")
include_with_info("augmented_inference.jl")
end
4 changes: 2 additions & 2 deletions test/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ include("../models/model_test_utils.jl")
validate_dims(lgssm)

# The two approaches to DTC computation should be equivalent up to roundoff error.
dtc_naive = dtc(VFE(f_naive(z_naive)), fx_naive, y)
dtc_naive = approx_log_evidence(DTC(f_naive(z_naive)), fx_naive, y)
dtc_sde = dtc(fx, y, z_r)
@test dtc_naive ≈ dtc_sde rtol=1e-6

Expand Down Expand Up @@ -150,7 +150,7 @@ include("../models/model_test_utils.jl")
fx_naive = f_naive(naive_inputs_missings, 0.1)

# Compute DTC using both approaches.
dtc_naive = dtc(VFE(f_naive(z_naive)), fx_naive, naive_y_missings)
dtc_naive = approx_log_evidence(DTC(f_naive(z_naive)), fx_naive, naive_y_missings)
dtc_sde = dtc(fx, y_missing, z_r)
@test dtc_naive ≈ dtc_sde rtol=1e-7 atol=1e-7

Expand Down
3 changes: 3 additions & 0 deletions test/space_time/regular_in_time.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ using TemporalGPs: RegularInTime
@test prod(size(x)) == length(collect(x))

@test all([getindex(x, n) for n in 1:length(x)] .== collect(x))
@test_throws BoundsError x[0]
@test_throws BoundsError x[-1]
@test_throws BoundsError x[length(x) + 1]
end
Loading