Remove Zygote as a dep #129

Merged (20 commits) on Sep 27, 2024
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
@@ -20,15 +20,13 @@ jobs:
matrix:
version:
- '1'
- '1.6'
os:
- ubuntu-latest
arch:
- x64
group:
- 'test util'
- 'test models'
- 'test models-lgssm'
- 'test gp'
- 'test space_time'
steps:
15 changes: 15 additions & 0 deletions NEWS.md
@@ -1,3 +1,18 @@
# 0.7

Mooncake.jl (and probably Enzyme.jl) is now able to differentiate everything in
TemporalGPs.jl _reasonably_ efficiently, and requires only a single hand-written rule (for `time_exp`).
This is in stark contrast with Zygote.jl, which required roughly 2.5k lines of rules to
achieve reasonable performance. That code was not robust, required maintenance from time
to time, and generally made it hard to make progress on improvements to this library.
Consequently, this version of TemporalGPs removes all Zygote-related functionality, and we
now recommend using Mooncake.jl (or perhaps Enzyme.jl) to differentiate code in this
package. In some places Mooncake.jl achieves worse performance than Zygote.jl did, but
this is worth it for the amount of code that has been removed.

If you wish to use Zygote + TemporalGPs, you should restrict yourself to the 0.6 series of
this package.
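
As an illustration, here is a minimal sketch of differentiating the log marginal
likelihood with Mooncake. It assumes Mooncake 0.4's documented `build_rrule` /
`value_and_gradient!!` interface, and the model and data below are placeholders rather
than code from this repository:

```julia
using AbstractGPs, TemporalGPs, Mooncake

# Placeholder model and data.
f = to_sde(GP(Matern52Kernel()), SArrayStorage(Float64))
x = RegularSpacing(0.0, 0.1, 1_000)
y = rand(f(x, 0.1))

# Differentiate the negative log marginal likelihood w.r.t. the noise variance.
loss(s2) = -logpdf(f(x, s2), y)
rule = Mooncake.build_rrule(loss, 0.1)
val, (dloss, ds2) = Mooncake.value_and_gradient!!(rule, loss, 0.1)
```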

# 0.5.12

- A collection of examples of inference, and inference + learning, have been added.
27 changes: 21 additions & 6 deletions Project.toml
@@ -1,29 +1,44 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <[email protected]> and contributors"]
version = "0.6.8"
authors = ["Will Tebbutt and contributors"]
version = "0.7.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
TemporalGPsMooncakeExt = "Mooncake"

[compat]
AbstractGPs = "0.5.17"
BenchmarkTools = "1"
Bessels = "0.2.8"
BlockDiagonals = "0.1.7"
ChainRulesCore = "1"
FillArrays = "0.13.0 - 0.13.7, 1"
JET = "0.9"
KernelFunctions = "0.9, 0.10.1"
Mooncake = "0.4.3"
StaticArrays = "1"
StructArrays = "0.5, 0.6"
Zygote = "0.6.65"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "JET", "Mooncake", "Pkg", "Test"]
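
For readers unfamiliar with the `[weakdeps]` / `[extensions]` mechanism above: the
`TemporalGPsMooncakeExt` module loads automatically only when both TemporalGPs and
Mooncake are present in the environment. A hypothetical skeleton of such an extension
(by convention it would live at `ext/TemporalGPsMooncakeExt.jl`; the body here is a
placeholder, not this repository's actual rule code):

```julia
module TemporalGPsMooncakeExt

using Mooncake: Mooncake
using TemporalGPs: TemporalGPs

# The single hand-written rule mentioned in NEWS.md (for `time_exp`) would be
# registered here via Mooncake's rule interface.

end # module
```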
10 changes: 5 additions & 5 deletions README.md
@@ -41,7 +41,7 @@ f = to_sde(f_naive, SArrayStorage(Float64))

# Project onto finite-dimensional distribution as usual.
# x = range(-5.0; step=0.1, length=10_000)
x = RegularSpacing(0.0, 0.1, 10_000) # Hack for Zygote.
x = RegularSpacing(0.0, 0.1, 10_000) # Hack for AD.
fx = f(x, 0.1)

# Sample from the prior as usual.
@@ -63,7 +63,7 @@ rand(f_post(x))
logpdf(f_post(x), y)
```
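
A side note on the `RegularSpacing` "hack" above: it is a lightweight, evenly-spaced
`AbstractVector` that tends to be friendlier to AD than a `Base` range. A small sketch of
the intended equivalence (assuming `RegularSpacing` implements the standard
`AbstractVector` interface, as its use above suggests):

```julia
using TemporalGPs

x = RegularSpacing(0.0, 0.1, 10_000)           # start, step, length
x_range = range(0.0; step=0.1, length=10_000)  # the range it stands in for
@assert collect(x) ≈ collect(x_range)
```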

## Learning kernel parameters with [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl), and [Zygote.jl](https://github.com/FluxML/Zygote.jl/)
## Learning kernel parameters with [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl), and [Mooncake.jl](https://github.com/compintell/Mooncake.jl/)

TemporalGPs.jl doesn't provide scikit-learn-like functionality to train your model (find good kernel parameter settings).
Instead, we offer the functionality needed to easily implement your own training functionality using standard tools from the Julia ecosystem, as shown below.
@@ -76,7 +76,7 @@ using TemporalGPs
# Load standard packages from the Julia ecosystem
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
using Zygote # Algorithmic Differentiation
using Mooncake # Algorithmic Differentiation

using ParameterHandling: flatten

@@ -115,7 +115,7 @@ objective(params)
# Optim.jl for more info on available optimisers and their properties.
training_results = Optim.optimize(
objective ∘ unpack,
θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
θ -> only(Mooncake.gradient(objective ∘ unpack, θ)),
flat_initial_params + randn(3), # Add some noise to make learning non-trivial
BFGS(
alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
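
If `Mooncake.gradient` is unavailable in your installed version of Mooncake, a gradient
function for Optim.jl can be sketched from Mooncake's documented `build_rrule` /
`value_and_gradient!!` interface instead (names reused from the snippet above; treat this
as an illustration rather than this repository's code):

```julia
loss = objective ∘ unpack
rule = Mooncake.build_rrule(loss, flat_initial_params)  # compile the reverse pass once

function loss_grad(theta)
    _, (_, dtheta) = Mooncake.value_and_gradient!!(rule, loss, theta)
    return dtheta
end

# e.g. Optim.optimize(loss, loss_grad, flat_initial_params, BFGS(); inplace=false)
```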
@@ -152,7 +152,7 @@ This tells TemporalGPs that you want all parameters of `f` and anything derived

"naive" timings are with the usual [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.

Gradient computations use Zygote. Custom adjoints have been implemented to achieve this level of performance.
Gradient computations use Mooncake. Custom adjoints have been implemented to achieve this level of performance.
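
For concreteness, the three benchmark configurations correspond to constructions along
these lines (a sketch; `f_naive` is an ordinary AbstractGPs `GP`, as in the examples
above):

```julia
f_lgssm = to_sde(f_naive)                           # "lgssm"
f_static = to_sde(f_naive, SArrayStorage(Float64))  # "static-lgssm"
```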


