Skip to content

Commit

Permalink
Wct/remove zygote dep (#130)
Browse files Browse the repository at this point in the history
* Remove literal_getfield usage

* Improve perf

* Progress

* Lots of changes

* Add Test as test dep

* Fix typo

* Add Pkg to examples

* Add Pkg to test deps

* Require Mooncake 0-4-3

* Import more names

* Remove Mooncake as direct dep

* Formatting

* Formatting

* Tidy up + enable all tests

* Enable all tests

* Add JET as test dep'

* Tidy up and use JET rather than inferred

* Some fixes

* Discuss the changes in this release

* Figure out how to avoid bad gradients

* Tidy up example
  • Loading branch information
willtebbutt authored Sep 27, 2024
1 parent b3405d0 commit d0486e9
Showing 1 changed file with 0 additions and 49 deletions.
49 changes: 0 additions & 49 deletions examples/approx_space_time_learning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,61 +54,12 @@ y = sin.(first.(xs)) .+ cos.(last.(xs)) + sqrt.(params.var_noise) .* randn(lengt
# Spatial pseudo-point inputs.
z_r = collect(range(-3.0, 3.0; length=5));

# # Specify an objective function for Optim to minimise in terms of x and y.
# # We choose the usual negative log marginal likelihood (NLML).
# function make_objective(unpack, x, y, z_r)
# function objective(flat_params)
# params = unpack(flat_params)
# f = build_gp(params)
# return elbo(f(x, params.var_noise), y, z_r)
# end
# return objective
# end
# objective = make_objective(unpack, x, y, z_r)

function objective(flat_params)
params = unpack(flat_params)
f = build_gp(params)
return -elbo(f(x, params.var_noise), y, z_r)
end

# using Random
# # y = y
# # z_r = z_r
# # fx = build_gp(unpack(flat_initial_params))(x, params.var_noise)
# # fx_dtc = TemporalGPs.dtcify(z_r, fx)
# # lgssm = TemporalGPs.build_lgssm(fx_dtc)
# # Σs = lgssm.emissions.fan_out.Q
# # marg_diags = TemporalGPs.marginals_diag(lgssm)

# # k = fx_dtc.f.f.kernel
# # Cf_diags = TemporalGPs.kernel_diagonals(k, fx_dtc.x)

# # # Transform a vector into a vector-of-vectors.
# # y_vecs = TemporalGPs.restructure(y, lgssm.emissions)

# # tmp = TemporalGPs.zygote_friendly_map(
# # ((Σ, Cf_diag, marg_diag, yn), ) -> begin
# # Σ_, _ = TemporalGPs.fill_in_missings(Σ, yn)
# # return sum(TemporalGPs.diag(Σ_ \ (Cf_diag - marg_diag.P))) -
# # count(ismissing, yn) + size(Σ_, 1)
# # end,
# # zip(Σs, Cf_diags, marg_diags, y_vecs),
# # )

# # logpdf(lgssm, y_vecs) # this is the failing thing

# # for _ in 1:10
# # Tapir.TestUtils.test_rule(
# # Xoshiro(123456), objective, flat_initial_params;
# # perf_flag=:none,
# # interp=Tapir.TapirInterpreter(),
# # interface_only=false,
# # is_primitive=false,
# # safety_on=false,
# # )
# # end

# Optimise using Optim.
function objective_grad(rule, flat_params)
return Mooncake.value_and_gradient!!(rule, objective, flat_params)[2][2]
Expand Down

0 comments on commit d0486e9

Please sign in to comment.