Skip to content

Commit

Permalink
Improve perf
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Aug 22, 2024
1 parent a169fcc commit 6d807e6
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/space_time/pseudo_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect

k = fx_dtc.f.f.kernel
Cf_diags = kernel_diagonals(k, fx_dtc.x)
# return Cf_diags

# Transform a vector into a vector-of-vectors.
y_vecs = restructure(y, lgssm.emissions)
Expand All @@ -85,6 +86,7 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect
end,
zip(Σs, Cf_diags, marg_diags, y_vecs),
)
# return -sum(tmp) / 2

return logpdf(lgssm, y_vecs) - sum(tmp) / 2
end
Expand All @@ -101,14 +103,8 @@ end

function kernel_diagonals(k::DTCSeparable, x::RegularInTime)
space_kernel = k.k.l
time_kernel = k.k.r
time_vars = kernelmatrix_diag(time_kernel, get_times(x))
return Diagonal.(
kernelmatrix_diag.(
Ref(space_kernel),
x.vs
) .* time_vars
)
time_vars = kernelmatrix_diag(k.k.r, get_times(x))
return map((v, tv) -> Diagonal(kernelmatrix_diag(space_kernel, v) * tv), x.vs, time_vars)
end

function kernel_diagonals(k::ScaledKernel, x::AbstractVector)
Expand Down

0 comments on commit 6d807e6

Please sign in to comment.