diff --git a/docs/src/kernels.md b/docs/src/kernels.md index 8f19091c6..ec1a05dd8 100644 --- a/docs/src/kernels.md +++ b/docs/src/kernels.md @@ -124,6 +124,7 @@ TransformedKernel ScaledKernel KernelSum KernelProduct +KernelTensorSum KernelTensorProduct NormalizedKernel ``` diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 63205b5bf..7ddd75b58 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -15,9 +15,10 @@ export LinearKernel, PolynomialKernel export RationalKernel, RationalQuadraticKernel, GammaRationalKernel export PiecewisePolynomialKernel export PeriodicKernel, NeuralNetworkKernel -export KernelSum, KernelProduct, KernelTensorProduct +export KernelSum, KernelProduct, KernelTensorSum, KernelTensorProduct export TransformedKernel, ScaledKernel, NormalizedKernel export GibbsKernel +export ⊕ export Transform, SelectTransform, @@ -108,6 +109,7 @@ include("kernels/normalizedkernel.jl") include("matrix/kernelmatrix.jl") include("kernels/kernelsum.jl") include("kernels/kernelproduct.jl") +include("kernels/kerneltensorsum.jl") include("kernels/kerneltensorproduct.jl") include("kernels/overloads.jl") include("kernels/neuralkernelnetwork.jl") diff --git a/src/kernels/kerneltensorsum.jl b/src/kernels/kerneltensorsum.jl new file mode 100644 index 000000000..c4b2a58f9 --- /dev/null +++ b/src/kernels/kerneltensorsum.jl @@ -0,0 +1,110 @@ +""" + KernelTensorSum + +Tensor sum of kernels. + +# Definition + +For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor +sum of kernels ``k_1, \\ldots, k_n`` is defined as +```math +k(x, x'; k_1, \\ldots, k_n) = \\sum_{i=1}^n k_i(x_i, x'_i). +``` + +# Construction + +The simplest way to specify a `KernelTensorSum` is to use the `⊕` operator (can be typed by `\\oplus`). 
+```jldoctest tensorsum +julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2); + +julia> kernelmatrix(k1 ⊕ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) + kernelmatrix(k2, X[:, 2]) +true +``` + +You can also specify a `KernelTensorSum` by providing kernels as individual arguments +or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or +individual arguments guarantees that `KernelTensorSum` is concretely typed but might +lead to large compilation times if the number of kernels is large. +```jldoctest tensorsum +julia> KernelTensorSum(k1, k2) == k1 ⊕ k2 +true + +julia> KernelTensorSum((k1, k2)) == k1 ⊕ k2 +true + +julia> KernelTensorSum([k1, k2]) == k1 ⊕ k2 +true +``` +""" +struct KernelTensorSum{K} <: Kernel +    kernels::K +end + +function KernelTensorSum(kernel::Kernel, kernels::Kernel...) +    return KernelTensorSum((kernel, kernels...)) +end + +@functor KernelTensorSum + +Base.length(kernel::KernelTensorSum) = length(kernel.kernels) + +function (kernel::KernelTensorSum)(x, y) +    if !((nx = length(x)) == (ny = length(y)) == (nkernels = length(kernel))) +        throw( +            DimensionMismatch( +                "number of kernels ($nkernels) and number of features (x=$nx, y=$ny) are not consistent", +            ), +        ) +    end +    return sum(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y)) +end + +function validate_domain(k::KernelTensorSum, x::AbstractVector, y::AbstractVector) +    return (dx = dim(x)) == (dy = dim(y)) == (nkernels = length(k)) || error( +        "number of kernels ($nkernels) and group of features (x=$dx, y=$dy) are not consistent", +    ) +end + +function validate_domain(k::KernelTensorSum, x::AbstractVector) +    return validate_domain(k, x, x) +end + +function kernelmatrix(k::KernelTensorSum, x::AbstractVector) +    validate_domain(k, x) +    return mapreduce(kernelmatrix, +, k.kernels, slices(x)) +end + +function kernelmatrix(k::KernelTensorSum, x::AbstractVector, y::AbstractVector) +    validate_domain(k, x, y) +    return mapreduce(kernelmatrix, +, k.kernels, 
slices(x), slices(y)) +end + +function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector) + validate_domain(k, x) + return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x)) +end + +function kernelmatrix_diag(k::KernelTensorSum, x::AbstractVector, y::AbstractVector) + validate_domain(k, x, y) + return mapreduce(kernelmatrix_diag, +, k.kernels, slices(x), slices(y)) +end + +function Base.:(==)(x::KernelTensorSum, y::KernelTensorSum) + return ( + length(x.kernels) == length(y.kernels) && + all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels)) + ) +end + +Base.show(io::IO, kernel::KernelTensorSum) = printshifted(io, kernel, 0) + +function printshifted(io::IO, kernel::KernelTensorSum, shift::Int) + print(io, "Tensor sum of ", length(kernel), " kernels:") + for k in kernel.kernels + print(io, "\n") + for _ in 1:(shift + 1) + print(io, "\t") + end + printshifted(io, k, shift + 2) + end +end diff --git a/src/kernels/overloads.jl b/src/kernels/overloads.jl index 3285c3dd4..e609a92bc 100644 --- a/src/kernels/overloads.jl +++ b/src/kernels/overloads.jl @@ -1,7 +1,11 @@ +function tensor_sum end +const ⊕ = tensor_sum + for (M, op, T) in ( (:Base, :+, :KernelSum), (:Base, :*, :KernelProduct), (:TensorCore, :tensor, :KernelTensorProduct), + (:KernelFunctions, :⊕, :KernelTensorSum), ) @eval begin $M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2) diff --git a/test/kernels/kerneltensorsum.jl b/test/kernels/kerneltensorsum.jl new file mode 100644 index 000000000..b8be6c965 --- /dev/null +++ b/test/kernels/kerneltensorsum.jl @@ -0,0 +1,67 @@ +@testset "kerneltensorsum" begin + rng = MersenneTwister(123456) + u1 = rand(rng, 10) + u2 = rand(rng, 10) + v1 = rand(rng, 5) + v2 = rand(rng, 5) + + # kernels + k1 = SqExponentialKernel() + k2 = ExponentialKernel() + kernel1 = KernelTensorSum(k1, k2) + kernel2 = KernelTensorSum([k1, k2]) + + @test kernel1 == kernel2 + @test kernel1.kernels == (k1, k2) === KernelTensorSum((k1, k2)).kernels + for (_k1, _k2) in Iterators.product( + (k1, 
KernelTensorSum((k1,)), KernelTensorSum([k1])), +        (k2, KernelTensorSum((k2,)), KernelTensorSum([k2])), +    ) +        @test kernel1 == _k1 ⊕ _k2 +    end +    @test length(kernel1) == length(kernel2) == 2 +    @test string(kernel1) == ( +        "Tensor sum of 2 kernels:\n" * +        "\tSquared Exponential Kernel (metric = Euclidean(0.0))\n" * +        "\tExponential Kernel (metric = Euclidean(0.0))" +    ) +    @test_throws DimensionMismatch kernel1(rand(3), rand(3)) + +    @testset "val" begin +        for (x, y) in (((v1, u1), (v2, u2)), ([v1, u1], [v2, u2])) +            val = k1(x[1], y[1]) + k2(x[2], y[2]) + +            @test kernel1(x, y) == kernel2(x, y) == val +        end +    end + +    # Standardised tests. +    TestUtils.test_interface(kernel1, ColVecs{Float64}) +    TestUtils.test_interface(kernel1, RowVecs{Float64}) +    TestUtils.test_interface( +        KernelTensorSum(WhiteKernel(), ConstantKernel(; c=1.1)), ColVecs{String} +    ) +    test_ADs( +        x -> KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=exp(x[1]))), +        rand(1); +        dims=[2, 2], +    ) +    types = [ColVecs{Float64,Matrix{Float64}}, RowVecs{Float64,Matrix{Float64}}] +    test_interface_ad_perf(2.1, StableRNG(123456), types) do c +        KernelTensorSum(SqExponentialKernel(), LinearKernel(; c=c)) +    end +    test_params(KernelTensorSum(k1, k2), (k1, k2)) + +    @testset "single kernel" begin +        kernel = KernelTensorSum(k1) +        @test length(kernel) == 1 + +        @testset "eval" begin +            for (x, y) in (((v1,), (v2,)), ([v1], [v2])) +                val = k1(x[1], y[1]) + +                @test kernel(x, y) == val +            end +        end +    end +end diff --git a/test/kernels/overloads.jl b/test/kernels/overloads.jl index eb79d41f8..456cd7796 100644 --- a/test/kernels/overloads.jl +++ b/test/kernels/overloads.jl @@ -5,8 +5,9 @@ k2 = SqExponentialKernel() k3 = RationalQuadraticKernel() -    for (op, T) in ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct)) -        if T === KernelTensorProduct +    for (op, T) in +        ((+, KernelSum), (*, KernelProduct), (⊗, KernelTensorProduct), (⊕, KernelTensorSum)) +        if T === KernelTensorProduct || T === KernelTensorSum v2_1 = 
rand(rng, 2) v2_2 = rand(rng, 2) v3_1 = rand(rng, 3) diff --git a/test/runtests.jl b/test/runtests.jl index caf43cb91..c33d67300 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -125,6 +125,7 @@ include("test_utils.jl") include("kernels/kernelproduct.jl") include("kernels/kernelsum.jl") include("kernels/kerneltensorproduct.jl") + include("kernels/kerneltensorsum.jl") include("kernels/overloads.jl") include("kernels/scaledkernel.jl") include("kernels/transformedkernel.jl")