Skip to content

Commit

Permalink
Excise remaining Zygote legacy code (#132)
Browse files Browse the repository at this point in the history
* Remove legacy collect statement

* Remove Zygote-related remark

* Remove and Zygote-related remark

* Remove redundant comments

* Remove outdated comment

* Display progress through pseudo point tests

* Excise zygote_friendly_map as it is redundant

* Remove zygote-friendly map include from runtests

* Bump patch

* Update readme timings discussion
  • Loading branch information
willtebbutt authored Sep 28, 2024
1 parent 9ddac9c commit 8777eb7
Show file tree
Hide file tree
Showing 13 changed files with 22 additions and 83 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["Will Tebbutt and contributors"]
version = "0.7.0"
version = "0.7.1"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,14 @@ This tells TemporalGPs that you want all parameters of `f` and anything derived



# Benchmarking Results
# Benchmarking Results (Old)

![](/examples/benchmarks.png)

"naive" timings are with the usual [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.

Gradient computations use Mooncake. Custom adjoints have been implemented to achieve this level of performance.
Gradient computations were performed using [Zygote.jl](https://github.com/FluxML/Zygote.jl/), and required many custom adjoints.
You should see similar results to this using [Mooncake.jl](https://github.com/compintell/Mooncake.jl) or [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl).


# Relevant literature
Expand Down
1 change: 0 additions & 1 deletion src/TemporalGPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ module TemporalGPs
# Various bits-and-bobs. Often commiting some type piracy.
include(joinpath("util", "linear_algebra.jl"))
include(joinpath("util", "scan.jl"))
include(joinpath("util", "zygote_friendly_map.jl"))

include(joinpath("util", "gaussian.jl"))
include(joinpath("util", "mul.jl"))
Expand Down
2 changes: 1 addition & 1 deletion src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ end
function lgssm_components(
m::AbstractGPs.MeanFunction, k::Kernel, t::AbstractVector, storage_type::StorageType
)
m = collect(mean_vector(m, t)) # `collect` is needed as there are still issues with Zygote and FillArrays.
m = mean_vector(m, t)
As, as, Qs, (Hs, hs), x0 = lgssm_components(k, t, storage_type)
hs = add_proj_mean(hs, m)
return As, as, Qs, (Hs, hs), x0
Expand Down
6 changes: 3 additions & 3 deletions src/models/lgssm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ end
function posterior(prior::LGSSM, y::AbstractVector)
_check_inputs(prior, y)
new_trans, xf = _a_bit_of_posterior(prior, y)
A = zygote_friendly_map(x -> x.A, new_trans)
a = zygote_friendly_map(x -> x.a, new_trans)
Q = zygote_friendly_map(x -> x.Q, new_trans)
A = map(x -> x.A, new_trans)
a = map(x -> x.a, new_trans)
Q = map(x -> x.Q, new_trans)
return LGSSM(GaussMarkovModel(reverse(ordering(prior)), A, a, Q, xf), prior.emissions)
end

Expand Down
10 changes: 4 additions & 6 deletions src/models/missings.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Several strategies for missing data handling were attempted.
# 1. Use `missing`s as expected. This turned out to be problematic for type-stability.
# 2. Sentinel values (NaNs). Also problematic for type-stability because Zygote.
# 2. Sentinel values (NaNs).
# 3. (The adopted strategy) - replace missings with arbitrary observations and _large_
# observation noises. While not optimal, type-stability is preserved inside the
# performance-sensitive code.
#
# In an ideal world, strategy 1 would work. Unfortunately Zygote isn't up to it yet.

function AbstractGPs.logpdf(
model::LGSSM, y::AbstractVector{Union{Missing, T}},
Expand All @@ -28,7 +26,7 @@ function transform_model_and_obs(
model::LGSSM, y::AbstractVector{<:Union{Missing, T}},
) where {T<:Union{<:AbstractVector, <:Real}}
Σs_filled_in, y_filled_in = fill_in_missings(
zygote_friendly_map(noise_cov, emissions(model)), y,
map(noise_cov, emissions(model)), y,
)
model_with_missings = replace_observation_noise_cov(model, Σs_filled_in)
return model_with_missings, y_filled_in
Expand All @@ -54,11 +52,11 @@ function _logpdf_volume_compensation(y::AbstractVector{<:Union{Missing, <:Real}}
return count(ismissing, y) * log(2π * _large_var_const()) / 2
end

function fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) where {T}
function fill_in_missings(Σs::AbstractVector, y::AbstractVector{Union{Missing, T}}) where {T}
return _fill_in_missings(Σs, y)
end

function _fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) where {T}
function _fill_in_missings(Σs::AbstractVector, y::AbstractVector{Union{Missing, T}}) where {T}

# Fill in observation covariance matrices with very large values.
Σs_filled_in = map(eachindex(y)) do n
Expand Down
17 changes: 5 additions & 12 deletions src/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,14 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect

k = fx_dtc.f.f.kernel
Cf_diags = kernel_diagonals(k, fx_dtc.x)
# return Cf_diags

# Transform a vector into a vector-of-vectors.
y_vecs = restructure(y, lgssm.emissions)

tmp = zygote_friendly_map(
((Σ, Cf_diag, marg_diag, yn), ) -> begin
Σ_, _ = fill_in_missings(Σ, yn)
return sum(diag(Σ_ \ (Cf_diag - marg_diag.P))) -
count(ismissing, yn) + size(Σ_, 1)
end,
zip(Σs, Cf_diags, marg_diags, y_vecs),
)
# return -sum(tmp) / 2

tmp = map(Σs, Cf_diags, marg_diags, y_vecs) do Σ, Cf_diag, marg_diag, yn
Σ_, _ = fill_in_missings(Σ, yn)
return sum(diag(Σ_ \ (Cf_diag - marg_diag.P))) -
count(ismissing, yn) + size(Σ_, 1)
end
return logpdf(lgssm, y_vecs) - sum(tmp) / 2
end

Expand Down
2 changes: 1 addition & 1 deletion src/space_time/rectilinear_grid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ end
# See docstring elsewhere for context.
function noise_var_to_time_form(x::RectilinearGrid, S::Diagonal{<:Real})
vs = restructure(diag(S), Fill(length(get_space(x)), length(get_times(x))))
return zygote_friendly_map(v -> Diagonal(collect(v)), vs)
return map(v -> Diagonal(collect(v)), vs)
end

destructure(::RectilinearGrid, y::AbstractVector) = reduce(vcat, y)
39 changes: 0 additions & 39 deletions src/util/zygote_friendly_map.jl

This file was deleted.

5 changes: 3 additions & 2 deletions test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ using KernelFunctions: kappa
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components
using Test

# Everything is tested once the LGSSM is constructed, so it is sufficient just to ensure
# that Zygote can handle construction.
# Everything is tested once the LGSSM is constructed, so the logpdf bit of this test
# function is probably redundant. It is good to do a little bit of integration testing
# though.
function _logpdf_tester(f_naive::GP, y, storage::StorageType, σ², t::AbstractVector)
f = to_sde(f_naive, storage)
return logpdf(f(t, σ²...), y)
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ include("front_matter.jl")
println("util:")
@testset "util" begin
include(joinpath("util", "scan.jl"))
include(joinpath("util", "zygote_friendly_map.jl"))
include(joinpath("util", "gaussian.jl"))
include(joinpath("util", "mul.jl"))
include(joinpath("util", "regular_data.jl"))
Expand Down
1 change: 1 addition & 0 deletions test/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ using Test
x_pr_r = randn(rng, 10)

@testset "kernel=$(k.name), x=$(x.name)" for k in kernels, x in xs
@info "kernel=$(k.name), x=$(x.name)"

# Compute pseudo-input locations. These have to share time points with `x`.
t = get_times(x.val)
Expand Down
14 changes: 0 additions & 14 deletions test/util/zygote_friendly_map.jl

This file was deleted.

4 comments on commit 8777eb7

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/116220

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.1 -m "<description of version>" 8777eb73f3d1e569ef75f76030a68a496625bea6
git push origin v0.7.1

Also, note the warning: Version 0.7.1 skips over 0.7.0
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/116220

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.1 -m "<description of version>" 8777eb73f3d1e569ef75f76030a68a496625bea6
git push origin v0.7.1

Please sign in to comment.