Progress
willtebbutt committed Aug 22, 2024
1 parent 6d807e6 commit da5540d
Showing 4 changed files with 92 additions and 18 deletions.
2 changes: 1 addition & 1 deletion examples/Project.toml
@@ -6,5 +6,5 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
73 changes: 65 additions & 8 deletions examples/approx_space_time_learning.jl
@@ -13,7 +13,7 @@ using TemporalGPs: Separable, approx_posterior_marginals, RegularInTime
# Load standard packages from the Julia ecosystem
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
-using Zygote # Algorithmic Differentiation
+using Tapir # Algorithmic Differentiation

using ParameterHandling: flatten

@@ -56,16 +56,72 @@ z_r = collect(range(-3.0, 3.0; length=5));

# Specify an objective function for Optim to minimise in terms of x and y.
# We choose the usual negative log marginal likelihood (NLML).
-function objective(params)
-    f = build_gp(params)
-    return -elbo(f(x, params.var_noise), y, z_r)
+function make_objective(unpack, x, y, z_r)
+    function objective(flat_params)
+        params = unpack(flat_params)
+        f = build_gp(params)
+        return elbo(f(x, params.var_noise), y, z_r)
+    end
+    return objective
end
+objective = make_objective(unpack, x, y, z_r)

-# Optimise using Optim. Takes a little while to compile because Zygote.
+using Tapir: CoDual, primal
+
+Tapir.@is_primitive Tapir.MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real}
+function Tapir.rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64})
+    B_dB = Tapir.zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t)))
+    B = primal(B_dB)
+    dB = tangent(B_dB)
+    time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B))
+    return B_dB, time_exp_pb
+end



+using Random
+# y = y
+# z_r = z_r
+# fx = build_gp(unpack(flat_initial_params))(x, params.var_noise)
+# fx_dtc = TemporalGPs.dtcify(z_r, fx)
+# lgssm = TemporalGPs.build_lgssm(fx_dtc)
+# Σs = lgssm.emissions.fan_out.Q
+# marg_diags = TemporalGPs.marginals_diag(lgssm)
+
+# k = fx_dtc.f.f.kernel
+# Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x)
+
+# # Transform a vector into a vector-of-vectors.
+# y_vecs = TemporalGPs.restructure(y, lgssm.emissions)
+
+# tmp = TemporalGPs.zygote_friendly_map(
+#     ((Σ, Cf_diag, marg_diag, yn), ) -> begin
+#         Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn)
+#         return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) -
+#             count(ismissing, yn) + size(Σ_, 1)
+#     end,
+#     zip(Σs, Cf_diags, marg_diags, y_vecs),
+# )
+
+# logpdf(lgssm, y_vecs) # this is the failing thing

+for _ in 1:10
+    Tapir.TestUtils.test_rule(
+        Xoshiro(123456), objective, flat_initial_params;
+        perf_flag=:none,
+        interp=Tapir.TapirInterpreter(),
+        interface_only=false,
+        is_primitive=false,
+        safety_on=false,
+    )
+end

+# Optimise using Optim.
+rule = Tapir.build_rrule(objective, flat_initial_params);
training_results = Optim.optimize(
-    objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
-    flat_initial_params,
+    objective,
+    θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2],
+    flat_initial_params + randn(4), # Add some noise to make learning non-trivial
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
        linesearch = Optim.LineSearches.BackTracking(),
@@ -74,6 +130,7 @@ training_results = Optim.optimize(
    inplace=false,
);


# Extracting the final values of the parameters.
# Should be close to truth.
final_params = unpack(training_results.minimizer);
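
A note on the hand-written rule above: TemporalGPs.time_exp(A, t) is assumed to compute the matrix exponential exp(A * t), so d/dt exp(A * t) = A * exp(A * t) = A * B, and the reverse-mode sensitivity for the scalar time argument is sum(dB .* (A * B)), exactly the expression in time_exp_pb (this particular pullback accumulates nothing for A). A minimal, Tapir-free sketch that checks that expression against central finite differences; the matrix A and cotangent dB below are arbitrary illustrative values, not from the commit:

using LinearAlgebra

A = [0.0 1.0; -1.0 -0.1]    # arbitrary illustrative matrix
t = 0.3
B = exp(A * t)              # primal output, assuming time_exp(A, t) == exp(A * t)
dB = [0.2 -0.5; 1.1 0.3]    # arbitrary cotangent for the output

t_bar = sum(dB .* (A * B))  # the pullback value for t, as written in time_exp_pb

# Central finite-difference check of the same sensitivity.
h = 1e-6
t_bar_fd = sum(dB .* (exp(A * (t + h)) - exp(A * (t - h)))) / (2h)
@assert isapprox(t_bar, t_bar_fd; rtol=1e-5)
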
24 changes: 19 additions & 5 deletions examples/exact_space_time_learning.jl
@@ -13,7 +13,7 @@ using TemporalGPs: Separable, RectilinearGrid
# Load standard packages from the Julia ecosystem
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
-using Zygote # Algorithmic Differentiation
+using Tapir # Algorithmic Differentiation

# Declare model parameters using `ParameterHandling.jl` types.
flat_initial_params, unflatten = ParameterHandling.flatten((
@@ -47,15 +47,29 @@ y = rand(build_gp(params)(x, 1e-4));

# Specify an objective function for Optim to minimise in terms of x and y.
# We choose the usual negative log marginal likelihood (NLML).
-function objective(params)
+function objective(flat_params)
+    params = unpack(flat_params)
    f = build_gp(params)
    return -logpdf(f(x, params.var_noise), y)
end

-# Optimise using Optim. Takes a little while to compile because Zygote.
+using Tapir: CoDual, primal
+
+Tapir.@is_primitive Tapir.MinimalCtx Tuple{typeof(TemporalGPs.time_exp), AbstractMatrix{<:Real}, Real}
+function Tapir.rrule!!(::CoDual{typeof(TemporalGPs.time_exp)}, A::CoDual, t::CoDual{Float64})
+    B_dB = Tapir.zero_fcodual(TemporalGPs.time_exp(primal(A), primal(t)))
+    B = primal(B_dB)
+    dB = tangent(B_dB)
+    time_exp_pb(::NoRData) = NoRData(), NoRData(), sum(dB .* (primal(A) * B))
+    return B_dB, time_exp_pb
+end

+rule = Tapir.build_rrule(objective, flat_initial_params);

+# Optimise using Optim.
training_results = Optim.optimize(
-    objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
+    objective,
+    θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2],
    flat_initial_params + randn(4), # Add some noise to make learning non-trivial
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
11 changes: 7 additions & 4 deletions examples/exact_time_learning.jl
@@ -12,7 +12,7 @@ using TemporalGPs: RegularSpacing
# Load standard packages from the Julia ecosystem
using Optim # Standard optimisation algorithms.
using ParameterHandling # Helper functionality for dealing with model parameters.
-using Zygote # Algorithmic Differentiation
+using Tapir # Algorithmic Differentiation

# Declare model parameters using `ParameterHandling.jl` types.
# var_kernel is the variance of the kernel, λ the inverse length scale, and var_noise the
@@ -42,15 +42,18 @@ y = rand(f(x, params.var_noise));

# Specify an objective function for Optim to minimise in terms of x and y.
# We choose the usual negative log marginal likelihood (NLML).
-function objective(params)
+function objective(flat_params)
+    params = unpack(flat_params)
    f = build_gp(params)
    return -logpdf(f(x, params.var_noise), y)
end

+rule = Tapir.build_rrule(objective, flat_initial_params);

# Optimise using Optim. Zygote takes a little while to compile.
training_results = Optim.optimize(
-    objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
+    objective,
+    θ -> Tapir.value_and_gradient!!(rule, objective, θ)[2][2],
    flat_initial_params .+ randn.(), # Perturb the parameters to make learning non-trivial
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
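
Each of the three examples now shares the same gradient wiring: build a Tapir rule once for the flat parameter vector, then hand Optim an explicit gradient closure. A minimal sketch of that pattern on a toy objective; the objective and dimensions are illustrative only, and the [2][2] indexing assumes Tapir.value_and_gradient!!(rule, f, θ) returns (value, (f̄, θ̄)):

using Optim, Tapir

toy_objective(θ) = sum(abs2, θ .- 1.0)  # stand-in for the NLML / ELBO objectives above
θ0 = randn(4)

rule = Tapir.build_rrule(toy_objective, θ0);
grad(θ) = Tapir.value_and_gradient!!(rule, toy_objective, θ)[2][2]

results = Optim.optimize(
    toy_objective,
    grad,
    θ0,
    BFGS(
        alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
        linesearch = Optim.LineSearches.BackTracking(),
    ),
    Optim.Options();
    inplace=false,
)
Optim.minimizer(results)  # should be close to ones(4)
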
