RFC: Sketch of vector-valued GP functionality #218
base: master
@@ -0,0 +1,75 @@

```julia
# Represents a GP whose output is vector-valued.
struct VectorValuedGP{Tf<:AbstractGP}
    f::Tf
    num_outputs::Int
end

# I gave up figuring out how to properly subtype MatrixDistribution, but I want this to
# subtype a distribution type which indicates that samples from this distribution produce a
# matrix of size num_features x num_outputs, or something like that.
struct FiniteVectorValuedGP{Tv<:VectorValuedGP,Tx<:AbstractVector,TΣy<:Real}
    v::Tv
    x::Tx
    Σy::TΣy
end

(f::VectorValuedGP)(x...) = FiniteVectorValuedGP(f, x...)

function Statistics.mean(vx::FiniteVectorValuedGP)
    # Construct equivalent FiniteGP.
    x_f = KernelFunctions.MOInputIsotopicByOutputs(vx.x, vx.v.num_outputs)
    f = vx.v.f
    fx = f(x_f, vx.Σy)
```
Comment on lines +20 to +23:

The amount of code duplication in here seems to beg for something like

```julia
function _vector_equivalent_finitegp(vx::FiniteVectorValuedGP)
    x_f = KernelFunctions.MOInputIsotopicByOutputs(vx.x, vx.v.num_outputs)
    f = vx.v.f
    fx = f(x_f, vx.Σy)
    return fx
end
```

(or similar).
```julia
    # Compute quantity under equivalent FiniteGP.
    m = mean(fx)

    # Construct the matrix-version of the quantity.
    M = reshape(m, length(vx.x), vx.v.num_outputs)
    return M
```
Comment on lines +28 to +30:

Could also have a

```julia
function _vectorvaluedgp_reshape(vx::FiniteVectorValuedGP, a)
    return reshape(a, length(vx.x), vx.v.num_outputs)
end
```

NB: should the return type be a `ColVecs`?
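Taken together, the two helpers suggested above would collapse each of the methods in this file to a one-liner; a sketch using those (hypothetical) helper names:

```julia
# Sketch built from the helpers proposed in the review comments above.
Statistics.mean(vx::FiniteVectorValuedGP) =
    _vectorvaluedgp_reshape(vx, mean(_vector_equivalent_finitegp(vx)))
Statistics.var(vx::FiniteVectorValuedGP) =
    _vectorvaluedgp_reshape(vx, var(_vector_equivalent_finitegp(vx)))
Random.rand(rng::AbstractRNG, vx::FiniteVectorValuedGP) =
    _vectorvaluedgp_reshape(vx, rand(rng, _vector_equivalent_finitegp(vx)))
Distributions.logpdf(vx::FiniteVectorValuedGP, Y::AbstractMatrix{<:Real}) =
    logpdf(_vector_equivalent_finitegp(vx), vec(Y))
```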
```julia
end

function Statistics.var(vx::FiniteVectorValuedGP)
    # Construct equivalent FiniteGP.
    x_f = KernelFunctions.MOInputIsotopicByOutputs(vx.x, vx.v.num_outputs)
    f = vx.v.f
    fx = f(x_f, vx.Σy)

    # Compute quantity under equivalent FiniteGP.
    v = var(fx)

    # Construct the matrix-version of the quantity.
    V = reshape(v, length(vx.x), vx.v.num_outputs)
    return V
end

function Random.rand(rng::AbstractRNG, vx::FiniteVectorValuedGP)
    # Construct equivalent FiniteGP.
    x_f = KernelFunctions.MOInputIsotopicByOutputs(vx.x, vx.v.num_outputs)
    f = vx.v.f
    fx = f(x_f, vx.Σy)

    # Compute quantity under equivalent FiniteGP.
    y = rand(rng, fx)

    # Construct the matrix-version of the quantity.
    Y = reshape(y, length(vx.x), vx.v.num_outputs)
    return Y
end

function Distributions.logpdf(vx::FiniteVectorValuedGP, Y::AbstractMatrix{<:Real})
    # Construct equivalent FiniteGP.
    x_f = KernelFunctions.MOInputIsotopicByOutputs(vx.x, vx.v.num_outputs)
    f = vx.v.f
    fx = f(x_f, vx.Σy)

    # Construct flattened version of the observations.
    y = vec(Y)

    # Compute logpdf using FiniteGP.
    return logpdf(fx, y)
end
```
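All four methods above rely on the same correspondence: flatten the vector-valued problem to a single-output GP over `(input, output_index)` pairs, compute there, and reshape back. For intuition about the ordering, a small sketch (assuming KernelFunctions.jl's `MOInputIsotopicByOutputs`):

```julia
using KernelFunctions

x = [0.0, 0.5, 1.0]
x_f = KernelFunctions.MOInputIsotopicByOutputs(x, 2)

# Entries are grouped by output: all of x paired with output 1, then output 2,
# which is why reshape(v, length(x), num_outputs) places output p in column p.
collect(x_f)
# (0.0, 1), (0.5, 1), (1.0, 1), (0.0, 2), (0.5, 2), (1.0, 2)
```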
@@ -0,0 +1,14 @@
```julia
@testset "vector_valued_gp" begin
    f = GP(LinearMixingModelKernel([Matern52Kernel(), Matern12Kernel()], randn(2, 2)))
    x = range(0.0, 10.0; length=3)
    Σy = 0.1

    v = AbstractGPs.VectorValuedGP(f, 2)
    vx = v(x, Σy)

    M = mean(vx)

    rng = MersenneTwister(123456)
    Y = rand(rng, vx)
    logpdf(vx, Y)
end
```
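With 3 inputs and 2 outputs, the reshape convention in the implementation implies that `M` and `Y` are 3×2 matrices. A few checks one might add to the testset (hypothetical additions, not part of the PR):

```julia
@test size(M) == (3, 2)      # num_features × num_outputs
@test size(Y) == (3, 2)
@test logpdf(vx, Y) isa Real
```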
Comment:

`f` must be, in some sense, a multi-output GP. We don't have a way to enforce that, but my feeling is that's fine?
Comment:

We could make `AbstractGP` parametric on the kernel type, e.g. `struct GP{Tm<:MeanFunction,Tk<:Kernel} <: AbstractGP{Tk}`, with a `struct VectorValuedGP{Tk<:MOKernel,Tf<:AbstractGP{Tk}}`?
Comment:

This would stop you from using any composition on it, though... This looks like a "trait" kind of problem. One could define `is_multioutput` and recursively check for any multi-output kernel.
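A minimal sketch of such a trait, assuming KernelFunctions.jl's `MOKernel` abstract type and the `kernel`/`kernels` fields of its wrapper kernels (the recursion cases here are illustrative, not exhaustive):

```julia
using KernelFunctions

# Hypothetical trait, not part of KernelFunctions.jl: true iff a multi-output
# kernel appears somewhere inside a (possibly composite) kernel.
is_multioutput(::Kernel) = false
is_multioutput(::MOKernel) = true
is_multioutput(k::ScaledKernel) = is_multioutput(k.kernel)
is_multioutput(k::TransformedKernel) = is_multioutput(k.kernel)
is_multioutput(k::KernelSum) = any(is_multioutput, k.kernels)
is_multioutput(k::KernelProduct) = any(is_multioutput, k.kernels)
```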
Comment:

Actually, the thing on which to parametrise the type should perhaps be the domain on which the GP is defined - that could then be made to work for transformations as well. But it's perhaps too much effort to be worth it. And I'm assuming that the "multi-output" kernel would generally be the outermost: in the end, the question is just whether the final kernel object accepts the `(all_the_features, output_index)` tuple correctly.
Comment:

Indeed. We've got a number of AbstractGPs now which don't have a kernel field.

Exactly. Throwing an error if we find out that the inputs aren't compatible with the GP to which they're supplied isn't super nice, but it's what we do throughout the ecosystem at the minute. The reason being that giving the GP knowledge about its domain is a massive hassle, and it seems to work fine most of the time 🤷
Comment:

Related: JuliaGaussianProcesses/KernelFunctions.jl#16