Commit

Remove Zygote as a dep (#129)
* Remove literal_getfield usage

* Improve perf

* Progress

* Lots of changes

* Add Test as test dep

* Fix typo

* Add Pkg to examples

* Add Pkg to test deps

* Require Mooncake 0-4-3

* Import more names

* Remove Mooncake as direct dep

* Formatting

* Formatting

* Tidy up + enable all tests

* Enable all tests

* Add JET as test dep

* Tidy up and use JET rather than inferred

* Some fixes

* Discuss the changes in this release

* Figure out how to avoid bad gradients
willtebbutt authored Sep 27, 2024
1 parent 784dbad commit b3405d0
Showing 54 changed files with 571 additions and 2,862 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/ci.yml
@@ -20,15 +20,13 @@ jobs:
matrix:
version:
- '1'
- '1.6'
os:
- ubuntu-latest
arch:
- x64
group:
- 'test util'
- 'test models'
- 'test models-lgssm'
- 'test gp'
- 'test space_time'
steps:
15 changes: 15 additions & 0 deletions NEWS.md
@@ -1,3 +1,18 @@
# 0.7

Mooncake.jl (and probably Enzyme.jl) is now able to differentiate everything in
TemporalGPs.jl _reasonably_ efficiently, and only requires a single rule (for `time_exp`).
This is in stark contrast with Zygote.jl, which required roughly 2.5k lines of custom rules
to achieve reasonable performance. That code was not robust, required maintenance from time
to time, and generally made it hard to make progress on improvements to this library.
Consequently, this version of TemporalGPs removes all Zygote-related functionality, and we
now recommend using Mooncake.jl (or perhaps Enzyme.jl) to differentiate code in this
package. In some places Mooncake.jl achieves worse performance than Zygote.jl, but this is
worth it for the amount of code that has been removed.

If you wish to use Zygote + TemporalGPs, you should restrict yourself to the 0.6 series of
this package.

# 0.5.12

- A collection of examples of inference, and inference + learning, has been added.
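The 0.7 entry above recommends using Mooncake.jl to differentiate code in this package, but does not show what that looks like. Below is a minimal, self-contained sketch applying Mooncake's reverse-mode interface to a toy objective; the `build_rrule` / `value_and_gradient!!` calls and the layout of the returned tangents are assumptions based on Mooncake's documented interface around v0.4, not code from this commit.

```julia
using Mooncake

# Toy stand-in for a TemporalGPs log-marginal-likelihood objective.
objective(θ) = sum(abs2, θ)
θ = randn(5)

# Build a reverse-mode rule for `objective` at this argument type, then evaluate
# the primal value and the gradient in one call. The returned tangent tuple has
# one entry per argument, the first being the tangent of `objective` itself.
rule = Mooncake.build_rrule(objective, θ)
val, (dobjective, dθ) = Mooncake.value_and_gradient!!(rule, objective, θ)
```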
27 changes: 21 additions & 6 deletions Project.toml
@@ -1,29 +1,44 @@
name = "TemporalGPs"
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
authors = ["willtebbutt <[email protected]> and contributors"]
version = "0.6.8"
authors = ["Will Tebbutt and contributors"]
version = "0.7.0"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Bessels = "0e736298-9ec6-45e8-9647-e4fc86a2fe38"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
TemporalGPsMooncakeExt = "Mooncake"

[compat]
AbstractGPs = "0.5.17"
BenchmarkTools = "1"
Bessels = "0.2.8"
BlockDiagonals = "0.1.7"
ChainRulesCore = "1"
FillArrays = "0.13.0 - 0.13.7, 1"
JET = "0.9"
KernelFunctions = "0.9, 0.10.1"
Mooncake = "0.4.3"
StaticArrays = "1"
StructArrays = "0.5, 0.6"
Zygote = "0.6.65"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["BenchmarkTools", "JET", "Mooncake", "Pkg", "Test"]
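The new `[weakdeps]` / `[extensions]` entries above register `TemporalGPsMooncakeExt` as a package extension, so the Mooncake-specific code only loads when Mooncake.jl is present in the environment alongside TemporalGPs. As a rough illustration of the mechanism, a skeleton of such an extension file is shown below; it is not the extension actually added in this commit, which (per NEWS.md) contains the single rule needed for `time_exp`.

```julia
# ext/TemporalGPsMooncakeExt.jl — illustrative skeleton of a package extension.
module TemporalGPsMooncakeExt

using Mooncake
using TemporalGPs

# The Mooncake rule for TemporalGPs' `time_exp` would be defined here; the
# actual rule added in this commit is not reproduced in this sketch.

end
```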
10 changes: 5 additions & 5 deletions README.md
@@ -41,7 +41,7 @@ f = to_sde(f_naive, SArrayStorage(Float64))

# Project onto finite-dimensional distribution as usual.
# x = range(-5.0; step=0.1, length=10_000)
x = RegularSpacing(0.0, 0.1, 10_000) # Hack for Zygote.
x = RegularSpacing(0.0, 0.1, 10_000) # Hack for AD.
fx = f(x, 0.1)

# Sample from the prior as usual.
@@ -63,7 +63,7 @@ rand(f_post(x))
logpdf(f_post(x), y)
```

## Learning kernel parameters with [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl), and [Zygote.jl](https://github.com/FluxML/Zygote.jl/)
## Learning kernel parameters with [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl), and [Mooncake.jl](https://github.com/compintell/Mooncake.jl/)

TemporalGPs.jl doesn't provide scikit-learn-like functionality to train your model (find good kernel parameter settings).
Instead, we offer the functionality needed to easily implement your own training functionality using standard tools from the Julia ecosystem, as shown below.
@@ -76,7 +76,7 @@ using TemporalGPs
# Load standard packages from the Julia ecosystem
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
using Zygote # Algorithmic Differentiation
using Mooncake # Algorithmic Differentiation

using ParameterHandling: flatten

@@ -115,7 +115,7 @@ objective(params)
# Optim.jl for more info on available optimisers and their properties.
training_results = Optim.optimize(
objective ∘ unpack,
θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
θ -> only(Mooncake.gradient(objective ∘ unpack, θ)),
flat_initial_params + randn(3), # Add some noise to make learning non-trivial
BFGS(
alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
@@ -152,7 +152,7 @@ This tells TemporalGPs that you want all parameters of `f` and anything derived

"naive" timings are with the usual [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.

Gradient computations use Zygote. Custom adjoints have been implemented to achieve this level of performance.
Gradient computations use Mooncake. Custom adjoints have been implemented to achieve this level of performance.



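The README hunks above use `flat_initial_params` and `unpack` without showing how they are built; that part of the file is elided. The construction comes from ParameterHandling.jl, sketched below with illustrative parameter names (the actual parameters of the README's model are an assumption here, not taken from the diff).

```julia
using ParameterHandling
using ParameterHandling: flatten, value

# Illustrative constrained parameters; the README's real model parameters are
# not shown in the hunks above.
raw_params = (var = positive(0.6), precision = positive(1.2))

# `flatten` maps the named tuple to a flat Vector{Float64}, plus a function
# that rebuilds the structured (still-constrained) parameters from such a vector.
flat_initial_params, unflatten = flatten(raw_params)

# Stripping the constraint wrappers with `value` yields plain numbers, which is
# what the objective passed to Optim.optimize ultimately consumes.
unpack = value ∘ unflatten
unpack(flat_initial_params)  # ≈ (var = 0.6, precision = 1.2)
```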

2 comments on commit b3405d0

@willtebbutt
Member Author


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/116263

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.0 -m "<description of version>" b3405d031e07b878103cab90f34159d2215c8156
git push origin v0.7.0
