From 6d807e68831fc59c9a9d6cf82e165fd810d67c74 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 22 Aug 2024 08:47:50 +0100 Subject: [PATCH] Improve perf --- src/space_time/pseudo_point.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/space_time/pseudo_point.jl b/src/space_time/pseudo_point.jl index bcb90d5a..c5867df0 100644 --- a/src/space_time/pseudo_point.jl +++ b/src/space_time/pseudo_point.jl @@ -73,6 +73,7 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect k = fx_dtc.f.f.kernel Cf_diags = kernel_diagonals(k, fx_dtc.x) + # return Cf_diags # Transform a vector into a vector-of-vectors. y_vecs = restructure(y, lgssm.emissions) @@ -85,6 +86,7 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect end, zip(Σs, Cf_diags, marg_diags, y_vecs), ) + # return -sum(tmp) / 2 return logpdf(lgssm, y_vecs) - sum(tmp) / 2 end @@ -101,14 +103,8 @@ end function kernel_diagonals(k::DTCSeparable, x::RegularInTime) space_kernel = k.k.l - time_kernel = k.k.r - time_vars = kernelmatrix_diag(time_kernel, get_times(x)) - return Diagonal.( - kernelmatrix_diag.( - Ref(space_kernel), - x.vs - ) .* time_vars - ) + time_vars = kernelmatrix_diag(k.k.r, get_times(x)) + return map((v, tv) -> Diagonal(kernelmatrix_diag(space_kernel, v) * tv), x.vs, time_vars) end function kernel_diagonals(k::ScaledKernel, x::AbstractVector)