diff --git a/examples/Project.toml b/examples/Project.toml
index 10ecbb6..1d673c0 100644
--- a/examples/Project.toml
+++ b/examples/Project.toml
@@ -6,5 +6,5 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
 TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/examples/approx_space_time_learning.jl b/examples/approx_space_time_learning.jl
index ca00798..2748425 100644
--- a/examples/approx_space_time_learning.jl
+++ b/examples/approx_space_time_learning.jl
@@ -13,7 +13,7 @@ using TemporalGPs: Separable, approx_posterior_marginals, RegularInTime
 # Load standard packages from the Julia ecosystem
 using Optim # Standard optimisation algorithms.
 using ParameterHandling # Helper functionality for dealing with model parameters.
-using Zygote # Algorithmic Differentiation
+using Tapir # Algorithmic Differentiation
 
 using ParameterHandling: flatten
 
@@ -56,16 +56,72 @@ z_r = collect(range(-3.0, 3.0; length=5));
 
 # Specify an objective function for Optim to minimise in terms of x and y.
 # We choose the usual negative log marginal likelihood (NLML).
-function objective(params)
-    f = build_gp(params)
-    return -elbo(f(x, params.var_noise), y, z_r)
+function make_objective(unpack, x, y, z_r)
+    function objective(flat_params)
+        params = unpack(flat_params)
+        f = build_gp(params)
+        return -elbo(f(x, params.var_noise), y, z_r)
+    end
+    return objective
 end
+objective = make_objective(unpack, x, y, z_r)
 
-# Optimise using Optim. Takes a little while to compile because Zygote.
+using Tapir: CoDual, primal, tangent, NoRData
+
+Tapir.@is_primitive Tapir.MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real}
+function Tapir.rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64})
+    B_dB = Tapir.zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t)))
+    B = primal(B_dB)
+    dB = tangent(B_dB)
+    time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B))
+    return B_dB, time_exp_pb
+end
+
+
+
+using Random
+# y = y
+# z_r = z_r
+# fx = build_gp(unpack(flat_initial_params))(x, params.var_noise)
+# fx_dtc = TemporalGPs.dtcify(z_r, fx)
+# lgssm = TemporalGPs.build_lgssm(fx_dtc)
+# Σs = lgssm.emissions.fan_out.Q
+# marg_diags = TemporalGPs.marginals_diag(lgssm)
+
+# k = fx_dtc.f.f.kernel
+# Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x)
+
+# # Transform a vector into a vector-of-vectors.
+# y_vecs = TemporalGPs.restructure(y, lgssm.emissions)
+
+# tmp = TemporalGPs.zygote_friendly_map(
+#     ((Σ, Cf_diag, marg_diag, yn), ) -> begin
+#         Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn)
+#         return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) -
+#             count(ismissing, yn) + size(Σ_, 1)
+#     end,
+#     zip(Σs, Cf_diags, marg_diags, y_vecs),
+# )
+
+# logpdf(lgssm, y_vecs) # this is the failing thing
+
+for _ in 1:10
+    Tapir.TestUtils.test_rule(
+        Xoshiro(123456), objective, flat_initial_params;
+        perf_flag=:none,
+        interp=Tapir.TapirInterpreter(),
+        interface_only=false,
+        is_primitive=false,
+        safety_on=false,
+    )
+end
+
+# Optimise using Optim.
+rule = Tapir.build_rrule(objective, flat_initial_params);
 training_results = Optim.optimize(
-    objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
-    flat_initial_params,
+    objective,
+    θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2],
+    flat_initial_params + randn(4), # Add some noise to make learning non-trivial
     BFGS(
         alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
         linesearch = Optim.LineSearches.BackTracking(),
@@ -74,6 +130,7 @@ training_results = Optim.optimize(
     inplace=false,
 );
 
+
 # Extracting the final values of the parameters.
 # Should be close to truth.
 final_params = unpack(training_results.minimizer);
diff --git a/examples/exact_space_time_learning.jl b/examples/exact_space_time_learning.jl
index f9664ac..c2ac99c 100644
--- a/examples/exact_space_time_learning.jl
+++ b/examples/exact_space_time_learning.jl
@@ -13,7 +13,7 @@ using TemporalGPs: Separable, RectilinearGrid
 # Load standard packages from the Julia ecosystem
 using Optim # Standard optimisation algorithms.
 using ParameterHandling # Helper functionality for dealing with model parameters.
-using Zygote # Algorithmic Differentiation
+using Tapir # Algorithmic Differentiation
 
 # Declare model parameters using `ParameterHandling.jl` types.
 flat_initial_params, unflatten = ParameterHandling.flatten((
@@ -47,15 +47,29 @@ y = rand(build_gp(params)(x, 1e-4));
 
 # Specify an objective function for Optim to minimise in terms of x and y.
 # We choose the usual negative log marginal likelihood (NLML).
-function objective(params)
+function objective(flat_params)
+    params = unpack(flat_params)
     f = build_gp(params)
     return -logpdf(f(x, params.var_noise), y)
 end
 
-# Optimise using Optim. Takes a little while to compile because Zygote.
+using Tapir: CoDual, primal, tangent, NoRData
+
+Tapir.@is_primitive Tapir.MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real}
+function Tapir.rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64})
+    B_dB = Tapir.zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t)))
+    B = primal(B_dB)
+    dB = tangent(B_dB)
+    time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B))
+    return B_dB, time_exp_pb
+end
+
+rule = Tapir.build_rrule(objective, flat_initial_params);
+
+# Optimise using Optim.
 training_results = Optim.optimize(
-    objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
+    objective,
+    θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2],
     flat_initial_params + randn(4), # Add some noise to make learning non-trivial
     BFGS(
         alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
diff --git a/examples/exact_time_learning.jl b/examples/exact_time_learning.jl
index b4b05cb..46a2420 100644
--- a/examples/exact_time_learning.jl
+++ b/examples/exact_time_learning.jl
@@ -12,7 +12,7 @@ using TemporalGPs: RegularSpacing
 # Load standard packages from the Julia ecosystem
 using Optim # Standard optimisation algorithms.
 using ParameterHandling # Helper functionality for dealing with model parameters.
-using Zygote # Algorithmic Differentiation
+using Tapir # Algorithmic Differentiation
 
 # Declare model parameters using `ParameterHandling.jl` types.
 # var_kernel is the variance of the kernel, λ the inverse length scale, and var_noise the
@@ -42,15 +42,18 @@ y = rand(f(x, params.var_noise));
 
 # Specify an objective function for Optim to minimise in terms of x and y.
 # We choose the usual negative log marginal likelihood (NLML).
-function objective(params)
+function objective(flat_params)
+    params = unpack(flat_params)
     f = build_gp(params)
     return -logpdf(f(x, params.var_noise), y)
 end
 
-# Optimise using Optim. Zygote takes a little while to compile.
+rule = Tapir.build_rrule(objective, flat_initial_params);
+
+# Optimise using Optim.
 training_results = Optim.optimize(
-    objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
+    objective,
+    θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2],
     flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial
     BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
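
Note on the pattern shared by all three examples: `time_exp(A, t) = exp(A * t)` is marked as a Tapir primitive, and its `rrule!!` mirrors the old Zygote adjoint by propagating a gradient only to `t` (via d/dt exp(At) = A exp(At)), while `A` and the function itself receive `NoRData()`. On the optimiser side, `Tapir.build_rrule(objective, flat_initial_params)` derives a reverse-mode rule once, and `Tapir.value_and_gradient!!(rule, objective, θ)` returns `(value, (grad_objective, grad_θ))`, which is why the gradient closures index with `[2][2]`. As written, each BFGS iteration evaluates the objective and its gradient through separate calls; the sketch below (not part of the patch; the helper name `fg!` is hypothetical, and it leans on Optim's `only_fg!` interface plus the `objective`, `rule`, and `flat_initial_params` names from the examples above) shows one way a single forward/reverse pass could serve both.

    # Sketch only: fuse value and gradient into one Tapir call per BFGS iteration.
    function fg!(F, G, θ)
        # One reverse pass yields the value and the gradients w.r.t. (objective, θ).
        v, (_, dθ) = Tapir.value_and_gradient!!(rule, objective, θ)
        G === nothing || copyto!(G, dθ)   # write the gradient in place when Optim asks for it
        return F === nothing ? nothing : v
    end

    training_results = Optim.optimize(
        Optim.only_fg!(fg!),
        flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial.
        BFGS(
            alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
            linesearch = Optim.LineSearches.BackTracking(),
        ),
    )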